├── .github
└── workflows
│ ├── build_and_publish_template.yml
│ ├── preview_docs.yml
│ ├── publish_nightly.yml
│ ├── publish_release.yml
│ └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .readthedocs.yaml
├── BUILD
├── CONTRIBUTING.md
├── LICENSE
├── MODULE.bazel
├── README.md
├── docs
├── CONTRIBUTING.md
├── README.md
├── api_choice.md
├── behind_the_scenes.md
├── conf.py
├── data_loader
│ ├── samplers.md
│ └── transformations.md
├── data_sources.md
├── grain.checkpoint.rst
├── grain.constants.rst
├── grain.dataset.rst
├── grain.experimental.rst
├── grain.multiprocessing.rst
├── grain.rst
├── grain.samplers.rst
├── grain.sharding.rst
├── grain.sources.rst
├── grain.transforms.rst
├── images
│ ├── data_flow_multiple_workers.png
│ ├── data_flow_zero_workers.png
│ └── grain_pipeline.svg
├── index.md
├── installation.md
├── requirements.txt
└── tutorials
│ ├── data_loader_tutorial.ipynb
│ ├── data_loader_tutorial.md
│ ├── data_sources
│ ├── arrayrecord_data_source_tutorial.ipynb
│ ├── arrayrecord_data_source_tutorial.md
│ ├── huggingface_dataset_tutorial.ipynb
│ ├── huggingface_dataset_tutorial.md
│ ├── index.rst
│ ├── load_from_gcs_tutorial.ipynb
│ ├── load_from_gcs_tutorial.md
│ ├── load_from_s3_tutorial.ipynb
│ ├── load_from_s3_tutorial.md
│ ├── parquet_dataset_tutorial.ipynb
│ ├── parquet_dataset_tutorial.md
│ ├── pytorch_dataset_tutorial.ipynb
│ └── pytorch_dataset_tutorial.md
│ ├── dataset_advanced_tutorial.ipynb
│ ├── dataset_advanced_tutorial.md
│ ├── dataset_basic_tutorial.ipynb
│ ├── dataset_basic_tutorial.md
│ ├── dataset_debugging_tutorial.ipynb
│ └── dataset_debugging_tutorial.md
├── grain
├── BUILD
├── __init__.py
├── _src
│ ├── __init__.py
│ ├── core
│ │ ├── BUILD
│ │ ├── config.py
│ │ ├── constants.py
│ │ ├── exceptions.py
│ │ ├── grain_random.py
│ │ ├── monitoring.py
│ │ ├── parallel.py
│ │ ├── parallel_test.py
│ │ ├── sharding.py
│ │ ├── sharding_test.py
│ │ ├── smoke_test_with_jax.py
│ │ ├── smoke_test_with_tf.py
│ │ ├── transforms.py
│ │ ├── transforms_test.py
│ │ ├── tree_lib.py
│ │ ├── tree_lib_jax_test.py
│ │ ├── tree_lib_test.py
│ │ ├── usage_logging.py
│ │ └── version_test.py
│ └── python
│ │ ├── BUILD
│ │ ├── checkpoint_handlers.py
│ │ ├── checkpointing.py
│ │ ├── data_loader.py
│ │ ├── data_loader_test.py
│ │ ├── data_sources.py
│ │ ├── data_sources_test.py
│ │ ├── dataset
│ │ ├── BUILD
│ │ ├── base.py
│ │ ├── base_test.py
│ │ ├── dataset.py
│ │ ├── dataset_test.py
│ │ ├── elastic_iterator.py
│ │ ├── elastic_iterator_test.py
│ │ ├── sources
│ │ │ ├── BUILD
│ │ │ ├── parquet_dataset.py
│ │ │ ├── parquet_dataset_test.py
│ │ │ ├── tfrecord_dataset.py
│ │ │ └── tfrecord_dataset_test.py
│ │ ├── stats.py
│ │ ├── stats_test.py
│ │ ├── stats_utils.py
│ │ ├── stats_utils_test.py
│ │ ├── transformations
│ │ │ ├── BUILD
│ │ │ ├── batch.py
│ │ │ ├── batch_test.py
│ │ │ ├── filter.py
│ │ │ ├── filter_test.py
│ │ │ ├── flatmap.py
│ │ │ ├── flatmap_test.py
│ │ │ ├── interleave.py
│ │ │ ├── interleave_test.py
│ │ │ ├── limit.py
│ │ │ ├── limit_test.py
│ │ │ ├── map.py
│ │ │ ├── map_test.py
│ │ │ ├── mix.py
│ │ │ ├── mix_test.py
│ │ │ ├── packing.py
│ │ │ ├── packing_concat_then_split.py
│ │ │ ├── packing_concat_then_split_test.py
│ │ │ ├── packing_packed_batch.py
│ │ │ ├── packing_test.py
│ │ │ ├── prefetch.py
│ │ │ ├── prefetch_test.py
│ │ │ ├── repeat.py
│ │ │ ├── repeat_test.py
│ │ │ ├── shuffle.py
│ │ │ ├── shuffle_test.py
│ │ │ ├── slice.py
│ │ │ ├── slice_test.py
│ │ │ ├── source.py
│ │ │ ├── source_test.py
│ │ │ ├── testing_util.py
│ │ │ ├── zip.py
│ │ │ └── zip_test.py
│ │ ├── visualize.py
│ │ └── visualize_test.py
│ │ ├── experimental
│ │ ├── example_packing
│ │ │ ├── BUILD
│ │ │ ├── packing.py
│ │ │ └── packing_test.py
│ │ └── index_shuffle
│ │ │ ├── BUILD
│ │ │ ├── index_shuffle.cc
│ │ │ ├── index_shuffle.h
│ │ │ └── python
│ │ │ ├── BUILD
│ │ │ ├── index_shuffle_module.cc
│ │ │ ├── index_shuffle_python.py
│ │ │ ├── index_shuffle_python_test.py
│ │ │ └── index_shuffle_test.py
│ │ ├── grain_logging.py
│ │ ├── grain_logging_test.py
│ │ ├── grain_pool.py
│ │ ├── grain_pool_test.py
│ │ ├── load.py
│ │ ├── load_test.py
│ │ ├── multiprocessing_common.py
│ │ ├── multiprocessing_common_test.py
│ │ ├── operations.py
│ │ ├── operations_test.py
│ │ ├── options.py
│ │ ├── record.py
│ │ ├── record_test.py
│ │ ├── samplers.py
│ │ ├── samplers_test.py
│ │ ├── shared_memory_array.py
│ │ ├── shared_memory_array_test.py
│ │ ├── testdata
│ │ ├── BUILD
│ │ ├── digits.array_record-00000-of-00002
│ │ ├── digits.array_record-00001-of-00002
│ │ └── morris_sequence_first_5.tfrecord
│ │ └── testing
│ │ ├── BUILD
│ │ ├── __init__.py
│ │ └── experimental.py
├── checkpoint.py
├── constants.py
├── experimental.py
├── multiprocessing.py
├── oss
│ ├── Dockerfile
│ ├── array_record
│ │ ├── Dockerfile.patch
│ │ ├── WORKSPACE.patch
│ │ ├── array_record_reader.patch
│ │ ├── build_whl.patch
│ │ ├── runner_common.patch
│ │ └── setup.patch
│ ├── build_whl.sh
│ └── runner_common.sh
├── proto
│ ├── BUILD
│ └── execution_summary.proto
├── python
│ ├── __init__.py
│ └── experimental.py
├── samplers.py
├── sharding.py
├── sources.py
└── transforms.py
├── pyproject.toml
├── setup.py
├── test_requirements.in
├── test_requirements_lock_3_10.txt
├── test_requirements_lock_3_11.txt
├── test_requirements_lock_3_12.txt
└── test_requirements_lock_3_13.txt
/.github/workflows/build_and_publish_template.yml:
--------------------------------------------------------------------------------
1 | # This workflow builds Grain wheels, uploads them as artifacts, and publishes them to PyPI.
2 |
3 | name: Build & Publish Template
4 |
5 | on:
6 | workflow_call:
7 | inputs:
8 | pypi_project_url:
9 | required: true
10 | type: string
11 | run_tests:
12 | required: true
13 | type: boolean
14 | is_nightly:
15 | required: true
16 | type: boolean
17 |
18 | permissions:
19 | contents: read
20 |
21 | jobs:
22 | build-and-test:
23 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
24 | runs-on: "${{ matrix.os }}"
25 |
26 | strategy:
27 | fail-fast: false
28 | matrix:
29 | python-version: ["3.10", "3.11", "3.12", "3.13"]
30 | os: [ubuntu-22.04, ubuntu-22.04-arm, macos-14]
31 |
32 | steps:
33 | - uses: "actions/checkout@v3"
34 | - name: Create directory
35 | run: |
36 | mkdir -p /tmp/grain
37 | cp -r . /tmp/grain
38 | - name: Build package
39 | run: |
40 | set -xe
41 | export PYTHON_VERSION=${{ matrix.python-version }}
42 | export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1)
43 | export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f2)
44 | export BAZEL_VERSION="7.1.1"
45 | export OUTPUT_DIR="/tmp/grain"
46 | export SOURCE_DIR="/tmp/grain"
47 | export RUN_TESTS=${{ inputs.run_tests }}
48 | export IS_NIGHTLY=${{ inputs.is_nightly }}
49 | . "${SOURCE_DIR}"'/grain/oss/runner_common.sh'
50 | build_and_test_grain
51 | - name: Upload Grain artifacts
52 | uses: actions/upload-artifact@v4
53 | with:
54 | name: built-grain-wheels-${{ matrix.os }}-${{ matrix.python-version }}
55 | path: /tmp/grain/all_dist/*.whl
56 |
57 | publish-wheel:
58 | runs-on: ubuntu-22.04
59 | needs: build-and-test
60 | permissions:
61 | id-token: write
62 | environment:
63 | name: pypi
64 | url: ${{ inputs.pypi_project_url }}
65 | steps:
66 | - name: Download Grain artifacts
67 | uses: actions/download-artifact@v4
68 | with:
69 | pattern: built-grain-wheels-*
70 | path: dist/
71 | merge-multiple: true
72 | - name: Publish package distributions to PyPI
73 | uses: pypa/gh-action-pypi-publish@release/v1
--------------------------------------------------------------------------------
/.github/workflows/preview_docs.yml:
--------------------------------------------------------------------------------
1 | # Add a link to preview the documentation on Read the Docs for every pull request.
2 | name: "RTD preview"
3 |
4 | on:
5 | pull_request_target:
6 | types:
7 | - opened
8 |
9 | permissions:
10 | pull-requests: write
11 |
12 | jobs:
13 | documentation-links:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: readthedocs/actions/preview@v1
17 | with:
18 | project-slug: "google-grain"
19 | single-version: true
--------------------------------------------------------------------------------
/.github/workflows/publish_nightly.yml:
--------------------------------------------------------------------------------
1 | name: Build and Publish Nightly
2 |
3 | on:
4 | schedule:
5 | # At 04:00 on Monday.
6 | - cron: '0 4 * * 1'
7 |
8 | jobs:
9 | call-workflow:
10 | uses: ./.github/workflows/build_and_publish_template.yml
11 | with:
12 | pypi_project_url: https://pypi.org/project/grain-nightly
13 | run_tests: true
14 | is_nightly: true
--------------------------------------------------------------------------------
/.github/workflows/publish_release.yml:
--------------------------------------------------------------------------------
1 | name: Build and Publish Release
2 |
3 | on:
4 | workflow_dispatch:
5 | inputs:
6 | run_tests:
7 | description: 'Run unit tests'
8 | required: false
9 | default: true
10 | type: boolean
11 |
12 | jobs:
13 | call-workflow:
14 | uses: ./.github/workflows/build_and_publish_template.yml
15 | with:
16 | pypi_project_url: https://pypi.org/project/grain
17 | run_tests: ${{ inputs.run_tests }}
18 | is_nightly: false
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Build & Test
2 |
3 | on:
4 | push:
5 | branches: ["main"]
6 | pull_request:
7 | branches: ["main"]
8 |
9 | jobs:
10 | build-and-test:
11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
12 | runs-on: "${{ matrix.os }}"
13 |
14 | strategy:
15 | matrix:
16 | python-version: ["3.10", "3.11", "3.12", "3.13"]
17 | os: [ubuntu-latest]
18 |
19 | steps:
20 | - uses: "actions/checkout@v2"
21 | - name: Create directory
22 | run: |
23 | mkdir -p /tmp/grain
24 | cp -r . /tmp/grain
25 | - name: Build package
26 | run: |
27 | set -xe
28 | export PYTHON_VERSION=${{ matrix.python-version }}
29 | export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1)
30 | export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f2)
31 | export BAZEL_VERSION="7.1.1"
32 | export AUDITWHEEL_PLATFORM="manylinux2014_x86_64"
33 | export RUN_TESTS="true"
34 | cd /tmp/grain
35 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \
36 | --build-arg AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
37 | --build-arg PYTHON_VERSION=${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION} \
38 | --build-arg BAZEL_VERSION=${BAZEL_VERSION} \
39 | -t grain:${PYTHON_VERSION} grain/oss
40 | docker run --rm -a stdin -a stdout -a stderr \
41 | --env PYTHON_VERSION=${PYTHON_VERSION} \
42 | --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
43 | --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
44 | --env BAZEL_VERSION=${BAZEL_VERSION} \
45 | --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
46 | --env RUN_TESTS=${RUN_TESTS} \
47 | -v /tmp/grain:/tmp/grain \
48 | --name grain grain:${PYTHON_VERSION} \
49 | bash grain/oss/build_whl.sh
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
26 | # Unit test / coverage reports
27 | htmlcov/
28 | .tox/
29 | .nox/
30 | .coverage
31 | .coverage.*
32 | .cache
33 | nosetests.xml
34 | coverage.xml
35 | *.cover
36 | *.py,cover
37 | .hypothesis/
38 | .pytest_cache/
39 | cover/
40 |
41 | # Translations
42 | *.mo
43 | *.pot
44 |
45 | # Sphinx documentation
46 | docs/_build/
47 |
48 | # Jupyter Notebook
49 | .ipynb_checkpoints
50 |
51 | # IPython
52 | profile_default/
53 | ipython_config.py
54 |
55 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
56 | __pypackages__/
57 |
58 | # Environments
59 | .venv
60 |
61 | # Bazel outputs
62 | bazel-bin/
63 | bazel-grain/
64 | bazel-out/
65 | bazel-testlogs/
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v5.0.0
4 | hooks:
5 | - id: check-ast
6 | - id: check-merge-conflict
7 | - id: end-of-file-fixer
8 | - id: trailing-whitespace
9 |
10 | - repo: https://github.com/mwouts/jupytext
11 | rev: v1.15.2
12 | hooks:
13 | - id: jupytext
14 | files: docs/tutorials/
15 | args: [--sync]
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-lts-latest
5 | tools:
6 | python: "3.12"
7 |
8 | sphinx:
9 | configuration: docs/conf.py
10 | # Note this is set to false for now while the warnings are resolved
11 | fail_on_warning: false
12 |
13 | python:
14 | install:
15 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/BUILD:
--------------------------------------------------------------------------------
1 | load("@python//3.10:defs.bzl", compile_pip_requirements_3_10 = "compile_pip_requirements")
2 | load("@python//3.11:defs.bzl", compile_pip_requirements_3_11 = "compile_pip_requirements")
3 | load("@python//3.12:defs.bzl", compile_pip_requirements_3_12 = "compile_pip_requirements")
4 | load("@python//3.13:defs.bzl", compile_pip_requirements_3_13 = "compile_pip_requirements")
5 |
6 | py_library(
7 | name = "setup",
8 | srcs = ["setup.py"],
9 | srcs_version = "PY3",
10 | )
11 |
12 | compile_pip_requirements_3_10(
13 | name = "requirements_3_10",
14 | requirements_in = "test_requirements.in",
15 | requirements_txt = "test_requirements_lock_3_10.txt",
16 | )
17 |
18 | compile_pip_requirements_3_11(
19 | name = "requirements_3_11",
20 | requirements_in = "test_requirements.in",
21 | requirements_txt = "test_requirements_lock_3_11.txt",
22 | )
23 |
24 | compile_pip_requirements_3_12(
25 | name = "requirements_3_12",
26 | requirements_in = "test_requirements.in",
27 | requirements_txt = "test_requirements_lock_3_12.txt",
28 | )
29 |
30 | compile_pip_requirements_3_13(
31 | name = "requirements_3_13",
32 | requirements_in = "test_requirements.in",
33 | requirements_txt = "test_requirements_lock_3_13.txt",
34 | )
35 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/MODULE.bazel:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | module(
15 | name = "grain",
16 | version = "0.2.9",
17 | repo_name = "com_google_grain",
18 | )
19 |
20 | bazel_dep(name = "bazel_skylib", version = "1.6.1")
21 | bazel_dep(name = "platforms", version = "0.0.8")
22 | bazel_dep(name = "rules_python", version = "0.37.0")
23 | bazel_dep(name = "rules_cc", version = "0.0.9")
24 | bazel_dep(name = "pybind11_bazel", version = "2.11.1")
25 | bazel_dep(name = "abseil-py", version = "2.1.0")
26 | bazel_dep(name = "abseil-cpp", version = "20230802.0.bcr.1")
27 | bazel_dep(name = "protobuf", version = "24.4", repo_name = "com_google_protobuf")
28 |
29 | http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
30 |
31 | http_archive(
32 | name = "pybind11",
33 | build_file = "@pybind11_bazel//:pybind11.BUILD",
34 | sha256 = "201966a61dc826f1b1879a24a3317a1ec9214a918c8eb035be2f30c3e9cfbdcb",
35 | strip_prefix = "pybind11-2.10.3",
36 | urls = ["https://github.com/pybind/pybind11/archive/refs/tags/v2.10.3.zip"],
37 | )
38 |
39 | SUPPORTED_PYTHON_VERSIONS = [
40 | "3.10",
41 | "3.11",
42 | "3.12",
43 | "3.13",
44 | ]
45 |
46 | DEFAULT_PYTHON_VERSION = "3.10"
47 |
48 | python_configure = use_extension("@pybind11_bazel//:python_configure.bzl", "extension")
49 | use_repo(python_configure, "local_config_python")
50 |
51 | python = use_extension("@rules_python//python/extensions:python.bzl", "python")
52 |
53 | [
54 | python.toolchain(
55 | ignore_root_user_error = True,
56 | is_default = python_version == DEFAULT_PYTHON_VERSION,
57 | python_version = python_version,
58 | )
59 | for python_version in SUPPORTED_PYTHON_VERSIONS
60 | ]
61 |
62 | use_repo(python, python = "python_versions")
63 |
64 | pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip")
65 |
66 | # requirements_lock.txt is generated by
67 | # bazel run //:requirements.update
68 | [
69 | pip.parse(
70 | hub_name = "pypi",
71 | python_version = version,
72 | requirements_lock = "test_requirements_lock_" + version.replace(".", "_") + ".txt",
73 | )
74 | for version in SUPPORTED_PYTHON_VERSIONS
75 | ]
76 |
77 | use_repo(pip, "pypi")
78 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Grain - Feeding JAX Models
2 |
3 | [](https://github.com/google/grain/actions/workflows/tests.yml)
4 | [](https://pypi.org/project/grain/)
5 |
6 |
7 | [**Installation**](#installation)
8 | | [**Quickstart**](#quickstart)
9 | | [**Reference docs**](https://google-grain.readthedocs.io/en/latest/)
10 |
11 | Grain is a Python library for reading and processing data for training and
12 | evaluating JAX models. It is flexible, fast and deterministic.
13 |
14 | Grain lets you define data processing steps in a simple, declarative way:
15 |
16 | ```python
17 | import grain
18 |
19 | dataset = (
20 | grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
21 | .shuffle(seed=42) # Shuffles elements globally.
22 | .map(lambda x: x+1) # Maps each element.
23 | .batch(batch_size=2) # Batches consecutive elements.
24 | )
25 |
26 | for batch in dataset:
27 |   ...  # Training step.
28 | ```
29 |
30 | Grain is designed to work with JAX models but it does not require JAX to run
31 | and can be used with other frameworks as well.
32 |
33 | ## Installation
34 |
35 | Grain is available on [PyPI](https://pypi.org/project/grain/) and can be
36 | installed with `pip install grain`.
37 |
38 | ### Supported platforms
39 |
40 | Grain does not directly use GPUs or TPUs in its transformations; by default, the
41 | processing within Grain is done on the CPU.
42 |
43 | | | Linux | Mac | Windows |
44 | |---------|---------|---------|---------|
45 | | x86_64 | yes | no | no |
46 | | aarch64 | yes | yes | n/a |
47 |
48 | ## Quickstart
49 |
50 | - [Basic `Dataset` tutorial](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_basic_tutorial.html)
51 |
52 | ## Existing users
53 |
54 | Grain is used by [MaxText](https://github.com/google/maxtext/tree/main),
55 | [Gemma](https://github.com/google-deepmind/gemma),
56 | [kauldron](https://github.com/google-research/kauldron) and multiple internal
57 | Google projects.
--------------------------------------------------------------------------------
/docs/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Grain
2 |
3 |
4 |
5 | ## Contributing to the Grain project documentation
6 |
7 | ### Pre-requisites
8 |
9 | To contribute to the documentation, you will need to set up your development environment.
10 |
11 | You can create a virtual environment or conda environment and install the packages in
12 | `docs/requirements.txt`.
13 |
14 | ```bash
15 | # Create a virtual environment
16 | python3 -m venv .venv
17 | # Activate the virtual environment
18 | source .venv/bin/activate
19 | # Install the requirements
20 | pip install -r docs/requirements.txt
21 | ```
22 |
23 | or with conda
24 |
25 | ```bash
26 | # Create a conda environment
27 | conda create -n "grain-docs" python=3.12
28 | # Activate the conda environment
29 | conda activate grain-docs
30 | # Install the requirements
31 | python3 -m pip install -r docs/requirements.txt
32 | ```
33 |
34 | ### Building the documentation locally
35 |
36 | To build the documentation locally, you can run the following command:
37 |
38 | ```bash
39 | # Change to the docs/ directory
40 | cd docs
41 | sphinx-build -b html . _build/html
42 | ```
43 |
44 | You can then open the generated HTML files in your browser by opening
45 | `docs/_build/html/index.html`.
46 |
47 | ## Documentation via Jupyter notebooks
48 |
49 | The Grain documentation includes Jupyter notebooks that are rendered
50 | directly into the website via the [myst-nb](https://myst-nb.readthedocs.io/) extension.
51 | To ease review and diff of notebooks, we keep markdown versions of the content
52 | synced via [jupytext](https://jupytext.readthedocs.io/).
53 |
54 | Note you will need to install `jupytext` to sync the notebooks with markdown files:
55 |
56 | ```bash
57 | # With pip
58 | python3 -m pip install jupytext
59 |
60 | # With conda
61 | conda install -c conda-forge jupytext
62 | ```
63 |
64 | ### Adding a new notebook
65 |
66 | We aim to have one notebook per topic or tutorial covered.
67 | To add a new notebook to the repository, first move the notebook into the appropriate
68 | location in the `docs` directory:
69 |
70 | ```bash
71 | mv ~/new-tutorial.ipynb docs/tutorials/new_tutorial.ipynb
72 | ```
73 |
74 | Next, we use `jupytext` to mark the notebook for syncing with Markdown:
75 |
76 | ```bash
77 | jupytext --set-formats ipynb,md:myst docs/tutorials/new_tutorial.ipynb
78 | ```
79 |
80 | Finally, we can sync the notebook and markdown source:
81 |
82 | ```bash
83 | jupytext --sync docs/tutorials/new_tutorial.ipynb
84 | ```
85 |
86 | To ensure that the new notebook is rendered as part of the site, be sure to add
87 | references to a `toctree` declaration somewhere in the source tree, for example
88 | in `docs/index.md`. You will also need to update `docs/conf.py` to specify
89 | whether the notebook should be executed, and which file Sphinx should use when
90 | generating the site.
91 |
92 | ### Editing an existing notebook
93 |
94 | When editing the text of an existing notebook, it is recommended to edit the
95 | markdown file only, and then automatically sync using `jupytext` via the
96 | `pre-commit` framework, which we use to check in GitHub CI that notebooks are
97 | properly synced.
98 | For example, say you have edited `docs/tutorials/new_tutorial.md`, then
99 | you can do the following:
100 |
101 | ```bash
102 | pip install pre-commit
103 | git add docs/tutorials/new_tutorial.* # stage the new changes
104 | pre-commit run # run pre-commit checks on added files
105 | git add docs/tutorials/new_tutorial.* # stage the files updated by pre-commit
106 | git commit -m "Update new tutorial" # commit to the branch
107 | ```
--------------------------------------------------------------------------------
/docs/api_choice.md:
--------------------------------------------------------------------------------
1 | # Choice of API
2 |
3 |
4 |
5 |
6 |
7 | Grain offers two different ways of defining data processing pipelines:
8 | [`DataLoader`](tutorials/data_loader_tutorial.md) and [`Dataset`](tutorials/dataset_basic_tutorial.md).
9 |
10 | > TL;DR: If you need to do one of the following:
11 | >
12 | > * mix multiple data sources
13 | > * pack variable length elements
14 | > * split dataset elements and globally shuffle the splits
15 | >
16 | > then you should use `Dataset`; otherwise, use the simpler `DataLoader`.
17 |
18 | ## `DataLoader`
19 |
20 | `DataLoader` is a high-level API that uses the following abstractions to define
21 | data processing:
22 |
23 | * [`RandomAccessDataSource`](https://github.com/google/grain/tree/main/grain/_src/python/data_sources.py)
24 | that reads raw input data.
25 | * A
26 | [`Sampler`](https://github.com/google/grain/tree/main/grain/_src/python/samplers.py)
27 | that defines the order in which the raw data should be read.
28 | * A flat sequence of
29 | [`Transformation`s](https://github.com/google/grain/tree/main/grain/_src/core/transforms.py)
30 | to apply to the raw data.
31 |
32 | You can also specify execution parameters for asynchronous data processing,
33 | sharding, and shuffling, and `DataLoader` will automatically take care of
34 | inserting them in the right places between the data processing steps.
35 |
36 | These are simple and usually general enough to cover most data processing use
37 | cases. Prefer using `DataLoader` if your workflow can be described using the
38 | abstractions above. See [tutorial](tutorials/data_loader_tutorial.md)
39 | for more details.
40 |
41 | ## `Dataset`
42 |
43 | `Dataset` is a lower-level API that uses chaining syntax to define data
44 | transformation steps. It allows more general types of processing (e.g. dataset
45 | mixing) and more control over the execution (e.g. different order of data
46 | sharding and shuffling). `Dataset` transformations are composed in a way that
47 | preserves the random access property past the source and some of the
48 | transformations. This, among other things, can be used for debugging by
49 | evaluating dataset elements at specific positions without processing the entire
50 | dataset.
51 |
52 | There are 3 main classes comprising the `Dataset` API:
53 |
54 | * [`MapDataset`](https://github.com/google/grain/tree/main/grain/_src/python/dataset/dataset.py)
55 | defines a dataset that supports efficient random access. Think of it as an
56 | (infinite) `Sequence` that computes values lazily.
57 | * [`IterDataset`](https://github.com/google/grain/tree/main/grain/_src/python/dataset/dataset.py)
58 | defines a dataset that does not support efficient random access and only
59 | supports iterating over it. It's an `Iterable`. Any `MapDataset` can be
60 | turned into an `IterDataset` by calling `to_iter_dataset()`.
61 | * [`DatasetIterator`](https://github.com/google/grain/tree/main/grain/_src/python/dataset/dataset.py)
62 | defines a stateful iterator of an `IterDataset`. The state of the iterator
63 | can be saved and restored.
64 |
65 | Most data pipelines will start with one or more `MapDataset`s (often derived from
66 | a `RandomAccessDataSource`) and switch to `IterDataset` late in the pipeline, or not at all. See
67 | [tutorial](tutorials/dataset_basic_tutorial.md)
68 | for more details.
69 |
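70 | ## Example: the same pipeline in both APIs
71 |
72 | Below is a minimal, illustrative sketch of the same pipeline written with each
73 | API. It is not meant as a definitive recipe: it assumes a plain Python list as
74 | the random-access source (anything exposing `__len__`/`__getitem__`) and the
75 | `DataLoader`/`IndexSampler` keyword arguments shown here; see the tutorials
76 | linked above for complete, up-to-date usage.
77 |
78 | ```python
79 | import grain
80 |
81 | source = list(range(20))  # Stands in for a RandomAccessDataSource.
82 |
83 |
84 | class PlusOne(grain.transforms.Map):
85 |
86 |   def map(self, x: int) -> int:
87 |     return x + 1
88 |
89 |
90 | # DataLoader: data source + sampler + flat sequence of transformations.
91 | loader = grain.DataLoader(
92 |     data_source=source,
93 |     sampler=grain.samplers.IndexSampler(
94 |         num_records=len(source),
95 |         num_epochs=1,
96 |         shard_options=grain.sharding.NoSharding(),
97 |         shuffle=True,
98 |         seed=42,
99 |     ),
100 |     operations=[PlusOne(), grain.transforms.Batch(batch_size=4)],
101 |     worker_count=0,  # Runs in-process; increase to use worker processes.
102 | )
103 | for batch in loader:
104 |   ...  # Training step.
105 |
106 | # Dataset: the same steps expressed with chaining syntax.
107 | dataset = (
108 |     grain.MapDataset.source(source)
109 |     .shuffle(seed=42)
110 |     .map(lambda x: x + 1)
111 |     .batch(batch_size=4)
112 | )
113 | print(dataset[0])  # Random access to any element, useful for debugging.
114 | for batch in dataset.to_iter_dataset():
115 |   ...  # Training step.
116 | ```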
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | """Configuration file for the Sphinx documentation builder.
2 |
3 | For the full list of built-in configuration values, see the documentation:
4 | https://www.sphinx-doc.org/en/master/usage/configuration.html
5 | """
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | import os
13 | import pathlib
14 | import sys
15 |
16 | sys.path.insert(0, str(pathlib.Path('..', 'grain').resolve()))
17 | sys.path.insert(0, os.path.abspath('..'))
18 |
19 | # -- Project information -----------------------------------------------------
20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
21 |
22 | project = 'Grain'
23 | copyright = '2024, Grain team' # pylint: disable=redefined-builtin
24 | author = 'Grain team'
25 |
26 | # -- General configuration ---------------------------------------------------
27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
28 |
29 | extensions = [
30 | 'myst_nb',
31 | 'sphinx_copybutton',
32 | 'sphinx_design',
33 | 'sphinx.ext.autodoc',
34 | 'sphinx.ext.autosummary',
35 | 'sphinx.ext.napoleon',
36 | ]
37 |
38 | templates_path = ['_templates']
39 | source_suffix = ['.rst', '.ipynb', '.md']
40 | exclude_patterns = [
41 | '_build',
42 | 'Thumbs.db',
43 | '.DS_Store',
44 | 'tutorials/dataset_basic_tutorial.md',
45 | ]
46 |
47 | # Suppress highlighting failure warnings caused by the exception cell in the basic Dataset tutorial.
48 | suppress_warnings = [
49 | 'misc.highlighting_failure',
50 | ]
51 |
52 | # -- Options for HTML output -------------------------------------------------
53 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
54 |
55 | html_theme = 'sphinx_book_theme'
56 | html_title = 'Grain'
57 | html_static_path = ['_static']
58 |
59 | # TODO: Add logo and favicon
60 | # html_logo = '_static/'
61 | # html_favicon = '_static/favicon.png'
62 |
63 | # Theme-specific options
64 | # https://sphinx-book-theme.readthedocs.io/en/stable/reference.html
65 | html_theme_options = {
66 | 'show_navbar_depth': 1,
67 | 'show_toc_level': 3,
68 | 'repository_url': 'https://github.com/google/grain',
69 | 'use_issues_button': True,
70 | 'use_repository_button': True,
71 | 'path_to_docs': 'docs/',
72 | 'navigation_with_keys': True,
73 | }
74 |
75 | # Autodoc settings
76 | autosummary_generate = True
77 | autodoc_typehints = 'description'
78 | # We mock dependencies and internal modules that require building to be able to
79 | # import the python symbols for extracting docstrings.
80 | autodoc_mock_imports = [
81 | 'grain.proto.execution_summary_pb2',
82 | 'grain._src.python.experimental.index_shuffle.python.index_shuffle_module',
83 | 'cloudpickle',
84 | 'numpy',
85 | 'orbax',
86 | 'tree',
87 | 'absl',
88 | 'absl.logging',
89 | 'array_record',
90 | ]
91 |
92 | # -- Myst configurations -------------------------------------------------
93 | myst_enable_extensions = ['colon_fence']
94 | nb_execution_mode = 'force'
95 | nb_execution_allow_errors = False
96 | nb_merge_streams = True
97 | nb_execution_show_tb = True
98 |
99 | # Notebook cell execution timeout; defaults to 30.
100 | nb_execution_timeout = 100
101 |
102 | # List of patterns, relative to source directory, that match notebook
103 | # files that will not be executed.
104 | nb_execution_excludepatterns = [
105 | 'tutorials/dataset_advanced_tutorial.ipynb',
106 | 'tutorials/dataset_basic_tutorial.ipynb',
107 | 'tutorials/data_loader_tutorial.ipynb',
108 | 'tutorials/dataset_debugging_tutorial.ipynb',
109 | 'tutorials/data_sources/load_from_s3_tutorial.ipynb',
110 | 'tutorials/data_sources/load_from_gcs_tutorial.ipynb',
111 | 'tutorials/data_sources/parquet_dataset_tutorial.ipynb',
112 | 'tutorials/data_sources/arrayrecord_data_source_tutorial.ipynb',
113 | 'tutorials/data_sources/huggingface_dataset_tutorial.ipynb',
114 | 'tutorials/data_sources/pytorch_dataset_tutorial.ipynb',
115 | ]
116 |
--------------------------------------------------------------------------------
/docs/data_loader/samplers.md:
--------------------------------------------------------------------------------
1 | # Samplers
2 |
3 | Samplers in Grain are responsible for determining the order in which records
4 | are processed. This allows Grain to implement global transformations (e.g.
5 | global shuffling, sharding, repeating for multiple epochs) before reading any
6 | records.
7 |
8 |
9 |
10 | Samplers need to implement the following iterator protocol:
11 |
12 | ```python
13 | class Sampler(Protocol):
14 |
15 | def __iter__(self):
16 | ...
17 |
18 | def __next__(self) -> record.RecordMetadata:
19 | ...
20 |
21 | @dataclasses.dataclass
22 | class RecordMetadata:
23 | """RecordMetadata contains metadata about individual records.
24 |
25 | RecordMetadata objects are emitted by the sampler to refer to which record to
26 | read next (record_key), what its index is (for tracking progress and
27 | checkpointing), as well as an optional rng for stateless random
28 | transformations. In addition, they are also used to keep information about
29 | records as they flow through the pipeline from one operation to the other.
30 | """
31 | index: int
32 | record_key: Optional[int] = None
33 | rng: Optional[np.random.Generator] = None
34 | ```
35 |
36 | ## Index Sampler
37 |
38 | This is our recommended Sampler. It supports:
39 |
40 | * Sharding across multiple machines (`shard_options` parameter).
41 | * Global shuffle of the data (`shuffle` parameter).
42 | * Repeating records for multiple epochs (`num_epochs` parameter). Note that
43 | the shuffle order changes across epochs. Behind the scenes, this relies on
44 | [tf.random_index_shuffle](https://www.tensorflow.org/api_docs/python/tf/random_index_shuffle).
45 | * Stateless random operations. Each `RecordMetadata` object emitted by the
46 | `IndexSampler` contains an RNG uniquely seeded on a per-record basis. This
47 | RNG can be used for random augmentations while not relying on a global
48 | state.
49 |
50 | ```python
51 | index_sampler = pygrain.IndexSampler(
52 | num_records=5,
53 | num_epochs=2,
54 | shard_options=pygrain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
55 | shuffle=True,
56 | seed=0)
57 | for record_metadata in index_sampler:
58 | print(record_metadata)
59 |
60 | # Output
61 | # RecordMetadata(index=0, record_key=0, rng=Generator(Philox) at 0x7FB09947AF80)
62 | # RecordMetadata(index=1, record_key=4, rng=Generator(Philox) at 0x7FB0994789E0)
63 | # RecordMetadata(index=2, record_key=2, rng=Generator(Philox) at 0x7FB099478740)
64 | # RecordMetadata(index=3, record_key=3, rng=Generator(Philox) at 0x7FB0994789E0)
65 | # RecordMetadata(index=4, record_key=1, rng=Generator(Philox) at 0x7FB099478740)
66 | # RecordMetadata(index=5, record_key=1, rng=Generator(Philox) at 0x7FB0994789E0)
67 | # RecordMetadata(index=6, record_key=0, rng=Generator(Philox) at 0x7FB099478740)
68 | # RecordMetadata(index=7, record_key=3, rng=Generator(Philox) at 0x7FB0994789E0)
69 | # RecordMetadata(index=8, record_key=4, rng=Generator(Philox) at 0x7FB099478740)
70 | # RecordMetadata(index=9, record_key=2, rng=Generator(Philox) at 0x7FB0994789E0)
71 | ```
72 |
73 | ## Implement your Own Sampler
74 | Grain can accommodate custom user-defined samplers. Users implementing their
75 | own sampler should ensure it:
76 |
77 | * implements the aforementioned interface.
78 | * is adequately performant. Since Grain's `DataLoader` iterates sequentially
79 | through the sampler to distribute indices to child processes, a slow sampler
80 | will become a bottleneck and reduce end-to-end pipeline performance. As a
81 | reference, we recommend a sampler iteration rate of approximately 50,000
82 | elements/sec for most use cases.
83 |
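84 | Below is a minimal sketch of a custom sampler that visits every record exactly
85 | once, in order, with no shuffling. It assumes `RecordMetadata` is available as
86 | `grain.RecordMetadata` (as listed in the API reference); adjust the import to
87 | match your setup.
88 |
89 | ```python
90 | import grain
91 |
92 |
93 | class SequentialOneEpochSampler:
94 |   """Emits RecordMetadata for indices 0..num_records-1, one epoch, no shuffle."""
95 |
96 |   def __init__(self, num_records: int):
97 |     self._num_records = num_records
98 |     self._next_index = 0
99 |
100 |   def __iter__(self):
101 |     return self
102 |
103 |   def __next__(self) -> grain.RecordMetadata:
104 |     if self._next_index >= self._num_records:
105 |       raise StopIteration
106 |     metadata = grain.RecordMetadata(
107 |         index=self._next_index, record_key=self._next_index
108 |     )
109 |     self._next_index += 1
110 |     return metadata
111 |
112 |
113 | for record_metadata in SequentialOneEpochSampler(num_records=3):
114 |   print(record_metadata)
115 | ```
116 |
117 | Grain already ships a `SequentialSampler` with similar behavior; the sketch above
118 | only illustrates the protocol a custom sampler needs to satisfy.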
--------------------------------------------------------------------------------
/docs/data_loader/transformations.md:
--------------------------------------------------------------------------------
1 | # Transformations
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | The Grain transforms interface describes transformations that are applied to data. In
10 | the case of local transformations (such as map, random map, and filter), the
11 | transform receives an element to which custom changes are applied. For global
12 | transformations (such as batching), one must provide the batch size.
13 |
14 | The Grain core transforms interface code is
15 | [here](https://github.com/google/grain/tree/main/grain/_src/core/transforms.py).
16 |
17 | ## MapTransform
18 |
19 | `MapTransform` is for 1:1 transformations of elements. Elements can be of any
20 | type; it is the user's responsibility to use the transformation such that the
21 | inputs it receives correspond to its signature.
22 |
23 | Example of transformation which implements `MapTransform` (for elements of type
24 | `int`):
25 |
26 | ```python
27 | class PlusOne(transforms.MapTransform):
28 |
29 | def map(self, x: int) -> int:
30 | return x + 1
31 | ```
32 |
33 | ## MapWithIndexTransform
34 |
35 | `MapWithIndexTransform` is similar to `MapTransform` in being a 1:1
36 | transformation of elements, but also takes in the index/position of the element
37 | as the first argument. This is useful for pairing elements with an index key or
38 | even keeping it as metadata alongside the actual data.
39 |
40 | Example of transformation which implements `MapWithIndexTransform` (for elements
41 | of type `int`):
42 |
43 | ```python
44 | class PlusOneWithIndexKey(transforms.MapWithIndexTransform):
45 |
46 | def map_with_index(self, i: int, x: int) -> tuple[int, int]:
47 | return (x + 1, i)
48 | ```
49 |
50 | ## RandomMapTransform
51 |
52 | `RandomMapTransform` is for 1:1 random transformations of elements. The
53 | interface requires a `np.random.Generator` as parameter to the `random_map`
54 | function.
55 |
56 | Example of a `RandomMapTransform`:
57 |
58 | ```python
59 | class PlusRandom(transforms.RandomMapTransform):
60 |
61 | def random_map(self, x: int, rng: np.random.Generator) -> int:
62 | return x + rng.integers(100_000)
63 | ```
64 |
65 | ## FlatMapTransform
66 |
67 | `FlatMapTransform` splits individual elements into multiple elements.
68 | `max_fan_out` is the maximum number of splits that an element can generate.
69 | Please consult the code for detailed info.
70 |
71 | Example of a `FlatMapTransform`:
72 |
73 | ```python
74 | class FlatMapTransformExample(transforms.FlatMapTransform):
75 | max_fan_out: int
76 |
77 | def flat_map(self, element: int):
78 | for _ in range(self.max_fan_out):
79 | yield element
80 | ```
81 |
82 | ## FilterTransform
83 |
84 | `FilterTransform` is for applying filtering to individual elements. Elements for
85 | which the filter function returns False will be removed.
86 |
87 | Example of a `FilterTransform` that removes all even elements:
88 |
89 | ```python
90 | class RemoveEvenElements(FilterTransform):
91 |
92 | def filter(self, element: int) -> bool:
93 | return element % 2
94 | ```
95 |
96 | ## Batch
97 |
98 | To apply the `Batch` transform, pass `grain.Batch(batch_size=batch_size,
99 | drop_remainder=drop_remainder)`.
100 |
101 | Note: The batch size used when passing the `Batch` transform will be the global
102 | batch size if batching is done before sharding, and the *per-host* batch size if it
103 | is done after. Typical usage with `IndexSampler` applies batching after sharding.
104 |
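105 | ## Putting transforms together
106 |
107 | Transformations are passed to the `DataLoader` as a flat sequence and applied in
108 | order. Below is a minimal sketch, reusing the `PlusOne` and `RemoveEvenElements`
109 | classes defined above; the `DataLoader`/`IndexSampler` keyword arguments are
110 | assumptions based on the DataLoader tutorial, so consult it for complete usage.
111 | Because the filter runs before batching, every batch contains only odd values.
112 |
113 | ```python
114 | import grain
115 |
116 | data_loader = grain.DataLoader(
117 |     data_source=list(range(10)),  # Any source with __len__/__getitem__.
118 |     sampler=grain.samplers.IndexSampler(
119 |         num_records=10,
120 |         num_epochs=1,
121 |         shard_options=grain.sharding.NoSharding(),
122 |         shuffle=False,
123 |         seed=0,
124 |     ),
125 |     operations=[
126 |         PlusOne(),             # MapTransform: 1:1 map.
127 |         RemoveEvenElements(),  # FilterTransform: drops even elements.
128 |         grain.transforms.Batch(batch_size=2, drop_remainder=False),
129 |     ],
130 | )
131 | for batch in data_loader:
132 |   print(batch)
133 | ```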
--------------------------------------------------------------------------------
/docs/grain.checkpoint.rst:
--------------------------------------------------------------------------------
1 |
2 | ``grain.checkpoint`` module
3 | ===========================
4 |
5 | .. automodule:: grain.checkpoint
6 |
7 | List of Members
8 | ---------------
9 |
10 | .. autosummary::
11 | :toctree: _autosummary
12 |
13 | CheckpointHandler
14 | CheckpointSave
15 | CheckpointRestore
--------------------------------------------------------------------------------
/docs/grain.constants.rst:
--------------------------------------------------------------------------------
1 | ``grain.constants`` module
2 | ==========================
3 |
4 | .. automodule:: grain.constants
5 |
6 | List of Constants
7 | -----------------
8 |
9 | .. autosummary::
10 | :toctree: _autosummary
11 |
12 | DATASET_INDEX
13 | EPOCH
14 | INDEX
15 | META_FEATURES
16 | RECORD
17 | RECORD_KEY
18 | SEED
--------------------------------------------------------------------------------
/docs/grain.dataset.rst:
--------------------------------------------------------------------------------
1 | ``grain`` Dataset
2 | =================
3 |
4 | .. automodule:: grain._src.python.dataset.dataset
5 | .. currentmodule:: grain
6 |
7 | List of Members
8 | ---------------
9 |
10 | .. autoclass:: _src.python.dataset.dataset.MapDatasetMeta
11 | :members:
12 |
13 | .. autoclass:: MapDataset
14 | :special-members: __init__, __len__, __getitem__, __iter__
15 | :members:
16 | :show-inheritance:
17 | :inherited-members:
18 | :undoc-members:
19 |
20 | .. autoclass:: _src.python.dataset.dataset.IterDatasetMeta
21 | :members:
22 |
23 | .. autoclass:: IterDataset
24 | :special-members: __init__, __iter__
25 | :members:
26 | :show-inheritance:
27 | :inherited-members:
28 | :undoc-members:
29 |
30 | .. autoclass:: DatasetIterator
31 | :special-members: __init__, __iter__
32 | :members:
33 | :show-inheritance:
34 | :inherited-members:
35 | :undoc-members:
36 |
--------------------------------------------------------------------------------
/docs/grain.experimental.rst:
--------------------------------------------------------------------------------
1 | ``grain.experimental`` module
2 | =============================
3 |
4 | .. automodule:: grain.experimental
5 |
6 | List of Members
7 | ---------------
8 |
9 | .. autosummary::
10 | :toctree: _autosummary
11 |
12 | FlatMapTransform
13 | DatasetOptions
14 | ExecutionTrackingMode
15 | apply_transformations
16 | ElasticIterator
17 | WithOptionsIterDataset
18 | ParquetIterDataset
19 | FlatMapMapDataset
20 | FlatMapIterDataset
21 | InterleaveIterDataset
22 | LimitIterDataset
23 | RngPool
24 | FirstFitPackIterDataset
25 | BOSHandling
26 | ConcatThenSplitIterDataset
27 | ThreadPrefetchIterDataset
28 | WindowShuffleMapDataset
29 | WindowShuffleIterDataset
30 | ZipMapDataset
31 | ZipIterDataset
32 | PackAndBatchOperation
33 | assert_equal_output_after_checkpoint
34 |
--------------------------------------------------------------------------------
/docs/grain.multiprocessing.rst:
--------------------------------------------------------------------------------
1 | ``grain.multiprocessing`` module
2 | ================================
3 |
4 | .. automodule:: grain.multiprocessing
5 |
6 | List of Members
7 | ---------------
8 |
9 | .. autosummary::
10 | :toctree: _autosummary
11 |
12 | MultiprocessingOptions
13 | SharedMemoryArray
14 |
--------------------------------------------------------------------------------
/docs/grain.rst:
--------------------------------------------------------------------------------
1 | .. currentmodule:: grain
2 |
3 | Public API: ``grain`` package
4 | =============================
5 |
6 | Subpackages
7 | -----------
8 |
9 | .. toctree::
10 | :maxdepth: 1
11 |
12 | grain.checkpoint
13 | grain.constants
14 | grain.experimental
15 | grain.multiprocessing
16 | grain.samplers
17 | grain.sharding
18 | grain.sources
19 | grain.transforms
20 |
21 |
22 | Flexible low-level pipelines
23 | ----------------------------
24 |
25 | .. autosummary::
26 |
27 | MapDataset
28 | IterDataset
29 | DatasetIterator
30 | ReadOptions
31 |
32 |
33 | Simple high-level pipelines
34 | ---------------------------
35 |
36 |
37 | .. autosummary::
38 | :toctree: _autosummary
39 |
40 | load
41 | DataLoader
42 | DataLoaderIterator
43 | Record
44 | RecordMetadata
--------------------------------------------------------------------------------
/docs/grain.samplers.rst:
--------------------------------------------------------------------------------
1 | ``grain.samplers`` module
2 | =========================
3 |
4 | .. automodule:: grain.samplers
5 |
6 | List of Members
7 | ---------------
8 |
9 | .. autosummary::
10 | :toctree: _autosummary
11 |
12 | IndexSampler
13 | Sampler
14 | SequentialSampler
15 |
--------------------------------------------------------------------------------
/docs/grain.sharding.rst:
--------------------------------------------------------------------------------
1 | ``grain.sharding`` module
2 | =========================
3 |
4 | .. automodule:: grain.sharding
5 |
6 | List of Members
7 | ---------------
8 |
9 | .. autosummary::
10 | :toctree: _autosummary
11 |
12 | NoSharding
13 | ShardByJaxProcess
14 | ShardOptions
15 |
--------------------------------------------------------------------------------
/docs/grain.sources.rst:
--------------------------------------------------------------------------------
1 | ``grain.sources`` module
2 | ========================
3 |
4 | .. automodule:: grain.sources
5 |
6 | List of Members
7 | ---------------
8 |
9 | .. autosummary::
10 | :toctree: _autosummary
11 |
12 | ArrayRecordDataSource
13 | SharedMemoryDataSource
14 | RandomAccessDataSource
15 | RangeDataSource
16 |
--------------------------------------------------------------------------------
/docs/grain.transforms.rst:
--------------------------------------------------------------------------------
1 | ``grain.transforms`` module
2 | ===========================
3 |
4 | .. automodule:: grain.transforms
5 |
6 | List of Members
7 | ---------------
8 |
9 | .. autosummary::
10 | :toctree: _autosummary
11 |
12 | Batch
13 | Filter
14 | Map
15 | MapWithIndex
16 | RandomMap
17 | Transformation
18 | Transformations
19 | DatasetSelectionMap
20 |
--------------------------------------------------------------------------------
/docs/images/data_flow_multiple_workers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/docs/images/data_flow_multiple_workers.png
--------------------------------------------------------------------------------
/docs/images/data_flow_zero_workers.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/docs/images/data_flow_zero_workers.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Grain - Feeding JAX Models
2 |
3 |
4 |
5 | Grain is a library for reading data for training and evaluating JAX models. It's
6 | open source, fast and deterministic.
7 |
8 | ::::{grid} 1 2 2 3
9 | :gutter: 1 1 1 2
10 |
11 | :::{grid-item-card} {octicon}`zap;1.5em;sd-mr-1` Powerful
12 | Users can bring arbitrary Python transformations.
13 | :::
14 |
15 | :::{grid-item-card} {octicon}`tools;1.5em;sd-mr-1` Flexible
16 | Grain is designed to
17 | be modular. Users can readily override Grain components with their own
18 | implementations if need be.
19 | :::
20 |
21 | :::{grid-item-card} {octicon}`versions;1.5em;sd-mr-1` Deterministic
22 | Multiple runs of the same pipeline will produce the same output.
23 | :::
24 |
25 | :::{grid-item-card} {octicon}`check-circle;1.5em;sd-mr-1` Resilient to preemptions
26 | Grain is designed such that checkpoints have minimal size. After
27 | preemption, Grain can resume from where it left off and produce the same output
28 | as if it had never been preempted.
29 | :::
30 |
31 | :::{grid-item-card} {octicon}`zap;1.5em;sd-mr-1` Performant
32 | We took care while designing Grain to ensure that it's performant (refer to the
33 | [Behind the Scenes](behind_the_scenes.md) section of the documentation). We also
34 | tested it against multiple data modalities (e.g. text/audio/images/videos).
35 | :::
36 |
37 | :::{grid-item-card} {octicon}`package;1.5em;sd-mr-1` With minimal dependencies
38 | Grain minimizes its set of dependencies when possible. For example, it should
39 | not depend on TensorFlow.
40 | :::
41 | ::::
42 |
43 | ``` {toctree}
44 | :maxdepth: 1
45 | :hidden:
46 | :caption: Get started
47 | installation
48 | api_choice
49 | data_sources
50 | behind_the_scenes
51 | ```
52 |
53 | ``` {toctree}
54 | :maxdepth: 1
55 | :hidden:
56 | :caption: Data Loader
57 | data_loader/samplers
58 | data_loader/transformations
59 | ```
60 |
61 | ```{toctree}
62 | :maxdepth: 3
63 | :hidden:
64 | :caption: Tutorials
65 | tutorials/data_loader_tutorial
66 | tutorials/dataset_basic_tutorial
67 | tutorials/dataset_advanced_tutorial
68 | tutorials/dataset_debugging_tutorial
69 | tutorials/dataset_load_from_s3_tutorial
70 | tutorials/data_sources/index
71 | ```
72 |
73 | ``` {toctree}
74 | :maxdepth: 1
75 | :hidden:
76 | :caption: API reference
77 | grain
78 | ```
79 |
80 | ``` {toctree}
81 | :maxdepth: 1
82 | :hidden:
83 | :caption: Contributor guides
84 | CONTRIBUTING
85 | ```
--------------------------------------------------------------------------------
/docs/installation.md:
--------------------------------------------------------------------------------
1 | # Installing Grain
2 |
3 |
4 |
5 | To install Grain, you can use pip:
6 |
7 | ```bash
8 | pip install grain
9 | ```
10 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | # Sphinx-related requirements.
2 | sphinx
3 | sphinx-book-theme>=1.0.1
4 | myst-nb
5 | myst-parser[linkify]
6 | sphinx-book-theme
7 | sphinx-copybutton
8 | sphinx-design
9 | # Avoiding an issue with the collapsible sidebar.
10 | pydata-sphinx-theme<0.16.0
11 | # To generate API documentation.
12 | sphinx-autoapi
13 | sphinx-autodoc2
14 | # To import the Grain package. We mock all other dependencies, but this one has
15 | # context managers that are tricky to mock.
16 | etils[epath,epy]
--------------------------------------------------------------------------------
/docs/tutorials/data_sources/arrayrecord_data_source_tutorial.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupyter:
3 | jupytext:
4 | text_representation:
5 | extension: .md
6 | format_name: markdown
7 | format_version: '1.3'
8 | jupytext_version: 1.17.1
9 | kernelspec:
10 | display_name: Python 3
11 | name: python3
12 | ---
13 |
14 |
15 | # Reading ArrayRecord Files
16 | [](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/arrayrecord_data_source_tutorial.ipynb)
17 |
18 | This tutorial provides an example of how to retrieve records from ArrayRecord
19 | files using `grain.sources.ArrayRecordDataSource`. It also covers how to process
20 | and transform the data with Grain.
21 |
22 |
23 |
24 |
25 | ## Install and Load Dependencies
26 |
27 |
28 | ```python id="tzWZLNklr4Iy"
29 | !pip install grain array_record
30 | ```
31 |
32 | ```python id="8NF4E-cCbyjV"
33 | import pickle
34 | import grain
35 | import tensorflow_datasets as tfds
36 | from array_record.python import array_record_module
37 | ```
38 |
39 |
40 | ## Write a temp ArrayRecord file
41 |
42 |
43 | ```python id="WrCQ-jH53t-K"
44 | # Load a public tensorflow dataset.
45 | test_tfds = tfds.data_source("bool_q", split="train")
46 | ```
47 |
48 | ```python id="_0yBaN7hXmbu"
49 | # Write the dataset into a test array_record file.
50 | example_file_path = "./test.array_record"
51 | writer = array_record_module.ArrayRecordWriter(
52 | example_file_path, "group_size:1"
53 | )
54 | record_count = 0
55 | for record in test_tfds:
56 | writer.write(pickle.dumps(record))
57 | record_count += 1
58 | writer.close()
59 |
60 | print(
61 | f"Number of records written to array_record file {example_file_path} :"
62 | f" {record_count}"
63 | )
64 | ```
65 |
66 | ```python id="HKJ_49JCXmbu"
67 | # @title Load Data Source
68 | example_array_record_data_source = (grain.sources.ArrayRecordDataSource(
69 | example_file_path
70 | ))
71 | print(f"Number of records: {len(example_array_record_data_source)}")
72 | ```
73 |
74 | ```python id="NVRGllY3Xmbu"
75 | print(example_array_record_data_source[0])
76 | ```
77 |
78 |
79 | ## Define Transformation Function
80 |
81 |
82 | ```python id="0AS5w9quXmbu"
83 | # Load a pre-trained tokenizer.
84 | from tokenizers import Tokenizer
85 |
86 | tokenizer = Tokenizer.from_pretrained("bert-base-cased")
87 | ```
88 |
89 | ```python id="YiS85paBXmbu"
90 | class ParseAndTokenizeText(grain.transforms.Map):
91 |   """Parses a serialized dict (as bytes) and tokenizes one of its features.
92 |
93 |   Decodes the element, applies a tokenizer to a specified feature within the
94 |   dict, and returns the first 10 tokens from the result.
95 |   """
96 |
97 | def __init__(self, tokenizer, feature_name):
98 | self._tokenizer = tokenizer
99 | self._feature_name = feature_name
100 |
101 |   def map(self, element: bytes) -> list[str]:
102 | parsed_element = pickle.loads(element)
103 | # only pick the first 10 token IDs from the tokenized text for testing
104 | return self._tokenizer.encode(
105 | parsed_element[self._feature_name].decode('utf-8')
106 | ).tokens[:10]
107 | ```
108 |
109 |
110 | ## Load and process data via the Dataset API
111 |
112 |
113 | ```python id="RPIy05gGUBzI"
114 | # Example using Grain's MapDataset with ArrayRecord file source.
115 | example_datasets = (
116 | grain.MapDataset.source(example_array_record_data_source)
117 | .shuffle(seed=42)
118 | .map(ParseAndTokenizeText(tokenizer, "question"))
119 | .batch(batch_size=10)
120 | )
121 | ```
122 |
123 | ```python id="xqJSeQ9hdAmF"
124 | # Output a record at a random index
125 | print(example_datasets[100])
126 | ```
127 |
--------------------------------------------------------------------------------
/docs/tutorials/data_sources/index.rst:
--------------------------------------------------------------------------------
1 | .. _data-sources-tutorials-section:
2 |
3 | Data Sources
4 | ============
5 |
6 | This section contains tutorials for using Grain to read data from various sources.
7 |
8 | .. toctree::
9 | :maxdepth: 1
10 |
11 | parquet_dataset_tutorial.md
12 | arrayrecord_data_source_tutorial.md
13 | load_from_s3_tutorial.md
14 | load_from_gcs_tutorial.md
15 | huggingface_dataset_tutorial.md
16 | pytorch_dataset_tutorial.md
17 |
--------------------------------------------------------------------------------
/docs/tutorials/data_sources/load_from_gcs_tutorial.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupyter:
3 | jupytext:
4 | text_representation:
5 | extension: .md
6 | format_name: markdown
7 | format_version: '1.3'
8 | jupytext_version: 1.17.1
9 | kernelspec:
10 | display_name: Python 3
11 | name: python3
12 | ---
13 |
14 |
15 |
16 | # Reading from GCS
17 |
18 | [](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/load_from_gcs_tutorial.ipynb)
19 |
20 | This document demonstrates how to access and load data from Google Cloud Storage using Grain. To achieve this, we'll utilize Cloud Storage [FUSE](https://cloud.google.com/storage/docs/cloud-storage-fuse/overview), an adapter that allows you to mount GCS buckets as local file systems. By using Cloud Storage FUSE to mount GCS buckets as local file systems, you can access cloud storage data just like local files.
21 |
22 |
23 |
24 | ## Mount a Cloud Storage location into the local filesystem
25 |
26 |
27 | ```python id="h6HqcZSQ_Y23"
28 | # Authenticate.
29 | from google.colab import auth
30 | auth.authenticate_user()
31 | ```
32 |
33 |
34 |
35 |
36 | The gcsfuse CLI offers various configurable options, detailed at https://cloud.google.com/storage/docs/gcsfuse-cli. Utilizing certain options, such as the caching features described at https://cloud.google.com/storage/docs/cloud-storage-fuse/caching, can enhance read performance and lower costs. For instance, MaxText sets gcsfuse flags ([MaxText gcsfuse setting link](https://github.com/AI-Hypercomputer/maxtext/blob/4e36b61cf40698224c5251c4aa4086df24140bd1/setup_gcsfuse.sh#L48)) to reduce data loading time for training. We advise users to consider adopting similar settings or customizing their own gcsfuse options.
37 |
38 |
39 | ```python id="bqz6cD7xl7F3"
40 | # Mount a Cloud Storage bucket or location, without the gs:// prefix.
41 | mount_path = "my-bucket" # or a location like "my-bucket/path/to/mount"
42 | local_path = f"./mnt/gs/{mount_path}"
43 |
44 | !mkdir -p {local_path}
45 | # The flags below are configured to improve GCS data loading performance. Users are encouraged to explore alternative settings and we would greatly appreciate any feedback or insights shared with the Grain team.
46 | !gcsfuse --implicit-dirs --type-cache-max-size-mb=-1 --stat-cache-max-size-mb=-1 --kernel-list-cache-ttl-secs=-1 --metadata-cache-ttl-secs=-1 {mount_path} {local_path}
47 | ```
48 |
49 | ```python id="j2e8nv0j_Y23"
50 | # Then you can access it like a local path.
51 | !ls -lh {local_path}
52 | ```
53 |
54 |
55 | ## Read files using Grain
56 |
57 | If your data is in an ArrayRecord file, you can directly load it using `grain.sources.ArrayRecordDataSource`. For information on handling other file formats, please see the Grain data sources documentation at: https://google-grain.readthedocs.io/en/latest/data_sources.html
58 |
59 |
60 | ```python id="yisjIpbZ_Y23"
61 | # Install Grain.
62 | !pip install grain
63 | ```
64 |
65 | ```python id="pvNTx6sL_Y23"
66 | import grain
67 |
68 | source = grain.sources.ArrayRecordDataSource(local_path+"/local_file_name")
69 |
70 | # Create a dataset from the data source then process the data.
71 | dataset = (
72 | grain.MapDataset.source(source)
73 | .shuffle(seed=10) # Shuffles globally.
74 | .batch(batch_size=2) # Batches consecutive elements.
75 | )
76 | ```
77 |
78 | ```python id="bJIYx60H_Y23"
79 | # Output a record at a random index
80 | print(dataset[10])
81 | ```
82 |
--------------------------------------------------------------------------------
/docs/tutorials/data_sources/load_from_s3_tutorial.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupyter:
3 | jupytext:
4 | text_representation:
5 | extension: .md
6 | format_name: markdown
7 | format_version: '1.3'
8 | jupytext_version: 1.17.1
9 | kernelspec:
10 | display_name: Python 3
11 | name: python3
12 | ---
13 |
14 |
15 | # Reading from AWS S3
16 |
17 | [](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/load_from_s3_tutorial.ipynb)
18 |
19 | This document outlines how to read data from an Amazon S3 bucket and construct a Grain pipeline. We will leverage [S3 Mountpoint](https://docs.aws.amazon.com/AmazonS3/latest/userguide/mountpoint.html), a service provided by AWS. S3 Mountpoint enables you to mount your S3 bucket as a local file system, allowing you to access and read data as if it were stored locally.
20 |
21 |
22 |
23 | ## Install Mountpoint for Amazon S3
24 |
25 |
26 | ```python id="K6UTOyamWlWf"
27 | !wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb
28 | ```
29 |
30 | ```python id="iHA-C85NhwFJ"
31 | !sudo apt-get install -y ./mount-s3.deb
32 | ```
33 |
34 |
35 | ## Configure AWS credentials
36 |
37 |
38 | ```python id="8fhEOwxcWlWf"
39 | # Install the AWS CLI, which is used below to configure credentials.
40 | !pip install awscli
41 | ```
42 |
43 | ```python id="5Lt_644G7G9R"
44 | !aws configure
45 | ```
46 |
47 |
48 | ## Mount your S3 bucket to your local filepath
49 |
50 |
51 | ```python id="G6boYrD5WlWf"
52 | !mount-s3 my-s3-bucket /path/to/mount/files  # Replace my-s3-bucket with your bucket name.
53 | ```
54 |
55 |
56 | ## Install Grain and other dependencies
57 |
58 |
59 | ```python id="3BZP9fBiWlWf"
60 | !pip install grain
61 | !pip install array_record
62 | ```
63 |
64 |
65 | ## Write temp ArrayRecord files to the bucket
66 |
67 |
68 | ```python id="xVGVuDKNic0B"
69 | from array_record.python import array_record_module
70 |
71 | digits = [b"1", b"2", b"3", b"4", b"5"]
72 |
73 | writer = array_record_module.ArrayRecordWriter("/path/to/mount/files/data.array_record")
74 | for i in digits:
75 | writer.write(i)
76 | writer.close()
77 | ```
78 |
79 |
80 | ## Read ArrayRecord files using Grain
81 |
82 |
83 | ```python id="3l4Pnc4bWlWf"
84 | import grain
85 | from pprint import pprint
86 |
87 | source = grain.sources.ArrayRecordDataSource(paths="/path/to/mount/files/data.array_record")
88 |
89 | dataset = (
90 | grain.MapDataset.source(source)
91 | .shuffle(seed=10) # Shuffles globally.
92 | .batch(batch_size=2) # Batches consecutive elements.
93 | )
94 |
95 | pprint(list(dataset))
96 | ```
97 |
--------------------------------------------------------------------------------
/docs/tutorials/data_sources/parquet_dataset_tutorial.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupyter:
3 | jupytext:
4 | text_representation:
5 | extension: .md
6 | format_name: markdown
7 | format_version: '1.3'
8 | jupytext_version: 1.17.1
9 | kernelspec:
10 | display_name: Python 3
11 | name: python3
12 | ---
13 |
14 |
15 | # Reading Apache Parquet Files
16 |
17 | [](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/parquet_dataset_tutorial.ipynb)
18 |
19 | This tutorial provides an example of how to read data from an [Apache Parquet](https://parquet.apache.org/) file, and also covers how to process and transform the data with Grain.
20 |
21 |
22 |
23 |
24 | ## Generate a test Parquet file locally
25 |
26 |
27 | ```python id="dFnwN6NNw0Oe"
28 | import pyarrow as pa
29 | import pyarrow.parquet as pq
30 |
31 | # Generate a sample PyArrow table containing email subjects and bodies.
32 | table = pa.table({
33 | 'email_subject': [
34 | "Meeting Reminder: Project X Update",
35 | "Important Announcement Regarding Company Policy",
36 | "FWD: Quick Question",
37 | "Your Order Confirmation #12345",
38 | "Invitation to Team Building Activity"
39 | ],
40 | 'email_body': [
41 | "Hi team,\n\nJust a reminder about our Project X update meeting tomorrow at 10 AM PST. Please come prepared to discuss your progress and any roadblocks.\n\nSee you there,\n[Your Name]",
42 | "Dear employees,\n\nPlease be advised of a new company policy regarding remote work, effective May 1st, 2025. You can find the full details on the company intranet.\n\nRegards,\nManagement",
43 | "Hi [Name],\n\nForwarding you this email as you might have the answer to this quick question:\n\n[Original Email Content]",
44 | "Dear [Customer Name],\n\nThank you for your recent order! This email confirms your order #12345. You can view the details and track its shipment here: [Link]\n\nSincerely,\nThe [Company Name] Team",
45 | "Hello everyone,\n\nYou're invited to participate in our upcoming team building activity on Friday, April 28th. It will be a fun afternoon of [Activity]. Please RSVP by Wednesday.\n\nBest,\n[Organizer Name]"
46 | ]
47 | })
48 |
49 | # Write this table to a parquet file.
50 | writer = pq.ParquetWriter('emails.parquet', table.schema)
51 | writer.write_table(table)
52 | writer.close()
53 |
54 | ```
55 |
56 |
57 | ## Load Dataset
58 |
59 |
60 | ```python id="fNIApmKT34ac"
61 | # Install Grain
62 | !pip install grain
63 | ```
64 |
65 | ```python id="nb5yPasvDwaj"
66 | import grain
67 | import pprint
68 | ```
69 |
70 | ```python id="62_D74LkdrQn"
71 | ds = grain.experimental.ParquetIterDataset('./emails.parquet')
72 | ```
73 |
74 | ```python id="DlhbJX5zdrQo"
75 | list(ds)[0]
76 | ```
77 |
78 |
79 | ## Transform Dataset
80 |
81 |
82 | ```python id="1qevksHMdrQo"
83 | # Load a pre-trained tokenizer.
84 | from tokenizers import Tokenizer
85 | tokenizer = Tokenizer.from_pretrained("bert-base-cased")
86 | ```
87 |
88 | ```python id="PNPv8LEGdrQo"
89 | class TokenizeText(grain.transforms.Map):
90 | """Tokenizes the text values within each element using a provided tokenizer."""
91 | def __init__(self, tokenizer):
92 | self._tokenizer = tokenizer
93 |
94 | def map(self, element):
95 | return [self._tokenizer.encode(item).tokens for item in element.values()]
96 | ```
97 |
98 | ```python id="O5vnwj7cek30"
99 | # Tokenize the data using the provided tokenizer.
100 | ds = ds.map(TokenizeText(tokenizer))
101 | ```
102 |
103 | ```python id="46ZVbfmtek30"
104 | # Create an iterator object of the dataset.
105 | ds_iter = iter(ds)
106 | # Print the first element in the dataset.
107 | pprint.pprint(next(ds_iter))
108 | ```
109 |
--------------------------------------------------------------------------------
/docs/tutorials/dataset_debugging_tutorial.md:
--------------------------------------------------------------------------------
1 | ---
2 | jupytext:
3 | formats: ipynb,md:myst
4 | main_language: python
5 | text_representation:
6 | extension: .md
7 | format_name: myst
8 | format_version: 0.13
9 | jupytext_version: 1.16.1
10 | kernelspec:
11 | display_name: Python 3
12 | name: python3
13 | ---
14 |
15 | +++ {"id": "OHoxgqr6sRKE"}
16 |
17 | # Performance & Debugging tool
18 | Grain offers two configurable modes that can be set to gain deeper insights into
19 | pipeline execution and identify potential issues.
20 |
21 | [](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/dataset_debugging_tutorial.ipynb)
22 |
23 | ```{code-cell}
24 | :id: xw_-jT1r6zNM
25 |
26 | # @test {"output": "ignore"}
27 | !pip install grain
28 | ```
29 |
30 | +++ {"id": "YLaRRlCPsRKE"}
31 |
32 | ## Visualization mode
33 | To get an overview of your dataset pipeline structure and a clear understanding
34 | of how the data flows, enable visualization mode. This logs a visual
35 | representation of your pipeline, allowing you to easily identify the different
36 | transformation stages and their relationships. To enable visualization mode, set
37 | the flag `--grain_py_dataset_visualization_output_dir=""` or call
38 | `grain.config.update("py_dataset_visualization_output_dir", "")`.
39 |
40 | ```{code-cell}
41 | :id: 4y89Wx7PsRKE
42 |
43 | # @test {"output": "ignore"}
44 | import grain.python as grain
45 |
46 | grain.config.update("py_dataset_visualization_output_dir", "")
47 | ds = (
48 | grain.MapDataset.range(20)
49 | .seed(seed=42)
50 | .shuffle()
51 | .batch(batch_size=2)
52 | .map(lambda x: x)
53 | .to_iter_dataset()
54 | )
55 | it = iter(ds)
56 |
57 | # Visualization graph is constructed once the dataset produces the first element
58 | for _ in range(10):
59 | next(it)
60 | ```
61 |
62 | +++ {"id": "_3h-u2I1i7wv"}
63 |
64 | ## Debug mode
65 | To troubleshoot performance issues in your dataset pipeline, enable debug mode.
66 | This will log a real-time execution summary of the pipeline at one-minute
67 | intervals. The execution summary provides detailed information on each
68 | transformation stage, such as processing time and the number of elements
69 | processed, which helps in identifying the slower stages in the pipeline.
70 | To enable debug mode, set the flag `--grain_py_debug_mode=true` or call
71 | `grain.config.update("py_debug_mode", True)`.
72 |
73 | ```{code-cell}
74 | :id: bN45Z58E3jGS
75 |
76 | import time
77 |
78 |
79 | # Define a dummy slow preprocessing function
80 | def _dummy_slow_fn(x):
81 | time.sleep(10)
82 | return x
83 | ```
84 |
85 | ```{code-cell}
86 | ---
87 | colab:
88 | height: 897
89 | id: bN45Z58E3jGS
90 | outputId: f3d640a8-1eae-414f-e6eb-e7c02c9a91df
91 | ---
92 | # @test {"output": "ignore"}
93 | import time
94 |
95 | grain.config.update("py_debug_mode", True)
96 |
97 | ds = (
98 | grain.MapDataset.range(20)
99 | .seed(seed=42)
100 | .shuffle()
101 | .batch(batch_size=2)
102 | .map(_dummy_slow_fn)
103 | .to_iter_dataset()
104 | .map(_dummy_slow_fn)
105 | )
106 | it = iter(ds)
107 |
108 | for _ in range(10):
109 | next(it)
110 | ```
111 |
112 | +++ {"id": "eSu9SOP8_x6A"}
113 |
114 | In the above execution summary, 86% of the time is spent in the
115 | `MapDatasetIterator` node, making it the slowest stage of the pipeline.
116 | 
117 | Note that although the `total_processing_time` might suggest that
118 | `MapMapDataset` (id: 2) is the slowest stage, the nodes with ids 2 to 6 are
119 | executed in multiple threads, so their `total_processing_time` should be
120 | compared to the `total_processing_time` of the iterator nodes (id: 0).
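
Debug mode is a diagnostic aid, so once the bottleneck has been identified you
may want to switch it off again. A minimal sketch, using the same config API
shown above:

```{code-cell}
grain.config.update("py_debug_mode", False)
```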
121 |
--------------------------------------------------------------------------------
/grain/BUILD:
--------------------------------------------------------------------------------
1 |
2 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library")
3 |
4 | package(default_visibility = ["//grain:__subpackages__"])
5 |
6 | licenses(["notice"])
7 |
8 | exports_files(["LICENSE"])
9 |
10 | # Backwards compatibility alias.
11 | alias(
12 | name = "python",
13 | actual = ":grain",
14 | visibility = ["//visibility:public"],
15 | )
16 |
17 | py_library(
18 | name = "grain",
19 | srcs = [
20 | "__init__.py",
21 | "_src/__init__.py",
22 | "checkpoint.py",
23 | "constants.py",
24 | "experimental.py",
25 | "multiprocessing.py",
26 | "python/__init__.py",
27 | "python/experimental.py",
28 | "samplers.py",
29 | "sharding.py",
30 | "sources.py",
31 | "transforms.py",
32 | ],
33 | data = ["//grain/_src/python/experimental/index_shuffle/python:index_shuffle_module.so"],
34 | srcs_version = "PY3",
35 | # Implicit build flag
36 | visibility = ["//visibility:public"],
37 | deps = [
38 | "//grain/_src/core:config",
39 | "//grain/_src/core:constants",
40 | "//grain/_src/core:monitoring",
41 | "//grain/_src/core:sharding",
42 | "//grain/_src/core:transforms",
43 | "//grain/_src/python:checkpoint_handlers",
44 | "//grain/_src/python:data_loader",
45 | "//grain/_src/python:data_sources",
46 | "//grain/_src/python:load",
47 | "//grain/_src/python:operations",
48 | "//grain/_src/python:options",
49 | "//grain/_src/python:record",
50 | "//grain/_src/python:samplers",
51 | "//grain/_src/python:shared_memory_array",
52 | "//grain/_src/python/dataset",
53 | "//grain/_src/python/dataset:base",
54 | "//grain/_src/python/dataset:elastic_iterator",
55 | "//grain/_src/python/dataset:stats",
56 | "//grain/_src/python/dataset:visualize",
57 | "//grain/_src/python/dataset/sources:parquet_dataset",
58 | "//grain/_src/python/dataset/sources:tfrecord_dataset",
59 | "//grain/_src/python/dataset/transformations:interleave",
60 | "//grain/_src/python/dataset/transformations:limit",
61 | "//grain/_src/python/dataset/transformations:packing",
62 | "//grain/_src/python/dataset/transformations:packing_concat_then_split",
63 | "//grain/_src/python/dataset/transformations:zip",
64 | "//grain/_src/python/experimental/example_packing:packing",
65 | "//grain/_src/python/testing:experimental",
66 | "//grain/proto:execution_summary_py_pb2",
67 | ],
68 | )
69 |
--------------------------------------------------------------------------------
/grain/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Public API for Grain."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=unused-import
19 | # pylint: disable=g-multiple-import
20 | # pylint: disable=g-import-not-at-top
21 |
22 | # We import all public modules here to enable the use of `grain.foo.Bar`
23 | # instead of forcing users to write `from grain import foo as grain_foo`.
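# For example, `grain.sources.ArrayRecordDataSource` and `grain.transforms.Map`
# become available after `import grain`.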
24 | from grain import (
25 | experimental,
26 | checkpoint,
27 | constants,
28 | multiprocessing,
29 | samplers,
30 | sharding,
31 | sources,
32 | transforms,
33 | )
34 |
35 | from grain._src.core.config import config
36 | from grain._src.python.data_loader import (
37 | DataLoader,
38 | DataLoaderIterator,
39 | )
40 | from grain._src.python.dataset.dataset import (
41 | DatasetIterator,
42 | IterDataset,
43 | MapDataset,
44 | )
45 | from grain._src.python.load import load
46 | from grain._src.python.options import ReadOptions
47 | from grain._src.python.record import Record, RecordMetadata
48 |
--------------------------------------------------------------------------------
/grain/_src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/grain/_src/core/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//grain:__subpackages__"])
2 |
3 | licenses(["notice"])
4 |
5 | py_library(
6 | name = "config",
7 | srcs = ["config.py"],
8 | srcs_version = "PY3",
9 | deps = [
10 | ":monitoring",
11 | "@abseil-py//absl/flags",
12 | ],
13 | )
14 |
15 | py_library(
16 | name = "constants",
17 | srcs = ["constants.py"],
18 | srcs_version = "PY3",
19 | )
20 |
21 | py_library(
22 | name = "exceptions",
23 | srcs = ["exceptions.py"],
24 | srcs_version = "PY3",
25 | )
26 |
27 | py_library(
28 | name = "monitoring",
29 | srcs = ["monitoring.py"],
30 | srcs_version = "PY3",
31 | deps = [
32 | ],
33 | )
34 |
35 | py_library(
36 | name = "parallel",
37 | srcs = ["parallel.py"],
38 | srcs_version = "PY3",
39 | deps = [
40 | ],
41 | )
42 |
43 | py_test(
44 | name = "parallel_test",
45 | srcs = ["parallel_test.py"],
46 | srcs_version = "PY3",
47 | deps = [
48 | ":parallel",
49 | "@abseil-py//absl/testing:absltest",
50 | "@abseil-py//absl/testing:parameterized",
51 | ],
52 | )
53 |
54 | py_library(
55 | name = "grain_random",
56 | srcs = ["grain_random.py"],
57 | srcs_version = "PY3",
58 | deps = [
59 | "@abseil-py//absl/logging",
60 | "@pypi//jax:pkg",
61 | "@pypi//numpy:pkg",
62 | ],
63 | )
64 |
65 | py_library(
66 | name = "sharding",
67 | srcs = ["sharding.py"],
68 | srcs_version = "PY3",
69 | deps = [
70 | "@abseil-py//absl/logging",
71 | ],
72 | )
73 |
74 | py_test(
75 | name = "sharding_test",
76 | srcs = ["sharding_test.py"],
77 | srcs_version = "PY3",
78 | deps = [
79 | ":sharding",
80 | "@abseil-py//absl/testing:absltest",
81 | "@abseil-py//absl/testing:parameterized",
82 | ],
83 | )
84 |
85 | py_library(
86 | name = "usage_logging",
87 | srcs = ["usage_logging.py"],
88 | srcs_version = "PY3",
89 | )
90 |
91 | py_library(
92 | name = "transforms",
93 | srcs = ["transforms.py"],
94 | srcs_version = "PY3",
95 | deps = [
96 | "@pypi//numpy:pkg",
97 | ],
98 | )
99 |
100 | py_test(
101 | name = "transforms_test",
102 | srcs = ["transforms_test.py"],
103 | srcs_version = "PY3",
104 | deps = [
105 | ":transforms",
106 | "@abseil-py//absl/testing:absltest",
107 | "@abseil-py//absl/testing:parameterized",
108 | ],
109 | )
110 |
111 | py_library(
112 | name = "tree_lib",
113 | srcs = [
114 | "tree_lib.py",
115 | ],
116 | srcs_version = "PY3",
117 | deps = [
118 | "@pypi//dm_tree:pkg",
119 | "@pypi//numpy:pkg",
120 | ],
121 | )
122 |
123 | py_library(
124 | name = "tree_test_lib",
125 | testonly = 1,
126 | srcs = ["tree_lib_test.py"],
127 | srcs_version = "PY3",
128 | deps = [
129 | ":tree_lib",
130 | "@abseil-py//absl/testing:absltest",
131 | "@abseil-py//absl/testing:parameterized",
132 | "@pypi//numpy:pkg",
133 | ],
134 | )
135 |
136 | py_test(
137 | name = "tree_lib_test",
138 | srcs = ["tree_lib_test.py"],
139 | srcs_version = "PY3",
140 | deps = [
141 | ":tree_test_lib",
142 | ],
143 | )
144 |
145 | py_test(
146 | name = "tree_lib_jax_test",
147 | srcs = ["tree_lib_jax_test.py"],
148 | srcs_version = "PY3",
149 | deps = [
150 | ":tree_lib",
151 | ":tree_test_lib",
152 | "@abseil-py//absl/testing:absltest",
153 | "@pypi//attrs:pkg",
154 | "@pypi//jax:pkg",
155 | "@pypi//numpy:pkg",
156 | ],
157 | )
158 |
159 | py_test(
160 | name = "version_test",
161 | srcs = ["version_test.py"],
162 | srcs_version = "PY3",
163 | deps = [
164 | "//grain:python",
165 | "@abseil-py//absl/testing:absltest",
166 | ],
167 | )
168 |
--------------------------------------------------------------------------------
/grain/_src/core/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Shared constants for various Grain APIs."""
15 |
16 | # Below are names of meta features used by index_dataset (and pipelines building
17 | # on top of it). These features are generated on the fly and help to track
18 | # progress over the dataset. Users can read these but shouldn't alter them. They
19 | # start with "_" to indicate that they are "private".
20 | # Index into the stream of all records (globally unique). Starts with 0.
21 | INDEX = "_index"
22 | # Key of the record. If DATASET_INDEX is present it's the key in the dataset.
23 | # Starts with 0.
24 | RECORD_KEY = "_record_key"
25 | # Index of the dataset from which to take the record. Only present when mixing.
26 | # Starts with 0.
27 | DATASET_INDEX = "_dataset_index"
28 | # Epoch for the record. When mixing datasets this is the epoch over the dataset,
29 | # not the mixture. Starts with 1.
30 | EPOCH = "_epoch"
31 | # Random seed for stateless random operations. This is unique per record
32 | # and changes every epoch.
33 | SEED = "_seed"
34 | # Serialized record.
35 | RECORD = "_record"
36 |
37 | META_FEATURES = frozenset(
38 | [INDEX, RECORD_KEY, DATASET_INDEX, EPOCH, SEED, RECORD]
39 | )
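
# Example (illustrative, not exhaustive): a record flowing through an
# index-based pipeline might carry {"_index": 7, "_record_key": 3, "_epoch": 1}
# alongside its user features.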
40 |
--------------------------------------------------------------------------------
/grain/_src/core/exceptions.py:
--------------------------------------------------------------------------------
1 | """Custom exception for PyGrain-specific internal errors."""
2 |
3 |
4 | class PyGrainInternalError(Exception):
5 | pass
6 |
--------------------------------------------------------------------------------
/grain/_src/core/monitoring.py:
--------------------------------------------------------------------------------
1 | """Grain metrics."""
2 |
3 | import enum
4 |
5 |
6 | @enum.unique
7 | class Units(enum.Enum):
8 | """Grain metric units."""
9 |
10 | NANOSECONDS = enum.auto()
11 | MILLISECONDS = enum.auto()
12 |
13 |
14 | class NoOpMetric:
15 | """Grain metric no-op implementation."""
16 |
17 | def __init__(self, *args, **kwargs):
18 | del args, kwargs
19 |
20 | def IncrementBy(self, *args, **kwargs):
21 | del args, kwargs
22 |
23 | def Increment(self, *args, **kwargs):
24 | self.IncrementBy(1, *args, **kwargs)
25 |
26 | def Set(self, *args, **kwargs):
27 | del args, kwargs
28 |
29 | def Record(self, *args, **kwargs):
30 | del args, kwargs
31 |
32 | def Get(self, *args, **kwargs):
33 | del args, kwargs
34 |
35 |
36 | class Metadata:
37 | """Grain metric no-op metadata."""
38 |
39 | def __init__(self, *args, **kwargs):
40 | del args, kwargs
41 |
42 |
43 | class Bucketer:
44 | """Grain metric no-op bucketer."""
45 |
46 | def __init__(self, *args, **kwargs):
47 | del args, kwargs
48 |
49 | def PowersOf(self, *args, **kwargs):
50 | del args, kwargs
51 |
52 |
53 | Counter = Metric = EventMetric = NoOpMetric
54 |
55 | def get_monitoring_root() -> None:
56 | return None
57 |
--------------------------------------------------------------------------------
/grain/_src/core/parallel.py:
--------------------------------------------------------------------------------
1 | """Provides a methods to run functions in parallel using a thread pool."""
2 | from collections.abc import Mapping, Sequence
3 | from typing import Any, Callable, TypeVar
4 |
5 | from concurrent import futures
6 |
7 | T = TypeVar("T")
8 |
9 |
10 | def run_in_parallel(
11 | function: Callable[..., T],
12 | list_of_kwargs_to_function: Sequence[Mapping[str, Any]],
13 | num_workers: int,
14 | thread_name_prefix: str = "parallel_",
15 | ) -> list[T]:
16 | """Run a function on a list of kwargs in parallel with ThreadPoolExecutor.
17 |
18 | Works best when the work is IO-bound rather than CPU-bound, as the threads
19 | used are bound by the GIL.
20 | 
21 | Propagates the first exception to the calling thread and attempts to cancel
22 | as many of the remaining work units as possible.
23 |
24 | Example usage:
25 | def io_bound_function(p):
26 | get_contents_from_cns(p)
27 |
28 | run_in_parallel(
29 | function=io_bound_function,
30 | list_of_kwargs_to_function=[{"p": p} for p in long_list_of_paths],
31 | num_workers=3)
32 |
33 | Args:
34 | function: a function.
35 | list_of_kwargs_to_function: A list of dicts mapping from string to argument
36 | value. These will be passed into `function` as kwargs.
37 | num_workers: int.
38 | thread_name_prefix: The thread name prefix string. The function calls are
39 | run in threads, and each thread is named. This parameter allows the user to
40 | control the prefix for that thread name.
41 |
42 | Returns:
43 | list of return values from function, in the same order as the arguments in
44 | list_of_kwargs_to_function.
45 | """
46 | if num_workers < 1:
47 | raise ValueError(
48 | "Number of workers must be greater than 0. Was {}".format(num_workers)
49 | )
50 |
51 | thread_name = thread_name_prefix + getattr(function, "__name__", "unknown")
52 | with futures.ThreadPoolExecutor(
53 | num_workers, thread_name_prefix=thread_name
54 | ) as executor:
55 | fs = []
56 |
57 | for kwargs in list_of_kwargs_to_function:
58 | f = executor.submit(function, **kwargs)
59 | fs.append(f)
60 |
61 | futures_as_completed = futures.as_completed(fs)
62 |
63 | for completed in futures_as_completed:
64 | if completed.exception():
65 | # Cancel all remaining futures, if possible.
66 | for remaining_future in fs:
67 | remaining_future.cancel()
68 |
69 | # Propagate exception to main thread.
70 | raise completed.exception()
71 |
72 | return [f.result() for f in fs]
73 |
--------------------------------------------------------------------------------
/grain/_src/core/parallel_test.py:
--------------------------------------------------------------------------------
1 | """Tests for parallel."""
2 | import threading
3 |
4 | from absl.testing import absltest
5 | from absl.testing import parameterized
6 | from grain._src.core import parallel
7 |
8 |
9 | def ReturnThreadName():
10 | return threading.current_thread().name
11 |
12 |
13 | def Identity(i):
14 | return i
15 |
16 |
17 | def FnThatAlwaysFails(arg):
18 | del arg
19 | raise ValueError("I always fail")
20 |
21 |
22 | def FnThatFailsOnOddInputs(i):
23 | if i % 2 == 1:
24 | raise ValueError("Failed on an odd input")
25 | return i
26 |
27 |
28 | class ParallelTest(parameterized.TestCase):
29 |
30 | @parameterized.named_parameters(
31 | dict(
32 | testcase_name=" empty list of args",
33 | num_workers=1,
34 | input_dict_list=[],
35 | expected=[],
36 | ),
37 | dict(
38 | testcase_name=" one worker, nonempty list",
39 | num_workers=1,
40 | input_dict_list=[dict(i=k) for k in range(1, 10)],
41 | expected=list(range(1, 10)),
42 | ),
43 | dict(
44 | testcase_name=" fewer workers than jobs, nonempty list",
45 | num_workers=3,
46 | input_dict_list=[dict(i=k) for k in range(1, 10)],
47 | expected=list(range(1, 10)),
48 | ),
49 | dict(
50 | testcase_name=" more workers than jobs, nonempty list",
51 | num_workers=20,
52 | input_dict_list=[dict(i=k) for k in range(1, 10)],
53 | expected=list(range(1, 10)),
54 | ),
55 | )
56 | def testRunInParallel(self, input_dict_list, num_workers: int, expected):
57 | actual = parallel.run_in_parallel(Identity, input_dict_list, num_workers)
58 | self.assertEqual(actual, expected)
59 |
60 | def testRunInParallelOnAlwaysFailingFn(self):
61 | with self.assertRaisesRegex(ValueError, "I always fail"):
62 | parallel.run_in_parallel(FnThatAlwaysFails, [dict(arg="hi")], 10)
63 |
64 | @parameterized.named_parameters(
65 | dict(
66 | testcase_name=" one failing input, one worker",
67 | num_workers=1,
68 | input_dict_list=[{"i": 1}],
69 | ),
70 | dict(
71 | testcase_name=" one failing input, many workers",
72 | num_workers=5,
73 | input_dict_list=[{"i": 1}],
74 | ),
75 | dict(
76 | testcase_name=" one failing input, one succeeding input",
77 | num_workers=5,
78 | input_dict_list=[{"i": 1}, {"i": 2}],
79 | ),
80 | dict(
81 | testcase_name=" two failing inputs, one succeeding input",
82 | num_workers=5,
83 | input_dict_list=[{"i": 1}, {"i": 2}, {"i": 3}],
84 | ),
85 | )
86 | def testRunInParallelFailsIfSomeFnCallsFail(
87 | self, input_dict_list, num_workers: int
88 | ):
89 | with self.assertRaisesRegex(ValueError, "Failed on an odd input"):
90 | parallel.run_in_parallel(
91 | FnThatFailsOnOddInputs, input_dict_list, num_workers
92 | )
93 |
94 | def testRunInParallelForThreadNamePrefix(self):
95 | input_kwarg_list = [{}]
96 | thread_names = parallel.run_in_parallel(
97 | ReturnThreadName, input_kwarg_list, 5, thread_name_prefix="Customized-"
98 | )
99 | self.assertStartsWith(thread_names[0], "Customized-")
100 |
101 | def testRunInParallelForDefaultThreadNamePrefix(self):
102 | input_kwarg_list = [{}]
103 | thread_names = parallel.run_in_parallel(
104 | ReturnThreadName, input_kwarg_list, 5
105 | )
106 | self.assertStartsWith(thread_names[0], "parallel_")
107 |
108 |
109 | if __name__ == "__main__":
110 | absltest.main()
111 |
--------------------------------------------------------------------------------
/grain/_src/core/sharding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Classes for handling sharding of data sources arcoss machines/VMs."""
15 | import dataclasses
16 |
17 | from absl import logging
18 |
19 |
20 | @dataclasses.dataclass(frozen=True)
21 | class ShardOptions:
22 | """Dataclass to hold options for sharding a data source.
23 |
24 | Attributes:
25 | shard_index: The index of the shard to use in this process. Must be in [0,
26 | shard_count - 1].
27 | shard_count: The total number of shards.
28 | drop_remainder: If True, shard() will create even splits and drop the
29 | remainder examples (all shards will have the same number of examples). If
30 | False, the remainder of N examples will be distributed over the first N shards.
31 | """
32 |
33 | shard_index: int
34 | shard_count: int
35 | drop_remainder: bool = False
36 |
37 | def __post_init__(self):
38 | if self.shard_count <= 0:
39 | raise ValueError(
40 | "Number of shards must be a positive integer but got "
41 | f"{self.shard_count}."
42 | )
43 | if self.shard_index < 0 or self.shard_index >= self.shard_count:
44 | raise ValueError(
45 | "Shard shard_index must be in [0, shard_count - 1], shard_count was "
46 | f"{self.shard_count} and shard_index was {self.shard_index}."
47 | )
48 |
49 |
50 | class NoSharding(ShardOptions):
51 | """Doesn't shard data. Each process will load all data."""
52 |
53 | def __init__(self):
54 | super().__init__(shard_index=0, shard_count=1, drop_remainder=False)
55 |
56 |
57 | class ShardByJaxProcess(ShardOptions):
58 | """Shards the data across JAX processes."""
59 |
60 | def __init__(self, drop_remainder: bool = False):
61 | process_index, process_count = get_process_index_and_count()
62 | super().__init__(
63 | shard_index=process_index,
64 | shard_count=process_count,
65 | drop_remainder=drop_remainder,
66 | )
67 |
68 |
69 | def even_split(num_examples: int, options: ShardOptions) -> tuple[int, int]:
70 | """Returns the interval for the shard when sharding `num_examples` evenly.
71 |
72 | This splits the interval [0, num_examples - 1] into `shard_count` intervals
73 | and returns the `shard_index`'s interval. If `drop_remainder` is True all
74 | intervals will have the same size.
75 |
76 | Args:
77 | num_examples: Number of examples to shard.
78 | options: Options for sharding the data in this process.
79 |
80 | Returns:
81 | Tuple with the start and end of the interval. The start is the first
82 | example that should be included in this interval and end - 1 is the last
83 | example to be included in the shard.
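
Example (illustrative): with num_examples=10 and shard_count=3, the shards
receive the intervals (0, 4), (4, 7) and (7, 10) when drop_remainder is
False; with drop_remainder=True each shard receives 3 examples, i.e.
(0, 3), (3, 6) and (6, 9), and the last example is dropped.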
84 | """
85 | examples_per_shard = num_examples // options.shard_count
86 | shard_start = examples_per_shard * options.shard_index
87 | shard_end = examples_per_shard * (options.shard_index + 1)
88 |
89 | # Handle remaining examples.
90 | num_unused_examples = num_examples % options.shard_count
91 |
92 | if num_unused_examples > 0:
93 | if options.drop_remainder:
94 | logging.warning(
95 | "Dropping %d examples of %d examples (shard %d).",
96 | num_unused_examples,
97 | num_examples,
98 | options.shard_count,
99 | )
100 | else:
101 | shard_start += min(options.shard_index, num_unused_examples)
102 | shard_end += min(options.shard_index + 1, num_unused_examples)
103 | return shard_start, shard_end
104 |
105 |
106 | def get_process_index_and_count():
107 | try:
108 | import jax # pylint:disable=g-import-not-at-top # pytype:disable=import-error
109 |
110 | return jax.process_index(), jax.process_count()
111 | except ImportError:
112 | return 0, 1
113 |
--------------------------------------------------------------------------------
/grain/_src/core/sharding_test.py:
--------------------------------------------------------------------------------
1 | """Tests for sharding."""
2 | from absl.testing import absltest
3 | from absl.testing import parameterized
4 | from grain._src.core import sharding
5 |
6 |
7 | class ShardingTest(parameterized.TestCase):
8 |
9 | @parameterized.parameters(
10 | # num_examples, shard_index, shard_count, drop_remainder, expected_output.
11 | (9, 0, 1, True, (0, 9)),
12 | (9, 0, 2, True, (0, 4)),
13 | (9, 1, 2, True, (4, 8)), # Last example gets dropped.
14 | (9, 0, 3, True, (0, 3)),
15 | (9, 1, 3, True, (3, 6)),
16 | (9, 2, 3, True, (6, 9)),
17 | (9, 0, 1, False, (0, 9)),
18 | (9, 0, 2, False, (0, 5)), # First shard gets an extra example.
19 | (9, 1, 2, False, (5, 9)),
20 | (8, 0, 3, False, (0, 3)), # First 2 shards get 1 example each.
21 | (8, 1, 3, False, (3, 6)),
22 | (8, 2, 3, False, (6, 8)),
23 | )
24 | def test_sharding(
25 | self,
26 | num_examples: int,
27 | shard_index: int,
28 | shard_count: int,
29 | drop_remainder,
30 | expected_output: tuple[int, int],
31 | ):
32 | shard_options = sharding.ShardOptions(
33 | shard_index=shard_index,
34 | shard_count=shard_count,
35 | drop_remainder=drop_remainder,
36 | )
37 | actual_output = sharding.even_split(num_examples, shard_options)
38 | self.assertEqual(actual_output, expected_output)
39 |
40 |
41 | if __name__ == '__main__':
42 | absltest.main()
43 |
--------------------------------------------------------------------------------
/grain/_src/core/smoke_test_with_jax.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Checks that OSS Grain Package works end-to-end with JAX."""
15 | from typing import Sequence
16 | from absl import app
17 | import grain
18 | import jax.numpy as jnp
19 |
20 |
21 | def main(argv: Sequence[str]) -> None:
22 | del argv
23 | ds = grain.MapDataset.source(jnp.arange(10)).map(lambda x: x + 1)
24 |
25 | for _ in ds:
26 | pass
27 |
28 |
29 | if __name__ == "__main__":
30 | app.run(main)
31 |
--------------------------------------------------------------------------------
/grain/_src/core/smoke_test_with_tf.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Checks that OSS Grain Package works end-to-end with TF."""
15 | from typing import Sequence
16 | from absl import app
17 | import grain
18 | import tensorflow as tf
19 |
20 |
21 | def main(argv: Sequence[str]) -> None:
22 | del argv
23 | ds = grain.MapDataset.source(range(10)).map(tf.convert_to_tensor)
24 |
25 | for _ in ds:
26 | pass
27 |
28 |
29 | if __name__ == "__main__":
30 | app.run(main)
31 |
--------------------------------------------------------------------------------
/grain/_src/core/transforms_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import functools
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | from grain._src.core import transforms
19 |
20 |
21 | class _TestFilter(transforms.Filter):
22 |
23 | def filter(self, x):
24 | return x % 2 == 0
25 |
26 |
27 | class _TestFilterWithStr(transforms.Filter):
28 |
29 | def filter(self, x):
30 | return x % 2 == 0
31 |
32 | def __str__(self):
33 | return "CustomStr"
34 |
35 |
36 | class _TestMapWithRepr(transforms.MapTransform):
37 |
38 | def map(self, x):
39 | return x % 2 == 0
40 |
41 | def __repr__(self):
42 | return "CustomRepr"
43 |
44 |
45 | class GetPrettyTransformNameTest(parameterized.TestCase):
46 |
47 | @parameterized.parameters(
48 | dict(
49 | transform=lambda x: x,
50 | expected_substring=" @ .../_src/core/transforms_test.py:",
51 | ),
52 | dict(
53 | transform=transforms.get_pretty_transform_name,
54 | expected_substring=(
55 | "get_pretty_transform_name @ .../_src/core/transforms.py:"
56 | ),
57 | ),
58 | dict(transform=list, expected_substring="list"),
59 | dict(
60 | transform=functools.partial(lambda x, y: x + y, 1),
61 | expected_substring="functools.partial",
62 | ),
63 | dict(transform=_TestFilter(), expected_substring="_TestFilter"),
64 | dict(
65 | transform=_TestFilterWithStr(),
66 | expected_substring="CustomStr",
67 | ),
68 | dict(
69 | transform=_TestMapWithRepr(),
70 | expected_substring="CustomRepr",
71 | ),
72 | )
73 | def test_get_pretty_transform_name(self, transform, expected_substring):
74 | self.assertIn(
75 | expected_substring, transforms.get_pretty_transform_name(transform)
76 | )
77 |
78 |
79 | if __name__ == "__main__":
80 | absltest.main()
81 |
--------------------------------------------------------------------------------
/grain/_src/core/tree_lib_jax_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Testes for tree_lib.py with JAX dependency present."""
15 |
16 | from absl.testing import absltest
17 | import attrs
18 | from grain._src.core import tree_lib
19 | from grain._src.core import tree_lib_test
20 | import jax
21 | import numpy as np
22 |
23 |
24 | class MyTree:
25 |
26 | def __init__(self, a, b):
27 | self.a = a
28 | self.b = b
29 |
30 | def __eq__(self, other):
31 | return self.a == other.a and self.b == other.b
32 |
33 |
34 | class MyClass:
35 |
36 | def __init__(self, c):
37 | self.c = c
38 |
39 |
40 | @attrs.define
41 | class MyAttrs:
42 | d: int
43 | e: str
44 |
45 |
46 | class TreeJaxTest(tree_lib_test.TreeTest):
47 |
48 | def test_map_custom_tree(self):
49 | jax.tree_util.register_pytree_node(
50 | MyTree, lambda t: ((t.a, t.b), None), lambda _, args: MyTree(*args)
51 | )
52 | self.assertEqual(
53 | tree_lib.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3)
54 | )
55 |
56 | def test_spec_like_with_class(self):
57 | self.assertEqual(
58 | tree_lib.spec_like({"B": 1232.4, "C": MyClass(1)}),
59 | {
60 | "B": "[]",
61 | "C": "[]",
62 | },
63 | )
64 |
65 | def test_spec_like_with_list(self):
66 | self.assertEqual(
67 | tree_lib.spec_like({
68 | "B": 1232.4,
69 | "C": [
70 | tree_lib_test.TestClass(a=1, b="v2"),
71 | tree_lib_test.TestClass(a=2, b="v2"),
72 | ],
73 | }),
74 | {
75 | "B": "[]",
76 | "C": "list[2]",
77 | },
78 | )
79 |
80 | def test_spec_like_with_unknown_shape(self):
81 | self.assertEqual(
82 | tree_lib.spec_like({
83 | "B": [np.zeros([2]), np.zeros([1])],
84 | "C": [],
85 | }),
86 | {"B": "list[unknown shape]", "C": "list<>[0]"},
87 | )
88 |
89 | def test_spec_like_with_dataclass(self):
90 | self.assertEqual(
91 | tree_lib.spec_like(tree_lib_test.TestClass(a=1, b="v2")),
92 | "\n"
93 | "{'a': \"[]\", 'b': \"[]\"}[]",
94 | )
95 |
96 | def test_spec_like_with_attrs(self):
97 | self.assertEqual(
98 | tree_lib.spec_like(MyAttrs(d=1, e="v2")),
99 | "\n"
100 | "{'d': \"[]\", 'e': \"[]\"}[]",
101 | )
102 |
103 |
104 | if __name__ == "__main__":
105 | absltest.main()
106 |
--------------------------------------------------------------------------------
/grain/_src/core/usage_logging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License
14 | """Internal usage logging."""
15 |
16 |
17 | def log_event(tag: str, *, tag_2: str = "", tag_3: str = "") -> None:
18 | return
19 |
--------------------------------------------------------------------------------
/grain/_src/core/version_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Checks that tests in OSS are run with the correct version of Python."""
15 | # Make sure grain can be imported.
16 | from grain import python as grain # pylint: disable=unused-import
17 |
18 | import os
19 | import sys
20 | from absl.testing import absltest
21 |
22 |
23 | class VersionTest(absltest.TestCase):
24 |
25 | def test_python_version(self):
26 | expected = os.getenv("PYTHON_VERSION")
27 | current = f"{sys.version_info.major}.{sys.version_info.minor}"
28 | if current != expected:
29 | raise ValueError(
30 | f"expected version '{expected}' is different than returned"
31 | f" '{current}'"
32 | )
33 |
34 |
35 | if __name__ == "__main__":
36 | absltest.main()
37 |
--------------------------------------------------------------------------------
/grain/_src/python/checkpoint_handlers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """This module provides a PyGrain CheckpointHandler for integration with Orbax."""
15 | import dataclasses
16 | import json
17 | from typing import Any, Optional, TypeVar
18 |
19 | from etils import epath
20 | from grain._src.core import sharding
21 | from grain._src.python import data_loader
22 | from grain._src.python.dataset import dataset
23 |
24 | IteratorType = TypeVar(
25 | "IteratorType", data_loader.DataLoaderIterator, dataset.DatasetIterator
26 | )
27 |
28 |
29 | # Implements orbax.checkpoint.CheckpointHandler.
30 | class CheckpointHandler:
31 | """Orbax CheckpointHandler for PyGrain iterators."""
32 |
33 | def save(
34 | self,
35 | directory: epath.Path,
36 | # `item` is for backwards compatibility with older Orbax API, see
37 | # https://orbax.readthedocs.io/en/latest/api_refactor.html.
38 | item: Optional[IteratorType] = None,
39 | args: Any = None,
40 | ):
41 | """Saves the given iterator to the checkpoint in `directory`."""
42 | item = item or args.item # pytype:disable=attribute-error
43 | if isinstance(item, dataset.DatasetIterator):
44 | state = json.dumps(item.get_state(), indent=4)
45 | else:
46 | state = item.get_state().decode()
47 | process_index, process_count = sharding.get_process_index_and_count()
48 | filename = directory / f"process_{process_index}-of-{process_count}.json"
49 | filename.write_text(state)
50 |
51 | def restore(
52 | self,
53 | directory: epath.Path,
54 | item: Optional[IteratorType] = None,
55 | args: Any = None,
56 | ) -> IteratorType:
57 | """Restores the given iterator from the checkpoint in `directory`."""
58 | item = item or args.item # pytype:disable=attribute-error
59 | process_index, process_count = sharding.get_process_index_and_count()
60 | filename = directory / f"process_{process_index}-of-{process_count}.json"
61 | if not filename.exists():
62 | raise ValueError(f"File {filename} does not exist.")
63 | state = filename.read_text()
64 | if isinstance(item, dataset.DatasetIterator):
65 | state = json.loads(state)
66 | else:
67 | state = state.encode()
68 | item.set_state(state)
69 | return item
70 |
71 | # Required by interface but not supported by PyGrain checkpoints.
72 | def structure(self, directory: epath.Path) -> Any:
73 | del directory
74 | return None
75 |
76 | # Required by interface.
77 |
78 | def metadata(self, directory: epath.Path) -> Optional[Any]:
79 | del directory
80 | return None
81 |
82 | def finalize(self, directory: epath.Path):
83 | pass
84 |
85 | def close(self):
86 | pass
87 |
88 | @classmethod
89 | def typestr(cls):
90 | return f"{cls.__module__}.{cls.__qualname__}"
91 |
92 |
93 | try:
94 | # Register the handler to be used with the new checkpointing API if Orbax is
95 | # present.
96 | import orbax.checkpoint as ocp # pylint:disable=g-import-not-at-top # pytype:disable=import-error
97 |
98 | @ocp.args.register_with_handler(CheckpointHandler, for_save=True) # pytype:disable=wrong-arg-types
99 | @dataclasses.dataclass
100 | class CheckpointSave(ocp.args.CheckpointArgs):
101 | item: Any
102 |
103 | @ocp.args.register_with_handler(CheckpointHandler, for_restore=True) # pytype:disable=wrong-arg-types
104 | @dataclasses.dataclass
105 | class CheckpointRestore(ocp.args.CheckpointArgs):
106 | item: Any
107 |
108 |
109 | except (ImportError, TypeError, AttributeError):
110 | pass
111 |
--------------------------------------------------------------------------------
/grain/_src/python/checkpointing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Utilities providing checkpointing capabilities for Grain iterators."""
15 |
16 | import asyncio
17 | from typing import Callable, Protocol
18 | from etils import epath
19 | from grain._src.core import sharding
20 |
21 |
22 | class PathAwaitingCreation(Protocol):
23 | """A path that is in the process of being created.
24 |
25 | Please see orbax/checkpoint for the full definition of this type.
26 | """
27 |
28 | async def await_creation(self) -> epath.Path:
29 | """Waits for the path to be created.
30 |
31 | This function MUST be called before accessing the physical path. Prefer to
32 | perform it as a background operation rather than as a main-thread-blocking
33 | operation.
34 |
35 | Returns:
36 | The path that is now created.
37 | """
38 |
39 |
40 | async def background_save(directory: PathAwaitingCreation, state: str):
41 | """An async function that saves iterator state in a background thread.
42 |
43 | Args:
44 | directory: The directory to save the state to.
45 | state: The state to save.
46 | """
47 | directory = await directory.await_creation()
48 | process_index, process_count = sharding.get_process_index_and_count()
49 | filename = directory / f"process_{process_index}-of-{process_count}.json"
50 | await asyncio.to_thread(filename.write_text, state)
51 |
52 |
53 | async def background_load(
54 | directory: epath.Path, set_state_fn: Callable[[str], None]
55 | ):
56 | process_index, process_count = sharding.get_process_index_and_count()
57 | filename = directory / f"process_{process_index}-of-{process_count}.json"
58 | if not await asyncio.to_thread(filename.exists):
59 | raise ValueError(f"File {filename} does not exist.")
60 | state = await asyncio.to_thread(filename.read_text)
61 | set_state_fn(state)
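
# Illustrative usage sketch (not part of the original module; assumes `it` is a
# Grain DatasetIterator whose state is serialized as JSON, as in
# checkpoint_handlers.py):
#   import asyncio, json
#   asyncio.run(background_load(directory, lambda s: it.set_state(json.loads(s))))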
62 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/base_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for base.py."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | from grain._src.python import data_sources
19 | from grain._src.python.dataset import base
20 |
21 |
22 | class RandomAccessDataSourceTest(parameterized.TestCase):
23 |
24 | @parameterized.parameters(
25 | data_sources.ArrayRecordDataSource,
26 | data_sources.RangeDataSource,
27 | data_sources.SharedMemoryDataSource,
28 | )
29 | def test_protocol(self, source_cls):
30 | self.assertIsInstance(source_cls, base.RandomAccessDataSource)
31 |
32 |
33 | class DatasetOptionsTest(parameterized.TestCase):
34 |
35 | @parameterized.named_parameters(
36 | dict(
37 | testcase_name="no_conflicts",
38 | a=base.DatasetOptions(filter_warn_threshold_ratio=0.1),
39 | b=base.DatasetOptions(filter_raise_threshold_ratio=0.2),
40 | expected=base.DatasetOptions(
41 | filter_warn_threshold_ratio=0.1,
42 | filter_raise_threshold_ratio=0.2,
43 | ),
44 | ),
45 | dict(
46 | testcase_name="all_fields_default",
47 | a=base.DatasetOptions(),
48 | b=base.DatasetOptions(
49 | filter_warn_threshold_ratio=0.4,
50 | filter_raise_threshold_ratio=0.3,
51 | ),
52 | expected=base.DatasetOptions(
53 | filter_warn_threshold_ratio=0.4,
54 | filter_raise_threshold_ratio=0.3,
55 | ),
56 | ),
57 | dict(
58 | testcase_name="field_conflict",
59 | a=base.DatasetOptions(filter_raise_threshold_ratio=0.1),
60 | b=base.DatasetOptions(filter_raise_threshold_ratio=0.2),
61 | expected=base.DatasetOptions(
62 | filter_raise_threshold_ratio=0.1,
63 | ),
64 | ),
65 | )
66 | def test_merge(self, a, b, expected):
67 | self.assertEqual(a.merge(b), expected)
68 |
69 |
70 | class IteratorContextTest(parameterized.TestCase):
71 |
72 | def test_merge(self):
73 | a = base.IteratorContext(
74 | dataset_options=base.DatasetOptions(filter_warn_threshold_ratio=0.1)
75 | )
76 | b = base.IteratorContext(
77 | dataset_options=base.DatasetOptions(
78 | filter_warn_threshold_ratio=0.2, filter_raise_threshold_ratio=0.2
79 | )
80 | )
81 | a.merge(b)
82 | self.assertEqual(
83 | a,
84 | base.IteratorContext(
85 | dataset_options=base.DatasetOptions(
86 | filter_warn_threshold_ratio=0.1,
87 | filter_raise_threshold_ratio=0.2,
88 | )
89 | ),
90 | )
91 |
92 | def test_merge_with_different_mp_context(self):
93 | a = base.IteratorContext(
94 | mp_context=base.MultiprocessingContext(process_index=0, process_count=1)
95 | )
96 | b = base.IteratorContext(
97 | mp_context=base.MultiprocessingContext(process_index=1, process_count=2)
98 | )
99 | with self.assertRaisesRegex(
100 | ValueError, "Cannot merge contexts from different worker processes"
101 | ):
102 | a.merge(b)
103 |
104 |
105 | if __name__ == "__main__":
106 | absltest.main()
107 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/elastic_iterator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Iterator supporting changes in the number of hosts (dataset shards)."""
15 |
16 | import functools
17 | from typing import Any
18 |
19 | from grain._src.core import sharding
20 | from grain._src.python import options
21 | from grain._src.python.dataset import dataset
22 | from grain._src.python.dataset.transformations import (
23 | filter as filter_dataset,
24 | )
25 |
26 | _GLOBAL_NEXT_INDEX_STATE_KEY = "global_next_index"
27 |
28 |
29 | class ElasticIterator(dataset.DatasetIterator):
30 | """Iterator supporting recovery from a checkpoint after changes in sharding.
31 |
32 | The input dataset is expected to be unbatched and unsharded. In order to
33 | provide the elasticity guarantee, this iterator performs both batching and
34 | sharding itself. It supports elastic re-configuration by having each
35 | shard produce exactly the same checkpoint (while producing different data) as
36 | long as all shards are advanced the same number of steps.
37 | 
38 | The state of any shard can therefore be used to restore all of the shards after
39 | changes in sharding and global batch size.
40 |
41 | This iterator explicitly disallows many-to-one transformations without
42 | a fixed ratio, like `filter` and generic `IterDataset` transformations.
43 | """
44 |
45 | def __init__(
46 | self,
47 | ds: dataset.MapDataset,
48 | global_batch_size: int,
49 | shard_options: sharding.ShardOptions,
50 | *,
51 | read_options: options.ReadOptions = options.ReadOptions(),
52 | multiprocessing_options: options.MultiprocessingOptions | None = None,
53 | ):
54 | super().__init__()
55 | to_check = [ds]
56 | while to_check:
57 | next_ds = to_check.pop()
58 | if isinstance(next_ds, filter_dataset.FilterMapDataset):
59 | raise ValueError(
60 | "ElasticIterator does not support `filter` transformation."
61 | )
62 | to_check.extend(next_ds.parents)
63 | self._ds = ds
64 | self._global_batch_size = global_batch_size
65 | self._shard_options = shard_options
66 | self._global_next_index = 0
67 | self._read_options = read_options
68 | self._multiprocessing_options = multiprocessing_options
69 |
70 | @functools.cached_property
71 | def _iterator(self) -> dataset.DatasetIterator:
72 | ds = self._ds[
73 | self._global_next_index
74 | + self._shard_options.shard_index :: self._shard_options.shard_count
75 | ]
76 | host_batch_size, remainder = divmod(
77 | self._global_batch_size, self._shard_options.shard_count
78 | )
79 | if remainder:
80 | raise ValueError(
81 | f"Global batch size {self._global_batch_size} is not divisible by"
82 | f" shard count {self._shard_options.shard_count}."
83 | )
84 | ds = ds.batch(host_batch_size, drop_remainder=True)
85 | ds = ds.to_iter_dataset(read_options=self._read_options)
86 | if self._multiprocessing_options is not None:
87 | ds = ds.mp_prefetch(self._multiprocessing_options)
88 | return ds.__iter__()
89 |
90 | def __iter__(self) -> dataset.DatasetIterator:
91 | return self
92 |
93 | def __next__(self) -> Any:
94 | result = next(self._iterator)
95 | self._global_next_index += self._global_batch_size
96 | return result
97 |
98 | def get_state(self) -> dict[str, Any]:
99 | return {
100 | _GLOBAL_NEXT_INDEX_STATE_KEY: self._global_next_index,
101 | }
102 |
103 | def set_state(self, state: dict[str, Any]):
104 | self._global_next_index = state[_GLOBAL_NEXT_INDEX_STATE_KEY]
105 | # Reset the iterator if it was already created.
106 | self.__dict__.pop("_iterator", None)
107 |
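A minimal usage sketch for `ElasticIterator` (not part of the file above), showing the elasticity guarantee: a checkpoint taken while running with 2 shards restores a run with 4 shards. Module paths follow the internal layout in this listing; a public re-export, if any, may live elsewhere.

from grain._src.core import sharding
from grain._src.python.dataset import dataset
from grain._src.python.dataset.elastic_iterator import ElasticIterator

ds = dataset.MapDataset.range(128)  # unbatched, unsharded input

# One step with shard_count=2; every shard produces this same state.
it_old = ElasticIterator(
    ds,
    global_batch_size=8,
    shard_options=sharding.ShardOptions(shard_index=0, shard_count=2),
)
_ = next(it_old)
state = it_old.get_state()  # {"global_next_index": 8}

# Re-create the pipeline with shard_count=4 and restore from the same state.
it_new = ElasticIterator(
    ds,
    global_batch_size=8,
    shard_options=sharding.ShardOptions(shard_index=3, shard_count=4),
)
it_new.set_state(state)
batch = next(it_new)  # per-host batch of 8 // 4 = 2 elements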
--------------------------------------------------------------------------------
/grain/_src/python/dataset/sources/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//grain:__subpackages__"])
2 |
3 | licenses(["notice"])
4 |
5 | py_library(
6 | name = "parquet_dataset",
7 | srcs = ["parquet_dataset.py"],
8 | srcs_version = "PY3",
9 | deps = [
10 | "//grain/_src/python/dataset",
11 | "@pypi//etils:pkg",
12 | ],
13 | )
14 |
15 | py_test(
16 | name = "parquet_dataset_test",
17 | srcs = ["parquet_dataset_test.py"],
18 | srcs_version = "PY3",
19 | deps = [
20 | ":parquet_dataset",
21 | "//grain:python",
22 | "@abseil-py//absl/flags",
23 | "@abseil-py//absl/testing:absltest",
24 | "@pypi//pyarrow:pkg",
25 | ],
26 | )
27 |
28 | py_library(
29 | name = "tfrecord_dataset",
30 | srcs = ["tfrecord_dataset.py"],
31 | srcs_version = "PY3",
32 | deps = ["//grain/_src/python/dataset"],
33 | )
34 |
35 | py_test(
36 | name = "tfrecord_dataset_test",
37 | srcs = ["tfrecord_dataset_test.py"],
38 | args = ["--test_srcdir=grain/_src/python/testdata"],
39 | data = [
40 | "//grain/_src/python/testdata:morris_sequence_first_5.tfrecord",
41 | ],
42 | srcs_version = "PY3",
43 | deps = [
44 | ":tfrecord_dataset",
45 | "//grain",
46 | "@abseil-py//absl/flags",
47 | "@abseil-py//absl/testing:absltest",
48 | ],
49 | )
50 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/sources/parquet_dataset.py:
--------------------------------------------------------------------------------
1 | """Provides an `IterDataset` for Parquet file format."""
2 |
3 | from typing import TypeVar
4 |
5 | from etils import epy
6 | from grain._src.python.dataset import dataset
7 |
8 |
9 | # lazy import for pyarrow
10 | with epy.lazy_imports():
11 | import pyarrow.parquet as pq # pytype: disable=import-error # pylint: disable=g-import-not-at-top
12 |
13 |
14 | T = TypeVar("T")
15 |
16 |
17 | class _ParquetDatasetIterator(dataset.DatasetIterator[T]):
18 | """A DatasetIterator for Parquet file format."""
19 |
20 | def __init__(
21 | self, path: str, row_group: int = 0, index_within_row_group: int = 0
22 | ):
23 | super().__init__()
24 | self._row_group = row_group
25 | self._index_within_row_group = index_within_row_group
26 | self._pq_path = path
27 | self._pq_file = pq.ParquetFile(self._pq_path)
28 | self._np_table = {}
29 | self._row_group_len = 0
30 | self._read_row_group_to_np_table()
31 |
32 | def _read_row_group_to_np_table(self):
33 | table = self._pq_file.read_row_group(self._row_group)
34 | self._row_group_len = len(table)
35 | self._np_table = {}
36 | for i in range(table.num_columns):
37 | self._np_table[table.field(i).name] = table.column(i).to_numpy()
38 |
39 | def __next__(self):
40 | if self._index_within_row_group >= self._row_group_len:
41 | if self._row_group < self._pq_file.num_row_groups - 1:
42 | self._row_group += 1
43 | self._index_within_row_group = 0
44 | self._read_row_group_to_np_table()
45 | return self.__next__()
46 | else:
47 | raise StopIteration()
48 | else:
49 | item = {
50 | k: v[self._index_within_row_group] for k, v in self._np_table.items()
51 | }
52 | self._index_within_row_group += 1
53 | return item
54 |
55 | def get_state(self):
56 | return {
57 | "row_group": self._row_group,
58 | "index_within_row_group": self._index_within_row_group,
59 | }
60 |
61 | def set_state(self, state):
62 | self._row_group = state["row_group"]
63 | self._index_within_row_group = state["index_within_row_group"]
64 | self._read_row_group_to_np_table()
65 |
66 |
67 | class ParquetIterDataset(dataset.IterDataset[T]):
68 | """An IterDataset for a parquet format file."""
69 |
70 | def __init__(self, path: str):
71 | """Initializes ParquetIterDataset.
72 |
73 | Args:
74 | path: A path to a Parquet format file.
75 | """
76 | super().__init__()
77 | self._path = path
78 |
79 | def __iter__(self) -> _ParquetDatasetIterator[T]:
80 | return _ParquetDatasetIterator(self._path)
81 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/sources/parquet_dataset_test.py:
--------------------------------------------------------------------------------
1 | from absl import flags
2 | from absl.testing import absltest
3 | from grain._src.python.dataset.sources import parquet_dataset
4 | import grain.python as grain
5 | import pyarrow as pa
6 | import pyarrow.parquet as pq
7 |
8 | flags.FLAGS.mark_as_parsed()
9 |
10 | SOME_TEXT = [
11 | [
12 | "This is the first file the first record",
13 | "This is the first file the second record",
14 | "This is the first file the third record",
15 | "This is the first file the fourth record",
16 | ],
17 | [
18 | "This is the second file the first record",
19 | "This is the second file the second record",
20 | "This is the second file the third record",
21 | "This is the second file the fourth record",
22 | ],
23 | ]
24 | INTERLEAVED_TEXT = [
25 | "This is the first file the first record",
26 | "This is the second file the first record",
27 | "This is the first file the second record",
28 | "This is the second file the second record",
29 | "This is the first file the third record",
30 | "This is the second file the third record",
31 | "This is the first file the fourth record",
32 | "This is the second file the fourth record",
33 | ]
34 | WINDOWSHUFFLED_TEXT = [
35 | "This is the first file the second record",
36 | "This is the second file the first record",
37 | "This is the first file the first record",
38 | "This is the first file the third record",
39 | "This is the second file the second record",
40 | "This is the second file the third record",
41 | "This is the second file the fourth record",
42 | "This is the first file the fourth record",
43 | ]
44 |
45 |
46 | class ParquetIterDatasetTest(absltest.TestCase):
47 |
48 | def setUp(self):
49 | super().setUp()
50 | self.filenames = []
51 | for i in range(len(SOME_TEXT)):
52 | temp_file = self.create_tempfile()
53 | filename = temp_file.full_path
54 | self.filenames.append(filename)
55 | table = pa.table({"text": SOME_TEXT[i]})
56 | pq.write_table(table, filename, row_group_size=2)
57 |
58 | def test_read_row_group(self):
59 | dataset = parquet_dataset.ParquetIterDataset(self.filenames[0])
60 | records = list(dataset)
61 | self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[0]])
62 |
63 | def test_checkpointing(self):
64 | dataset = parquet_dataset.ParquetIterDataset(self.filenames[0])
65 | grain.experimental.assert_equal_output_after_checkpoint(dataset)
66 |
67 | def test_sharded_files_and_interleaved_dataset(self):
68 | dataset = grain.MapDataset.source(self.filenames)
69 | dataset = dataset.map(parquet_dataset.ParquetIterDataset)
70 | dataset = grain.experimental.InterleaveIterDataset(
71 | dataset, cycle_length=len(self.filenames)
72 | )
73 | self.assertSequenceEqual(
74 | list(iter(dataset)), [{"text": x} for x in INTERLEAVED_TEXT]
75 | )
76 |
77 | dataset = grain.experimental.WindowShuffleIterDataset(
78 | dataset, window_size=3, seed=42
79 | )
80 | self.assertSequenceEqual(
81 | list(iter(dataset)), [{"text": x} for x in WINDOWSHUFFLED_TEXT]
82 | )
83 |
84 |
85 | if __name__ == "__main__":
86 | absltest.main()
87 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/sources/tfrecord_dataset.py:
--------------------------------------------------------------------------------
1 | """Provides an `IterDataset` for TFRecord file format."""
2 |
3 | import codecs
4 | import struct
5 | from typing import TypeVar
6 |
7 | from grain._src.python.dataset import dataset
8 |
9 |
10 | T = TypeVar("T")
11 |
12 | # Format of a single tf_record:
13 | # uint64 length of record in bytes
14 | # uint32 masked crc of length
15 | # bytes record data
16 | # uint32 masked crc of data
17 | _UNIT32_SIZE_IN_BYTES = 4
18 | _UNIT64_SIZE_IN_BYTES = 8
19 |
20 |
21 | class _TFRecordReader:
22 | """A reader for TFRecord files."""
23 |
24 | def __init__(self, path: str):
25 | self._reader = open(path, "rb")
26 |
27 | def __next__(self) -> bytes:
28 | """Reads the next record from the reader."""
29 | # Read the length and the length mask of the tf_record (uint64 and uint32
30 | # respectively)
31 | buf_length_expected = _UNIT64_SIZE_IN_BYTES + _UNIT32_SIZE_IN_BYTES
32 | buf = self._reader.read(buf_length_expected)
33 | if not buf:
34 | # If the buffer is empty, we have reached the end of the dataset.
35 | raise StopIteration()
36 | if len(buf) != buf_length_expected:
37 | raise ValueError(
38 | f"Not a valid TFRecord. Fewer than {buf_length_expected} bytes:"
39 | f" {codecs.encode(buf, 'hex')}"
40 | )
41 | length, _ = struct.unpack("<QI", buf)
42 | 
43 | # Read the record data (`length` bytes) plus the masked crc of the data
44 | # (uint32), and verify that the expected number of bytes is present
45 | # before unpacking.
46 | buf_length_expected = length + _UNIT32_SIZE_IN_BYTES
47 | buf = self._reader.read(buf_length_expected)
48 | if len(buf) != buf_length_expected:
49 | raise ValueError(
50 | f"Not a valid TFRecord. Fewer than {buf_length_expected} bytes:"
51 | f" {codecs.encode(buf, 'hex')}"
52 | )
53 | data, _ = struct.unpack(f"<{length}sI", buf)
54 | return data
55 | 
56 | def seek(self, offset: int) -> None:
57 | """Moves the read position to the given byte offset."""
58 | self._reader.seek(offset)
59 | 
60 | def tell(self) -> int:
61 | return self._reader.tell()
62 |
63 | def __del__(self):
64 | if hasattr(self, "_reader") and self._reader:
65 | self._reader.close()
66 |
67 |
68 | class _TFRecordDatasetIterator(dataset.DatasetIterator[T]):
69 | """A DatasetIterator for TFRecord file format."""
70 |
71 | def __init__(self, path: str):
72 | super().__init__()
73 | self._reader = _TFRecordReader(path)
74 |
75 | def __next__(self) -> T:
76 | return next(self._reader)
77 |
78 | def get_state(self) -> dict[str, int]:
79 | return {
80 | "reader_offset": self._reader.tell(),
81 | }
82 |
83 | def set_state(self, state: dict[str, int]):
84 | self._reader.seek(state["reader_offset"])
85 |
86 |
87 | class TFRecordIterDataset(dataset.IterDataset[T]):
88 | """An IterDataset for a TFRecord format file."""
89 |
90 | def __init__(self, path: str):
91 | super().__init__()
92 | self._path = path
93 |
94 | def __iter__(self) -> dataset.DatasetIterator[T]:
95 | return _TFRecordDatasetIterator[T](self._path)
96 |
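A minimal sketch (not part of the file above) of the record framing that `_TFRecordReader` parses: a little-endian uint64 length, a uint32 masked crc of the length, the record bytes, and a uint32 masked crc of the data. The reader above does not verify the crc fields, so zero placeholders are enough to produce a file it can parse; real TFRecord writers fill in masked crc32c values. The helper name `write_tiny_tfrecord` is illustrative only.

import struct

def write_tiny_tfrecord(path, records):
  with open(path, "wb") as f:
    for data in records:
      f.write(struct.pack("<QI", len(data), 0))  # length + masked crc of length (placeholder)
      f.write(data)                              # record payload
      f.write(struct.pack("<I", 0))              # masked crc of data (placeholder)

write_tiny_tfrecord("/tmp/tiny.tfrecord", [b"1", b"1 1", b"2 1"])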
--------------------------------------------------------------------------------
/grain/_src/python/dataset/sources/tfrecord_dataset_test.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 |
3 | from absl import flags
4 | from absl.testing import absltest
5 | from grain._src.python.dataset.sources import tfrecord_dataset
6 | import grain.python as pygrain
7 |
8 |
9 | class TFRecordIterDatasetTest(absltest.TestCase):
10 |
11 | def setUp(self):
12 | super().setUp()
13 | self.testdata_dir = pathlib.Path(flags.FLAGS.test_srcdir)
14 | self.testdata_file_path = self.testdata_dir
15 | self.testdata_file_path /= "morris_sequence_first_5.tfrecord"
16 | self.expected_data = [
17 | b"1",
18 | b"1 1",
19 | b"2 1",
20 | b"1 2 1 1",
21 | b"1 1 1 2 2 1",
22 | ]
23 |
24 | def test_nonexistent_tfrecord_file(self):
25 | dataset = tfrecord_dataset.TFRecordIterDataset(
26 | str(self.testdata_dir / "non_existent_file.tfrecord")
27 | )
28 | with self.assertRaises(FileNotFoundError):
29 | list(dataset)
30 |
31 | def test_empty_tfrecord_file(self):
32 | empty_tf_record_file = self.create_tempfile("empty_file.tfrecord")
33 | dataset = tfrecord_dataset.TFRecordIterDataset(
34 | empty_tf_record_file.full_path
35 | )
36 | self.assertSequenceEqual(list(dataset), [])
37 |
38 | def test_invalid_tfrecord_file(self):
39 | truncated_length_tf_record_file = self.create_tempfile(
40 | "truncated_length_file.tfrecord"
41 | )
42 | # Directly write the data to the file instead of using the tfrecord writer,
43 | # and without the length prefix. This will create an invalid tfrecord file.
44 | with open(truncated_length_tf_record_file.full_path, "wb") as f:
45 | f.write(b"1")
46 |
47 | dataset = tfrecord_dataset.TFRecordIterDataset(
48 | truncated_length_tf_record_file.full_path
49 | )
50 | with self.assertRaises(ValueError):
51 | list(dataset)
52 |
53 | def test_read_tfrecord_file(self):
54 | dataset = tfrecord_dataset.TFRecordIterDataset(str(self.testdata_file_path))
55 | self.assertSequenceEqual(list(dataset), self.expected_data)
56 |
57 | def test_checkpointing(self):
58 | dataset = tfrecord_dataset.TFRecordIterDataset(str(self.testdata_file_path))
59 | pygrain.experimental.assert_equal_output_after_checkpoint(dataset)
60 |
61 |
62 | if __name__ == "__main__":
63 | absltest.main()
64 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/limit.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Implements limit transformations."""
15 |
16 | from typing import Any, TypeVar
17 |
18 | from grain._src.python.dataset import dataset
19 | from grain._src.python.dataset import stats
20 |
21 | Element = Any
22 | T = TypeVar("T") # pylint: disable=invalid-name
23 |
24 |
25 | class _LimitDatasetIterator(dataset.DatasetIterator[T]):
26 | """Iterator that limits the number of elements in the dataset."""
27 |
28 | def __init__(
29 | self,
30 | parent: dataset.DatasetIterator[T],
31 | count: int,
32 | ):
33 | super().__init__(parent)
34 | self._count = count
35 | self._count_elements_read = 0
36 |
37 | @stats.record_next_duration_if_output
38 | def __next__(self):
39 | if self._count_elements_read >= self._count:
40 | raise StopIteration
41 | value = next(self._parent)
42 | self._count_elements_read += 1
43 | return value
44 |
45 | def get_state(self):
46 | return {
47 | "parent": self._parent.get_state(),
48 | "count_elements_read": self._count_elements_read,
49 | }
50 |
51 | def set_state(self, state):
52 | self._parent.set_state(state["parent"])
53 | self._count_elements_read = state["count_elements_read"]
54 |
55 |
56 | class LimitIterDataset(dataset.IterDataset[T]):
57 | """Limits the number of elements in the dataset.
58 |
59 | Example usage:
60 |
61 | ```
62 | list(LimitIterDataset(MapDataset.range(5).to_iter_dataset(), 2)) == [0, 1]
63 | ```
64 |
65 | Attributes:
66 | parent: The dataset to limit.
67 | count: The maximum number of elements to include in the dataset.
68 | """
69 |
70 | def __init__(
71 | self,
72 | parent: dataset.IterDataset[T],
73 | count: int,
74 | ):
75 | """Initializes the limit dataset."""
76 | if count <= 0:
77 | raise ValueError(f"Count must be a positive integer. Got {count}")
78 | super().__init__(parent)
79 | self._count = count
80 |
81 | def __iter__(self) -> _LimitDatasetIterator[T]:
82 | parent_iter = self._parent.__iter__()
83 | return _LimitDatasetIterator(parent_iter, self._count)
84 |
85 | def __str__(self) -> str:
86 | return f"LimitIterDataset(count={self._count})"
87 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/limit_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for limit transformations."""
15 |
16 | from absl.testing import absltest
17 | from absl.testing import parameterized
18 | from grain._src.python.dataset import dataset
19 | from grain._src.python.dataset.transformations import limit
20 | import grain._src.python.testing.experimental as test_util
21 |
22 |
23 | class LimitIterDatasetTest(parameterized.TestCase):
24 |
25 | @parameterized.parameters([0, -1, -5])
26 | def test_non_positive_count_raises_error(self, count):
27 | ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
28 | with self.assertRaises(ValueError):
29 | _ = limit.LimitIterDataset(ds, count=count)
30 |
31 | def test_stop_iteration_raised_after_limit_reached(self):
32 | ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
33 | ds = limit.LimitIterDataset(ds, count=1)
34 | ds_iter = iter(ds)
35 | _ = next(ds_iter)
36 | with self.assertRaises(StopIteration):
37 | next(ds_iter)
38 |
39 | @parameterized.parameters([1, 3, 5, 7, 10])
40 | def test_count(self, count):
41 | ds = dataset.MapDataset.range(0, 10).to_iter_dataset()
42 | ds = limit.LimitIterDataset(ds, count=count)
43 | actual_data = list(ds)
44 | self.assertLen(actual_data, count)
45 | self.assertEqual(actual_data, list(range(count)))
46 |
47 | def test_count_over_epochs(self):
48 | ds = dataset.MapDataset.range(0, 10).repeat(2).to_iter_dataset()
49 | ds = limit.LimitIterDataset(ds, count=15)
50 | actual_data = list(ds)
51 | self.assertLen(actual_data, 15)
52 | self.assertEqual(actual_data, list(range(10)) + list(range(5)))
53 |
54 | def test_limit_after_batch(self):
55 | def flatten_batches(batches):
56 | actual_data = []
57 | for batch in batches:
58 | actual_data.extend(batch.tolist())
59 | return actual_data
60 |
61 | ds = dataset.MapDataset.range(0, 10).batch(3).to_iter_dataset()
62 |
63 | ds_1 = limit.LimitIterDataset(ds, count=2)
64 | batches = list(ds_1)
65 | actual_data = flatten_batches(batches)
66 | self.assertEqual(actual_data, list(range(6)))
67 |
68 | ds_2 = limit.LimitIterDataset(ds, count=5)
69 | batches = list(ds_2)
70 | actual_data = flatten_batches(batches)
71 | self.assertLen(batches, 4)
72 | self.assertEqual(actual_data, list(range(10)))
73 |
74 | def test_checkpointing(self):
75 | ds = dataset.MapDataset.range(0, 10).batch(3).to_iter_dataset()
76 | limited_ds = limit.LimitIterDataset(ds, count=2)
77 | test_util.assert_equal_output_after_checkpoint(limited_ds)
78 |
79 |
80 | if __name__ == "__main__":
81 | absltest.main()
82 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/packing_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for batch transformation."""
15 |
16 | from absl.testing import absltest
17 | from grain._src.python.dataset.transformations import testing_util
18 |
19 |
20 | class FirstFitPackIterDatasetTest(testing_util.BaseFirstFitPackIterDatasetTest):
21 |
22 | def setUp(self):
23 | super().setUp()
24 | self.kwargs = {}
25 |
26 | if __name__ == "__main__":
27 | absltest.main()
28 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/repeat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Implements repeat transformation."""
15 | import sys
16 | from typing import Optional, TypeVar
17 |
18 | from grain._src.python.dataset import dataset
19 |
20 | T = TypeVar("T")
21 |
22 |
23 | class RepeatMapDataset(dataset.MapDataset[T]):
24 | """Repeats the underlying dataset for num_epochs.
25 |
26 | This effectively just changes the length, which indicates the size of a single
27 | epoch, of the dataset. This makes it easier to iterate for a fixed number
28 | of steps.
29 | """
30 |
31 | _MUTATES_ELEMENT_SPEC = False
32 |
33 | def __init__(
34 | self,
35 | parent: dataset.MapDataset[T],
36 | num_epochs: Optional[int] = None,
37 | ):
38 | super().__init__(parent)
39 | if num_epochs is not None and num_epochs <= 0:
40 | raise ValueError(f"num_epochs must be positive, but got {num_epochs}.")
41 | if len(parent) >= sys.maxsize:
42 | raise ValueError(
43 | f"Repeating already infinite dataset {parent} does nothing."
44 | )
45 | self._num_epochs = num_epochs
46 | if num_epochs is None:
47 | if len(parent) == 0: # pylint: disable=g-explicit-length-test
48 | self._length: int = 0
49 | else:
50 | self._length: int = sys.maxsize
51 | else:
52 | self._length = num_epochs * len(parent)
53 |
54 | def __len__(self) -> int:
55 | return self._length
56 |
57 | def __str__(self) -> str:
58 | return f"RepeatMapDataset(num_epochs={self._num_epochs})"
59 |
60 | def __getitem__(self, index):
61 | if isinstance(index, slice):
62 | return self.slice(index)
63 | return self._stats.record_output_spec(self._parent[index])
64 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/repeat_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for repeat transformation."""
15 | import sys
16 |
17 | from absl.testing import absltest
18 | from grain._src.python.dataset import dataset
19 | from grain._src.python.dataset.transformations import repeat
20 | from typing_extensions import override
21 |
22 |
23 | class EmptyMapDataset(dataset.MapDataset[int]):
24 |
25 | def __init__(self):
26 | super().__init__(parents=[])
27 |
28 | @override
29 | def __len__(self) -> int:
30 | return 0
31 |
32 | @override
33 | def __getitem__(self, index):
34 | raise IndexError("Index out of range")
35 |
36 |
37 | class RepeatMapDatasetTest(absltest.TestCase):
38 |
39 | def test_finite_num_epochs_changes_length(self):
40 | ds = dataset.MapDataset.range(6)
41 | self.assertLen(ds, 6)
42 | ds = repeat.RepeatMapDataset(ds, num_epochs=3)
43 | self.assertLen(ds, 18)
44 |
45 | def test_finite_num_epochs_produces_expected_elements_when_iterated(self):
46 | ds = dataset.MapDataset.range(4)
47 | ds = repeat.RepeatMapDataset(ds, num_epochs=3)
48 | self.assertSequenceEqual(list(ds), [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])
49 |
50 | def test_infinite_epochs_sets_length_to_maxsize(self):
51 | ds = dataset.MapDataset.range(6)
52 | ds = repeat.RepeatMapDataset(ds, num_epochs=None)
53 | self.assertLen(ds, sys.maxsize)
54 |
55 | def test_repeat_after_setting_infinite_epochs_raises_value_error(self):
56 | ds = dataset.MapDataset.range(6)
57 | ds = repeat.RepeatMapDataset(ds, num_epochs=None)
58 | with self.assertRaises(ValueError):
59 | repeat.RepeatMapDataset(ds, num_epochs=2)
60 |
61 | def test_setting_zero_epochs_raises_value_error(self):
62 | ds = dataset.MapDataset.range(6)
63 | with self.assertRaises(ValueError):
64 | repeat.RepeatMapDataset(ds, num_epochs=0)
65 |
66 | def test_setting_negative_epochs_raises_value_error(self):
67 | ds = dataset.MapDataset.range(6)
68 | with self.assertRaises(ValueError):
69 | repeat.RepeatMapDataset(ds, num_epochs=-1)
70 |
71 | def test_infinite_epochs_of_empty_dataset_keeps_length_zero(self):
72 | ds = EmptyMapDataset()
73 | ds = repeat.RepeatMapDataset(ds, num_epochs=None)
74 | self.assertEmpty(ds)
75 |
76 |
77 | if __name__ == "__main__":
78 | absltest.main()
79 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/slice.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Implements slice transformation."""
15 | from typing import TypeVar
16 |
17 | from grain._src.python.dataset import dataset
18 |
19 | T = TypeVar("T")
20 |
21 |
22 | class SliceMapDataset(dataset.MapDataset[T]):
23 | """Slices a MapDataset similar to the slicing syntax in Python."""
24 |
25 | _MUTATES_ELEMENT_SPEC = False
26 |
27 | def __init__(self, parent: dataset.MapDataset[T], sl: slice):
28 | super().__init__(parent)
29 | if not isinstance(sl, slice):
30 | raise ValueError(f"sl is not a slice: {type(sl)}")
31 | self._start, self._stop, self._step = sl.indices(len(parent))
32 | self._length = len(range(self._start, self._stop, self._step))
33 |
34 | def __len__(self) -> int:
35 | return self._length
36 |
37 | def __getitem__(self, index):
38 | if isinstance(index, slice):
39 | return SliceMapDataset(self, index)
40 | with self._stats.record_self_time():
41 | parent_index = self._start + (index % len(self)) * self._step
42 | return self._parent[parent_index]
43 |
44 | def __str__(self) -> str:
45 | return f"SliceMapDataset[{self._start}:{self._stop}:{self._step}]"
46 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/slice_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for slice transformation."""
15 |
16 | import itertools
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from grain._src.python.dataset import dataset
21 | import grain._src.python.dataset.transformations.slice as slice_ds
22 | from typing_extensions import override
23 |
24 |
25 | class EmptyMapDataset(dataset.MapDataset[int]):
26 |
27 | def __init__(self):
28 | super().__init__(parents=[])
29 |
30 | @override
31 | def __len__(self) -> int:
32 | return 0
33 |
34 | @override
35 | def __getitem__(self, index):
36 | raise IndexError("Index out of range")
37 |
38 |
39 | class SliceMapDatasetTest(parameterized.TestCase):
40 |
41 | @parameterized.parameters(
42 | (0, 1, 20),
43 | (0, 2, 10),
44 | (1, 2, 10),
45 | (0, 3, 7),
46 | (1, 3, 7),
47 | (2, 3, 6),
48 | (30, 100, 0),
49 | )
50 | def test_len(self, start: int, step: int, expected_len: int):
51 | ds = dataset.MapDataset.range(20)
52 | sl = slice(start, 20, step)
53 | range_ds_for_process = slice_ds.SliceMapDataset(ds, sl)
54 | self.assertLen(range_ds_for_process, expected_len)
55 |
56 | @parameterized.parameters(
57 | itertools.product(range(-8, 8), range(-9, 8), [-2, -1, 1, 2])
58 | )
59 | def test_getitem(self, start: int, stop: int, step: int):
60 | ds = dataset.MapDataset.range(20)
61 | ds = slice_ds.SliceMapDataset(ds, slice(start, stop, step))
62 | ds_items = [ds[i] for i in range(len(ds))]
63 | self.assertSequenceEqual(ds_items, list(range(20))[start:stop:step])
64 |
65 | @parameterized.parameters(
66 | itertools.product(range(-8, 8), range(-9, 8), [-2, -1, 1, 2])
67 | )
68 | def test_getitem_slice(self, start: int, stop: int, step: int):
69 | ds = dataset.MapDataset.range(20)
70 | ds = ds[start:stop:step]
71 | ds_items = [ds[i] for i in range(len(ds))]
72 | self.assertSequenceEqual(ds_items, list(range(20))[start:stop:step])
73 |
74 | @parameterized.parameters(
75 | itertools.product(range(-8, 8), range(-9, 8), [-2, -1, 1, 2])
76 | )
77 | def test_iter(self, start: int, stop: int, step: int):
78 | ds = dataset.MapDataset.range(20)
79 | ds = slice_ds.SliceMapDataset(ds, slice(start, stop, step))
80 | ds_iter = iter(ds)
81 | ds_items = list(ds_iter)
82 | self.assertSequenceEqual(ds_items, list(range(20))[start:stop:step])
83 |
84 | def test_slice_of_empty_dataset_is_empty(self):
85 | ds = EmptyMapDataset()
86 | ds = slice_ds.SliceMapDataset(ds, slice(0, 10))
87 | self.assertEmpty(ds)
88 |
89 | def test_accessing_items_beyond_len_minus_one_succeeds(self):
90 | ds = dataset.MapDataset.range(20)
91 | ds = slice_ds.SliceMapDataset(ds, slice(5)) # 0, 1, 2, 3, 4
92 | self.assertLen(ds, 5)
93 | self.assertEqual(ds[5], 0)
94 | self.assertEqual(ds[13], 3)
95 | self.assertEqual(ds[42], 2)
96 |
97 | def test_composing_slices_contains_correct_elements(self):
98 | ds = dataset.MapDataset.range(20)
99 | ds = slice_ds.SliceMapDataset(ds, slice(0, 15, 3)) # 0, 3, 6, 9, 12
100 | ds = slice_ds.SliceMapDataset(ds, slice(0, 20, 2)) # 0, 6, 12
101 | self.assertSequenceEqual(list(ds), [0, 6, 12])
102 |
103 |
104 | if __name__ == "__main__":
105 | absltest.main()
106 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/source.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """LazyDataset data sources."""
15 | from __future__ import annotations
16 |
17 | import functools
18 | from typing import Sequence, Union
19 |
20 | from absl import logging
21 | from grain._src.python import options
22 | from grain._src.python.dataset import base
23 | from grain._src.python.dataset import dataset
24 |
25 |
26 | class SourceMapDataset(dataset.MapDataset):
27 | """Simple wrapper for random access data sources."""
28 |
29 | def __init__(self, source: base.RandomAccessDataSource):
30 | super().__init__()
31 | self._source = source
32 |
33 | def __len__(self) -> int:
34 | return len(self._source)
35 |
36 | def __str__(self) -> str:
37 | return f"SourceMapDataset(source={self._source.__class__.__name__})"
38 |
39 | def __getitem__(self, index):
40 | if isinstance(index, slice):
41 | return self.slice(index)
42 | with self._stats.record_self_time():
43 | return self._stats.record_output_spec(self._source[index % len(self)])
44 |
45 | def log_lineage(self):
46 | pass
47 |
48 | @property
49 | def paths(self) -> str | Sequence[str]:
50 | if hasattr(self._source, "paths"):
51 | assert isinstance(self._source, base.RandomAccessDataSource)
52 | return self._source.paths
53 | else:
54 | return []
55 |
56 |
57 | def log_lineage_for_sources(
58 | root: Union[dataset.MapDataset, dataset.IterDataset],
59 | ):
60 | """Traverses tree of transformations and logs lineage on source datasets."""
61 | pass
62 |
63 |
64 | class RangeMapDataset(dataset.MapDataset[int]):
65 | """Range data source, similar to python range() function."""
66 |
67 | def __init__(self, start: int, stop: int | None = None, step: int = 1):
68 | super().__init__()
69 | self.start = 0 if stop is None else start
70 | self.stop = start if stop is None else stop
71 | self.step = step
72 |
73 | @functools.cached_property
74 | def _length(self) -> int:
75 | return len(range(self.start, self.stop, self.step))
76 |
77 | def __len__(self) -> int:
78 | return self._length
79 |
80 | def __str__(self) -> str:
81 | return (
82 | f"RangeMapDataset(start={self.start}, stop={self.stop},"
83 | f" step={self.step})"
84 | )
85 |
86 | def __getitem__(self, index):
87 | if isinstance(index, slice):
88 | return self.slice(index)
89 | with self._stats.record_self_time():
90 | return self._stats.record_output_spec(
91 | self.start + (index % self._length) * self.step
92 | )
93 |
94 | def to_iter_dataset(
95 | self,
96 | read_options: options.ReadOptions | None = None,
97 | *,
98 | allow_nones: bool = False,
99 | ) -> dataset.IterDataset[int]:
100 | # Override the default multithreaded execution to avoid wasting memory.
101 | # The prefetch is not necessary since there's no IO.
102 | return super().to_iter_dataset(
103 | read_options=(
104 | read_options or options.ReadOptions(prefetch_buffer_size=0)
105 | ),
106 | allow_nones=allow_nones,
107 | )
108 |
--------------------------------------------------------------------------------
/grain/_src/python/dataset/transformations/source_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for LazyDataset data sources."""
15 | import random
16 | from unittest import mock
17 |
18 | from absl.testing import absltest
19 | from grain._src.python.dataset import dataset
20 | from grain._src.python.dataset.transformations import source
21 |
22 |
23 | class _Interleave(dataset.MapDataset):
24 |
25 | def __len__(self):
26 | return sum((len(p) for p in self.parents))
27 |
28 | def __getitem__(self, index):
29 | index, parent_index = divmod(index, len(self.parents))
30 | return self.parents[parent_index][index]
31 |
32 |
33 | class SourceMapDatasetTest(absltest.TestCase):
34 |
35 | def setUp(self):
36 | super().setUp()
37 | self.sample_data_source = [1, 2, 3, 4, 5]
38 | self.lazy_dataset_source = source.SourceMapDataset( # pytype: disable=wrong-arg-types
39 | self.sample_data_source
40 | )
41 |
42 | def test_lazy_dataset_source_len(self):
43 | self.assertLen(self.lazy_dataset_source, 5)
44 |
45 | def test_lazy_dataset_source_sequential_get(self):
46 | indices_to_read = [0, 1, 2, 3, 4]
47 | expected_data = [1, 2, 3, 4, 5]
48 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read]
49 | self.assertEqual(expected_data, actual_data)
50 |
51 | def test_lazy_dataset_source_reverse_sequential_get(self):
52 | indices_to_read = [0, 1, 2, 3, 4]
53 | expected_data = [1, 2, 3, 4, 5]
54 | indices_to_read.reverse()
55 | expected_data.reverse()
56 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read]
57 | self.assertEqual(expected_data, actual_data)
58 |
59 | def test_lazy_dataset_source_random_get(self):
60 | indices_to_read = [0, 1, 2, 3, 4]
61 | random.shuffle(indices_to_read)
62 | expected_data = [self.sample_data_source[i] for i in indices_to_read]
63 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read]
64 | self.assertEqual(expected_data, actual_data)
65 |
66 | def test_lazy_dataset_source_random_modulo_get(self):
67 | len_data_source = len(self.lazy_dataset_source)
68 | indices_to_read = [100, 207, 303, 401]
69 | expected_data = [
70 | self.sample_data_source[i % len_data_source] for i in indices_to_read
71 | ]
72 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read]
73 | self.assertEqual(expected_data, actual_data)
74 |
75 |
76 | class RangeMapDatasetTest(absltest.TestCase):
77 |
78 | def test_len(self):
79 | ds = source.RangeMapDataset(12)
80 | self.assertLen(ds, 12)
81 | ds = source.RangeMapDataset(0, 12)
82 | self.assertLen(ds, 12)
83 | ds = source.RangeMapDataset(2, 12)
84 | self.assertLen(ds, 10)
85 | ds = source.RangeMapDataset(2, 12, 1)
86 | self.assertLen(ds, 10)
87 | ds = source.RangeMapDataset(2, 12, 2)
88 | self.assertLen(ds, 5)
89 | ds = source.RangeMapDataset(2, 13, 2)
90 | self.assertLen(ds, 6)
91 |
92 | def test_getitem(self):
93 | ds = source.RangeMapDataset(12)
94 | for i in range(12):
95 | self.assertEqual(ds[i], i)
96 | for i in range(12):
97 | self.assertEqual(ds[i + 12], i)
98 | ds = source.RangeMapDataset(2, 9, 2)
99 | self.assertEqual(ds[0], 2)
100 | self.assertEqual(ds[1], 4)
101 | self.assertEqual(ds[2], 6)
102 | self.assertEqual(ds[3], 8)
103 | self.assertEqual(ds[4], 2)
104 | self.assertEqual(ds[5], 4)
105 |
106 | def test_iter(self):
107 | ds = source.RangeMapDataset(12)
108 | ds_iter = iter(ds)
109 | elements = [next(ds_iter) for _ in range(12)]
110 | self.assertEqual(elements, list(range(12)))
111 | ds = source.RangeMapDataset(2, 9, 2)
112 | ds_iter = iter(ds)
113 | elements = [next(ds_iter) for _ in range(4)]
114 | self.assertEqual(elements, [2, 4, 6, 8])
115 |
116 |
117 | if __name__ == "__main__":
118 | absltest.main()
119 |
--------------------------------------------------------------------------------
/grain/_src/python/experimental/example_packing/BUILD:
--------------------------------------------------------------------------------
1 | # Experimental transformation for example packing in PyGrain.
2 |
3 | package(default_visibility = ["//grain:__subpackages__"])
4 |
5 | py_library(
6 | name = "packing",
7 | srcs = ["packing.py"],
8 | srcs_version = "PY3",
9 | deps = [
10 | "//grain/_src/core:tree_lib",
11 | "//grain/_src/python:record",
12 | "@abseil-py//absl/logging",
13 | "@pypi//numpy:pkg",
14 | ],
15 | )
16 |
17 | py_test(
18 | name = "packing_test",
19 | srcs = ["packing_test.py"],
20 | srcs_version = "PY3",
21 | deps = [
22 | ":packing",
23 | "//grain/_src/core:tree_lib",
24 | "//grain/_src/python:record",
25 | "@abseil-py//absl/testing:absltest",
26 | "@pypi//jax:pkg",
27 | "@pypi//numpy:pkg",
28 | ],
29 | )
30 |
--------------------------------------------------------------------------------
/grain/_src/python/experimental/index_shuffle/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//grain:__subpackages__"])
2 |
3 | licenses(["notice"])
4 |
5 | cc_library(
6 | name = "index_shuffle",
7 | srcs = ["index_shuffle.cc"],
8 | hdrs = ["index_shuffle.h"],
9 | )
10 |
--------------------------------------------------------------------------------
/grain/_src/python/experimental/index_shuffle/index_shuffle.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2023 Google LLC. All Rights Reserved.
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | ==============================================================================*/
15 |
16 | #ifndef GRAIN_RANDOM_RANDOM_INDEX_SHUFFLE_H_
17 | #define GRAIN_RANDOM_RANDOM_INDEX_SHUFFLE_H_
18 |
19 | #include <cstddef>
20 | #include <cstdint>
21 |
22 | namespace grain {
23 | namespace random {
24 |
25 | // Returns the position of `index` in a permutation of [0, ..., max_index].
26 | //
27 | // `index` must be a number in [0, ..., max_index].
28 | // `seed` is the random key for the permutation.
29 | // The returned index will also be in [0, ..., max_index]. For a fixed `seed`
30 | // and `max_index`, all the possible `index` values and the returned values form
31 | // a bijection.
32 | // `rounds` must be a positive even integer >= 4. Larger values improve the
33 | // 'randomness' of permutations for small `max_index` values. The time to
34 | // compute the result scales linearly with the number of rounds. We recommend 8
35 | // rounds for a good trade-off.
36 | //
37 | // For more details on the algorithm see the top of the cc file.
38 | uint64_t index_shuffle(uint64_t index, uint64_t max_index, uint32_t seed,
39 | uint32_t rounds);
40 |
41 | } // namespace random
42 | } // namespace grain
43 |
44 | #endif // GRAIN_RANDOM_RANDOM_INDEX_SHUFFLE_H_
45 |
--------------------------------------------------------------------------------
/grain/_src/python/experimental/index_shuffle/python/BUILD:
--------------------------------------------------------------------------------
1 | load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
2 |
3 | package(default_visibility = ["//grain:__subpackages__"])
4 |
5 | licenses(["notice"])
6 |
7 | pybind_extension(
8 | name = "index_shuffle_module",
9 | srcs = ["index_shuffle_module.cc"],
10 | deps = [
11 | "//grain/_src/python/experimental/index_shuffle",
12 | ],
13 | )
14 |
15 | py_test(
16 | name = "index_shuffle_test",
17 | srcs = ["index_shuffle_test.py"],
18 | data = [":index_shuffle_module.so"],
19 | srcs_version = "PY3",
20 | deps = ["@abseil-py//absl/testing:absltest"],
21 | )
22 |
--------------------------------------------------------------------------------
/grain/_src/python/experimental/index_shuffle/python/index_shuffle_module.cc:
--------------------------------------------------------------------------------
1 | #include <pybind11/pybind11.h>
2 |
3 | #include "grain/_src/python/experimental/index_shuffle/index_shuffle.h"
4 |
5 | namespace py = pybind11;
6 |
7 | PYBIND11_MODULE(index_shuffle_module, m) {
8 | constexpr char kDoc[] =
9 | "Returns the position of `index` in a permutation of [0, ..., "
10 | "max_index].";
11 | m.doc() = kDoc;
12 | m.def("index_shuffle", &::grain::random::index_shuffle, kDoc,
13 | py::arg("index"), py::arg("max_index"), py::arg("seed"),
14 | py::arg("rounds"));
15 | }
16 |
--------------------------------------------------------------------------------
/grain/_src/python/experimental/index_shuffle/python/index_shuffle_python.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Pure Python version of `index_shuffle`.
15 |
16 | This is roughly 10x slower than the C++ index_shuffle but still sufficiently
17 | fast for many use cases. Use it if the C++ version (and its CLIF wrapper)
18 | doesn't work for you.
19 | """
20 |
21 | import hashlib
22 |
23 |
24 | def _fingerprint(*args) -> int:
25 | """A 128-bit fingerprint based on md5.
26 |
27 | For data shuffling - not for cryptography.
28 |
29 | Args:
30 | *args: any argument list that can be converted to a string
31 |
32 | Returns:
33 | an integer in [0, 2 ** 128)
34 | """
35 | return int.from_bytes(hashlib.md5(str(args).encode()).digest(), "little")
36 |
37 |
38 | def index_shuffle(index: int, max_index: int, seed: int, rounds: int) -> int:
39 | """Computes the position of `index` after a pseudorandom permutation of `[0, max_index]`.
40 |
41 | Based on Feistel ciphers.
42 |
43 | For data shuffling - not for cryptography.
44 |
45 | if i != j, then
46 | pseudorandom_permutation(n, i, seed) != pseudorandom_permutation(n, j, seed)
47 |
48 | Args:
49 | index: An integer in [0, max_index].
50 | max_index: A positive integer.
51 | seed: A positive integer used as seed for the pseudorandom permutation.
52 | rounds: Ignored. For compatibility with C++ version.
53 |
54 | Returns:
55 | An integer in [0, max_index].
56 | """
57 | del rounds
58 | if not isinstance(max_index, int):
59 | raise ValueError("max_index must be an integer")
60 |
61 | if index < 0 or index > max_index:
62 | raise ValueError("out of range")
63 |
64 | if max_index == 0:
65 | return 0
66 |
67 | # smallest k such that max_index fits in 2k bits
68 | k = (max_index.bit_length() + 1) // 2
69 | assert max_index <= 4**k
70 | # Apply the permutation on [0, 4 ** k) repeatedly until the result lands in
71 | # [0, max_index] (cycle-walking). This constitutes a permutation of [0, max_index].
72 | while True:
73 | # Feistel cipher on 2k bits - i.e. a permutation of [0, 4 ** k)
74 | a, b = index // (2**k), index % (2**k)
75 | for r in range(3):
76 | a, b = b, a ^ (_fingerprint(b, r, seed) % (2**k))
77 | index = a * (2**k) + b
78 | if index <= max_index:
79 | return int(index)
80 |
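A minimal sketch (not part of the file above): because `index_shuffle` is a bijection on `[0, max_index]`, each shuffled position can be computed independently, with no need to build and shuffle an index list up front. The values of `n` and `seed` below are arbitrary.

from grain._src.python.experimental.index_shuffle.python import index_shuffle_python

n = 10
order = [
    index_shuffle_python.index_shuffle(i, max_index=n - 1, seed=7, rounds=4)
    for i in range(n)
]
assert sorted(order) == list(range(n))  # a permutation of [0, n)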
--------------------------------------------------------------------------------
/grain/_src/python/experimental/index_shuffle/python/index_shuffle_python_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Minimal unit test for the Python wrapper of index_shuffle."""
15 |
16 | from absl.testing import absltest
17 | from grain._src.python.experimental.index_shuffle.python import index_shuffle_python
18 |
19 |
20 | class IndexShuffleTest(absltest.TestCase):
21 |
22 | def test_index_shuffle(self):
23 | max_index = 46_204
24 | seen = set()
25 | for x in range(max_index + 1):
26 | y = index_shuffle_python.index_shuffle(x, max_index, seed=52, rounds=4)
27 | self.assertBetween(y, 0, max_index)
28 | seen.add(y)
29 | self.assertLen(seen, max_index + 1)
30 |
31 | def test_index_shuffle_huge_number(self):
32 | max_index = 1_234_567_891
33 | seen = set()
34 | for x in range(10_000):
35 | y = index_shuffle_python.index_shuffle(x, max_index, seed=27, rounds=4)
36 | self.assertBetween(y, 0, max_index)
37 | seen.add(y)
38 | self.assertLen(seen, 10_000)
39 |
40 | def test_index_shuffle_single_record(self):
41 | self.assertEqual(
42 | 0,
43 | index_shuffle_python.index_shuffle(
44 | index=0, max_index=0, seed=0, rounds=4
45 | ),
46 | )
47 |
48 |
49 | if __name__ == '__main__':
50 | absltest.main()
51 |
--------------------------------------------------------------------------------
/grain/_src/python/experimental/index_shuffle/python/index_shuffle_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Minimal unit test for the Python wrapper of index_shuffle."""
15 |
16 | from absl.testing import absltest
17 | from grain._src.python.experimental.index_shuffle.python import index_shuffle_module as index_shuffle
18 |
19 |
20 | class IndexShuffleTest(absltest.TestCase):
21 |
22 | def test_index_shuffle(self):
23 | max_index = 46_204
24 | seen = set()
25 | for x in range(max_index + 1):
26 | y = index_shuffle.index_shuffle(x, max_index, seed=52, rounds=4)
27 | self.assertBetween(y, 0, max_index)
28 | seen.add(y)
29 | self.assertLen(seen, max_index + 1)
30 |
31 | def test_index_shuffle_huge_number(self):
32 | max_index = 1_234_567_891
33 | seen = set()
34 | for x in range(10_000):
35 | y = index_shuffle.index_shuffle(x, max_index, seed=27, rounds=4)
36 | self.assertBetween(y, 0, max_index)
37 | seen.add(y)
38 | self.assertLen(seen, 10_000)
39 |
40 | def test_index_shuffle_single_record(self):
41 | self.assertEqual(
42 | 0, index_shuffle.index_shuffle(index=0, max_index=0, seed=0, rounds=4)
43 | )
44 |
45 |
46 | if __name__ == '__main__':
47 | absltest.main()
48 |
--------------------------------------------------------------------------------
/grain/_src/python/grain_logging.py:
--------------------------------------------------------------------------------
1 | """A library for adding a custom identifier to Python log messages.
2 |
3 | It allows adding a prefix identifying the process generating the log statement.
4 | The main purpose is to make logs more readable when using multiprocessing.
5 | """
6 |
7 | import logging
8 | from absl import logging as absl_logging
9 |
10 |
11 | # Adds a prefix containing the `identifier` to all new Python log messages.
12 | def set_process_identifier_prefix(identifier: str) -> None:
13 | log_formatter = logging.Formatter(f'[{identifier}] %(message)s')
14 | absl_logging.get_absl_handler().setFormatter(log_formatter)
15 |
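A short usage sketch (not part of the file above): each worker process can call this once at startup so that interleaved log lines from multiple processes stay attributable. The identifier string is arbitrary.

from absl import logging as absl_logging
from grain._src.python import grain_logging

grain_logging.set_process_identifier_prefix("grain worker 3")
absl_logging.info("prefetch buffer filled")  # formatted as "[grain worker 3] prefetch buffer filled"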
--------------------------------------------------------------------------------
/grain/_src/python/grain_logging_test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 | from absl import logging as absl_logging
4 | from grain._src.python import grain_logging
5 | from absl.testing import absltest
6 |
7 |
8 | class GrainLoggingTest(absltest.TestCase):
9 |
10 | def test_prefix_is_part_of_message(self):
11 | # self.assertLogs() doesn't format the messages, so we have to resort to
12 | # formatting directly with the absl handler to test whether the
13 | # prefix is being added.
14 | grain_logging.set_process_identifier_prefix('foo prefix')
15 | with self.assertLogs() as cm:
16 | absl_logging.info('example message')
17 | self.assertLen(cm.records, 1)
18 | log_record = cm.records[0]
19 | self.assertIn(
20 | 'foo prefix', absl_logging.get_absl_handler().format(log_record)
21 | )
22 |
23 | def test_message_is_kept(self):
24 | grain_logging.set_process_identifier_prefix('Foo')
25 | with self.assertLogs() as cm:
26 | absl_logging.info('some info message: %i', 1337)
27 | self.assertLen(cm.output, 1)
28 | self.assertIn('some info message: 1337', cm.output[0])
29 |
30 | def test_message_formatting(self):
31 | log_record = logging.LogRecord(
32 | name=absl_logging.get_absl_logger().name,
33 | level=logging.INFO,
34 | pathname='file.cc',
35 | lineno=42,
36 | msg='some info message: %i',
37 | args=(789,),
38 | exc_info=None,
39 | )
40 | grain_logging.set_process_identifier_prefix('FooBarBaz 123')
41 | # We don't want to enforce a specific prefix format, but we want to avoid
42 | # duplicating the absl prefix (e.g.:
43 | # I0814 11:48:49.888083 1726756 grain_pool.py:161
44 | # ).
45 | self.assertTrue(
46 | re.search(
47 | r'.{0,5}FooBarBaz 123.{0,5} some info message: 789',
48 | absl_logging.get_absl_handler().format(log_record),
49 | )
50 | )
51 |
52 |
53 | if __name__ == '__main__':
54 | absltest.main()
55 |
--------------------------------------------------------------------------------
/grain/_src/python/load.py:
--------------------------------------------------------------------------------
1 | """High level APIs that serve as a single endpoint for very common use cases."""
2 |
3 | from typing import Optional
4 |
5 | from grain._src.core import monitoring as grain_monitoring
6 | from grain._src.core import sharding
7 | from grain._src.core import transforms
8 | from grain._src.core import usage_logging
9 | from grain._src.python import data_loader
10 | from grain._src.python import data_sources
11 | from grain._src.python import options
12 | from grain._src.python import samplers
13 |
14 | from grain._src.core import monitoring
15 |
16 |
17 | _api_usage_counter = monitoring.Counter(
18 | "/grain/python/load/api",
19 | monitoring.Metadata(description="API initialization counter."),
20 | root=grain_monitoring.get_monitoring_root(),
21 | fields=[("name", str)],
22 | )
23 |
24 |
25 | def load(
26 | source: data_sources.RandomAccessDataSource,
27 | *,
28 | num_epochs: Optional[int] = None,
29 | shuffle: bool = False,
30 | seed: Optional[int] = None,
31 | shard_options: sharding.ShardOptions = sharding.NoSharding(),
32 | transformations: transforms.Transformations = (),
33 | batch_size: Optional[int] = None,
34 | drop_remainder: bool = False,
35 | worker_count: Optional[int] = 0,
36 | read_options: Optional[options.ReadOptions] = None,
37 | ) -> data_loader.DataLoader:
38 | """Convenient method for simple pipelines on top of a data source.
39 |
40 | Args:
41 | source: Data source to load from. This can be one of the file data sources
42 | provided by Grain, a TFDS data source (`tfds.data_source(...)`) or your
43 | custom data source.
44 | num_epochs: See IndexSampler.
45 | shuffle: See IndexSampler.
46 | seed: See IndexSampler.
47 | shard_options: See IndexSampler.
48 | transformations: List of local (stateless) transformations.
49 | batch_size: Optional batch size. If provided, a Batch transformation is appended.
50 | drop_remainder: Whether to drop partial batches.
51 | worker_count: Number of child processes launched to parallelize the
52 | transformations. Zero means all processing runs in the main Python process.
53 | read_options: Read options for the data loader. See ReadOptions.
54 |
55 | Returns:
56 | DataLoader for this dataset.
57 | """
58 | usage_logging.log_event("load", tag_3="PyGrain")
59 | _api_usage_counter.Increment("load")
60 | sampler = samplers.IndexSampler(
61 | num_records=len(source),
62 | shuffle=shuffle,
63 | seed=seed,
64 | num_epochs=num_epochs,
65 | shard_options=shard_options,
66 | )
67 | if batch_size is not None:
68 | transformations = list(transformations)
69 | transformations.append(
70 | transforms.Batch(batch_size, drop_remainder=drop_remainder)
71 | )
72 | return data_loader.DataLoader(
73 | data_source=source,
74 | sampler=sampler,
75 | operations=transformations,
76 | worker_count=worker_count,
77 | read_options=read_options,
78 | )
79 |
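A minimal sketch of calling `load()` with pieces defined elsewhere in this package (`RangeDataSource`, `MapTransform`, `NoSharding`); the transformation, seed, and batch size are illustrative.

from grain._src.core import sharding
from grain._src.core import transforms
from grain._src.python import data_sources
from grain._src.python import load


class PlusOne(transforms.MapTransform):

  def map(self, x: int) -> int:
    return x + 1


loader = load.load(
    data_sources.RangeDataSource(start=0, stop=10, step=1),
    num_epochs=1,
    shuffle=True,
    seed=42,
    shard_options=sharding.NoSharding(),
    transformations=[PlusOne()],
    batch_size=2,
)
for batch in loader:  # Each batch is a NumPy array of 2 elements.
  ...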
--------------------------------------------------------------------------------
/grain/_src/python/load_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for load()."""
15 |
16 | from absl import flags
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from grain._src.core import transforms
20 | import multiprocessing as mp
21 | from grain._src.python import data_sources
22 | from grain._src.python import load
23 | import numpy as np
24 |
25 |
26 | FLAGS = flags.FLAGS
27 |
28 |
29 | class FilterEven(transforms.Filter):
30 |
31 | def filter(self, x: int) -> bool:
32 | return x % 2 == 0
33 |
34 |
35 | class PlusOne(transforms.MapTransform):
36 |
37 | def map(self, x: int) -> int:
38 | return x + 1
39 |
40 |
41 | class DataLoaderTest(parameterized.TestCase):
42 |
43 | def test_with_range_source(self):
44 | range_data_source = data_sources.RangeDataSource(start=0, stop=8, step=1)
45 | transformations = [
46 | PlusOne(),
47 | FilterEven(),
48 | ]
49 | data_loader = load.load(
50 | range_data_source, transformations=transformations, num_epochs=1
51 | )
52 | expected = [2, 4, 6, 8]
53 | actual = list(data_loader)
54 | np.testing.assert_equal(actual, expected)
55 |
56 | def test_with_range_source_with_batch(self):
57 | range_data_source = data_sources.RangeDataSource(start=0, stop=8, step=1)
58 | transformations = [
59 | PlusOne(),
60 | FilterEven(),
61 | ]
62 | data_loader = load.load(
63 | range_data_source,
64 | transformations=transformations,
65 | batch_size=2,
66 | num_epochs=1,
67 | )
68 | expected = [np.array([2, 4]), np.array([6, 8])]
69 | actual = list(data_loader)
70 | np.testing.assert_equal(actual, expected)
71 |
72 |
73 | if __name__ == "__main__":
74 | absltest.main()
75 |
--------------------------------------------------------------------------------
/grain/_src/python/multiprocessing_common.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """This module defines common functions for multiprocessing/threading."""
15 |
16 | import dataclasses
17 | import multiprocessing
18 | from multiprocessing import pool
19 | import queue
20 | from typing import TypeVar, Union, Callable
21 |
22 | T = TypeVar('T')
23 |
24 | _QUEUE_WAIT_TIMEOUT_SECONDS = 0.5
25 | _ASYNC_RESULT_WAIT_TIMEOUT_SECONDS = 0.5
26 |
27 |
28 | @dataclasses.dataclass
29 | class _SystemTerminated:
30 | """When system terminates, this is returned instead of actual elements."""
31 |
32 |
33 | SYSTEM_TERMINATED = _SystemTerminated()
34 |
35 |
36 | def add_element_to_queue(
37 | element: T,
38 | elements_queue: queue.Queue[T],
39 | should_stop: Callable[[], bool],
40 | ) -> bool:
41 | """Try adding element to queue as long as should_stop() is not True.
42 |
43 | Args:
44 | element: Element to add.
45 | elements_queue: Target queue.
46 | should_stop: Callable to check whether addition should proceed (possibly
47 | after a re-try).
48 |
49 | Returns:
50 | Bool indicating whether the addition was successful.
51 | """
52 | while not should_stop():
53 | try:
54 | elements_queue.put(element, timeout=_QUEUE_WAIT_TIMEOUT_SECONDS)
55 | return True
56 | except queue.Full:
57 | pass
58 | return False
59 |
60 |
61 | def get_element_from_queue(
62 | elements_queue: queue.Queue[T],
63 | should_stop: Callable[[], bool],
64 | ) -> Union[T, _SystemTerminated]:
65 | """Try getting element from queue as long as should_stop() is not True."""
66 | while not should_stop():
67 | try:
68 | return elements_queue.get(timeout=_QUEUE_WAIT_TIMEOUT_SECONDS)
69 | except queue.Empty:
70 | pass
71 | return SYSTEM_TERMINATED
72 |
73 |
74 | def get_async_result(
75 | async_result: pool.AsyncResult[T],
76 | should_stop: Callable[[], bool],
77 | ) -> Union[T, _SystemTerminated]:
78 | """Wait for async result as long as should_stop() is not True."""
79 | while not should_stop():
80 | try:
81 | return async_result.get(timeout=_ASYNC_RESULT_WAIT_TIMEOUT_SECONDS)
82 | except multiprocessing.TimeoutError:
83 | pass
84 | return SYSTEM_TERMINATED
85 |
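A minimal sketch of the helpers above using a `threading.Event` as the `should_stop` predicate; the queue size and element values are illustrative.

import queue
import threading

from grain._src.python import multiprocessing_common

stop_event = threading.Event()
elements = queue.Queue(maxsize=2)


def produce():
  for i in range(100):
    # Retries while the queue is full; gives up and returns False once
    # stop_event is set.
    if not multiprocessing_common.add_element_to_queue(
        i, elements, should_stop=stop_event.is_set
    ):
      return


producer = threading.Thread(target=produce)
producer.start()
first = multiprocessing_common.get_element_from_queue(
    elements, should_stop=stop_event.is_set
)
stop_event.set()  # Unblocks the producer so it can exit.
producer.join()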
--------------------------------------------------------------------------------
/grain/_src/python/multiprocessing_common_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for multiprocessing common functions."""
15 |
16 | import multiprocessing
17 | from multiprocessing import pool
18 | import queue
19 |
20 | from absl.testing import absltest
21 | from grain._src.python import multiprocessing_common
22 |
23 |
24 | class MultiProcessingCommonTest(absltest.TestCase):
25 |
26 | def test_add_element_to_queue(self):
27 | test_queue = multiprocessing.Queue()
28 | element = 1
29 | termination_event = multiprocessing.Event()
30 | self.assertTrue(
31 | multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
32 | element=element,
33 | elements_queue=test_queue,
34 | should_stop=termination_event.is_set,
35 | )
36 | )
37 | self.assertEqual(test_queue.get(), 1)
38 |
39 | def test_add_element_to_queue_already_terminated(self):
40 | test_queue = multiprocessing.Queue()
41 | element = 1
42 | termination_event = multiprocessing.Event()
43 | termination_event.set()
44 | self.assertFalse(
45 | multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
46 | element=element,
47 | elements_queue=test_queue,
48 | should_stop=termination_event.is_set,
49 | )
50 | )
51 | with self.assertRaises(queue.Empty):
52 | test_queue.get(timeout=0.1)
53 |
54 | def test_get_element_from_queue(self):
55 | test_queue = multiprocessing.Queue()
56 | expected_element = 1
57 | test_queue.put(expected_element)
58 | termination_event = multiprocessing.Event()
59 | actual_element = multiprocessing_common.get_element_from_queue( # pytype: disable=wrong-arg-types
60 | elements_queue=test_queue,
61 | should_stop=termination_event.is_set,
62 | )
63 | self.assertEqual(actual_element, expected_element)
64 |
65 | def test_get_element_from_queue_already_terminated(self):
66 | test_queue = multiprocessing.Queue()
67 | expected_element = 1
68 | test_queue.put(expected_element)
69 | termination_event = multiprocessing.Event()
70 | termination_event.set()
71 | actual_element = multiprocessing_common.get_element_from_queue( # pytype: disable=wrong-arg-types
72 | elements_queue=test_queue,
73 | should_stop=termination_event.is_set,
74 | )
75 | self.assertEqual(actual_element, multiprocessing_common.SYSTEM_TERMINATED)
76 |
77 | def test_get_async_result(self):
78 | thread_pool = pool.ThreadPool(1)
79 | async_result = thread_pool.apply_async(func=lambda x: x + 1, args=(1,))
80 | termination_event = multiprocessing.Event()
81 | result = multiprocessing_common.get_async_result(
82 | should_stop=termination_event.is_set, async_result=async_result
83 | )
84 | self.assertEqual(result, 2)
85 |
86 | def test_get_async_result_already_terminated(self):
87 | thread_pool = pool.ThreadPool(1)
88 | async_result = thread_pool.apply_async(func=lambda x: x + 1, args=(1,))
89 | termination_event = multiprocessing.Event()
90 | termination_event.set()
91 | result = multiprocessing_common.get_async_result(
92 | should_stop=termination_event.is_set, async_result=async_result
93 | )
94 | self.assertEqual(result, multiprocessing_common.SYSTEM_TERMINATED)
95 |
96 |
97 | if __name__ == "__main__":
98 | absltest.main()
99 |
--------------------------------------------------------------------------------
/grain/_src/python/options.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Dataclasses for holdings options."""
15 | import dataclasses
16 |
17 |
18 | @dataclasses.dataclass(slots=True)
19 | class ReadOptions:
20 | """Options for reading data from the DataSource.
21 |
22 | These settings configure a single Python process. Each process uses its own
23 | threads and buffer for reading and processing data.
24 |
25 | Example: With ReadOptions.num_threads=8 and
26 | MultiprocessingOptions.num_workers=10 there will be 80 threads reading the
27 | data (8 threads in each of 10 Python processes).
28 |
29 | Attributes:
30 | num_threads: Number of threads reading from the DataSource in parallel. If
31 | the data are already loaded in memory, we recommend setting this to 0 to
32 | avoid Python GIL contention by multiple threads.
33 | prefetch_buffer_size: Size of the buffer for reading elements per Python
34 | process (not per thread). Useful when reading from a distributed file
35 | system.
36 | """
37 |
38 | # The current default values were chosen by running a few selected
39 | # benchmarks reading from remote hard drives.
40 | # These values should work well for datasets with elements between 1 and
41 | # 10 KiB on disk.
42 | num_threads: int = 16
43 | prefetch_buffer_size: int = 500
44 |
45 |
46 | @dataclasses.dataclass(slots=True)
47 | class MultiprocessingOptions:
48 | """Options for using Python multiprocessing.
49 |
50 | Attributes:
51 | num_workers: Number of Python worker processes. More processes can speed up
52 | the pipeline if it's compute-bound and bottlenecked on CPython's GIL.
53 | The default value of 0 means no Python multiprocessing, and as a result
54 | all data loading and transformation will run in the main Python process.
55 | per_worker_buffer_size: Size of the buffer for preprocessed elements that
56 | each worker maintains. These are elements after all transformations. If
57 | your transformations include batching, this means a single element is a
58 | batch.
59 | enable_profiling: If True, profiling info is logged. This is only available
60 | when num_workers >= 1.
61 | """
62 |
63 | num_workers: int = 0
64 | per_worker_buffer_size: int = 1
65 | enable_profiling: bool = False
66 |
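A minimal sketch of constructing the options above for a pipeline whose source already lives in memory; the worker and buffer counts are illustrative.

from grain._src.python import options

# Data is already in RAM, so skip reader threads to avoid GIL contention.
read_options = options.ReadOptions(num_threads=0)
mp_options = options.MultiprocessingOptions(
    num_workers=4, per_worker_buffer_size=2
)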
--------------------------------------------------------------------------------
/grain/_src/python/record.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Define record class used by various modules in the Grain Python Backend."""
15 |
16 | import dataclasses
17 | from typing import Generic, Optional, TypeVar
18 | import numpy as np
19 |
20 | T = TypeVar("T")
21 |
22 |
23 | @dataclasses.dataclass(slots=True)
24 | class RecordMetadata:
25 | """RecordMetadata contains metadata about indidivual records.
26 |
27 | Metadata is emitted by the sampler to indicate which record to read next.
28 | In addition, it is used to keep information about records as they flow
29 | through the pipeline from one operation to the next.
30 | """
31 |
32 | index: int
33 | record_key: Optional[int] = None
34 | rng: Optional[np.random.Generator] = None
35 |
36 | def remove_record_key(self):
37 | """Removes record key if exists."""
38 | if self.record_key is None:
39 | return self
40 | else:
41 | return dataclasses.replace(self, record_key=None)
42 |
43 | def __str__(self):
44 | return (
45 | f"RecordMetadata(index={self.index}, record_key={self.record_key}, "
46 | f"rng={self.rng})"
47 | )
48 |
49 |
50 | @dataclasses.dataclass(slots=True)
51 | class Record(Generic[T]):
52 | metadata: RecordMetadata
53 | data: T
54 |
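A minimal sketch of the dataclasses above; the index, key, and payload values are illustrative.

import numpy as np

from grain._src.python import record

metadata = record.RecordMetadata(
    index=7, record_key=123, rng=np.random.default_rng(0)
)
r = record.Record(metadata=metadata, data={'x': np.zeros(3)})
# Drop the record key once it is no longer needed downstream.
r = record.Record(metadata=r.metadata.remove_record_key(), data=r.data)
print(r.metadata)  # RecordMetadata(index=7, record_key=None, rng=Generator(PCG64))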
--------------------------------------------------------------------------------
/grain/_src/python/record_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Tests for record."""
15 |
16 | from grain._src.python import record
17 | import numpy as np
18 | from absl.testing import absltest
19 |
20 |
21 | class RecordTest(absltest.TestCase):
22 |
23 | def test_RecordMetadata_str(self):
24 | record_metadata = record.RecordMetadata(
25 | index=0, record_key=0, rng=np.random.default_rng()
26 | )
27 | self.assertEqual(
28 | str(record_metadata),
29 | "RecordMetadata(index=0, record_key=0, rng=Generator(PCG64))",
30 | )
31 |
32 | def test_RecordMetadata_str_none_rng(self):
33 | record_metadata = record.RecordMetadata(index=0, record_key=0)
34 | self.assertStartsWith(
35 | str(record_metadata),
36 | "RecordMetadata(index=0, record_key=0, rng=None",
37 | )
38 |
39 |
40 | if __name__ == "__main__":
41 | absltest.main()
42 |
--------------------------------------------------------------------------------
/grain/_src/python/testdata/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//grain:__subpackages__"])
2 |
3 | licenses(["notice"])
4 |
5 | exports_files(glob(["*.array_record-*"]))
6 |
7 | exports_files(glob(["*.tfrecord"]))
8 |
--------------------------------------------------------------------------------
/grain/_src/python/testdata/digits.array_record-00000-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/grain/_src/python/testdata/digits.array_record-00000-of-00002
--------------------------------------------------------------------------------
/grain/_src/python/testdata/digits.array_record-00001-of-00002:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/grain/_src/python/testdata/digits.array_record-00001-of-00002
--------------------------------------------------------------------------------
/grain/_src/python/testdata/morris_sequence_first_5.tfrecord:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/grain/_src/python/testdata/morris_sequence_first_5.tfrecord
--------------------------------------------------------------------------------
/grain/_src/python/testing/BUILD:
--------------------------------------------------------------------------------
1 | package(default_visibility = ["//grain:__subpackages__"])
2 |
3 | py_library(
4 | name = "experimental",
5 | srcs = ["experimental.py"],
6 | srcs_version = "PY3",
7 | deps = [
8 | "//grain/_src/core:tree_lib",
9 | "@pypi//numpy:pkg",
10 | ],
11 | )
12 |
--------------------------------------------------------------------------------
/grain/_src/python/testing/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/grain/_src/python/testing/experimental.py:
--------------------------------------------------------------------------------
1 | """API to test checkpointing."""
2 |
3 | import itertools
4 | from typing import Any
5 |
6 | from grain._src.core import tree_lib
7 | import numpy as np
8 |
9 |
10 | def assert_equal_output_after_checkpoint(
11 | ds: Any,
12 | ):
13 | """Tests restoring an iterator to various checkpointed states.
14 |
15 | Args:
16 | ds: The dataset to test. It is recommended to use a small dataset,
17 | potentially created using `grain.python.experimental.LimitIterDataset`, to
18 | restrict the number of steps being tested. The underlying dataset iterator
19 | must implement `get_state` and `set_state` for checkpointing.
20 | """
21 |
22 | iterator = ds.__iter__()
23 | checkpoints = []
24 | expected_values = []
25 | state_spec = None
26 | for i in itertools.count():
27 | current_state = iterator.get_state()
28 | if state_spec is None:
29 | state_spec = tree_lib.spec_like(current_state)
30 | else:
31 | np.testing.assert_equal(
32 | state_spec,
33 | tree_lib.spec_like(current_state),
34 | f"State spec does not match the original state spec at step {i}."
35 | f" Expected: {state_spec},"
36 | f" Actual: {tree_lib.spec_like(current_state)}",
37 | )
38 | try:
39 | value = next(iterator)
40 | except StopIteration:
41 | break
42 | checkpoints.append(current_state)
43 | expected_values.append(value)
44 |
45 | assert expected_values, "Dataset did not produce any elements."
46 |
47 | # Restore the iterator at every state, and compare the values.
48 | for i, state in enumerate(checkpoints):
49 | new_iterator = ds.__iter__()
50 | new_iterator.set_state(state)
51 | np.testing.assert_equal(
52 | new_iterator.get_state(),
53 | state,
54 | f"Restored state does not match the original state at step {i}."
55 | f" Expected: {state}, Actual: {new_iterator.get_state()}",
56 | )
57 |
58 | # Test the values at the current state.
59 | new_values = list(new_iterator)
60 | np.testing.assert_equal(
61 | new_values,
62 | expected_values[i:],
63 | f"Restored values mismatch at step {i} for state {state}."
64 | f" \nExpected: {expected_values[i:]}, \nActual: {new_values}",
65 | )
66 |
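A minimal sketch of exercising the helper above on a small dataset; it assumes `MapDataset` from `grain._src.python.dataset.dataset`, whose iterators implement `get_state`/`set_state`.

from grain._src.python.dataset import dataset
from grain._src.python.testing import experimental

ds = dataset.MapDataset.range(10).shuffle(seed=42).to_iter_dataset()
experimental.assert_equal_output_after_checkpoint(ds)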
--------------------------------------------------------------------------------
/grain/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """APIs for saving and restoring pipeline state."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=g-import-not-at-top
19 | # pylint: disable=unused-import
20 | # pylint: disable=g-multiple-import
21 |
22 | from grain._src.python.checkpoint_handlers import (
23 | CheckpointHandler,
24 | )
25 |
26 | # These are imported only if Orbax is present.
27 | try:
28 | from grain._src.python.checkpoint_handlers import (
29 | CheckpointSave,
30 | CheckpointRestore,
31 | )
32 | except ImportError:
33 | pass
34 |
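A minimal sketch, assuming Orbax is installed, of saving and restoring iterator state with the handlers re-exported above; the directory, step number, and pipeline are illustrative.

import orbax.checkpoint as ocp

from grain import checkpoint as grain_checkpoint
from grain._src.python.dataset import dataset

iterator = dataset.MapDataset.range(10).to_iter_dataset().__iter__()
next(iterator)  # Advance the iterator so there is non-trivial state to save.

manager = ocp.CheckpointManager('/tmp/grain_ckpt')
manager.save(1, args=grain_checkpoint.CheckpointSave(iterator))
manager.wait_until_finished()
manager.restore(1, args=grain_checkpoint.CheckpointRestore(iterator))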
--------------------------------------------------------------------------------
/grain/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Metadata constants."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=unused-import
19 | # pylint: disable=g-multiple-import
20 |
21 | from grain._src.core.constants import (
22 | DATASET_INDEX,
23 | EPOCH,
24 | INDEX,
25 | META_FEATURES,
26 | RECORD,
27 | RECORD_KEY,
28 | SEED,
29 | )
30 |
--------------------------------------------------------------------------------
/grain/experimental.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Experimental Grain APIs."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=g-bad-import-order
19 | # pylint: disable=g-multiple-import
20 | # pylint: disable=unused-import
21 |
22 | from grain._src.core.transforms import FlatMapTransform
23 |
24 | from grain._src.python.dataset.base import (
25 | DatasetOptions,
26 | ExecutionTrackingMode,
27 | )
28 | from grain._src.python.dataset.dataset import (
29 | apply_transformations,
30 | WithOptionsIterDataset,
31 | )
32 | from grain._src.python.dataset.elastic_iterator import ElasticIterator
33 | from grain._src.python.dataset.sources.parquet_dataset import ParquetIterDataset
34 | from grain._src.python.dataset.sources.tfrecord_dataset import TFRecordIterDataset
35 |
36 | from grain._src.python.dataset.transformations.flatmap import (
37 | FlatMapMapDataset,
38 | FlatMapIterDataset,
39 | )
40 | from grain._src.python.dataset.transformations.interleave import (
41 | InterleaveIterDataset,
42 | )
43 | from grain._src.python.dataset.transformations.limit import LimitIterDataset
44 | from grain._src.python.dataset.transformations.map import RngPool
45 | from grain._src.python.dataset.transformations.packing import (
46 | FirstFitPackIterDataset,
47 | )
48 | from grain._src.python.dataset.transformations.packing_concat_then_split import (
49 | BOSHandling,
50 | ConcatThenSplitIterDataset,
51 | )
52 | from grain._src.python.dataset.transformations.prefetch import (
53 | ThreadPrefetchIterDataset,
54 | )
55 | from grain._src.python.dataset.transformations.shuffle import (
56 | WindowShuffleMapDataset,
57 | WindowShuffleIterDataset,
58 | )
59 | from grain._src.python.dataset.transformations.zip import (
60 | ZipMapDataset,
61 | ZipIterDataset,
62 | )
63 | from grain._src.python.experimental.example_packing.packing import (
64 | PackAndBatchOperation,
65 | )
66 | from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import (
67 | index_shuffle,
68 | )
69 |
70 | # This should eventually live under grain.testing.
71 | from grain._src.python.testing.experimental import (
72 | assert_equal_output_after_checkpoint,
73 | )
74 |
--------------------------------------------------------------------------------
/grain/multiprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """APIs to work with multiprocessing."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=unused-import
19 |
20 | from grain._src.python.options import MultiprocessingOptions
21 | from grain._src.python.shared_memory_array import SharedMemoryArray
22 |
--------------------------------------------------------------------------------
/grain/oss/Dockerfile:
--------------------------------------------------------------------------------
1 | # Constructs the environment within which we will build the grain pip wheels.
2 |
3 |
4 | ARG AUDITWHEEL_PLATFORM
5 |
6 | FROM quay.io/pypa/${AUDITWHEEL_PLATFORM}
7 | LABEL maintainer="Grain team "
8 |
9 | ARG PYTHON_VERSION
10 | ARG BAZEL_VERSION
11 |
12 | ENV DEBIAN_FRONTEND=noninteractive
13 |
14 | RUN ulimit -n 1024 && yum install -y rsync
15 |
16 | ENV PYTHON_BIN_PATH=/opt/python/cp${PYTHON_VERSION}-cp${PYTHON_VERSION}/bin
17 | ENV PATH="${PYTHON_BIN_PATH}:${PATH}"
18 |
19 | ENV PYTHON_BIN=${PYTHON_BIN_PATH}/python
20 |
21 | # Download the correct bazel version and make sure it's on path.
22 | RUN BAZEL_ARCH_SUFFIX="$(uname -m | sed s/aarch64/arm64/)" \
23 | && curl -sSL --fail -o /usr/local/bin/bazel "https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-linux-$BAZEL_ARCH_SUFFIX" \
24 | && chmod a+x /usr/local/bin/bazel
25 |
26 | # Install dependencies needed for grain.
27 | RUN --mount=type=cache,target=/root/.cache \
28 | $PYTHON_BIN -m pip install -U \
29 | setuptools;
30 |
31 |
32 | WORKDIR "/tmp/grain"
--------------------------------------------------------------------------------
/grain/oss/array_record/Dockerfile.patch:
--------------------------------------------------------------------------------
1 | diff --git a/oss/build.Dockerfile b/oss/build.Dockerfile
2 | index 5fefa86..17fd296 100644
3 | --- a/oss/build.Dockerfile
4 | +++ b/oss/build.Dockerfile
5 | @@ -11,7 +11,7 @@ ARG BAZEL_VERSION
6 |
7 | ENV DEBIAN_FRONTEND=noninteractive
8 |
9 | -RUN yum install -y rsync
10 | +RUN ulimit -n 1024 && yum install -y rsync
11 | ENV PATH="${PYTHON_BIN}:${PATH}"
12 |
13 | # Download the correct bazel version and make sure it's on path.
14 |
--------------------------------------------------------------------------------
/grain/oss/array_record/WORKSPACE.patch:
--------------------------------------------------------------------------------
1 | diff --git a/WORKSPACE b/WORKSPACE
2 | index e63922f..a655bdb 100644
3 | --- a/WORKSPACE
4 | +++ b/WORKSPACE
5 | @@ -3,13 +3,11 @@ workspace(name = "array_record")
6 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
7 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
8 |
9 | -# Abseil LTS 20230125.0
10 | http_archive(
11 | name = "com_google_absl",
12 | - sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21", # SHARED_ABSL_SHA
13 | - strip_prefix = "abseil-cpp-20230125.0",
14 | + strip_prefix = "abseil-cpp-20230802.1",
15 | urls = [
16 | - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz",
17 | + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.1.tar.gz",
18 | ],
19 | )
20 | # Version: pypi-v0.11.0, 2020/10/27
21 | @@ -70,15 +68,13 @@ http_archive(
22 | load("@pybind11_bazel//:python_configure.bzl", "python_configure")
23 | python_configure(name = "local_config_python")
24 |
25 | -# V21.12, 20230130
26 | # proto_library, cc_proto_library, and java_proto_library rules implicitly
27 | # depend on @com_google_protobuf for protoc and proto runtimes.
28 | # This statement defines the @com_google_protobuf repo.
29 | http_archive(
30 | name = "com_google_protobuf",
31 | - sha256 = "22fdaf641b31655d4b2297f9981fa5203b2866f8332d3c6333f6b0107bb320de",
32 | - strip_prefix = "protobuf-21.12",
33 | - urls = ["https://github.com/protocolbuffers/protobuf/archive/v21.12.tar.gz"],
34 | + strip_prefix = "protobuf-23.1",
35 | + urls = ["https://github.com/protocolbuffers/protobuf/archive/v23.1.tar.gz"],
36 | )
37 |
38 | load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
39 | @@ -87,9 +83,9 @@ protobuf_deps()
40 | # Riegeli does not cut releases, so we reference the head
41 | http_archive(
42 | name = "com_google_riegeli",
43 | - strip_prefix = "riegeli-master",
44 | + strip_prefix = "riegeli-904c0c263b8632265103f0066c168a92c7713b07",
45 | urls = [
46 | - "https://github.com/google/riegeli/archive/master.zip",
47 | + "https://github.com/google/riegeli/archive/904c0c263b8632265103f0066c168a92c7713b07.zip",
48 | ],
49 | )
50 | # Riegeli's dependencies
51 | @@ -131,9 +127,8 @@ http_archive(
52 | http_archive(
53 | name = "highwayhash",
54 | build_file = "@com_google_riegeli//third_party:highwayhash.BUILD",
55 | - sha256 = "5380cb7cf19e7c9591f31792b7794d48084f6a3ab7c03d637cd6a32cf2ee8686",
56 | - strip_prefix = "highwayhash-a7f68e2f95fac08b24327d74747521cf634d5aff",
57 | - urls = ["https://github.com/google/highwayhash/archive/a7f68e2f95fac08b24327d74747521cf634d5aff.zip"], # 2023-08-09
58 | + strip_prefix = "highwayhash-3d6a8d35a6bc823b9dbe08804fc2a2d08d373cd7",
59 | + urls = ["https://github.com/google/highwayhash/archive/3d6a8d35a6bc823b9dbe08804fc2a2d08d373cd7.zip"],
60 | )
61 |
62 | # Tensorflow, 20230705
63 |
--------------------------------------------------------------------------------
/grain/oss/array_record/array_record_reader.patch:
--------------------------------------------------------------------------------
1 | diff --git a/cpp/array_record_reader.cc b/cpp/array_record_reader.cc
2 | index bc7675e..623911a 100644
3 | --- a/cpp/array_record_reader.cc
4 | +++ b/cpp/array_record_reader.cc
5 | @@ -196,7 +196,7 @@ void ArrayRecordReaderBase::Initialize() {
6 | max_parallelism = state_->pool->NumThreads();
7 | if (state_->options.max_parallelism().has_value()) {
8 | max_parallelism =
9 | - std::min(max_parallelism, state_->options.max_parallelism().value());
10 | + std::min(max_parallelism, state_->options.max_parallelism().value());
11 | }
12 | }
13 | state_->options.set_max_parallelism(max_parallelism);
14 | @@ -331,7 +331,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecords(
15 | return absl::OkStatus();
16 | }
17 | uint64_t num_chunk_groups =
18 | - CeilOfRatio(state_->chunk_offsets.size(), state_->chunk_group_size);
19 | + CeilOfRatio(state_->chunk_offsets.size(), state_->chunk_group_size);
20 | const auto reader = get_backing_reader();
21 | Reader* mutable_reader = const_cast<Reader*>(
22 | reinterpret_cast<const Reader*>(reader.get()));
23 | @@ -340,7 +340,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecords(
24 | uint64_t chunk_idx_start = buf_idx * state_->chunk_group_size;
25 | // inclusive index, not the conventional exclusive index.
26 | uint64_t last_chunk_idx =
27 | - std::min((buf_idx + 1) * state_->chunk_group_size - 1,
28 | + std::min((buf_idx + 1) * state_->chunk_group_size - 1,
29 | state_->chunk_offsets.size() - 1);
30 | uint64_t buf_len = state_->ChunkEndOffset(last_chunk_idx) -
31 | state_->chunk_offsets[chunk_idx_start];
32 | @@ -406,9 +406,9 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsInRange(
33 | "Invalid range [%d, %d). Total records: %d", begin, end, NumRecords()));
34 | }
35 | uint64_t chunk_idx_begin = begin / state_->record_group_size;
36 | - uint64_t chunk_idx_end = CeilOfRatio(end, state_->record_group_size);
37 | + uint64_t chunk_idx_end = CeilOfRatio(end, state_->record_group_size);
38 | uint64_t num_chunks = chunk_idx_end - chunk_idx_begin;
39 | - uint64_t num_chunk_groups = CeilOfRatio(num_chunks, state_->chunk_group_size);
40 | + uint64_t num_chunk_groups = CeilOfRatio(num_chunks, state_->chunk_group_size);
41 |
42 | const auto reader = get_backing_reader();
43 | Reader* mutable_reader =
44 | @@ -418,7 +418,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsInRange(
45 | uint64_t chunk_idx_start =
46 | chunk_idx_begin + buf_idx * state_->chunk_group_size;
47 | // inclusive index, not the conventional exclusive index.
48 | - uint64_t last_chunk_idx = std::min(
49 | + uint64_t last_chunk_idx = std::min(
50 | chunk_idx_begin + (buf_idx + 1) * state_->chunk_group_size - 1,
51 | chunk_idx_end - 1);
52 | uint64_t buf_len = state_->ChunkEndOffset(last_chunk_idx) -
53 | @@ -617,7 +617,7 @@ bool ArrayRecordReaderBase::SeekRecord(uint64_t record_index) {
54 | if (!ok()) {
55 | return false;
56 | }
57 | - state_->record_idx = std::min(record_index, state_->num_records);
58 | + state_->record_idx = std::min(record_index, state_->num_records);
59 | return true;
60 | }
61 |
62 | @@ -667,7 +667,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
63 | std::vector decoders;
64 | decoders.reserve(state_->chunk_group_size);
65 | uint64_t chunk_start = buffer_idx * state_->chunk_group_size;
66 | - uint64_t chunk_end = std::min(state_->chunk_offsets.size(),
67 | + uint64_t chunk_end = std::min(state_->chunk_offsets.size(),
68 | (buffer_idx + 1) * state_->chunk_group_size);
69 | const auto reader = get_backing_reader();
70 | for (uint64_t chunk_idx = chunk_start; chunk_idx < chunk_end; ++chunk_idx) {
71 | @@ -708,7 +708,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) {
72 | chunk_offsets.reserve(state_->chunk_group_size);
73 | uint64_t chunk_start = buffer_to_add * state_->chunk_group_size;
74 | uint64_t chunk_end =
75 | - std::min(state_->chunk_offsets.size(),
76 | + std::min(state_->chunk_offsets.size(),
77 | (buffer_to_add + 1) * state_->chunk_group_size);
78 | for (uint64_t chunk_idx = chunk_start; chunk_idx < chunk_end; ++chunk_idx) {
79 | chunk_offsets.push_back(state_->chunk_offsets[chunk_idx]);
80 |
--------------------------------------------------------------------------------
/grain/oss/array_record/build_whl.patch:
--------------------------------------------------------------------------------
1 | diff --git a/oss/build_whl.sh b/oss/build_whl.sh
2 | index 275c868..27080a9 100755
3 | --- a/oss/build_whl.sh
4 | +++ b/oss/build_whl.sh
5 | @@ -27,11 +27,15 @@ function main() {
6 | [ -e .bazelrc ] && rm .bazelrc
7 |
8 | write_to_bazelrc "build -c opt"
9 | + write_to_bazelrc "build --action_env MACOSX_DEPLOYMENT_TARGET=11.0"
10 | write_to_bazelrc "build --cxxopt=-std=c++17"
11 | write_to_bazelrc "build --host_cxxopt=-std=c++17"
12 | - write_to_bazelrc "build --linkopt=\"-lrt -lm\""
13 | write_to_bazelrc "build --experimental_repo_remote_exec"
14 | write_to_bazelrc "build --python_path=\"${PYTHON_BIN}\""
15 | + PLATFORM="$(uname)"
16 | + if [[ "$PLATFORM" != "Darwin" ]]; then
17 | + write_to_bazelrc "build --linkopt=\"-lrt -lm\""
18 | + fi
19 |
20 | if [ -n "${CROSSTOOL_TOP}" ]; then
21 | write_to_bazelrc "build --crosstool_top=${CROSSTOOL_TOP}"
22 | @@ -42,8 +46,8 @@ function main() {
23 | # https://github.com/bazelbuild/bazel/issues/8622
24 | export USE_BAZEL_VERSION=5.4.0
25 | bazel clean
26 | - bazel build ...
27 | - bazel test --verbose_failures --test_output=errors ...
28 | + bazel build ... --action_env PYTHON_BIN_PATH="${PYTHON_BIN}"
29 | + bazel test --verbose_failures --test_output=errors ... --action_env PYTHON_BIN_PATH="${PYTHON_BIN}"
30 |
31 | DEST="/tmp/array_record/all_dist"
32 | # Create the directory, then do dirname on a non-existent file inside it to
33 | @@ -71,7 +75,11 @@ function main() {
34 |
35 | pushd ${TMPDIR}
36 | echo $(date) : "=== Building wheel"
37 | - ${PYTHON_BIN} setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION}
38 | + plat_name=""
39 | + if [[ "$(uname)" == "Darwin" ]]; then
40 | + plat_name="--plat-name macosx_11_0_$(uname -m)"
41 | + fi
42 | + ${PYTHON_BIN} setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION} $plat_name
43 |
44 | if [ -n "${AUDITWHEEL_PLATFORM}" ]; then
45 | echo $(date) : "=== Auditing wheel"
46 |
--------------------------------------------------------------------------------
/grain/oss/array_record/runner_common.patch:
--------------------------------------------------------------------------------
1 | diff --git a/oss/runner_common.sh b/oss/runner_common.sh
2 | index 2ee2c8c..433698f 100644
3 | --- a/oss/runner_common.sh
4 | +++ b/oss/runner_common.sh
5 | @@ -2,7 +2,7 @@
6 |
7 | # Builds ArrayRecord from source code located in SOURCE_DIR producing wheels
8 | # under $SOURCE_DIR/all_dist.
9 | -function build_and_test_array_record() {
10 | +function build_and_test_array_record_linux() {
11 | SOURCE_DIR=$1
12 |
13 | # Automatically decide which platform to build for by checking on which
14 | @@ -15,7 +15,7 @@ function build_and_test_array_record() {
15 |
16 | # Build wheels for multiple Python minor versions.
17 | PYTHON_MAJOR_VERSION=3
18 | - for PYTHON_MINOR_VERSION in 9 10 11 12
19 | + for PYTHON_MINOR_VERSION in 10 11 12
20 | do
21 | PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}
22 | PYTHON_BIN=/opt/python/cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}-cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}/bin
23 | @@ -41,4 +41,78 @@ function build_and_test_array_record() {
24 | done
25 |
26 | ls ${SOURCE_DIR}/all_dist/*.whl
27 | +}
28 | +
29 | +function install_and_init_pyenv {
30 | + pyenv_root=${1:-$HOME/.pyenv}
31 | + export PYENV_ROOT=$pyenv_root
32 | + if [[ ! -d $PYENV_ROOT ]]; then
33 | + echo "Installing pyenv.."
34 | + git clone https://github.com/pyenv/pyenv.git "$PYENV_ROOT"
35 | + pushd "$PYENV_ROOT"
36 | + git checkout "v2.4.21"
37 | + popd
38 | + export PATH="/home/kbuilder/.local/bin:$PYENV_ROOT/bin:$PATH"
39 | + eval "$(pyenv init --path)"
40 | + fi
41 | +
42 | + echo "Python setup..."
43 | + pyenv install -s "$PYENV_PYTHON_VERSION"
44 | + pyenv global "$PYENV_PYTHON_VERSION"
45 | + export PYTHON_BIN=$(pyenv which python)
46 | +}
47 | +
48 | +function setup_env_vars_py {
49 | + # This controls the python binary to use.
50 | + PYTHON_MAJOR_VERSION=$1
51 | + PYTHON_MINOR_VERSION=$2
52 | + # This is for pyenv install.
53 | + PYENV_PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}
54 | + PYTHON="python$PYENV_PYTHON_VERSION"
55 | +}
56 | +
57 | +function update_bazel_macos {
58 | + BAZEL_VERSION=$1
59 | + ARCH="$(uname -m)"
60 | + curl -L https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh -O
61 | + ls
62 | + chmod +x bazel-*.sh
63 | + ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh --user
64 | + rm -f ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh
65 | + # Add new bazel installation to path
66 | + export PATH="$HOME/bin:$PATH"
67 | +}
68 | +
69 | +function install_ar_deps {
70 | + $PYTHON_BIN -m pip install -U \
71 | + absl-py \
72 | + build \
73 | + etils[epath] \
74 | + setuptools \
75 | + twine \
76 | + wheel;
77 | +}
78 | +
79 | +function build_and_test_array_record_macos() {
80 | + SOURCE_DIR=$1
81 | + # Set up Bazel.
82 | + # Using a previous version of Bazel to avoid:
83 | + # https://github.com/bazelbuild/bazel/issues/8622
84 | + export BAZEL_VERSION="5.4.0"
85 | + update_bazel_macos ${BAZEL_VERSION}
86 | + bazel --version
87 | +
88 | + PYTHON_MAJOR_VERSION=3
89 | + for PYTHON_MINOR_VERSION in 10 11 12
90 | + do
91 | + # Set up Pyenv.
92 | + PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION}
93 | + echo "Creating array_record wheel for Python Version $PYTHON_VERSION"
94 | + setup_env_vars_py $PYTHON_MAJOR_VERSION $PYTHON_MINOR_VERSION
95 | + install_and_init_pyenv
96 | + install_ar_deps
97 | +
98 | + # Build and test ArrayRecord.
99 | + bash ${SOURCE_DIR}/oss/build_whl.sh
100 | + done
101 | }
102 | \ No newline at end of file
103 |
--------------------------------------------------------------------------------
/grain/oss/array_record/setup.patch:
--------------------------------------------------------------------------------
1 | diff --git a/setup.py b/setup.py
2 | index cfb0bac..4763c00 100644
3 | --- a/setup.py
4 | +++ b/setup.py
5 | @@ -25,7 +25,7 @@ class BinaryDistribution(Distribution):
6 |
7 | setup(
8 | name='array_record',
9 | - version='0.5.1',
10 | + version='0.6.0',
11 | description='A file format that achieves a new frontier of IO efficiency',
12 | author='ArrayRecord team',
13 | author_email='no-reply@google.com',
14 |
--------------------------------------------------------------------------------
/grain/proto/BUILD:
--------------------------------------------------------------------------------
1 | load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
2 |
3 | default_visibility = ["//grain:__subpackages__"]
4 |
5 | package(default_visibility = default_visibility)
6 |
7 | py_proto_library(
8 | name = "execution_summary_py_pb2",
9 | # For profiling tooling.
10 | srcs = ["execution_summary.proto"],
11 | )
12 |
--------------------------------------------------------------------------------
/grain/proto/execution_summary.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package grain.python.execution_summary;
4 |
5 | message ExecutionSummary {
6 | message Node {
7 | // Unique ID of the node.
8 | int32 id = 2;
9 | // Human-readable name of the node.
10 | string name = 3;
11 | // Node IDs of the parent nodes.
12 | repeated int32 inputs = 4;
13 | // Ratio of time spent by the pipeline waiting for the given transformation
14 | // node.
15 | double wait_time_ratio = 5;
16 | // Cumulative processing time spent in the node from the start in
17 | // nanoseconds.
18 | int64 total_processing_time_ns = 6;
19 | // Minimum per-element processing time in nanoseconds.
20 | int64 min_processing_time_ns = 7;
21 | // Maximum per-element processing time in nanoseconds.
22 | int64 max_processing_time_ns = 8;
23 | // Number of elements produced by the node.
24 | int64 num_produced_elements = 9;
25 | // Human-readable specification of the produced elements.
26 | string output_spec = 10;
27 | // Whether the node is the root node.
28 | bool is_output = 11;
29 | // Whether the node is a prefetch node. Child nodes of a prefetch node have
30 | // their wait time ratio derived from the ratio of the prefetch node.
31 | // The sum of all ratios in a single pipeline is 1.
32 | bool is_prefetch = 12;
33 | // Bytes consumed by the node. Currently, bytes consumed and bytes produced
34 | // by the node are best estimated only for prefetch nodes. The difference is
35 | // used to estimate the memory usage of the node.
36 | int64 bytes_consumed = 13;
37 | // Bytes produced by the node.
38 | int64 bytes_produced = 14;
39 | }
40 | // Map of node IDs to nodes in the pipeline.
41 | map<int32, Node> nodes = 1;
42 | }
43 |
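A minimal sketch of filling in the message above from Python, assuming the `py_proto_library` target generates a module importable as `grain.proto.execution_summary_pb2` (the import path is an assumption); the node values are illustrative.

from grain.proto import execution_summary_pb2

summary = execution_summary_pb2.ExecutionSummary()
node = summary.nodes[1]  # Map access creates the Node entry for ID 1.
node.id = 1
node.name = 'MapDatasetIterator(transform=AddOne)'
node.wait_time_ratio = 0.25
node.num_produced_elements = 1000
node.is_output = True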
--------------------------------------------------------------------------------
/grain/python/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Public API for Grain.
15 |
16 | Backwards compatibility re-import. Prefer adding new APIs to top-level modules.
17 | """
18 |
19 | # pylint: disable=g-importing-member
20 | # pylint: disable=g-import-not-at-top
21 | # pylint: disable=g-multiple-import
22 | # pylint: disable=unused-import
23 |
24 |
25 | from grain._src.core.config import config
26 | from grain._src.core.constants import (
27 | DATASET_INDEX,
28 | EPOCH,
29 | INDEX,
30 | META_FEATURES,
31 | RECORD,
32 | RECORD_KEY,
33 | SEED,
34 | )
35 | from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
36 | from grain._src.core.transforms import (
37 | Batch,
38 | Filter as FilterTransform,
39 | MapTransform,
40 | MapWithIndex as MapWithIndexTransform,
41 | RandomMapTransform,
42 | Transformation,
43 | Transformations,
44 | )
45 |
46 | from grain._src.python.checkpoint_handlers import (
47 | CheckpointHandler as PyGrainCheckpointHandler,
48 | )
49 | from grain._src.python.data_loader import (
50 | DataLoader,
51 | DataLoaderIterator as PyGrainDatasetIterator,
52 | )
53 | from grain._src.python.data_sources import (
54 | ArrayRecordDataSource,
55 | SharedMemoryDataSource as InMemoryDataSource,
56 | RandomAccessDataSource,
57 | RangeDataSource,
58 | )
59 | from grain._src.python.dataset.base import DatasetSelectionMap
60 | from grain._src.python.dataset.dataset import (
61 | MapDataset,
62 | IterDataset,
63 | DatasetIterator,
64 | )
65 |
66 | from grain._src.python.load import load
67 | from grain._src.python.operations import (
68 | BatchOperation,
69 | FilterOperation,
70 | MapOperation,
71 | Operation,
72 | RandomMapOperation,
73 | )
74 | from grain._src.python.options import ReadOptions, MultiprocessingOptions
75 | from grain._src.python.record import (Record, RecordMetadata)
76 | from grain._src.python.samplers import (
77 | IndexSampler,
78 | Sampler,
79 | SequentialSampler,
80 | )
81 | from grain._src.python.shared_memory_array import SharedMemoryArray
82 | from grain.python import experimental
83 |
84 | # These are imported only if Orbax is present.
85 | try:
86 | from grain._src.python.checkpoint_handlers import (
87 | CheckpointSave as PyGrainCheckpointSave,
88 | CheckpointRestore as PyGrainCheckpointRestore,
89 | )
90 | except ImportError:
91 | pass
92 |
--------------------------------------------------------------------------------
/grain/python/experimental.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Experimental Grain APIs.
15 |
16 | Backwards compatibility re-import. Prefer adding new APIs to top-level
17 | `experimental.py` file.
18 | """
19 |
20 | # pylint: disable=g-importing-member
21 | # pylint: disable=g-bad-import-order
22 | # pylint: disable=g-multiple-import
23 | # pylint: disable=unused-import
24 |
25 | from grain._src.python.dataset.base import (
26 | DatasetOptions,
27 | ExecutionTrackingMode,
28 | )
29 | from grain._src.python.dataset.dataset import (
30 | apply_transformations,
31 | WithOptionsIterDataset,
32 | )
33 | from grain._src.python.dataset.sources.parquet_dataset import ParquetIterDataset
34 |
35 | from grain._src.python.dataset.transformations.flatmap import (
36 | FlatMapMapDataset,
37 | FlatMapIterDataset,
38 | )
39 | from grain._src.python.dataset.transformations.interleave import (
40 | InterleaveIterDataset,
41 | )
42 | from grain._src.python.dataset.transformations.limit import LimitIterDataset
43 | from grain._src.python.dataset.transformations.map import RngPool
44 | from grain._src.python.dataset.transformations.mix import ConcatenateMapDataset
45 | from grain._src.python.dataset.transformations.packing import FirstFitPackIterDataset
46 | from grain._src.python.dataset.transformations.packing_concat_then_split import (
47 | BOSHandling,
48 | ConcatThenSplitIterDataset,
49 | )
50 | from grain._src.python.dataset.transformations.prefetch import (
51 | MultiprocessPrefetchIterDataset,
52 | ThreadPrefetchIterDataset,
53 | )
54 | from grain._src.python.dataset.transformations.shuffle import (
55 | WindowShuffleMapDataset,
56 | WindowShuffleIterDataset,
57 | )
58 | from grain._src.python.dataset.transformations.zip import (
59 | ZipMapDataset,
60 | ZipIterDataset,
61 | )
62 | from grain._src.core.transforms import (
63 | FlatMapTransform,
64 | MapWithIndex as MapWithIndexTransform,
65 | )
66 | from grain._src.python.experimental.example_packing.packing import PackAndBatchOperation
67 |
68 | # This should eventually live under grain.testing.
69 | from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint
70 |
71 | from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import index_shuffle
72 |
--------------------------------------------------------------------------------
/grain/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Sampler APIs."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=g-multiple-import
19 | # pylint: disable=unused-import
20 |
21 | from grain._src.python.samplers import (
22 | IndexSampler,
23 | Sampler,
24 | SequentialSampler,
25 | )
26 |
--------------------------------------------------------------------------------
/grain/sharding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """APIs for sharding pipelines for distributed training."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=g-multiple-import
19 | # pylint: disable=unused-import
20 |
21 | from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions
22 |
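
A minimal sketch of the three sharding options re-exported above; the shard counts are illustrative only:

```python
import grain

# Explicit sharding: this process reads shard 0 of 4 disjoint shards.
explicit = grain.sharding.ShardOptions(
    shard_index=0, shard_count=4, drop_remainder=True
)

# Shard by JAX process: the shard index and shard count are derived from the
# current JAX process index and process count.
by_jax_process = grain.sharding.ShardByJaxProcess(drop_remainder=True)

# No sharding: every process sees the full dataset.
unsharded = grain.sharding.NoSharding()
```
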
--------------------------------------------------------------------------------
/grain/sources.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """APIs for reading data from various file formats."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=g-multiple-import
19 | # pylint: disable=unused-import
20 |
21 | # Note to developers:
22 | # - When adding a new OSS data source, make the format library dependency
23 | #   optional by lazily importing the source and providing an installation
24 | #   extra, e.g. `pip install grain[parquet]`. This allows users to avoid
25 | #   installing the dependencies of every supported format.
26 | from grain._src.python.data_sources import (
27 | ArrayRecordDataSource,
28 | SharedMemoryDataSource,
29 | RandomAccessDataSource,
30 | RangeDataSource,
31 | )
32 |
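
A minimal sketch of the lazy-import pattern described in the developer note above; `FooDataSource` and the `grain[foo]` extra are hypothetical names used only for illustration:

```python
# Lazy re-export via module-level __getattr__ (PEP 562): the format library
# is only imported when the source is first accessed, so `import grain`
# works even if the optional dependency is not installed.
def __getattr__(name):
  if name == "FooDataSource":
    try:
      # Importing the source pulls in its format library.
      from grain._src.python.data_sources import FooDataSource
    except ImportError as e:
      raise ImportError(
          "FooDataSource requires an optional dependency; install it with"
          " `pip install grain[foo]`."
      ) from e
    return FooDataSource
  raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

The `parquet` extra declared in `pyproject.toml` below follows the same idea for `pyarrow`.
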
--------------------------------------------------------------------------------
/grain/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data transformation APIs."""
15 |
16 |
17 | # pylint: disable=g-importing-member
18 | # pylint: disable=g-multiple-import
19 | # pylint: disable=unused-import
20 |
21 | from grain._src.core.transforms import (
22 | Batch,
23 | Filter,
24 | MapTransform as Map,
25 | MapWithIndex,
26 | RandomMapTransform as RandomMap,
27 | Transformation,
28 | Transformations,
29 | )
30 |
31 | from grain._src.python.dataset.base import DatasetSelectionMap
32 |
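
A minimal sketch of how these transform base classes are typically subclassed and combined into a pipeline; the class names and element values are illustrative only:

```python
import grain

class ScaleBy2(grain.transforms.Map):
  """Element-wise deterministic transform: multiply every record by 2."""

  def map(self, element):
    return element * 2

class DropNegatives(grain.transforms.Filter):
  """Keep only records that are >= 0."""

  def filter(self, element) -> bool:
    return element >= 0

# A pipeline is usually described as a sequence of transformations
# (`Transformations`) handed to a loader; `Batch` groups elements.
operations = [
    ScaleBy2(),
    DropNegatives(),
    grain.transforms.Batch(batch_size=4, drop_remainder=False),
]
```
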
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "grain"
7 | version = "0.2.10"
8 | description = "Grain: A library for loading and transforming data for ML training."
9 | keywords = []
10 | authors = [
11 | {name = "Grain team", email = "grain-dev@google.com"},
12 | ]
13 | dependencies = [
14 | 'absl-py',
15 | 'array-record',
16 | 'cloudpickle',
17 | 'dm-tree',
18 | 'etils[epath,epy]',
19 | 'more-itertools>=9.1.0',
20 | 'numpy',
21 | 'protobuf>=3.20.3',
22 | ]
23 | readme = "README.md"
24 | license = { file = "LICENSE" }
25 | requires-python = ">=3.10"
26 | classifiers = [
27 | "Programming Language :: Python :: 3",
28 | "License :: OSI Approved :: Apache Software License",
29 | "Operating System :: POSIX :: Linux",
30 | ]
31 |
32 | [project.optional-dependencies]
33 | testing = [
34 | 'attrs',
35 | 'dill',
36 | 'jax',
37 | 'jaxlib',
38 | 'jaxtyping',
39 | 'pyarrow',
40 | 'tensorflow-datasets',
41 | ]
42 | parquet = [
43 | 'pyarrow',
44 | ]
45 |
46 | [project.urls]
47 | homepage = "https://github.com/google/grain"
48 |
49 | [tool.setuptools.packages.find]
50 | include = ["grain*"]
51 |
52 | [tool.setuptools.package-data]
53 | "*" = ["*.so"]
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """Setup.py file for grain.
2 |
3 | Most project configs are in `pyproject.toml` -- prefer to modify
4 | `pyproject.toml` over this file if possible.
5 | """
6 |
7 | import setuptools
8 | from setuptools import dist
9 |
10 |
11 | class BinaryDistribution(dist.Distribution):
12 | """This class makes 'bdist_wheel' include an ABI tag on the wheel."""
13 |
14 | def has_ext_modules(self):
15 | return True
16 |
17 |
18 | setuptools.setup(
19 | distclass=BinaryDistribution,
20 | )
21 |
--------------------------------------------------------------------------------
/test_requirements.in:
--------------------------------------------------------------------------------
1 | # Copyright 2025 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # This is the list of our Python third_party dependencies that Bazel should
16 | # pull from PyPI.
17 | # Note that requirements.txt must be regenerated using
18 | # bazel run //:requirements.update in the OSS version.
19 | array-record
20 | absl-py
21 | dm-tree
22 | etils[epath,epy]
23 | cloudpickle
24 | jax
25 | jaxtyping
26 | numpy
27 | attrs
28 | pyarrow
29 | parameterized
--------------------------------------------------------------------------------