├── .github └── workflows │ ├── build_and_publish_template.yml │ ├── preview_docs.yml │ ├── publish_nightly.yml │ ├── publish_release.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── BUILD ├── CONTRIBUTING.md ├── LICENSE ├── MODULE.bazel ├── README.md ├── docs ├── CONTRIBUTING.md ├── README.md ├── api_choice.md ├── behind_the_scenes.md ├── conf.py ├── data_loader │ ├── samplers.md │ └── transformations.md ├── data_sources.md ├── grain.checkpoint.rst ├── grain.constants.rst ├── grain.dataset.rst ├── grain.experimental.rst ├── grain.multiprocessing.rst ├── grain.rst ├── grain.samplers.rst ├── grain.sharding.rst ├── grain.sources.rst ├── grain.transforms.rst ├── images │ ├── data_flow_multiple_workers.png │ ├── data_flow_zero_workers.png │ └── grain_pipeline.svg ├── index.md ├── installation.md ├── requirements.txt └── tutorials │ ├── data_loader_tutorial.ipynb │ ├── data_loader_tutorial.md │ ├── data_sources │ ├── arrayrecord_data_source_tutorial.ipynb │ ├── arrayrecord_data_source_tutorial.md │ ├── huggingface_dataset_tutorial.ipynb │ ├── huggingface_dataset_tutorial.md │ ├── index.rst │ ├── load_from_gcs_tutorial.ipynb │ ├── load_from_gcs_tutorial.md │ ├── load_from_s3_tutorial.ipynb │ ├── load_from_s3_tutorial.md │ ├── parquet_dataset_tutorial.ipynb │ ├── parquet_dataset_tutorial.md │ ├── pytorch_dataset_tutorial.ipynb │ └── pytorch_dataset_tutorial.md │ ├── dataset_advanced_tutorial.ipynb │ ├── dataset_advanced_tutorial.md │ ├── dataset_basic_tutorial.ipynb │ ├── dataset_basic_tutorial.md │ ├── dataset_debugging_tutorial.ipynb │ └── dataset_debugging_tutorial.md ├── grain ├── BUILD ├── __init__.py ├── _src │ ├── __init__.py │ ├── core │ │ ├── BUILD │ │ ├── config.py │ │ ├── constants.py │ │ ├── exceptions.py │ │ ├── grain_random.py │ │ ├── monitoring.py │ │ ├── parallel.py │ │ ├── parallel_test.py │ │ ├── sharding.py │ │ ├── sharding_test.py │ │ ├── smoke_test_with_jax.py │ │ ├── smoke_test_with_tf.py │ │ ├── transforms.py │ │ ├── transforms_test.py │ │ ├── tree_lib.py │ │ ├── tree_lib_jax_test.py │ │ ├── tree_lib_test.py │ │ ├── usage_logging.py │ │ └── version_test.py │ └── python │ │ ├── BUILD │ │ ├── checkpoint_handlers.py │ │ ├── checkpointing.py │ │ ├── data_loader.py │ │ ├── data_loader_test.py │ │ ├── data_sources.py │ │ ├── data_sources_test.py │ │ ├── dataset │ │ ├── BUILD │ │ ├── base.py │ │ ├── base_test.py │ │ ├── dataset.py │ │ ├── dataset_test.py │ │ ├── elastic_iterator.py │ │ ├── elastic_iterator_test.py │ │ ├── sources │ │ │ ├── BUILD │ │ │ ├── parquet_dataset.py │ │ │ ├── parquet_dataset_test.py │ │ │ ├── tfrecord_dataset.py │ │ │ └── tfrecord_dataset_test.py │ │ ├── stats.py │ │ ├── stats_test.py │ │ ├── stats_utils.py │ │ ├── stats_utils_test.py │ │ ├── transformations │ │ │ ├── BUILD │ │ │ ├── batch.py │ │ │ ├── batch_test.py │ │ │ ├── filter.py │ │ │ ├── filter_test.py │ │ │ ├── flatmap.py │ │ │ ├── flatmap_test.py │ │ │ ├── interleave.py │ │ │ ├── interleave_test.py │ │ │ ├── limit.py │ │ │ ├── limit_test.py │ │ │ ├── map.py │ │ │ ├── map_test.py │ │ │ ├── mix.py │ │ │ ├── mix_test.py │ │ │ ├── packing.py │ │ │ ├── packing_concat_then_split.py │ │ │ ├── packing_concat_then_split_test.py │ │ │ ├── packing_packed_batch.py │ │ │ ├── packing_test.py │ │ │ ├── prefetch.py │ │ │ ├── prefetch_test.py │ │ │ ├── repeat.py │ │ │ ├── repeat_test.py │ │ │ ├── shuffle.py │ │ │ ├── shuffle_test.py │ │ │ ├── slice.py │ │ │ ├── slice_test.py │ │ │ ├── source.py │ │ │ ├── source_test.py │ │ │ ├── testing_util.py │ │ │ ├── zip.py │ │ │ └── 
zip_test.py │ │ ├── visualize.py │ │ └── visualize_test.py │ │ ├── experimental │ │ ├── example_packing │ │ │ ├── BUILD │ │ │ ├── packing.py │ │ │ └── packing_test.py │ │ └── index_shuffle │ │ │ ├── BUILD │ │ │ ├── index_shuffle.cc │ │ │ ├── index_shuffle.h │ │ │ └── python │ │ │ ├── BUILD │ │ │ ├── index_shuffle_module.cc │ │ │ ├── index_shuffle_python.py │ │ │ ├── index_shuffle_python_test.py │ │ │ └── index_shuffle_test.py │ │ ├── grain_logging.py │ │ ├── grain_logging_test.py │ │ ├── grain_pool.py │ │ ├── grain_pool_test.py │ │ ├── load.py │ │ ├── load_test.py │ │ ├── multiprocessing_common.py │ │ ├── multiprocessing_common_test.py │ │ ├── operations.py │ │ ├── operations_test.py │ │ ├── options.py │ │ ├── record.py │ │ ├── record_test.py │ │ ├── samplers.py │ │ ├── samplers_test.py │ │ ├── shared_memory_array.py │ │ ├── shared_memory_array_test.py │ │ ├── testdata │ │ ├── BUILD │ │ ├── digits.array_record-00000-of-00002 │ │ ├── digits.array_record-00001-of-00002 │ │ └── morris_sequence_first_5.tfrecord │ │ └── testing │ │ ├── BUILD │ │ ├── __init__.py │ │ └── experimental.py ├── checkpoint.py ├── constants.py ├── experimental.py ├── multiprocessing.py ├── oss │ ├── Dockerfile │ ├── array_record │ │ ├── Dockerfile.patch │ │ ├── WORKSPACE.patch │ │ ├── array_record_reader.patch │ │ ├── build_whl.patch │ │ ├── runner_common.patch │ │ └── setup.patch │ ├── build_whl.sh │ └── runner_common.sh ├── proto │ ├── BUILD │ └── execution_summary.proto ├── python │ ├── __init__.py │ └── experimental.py ├── samplers.py ├── sharding.py ├── sources.py └── transforms.py ├── pyproject.toml ├── setup.py ├── test_requirements.in ├── test_requirements_lock_3_10.txt ├── test_requirements_lock_3_11.txt ├── test_requirements_lock_3_12.txt └── test_requirements_lock_3_13.txt /.github/workflows/build_and_publish_template.yml: -------------------------------------------------------------------------------- 1 | # This workflow builds Grain wheels and uploads them as artifacts. 2 | 3 | name: Build & Publish Template 4 | 5 | on: 6 | workflow_call: 7 | inputs: 8 | pypi_project_url: 9 | required: true 10 | type: string 11 | run_tests: 12 | required: true 13 | type: boolean 14 | is_nightly: 15 | required: true 16 | type: boolean 17 | 18 | permissions: 19 | contents: read 20 | 21 | jobs: 22 | build-and-test: 23 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 24 | runs-on: "${{ matrix.os }}" 25 | 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | python-version: ["3.10", "3.11", "3.12", "3.13"] 30 | os: [ubuntu-22.04, ubuntu-22.04-arm, macos-14] 31 | 32 | steps: 33 | - uses: "actions/checkout@v3" 34 | - name: Create directory 35 | run: | 36 | mkdir -p /tmp/grain 37 | cp -r . /tmp/grain 38 | - name: Build package 39 | run: | 40 | set -xe 41 | export PYTHON_VERSION=${{ matrix.python-version }} 42 | export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1) 43 | export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f2) 44 | export BAZEL_VERSION="7.1.1" 45 | export OUTPUT_DIR="/tmp/grain" 46 | export SOURCE_DIR="/tmp/grain" 47 | export RUN_TESTS=${{ inputs.run_tests }} 48 | export IS_NIGHTLY=${{ inputs.is_nightly }} 49 | . 
"${SOURCE_DIR}"'/grain/oss/runner_common.sh' 50 | build_and_test_grain 51 | - name: Upload Grain artifacts 52 | uses: actions/upload-artifact@v4 53 | with: 54 | name: built-grain-wheels-${{ matrix.os }}-${{ matrix.python-version }} 55 | path: /tmp/grain/all_dist/*.whl 56 | 57 | publish-wheel: 58 | runs-on: ubuntu-22.04 59 | needs: build-and-test 60 | permissions: 61 | id-token: write 62 | environment: 63 | name: pypi 64 | url: ${{ inputs.pypi_project_url }} 65 | steps: 66 | - name: Download Grain artifacts 67 | uses: actions/download-artifact@v4 68 | with: 69 | pattern: built-grain-wheels-* 70 | path: dist/ 71 | merge-multiple: true 72 | - name: Publish package distributions to PyPI 73 | uses: pypa/gh-action-pypi-publish@release/v1 -------------------------------------------------------------------------------- /.github/workflows/preview_docs.yml: -------------------------------------------------------------------------------- 1 | # Add a link to preview the documentation on Read the Docs for every pull request. 2 | name: "RTD preview" 3 | 4 | on: 5 | pull_request_target: 6 | types: 7 | - opened 8 | 9 | permissions: 10 | pull-requests: write 11 | 12 | jobs: 13 | documentation-links: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: readthedocs/actions/preview@v1 17 | with: 18 | project-slug: "google-grain" 19 | single-version: true -------------------------------------------------------------------------------- /.github/workflows/publish_nightly.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Nightly 2 | 3 | on: 4 | schedule: 5 | # At 04:00 on Monday. 6 | - cron: '0 4 * * 1' 7 | 8 | jobs: 9 | call-workflow: 10 | uses: ./.github/workflows/build_and_publish_template.yml 11 | with: 12 | pypi_project_url: https://pypi.org/project/grain-nightly 13 | run_tests: true 14 | is_nightly: true -------------------------------------------------------------------------------- /.github/workflows/publish_release.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Release 2 | 3 | on: 4 | workflow_dispatch: 5 | inputs: 6 | run_tests: 7 | description: 'Run unit tests' 8 | required: false 9 | default: true 10 | type: boolean 11 | 12 | jobs: 13 | call-workflow: 14 | uses: ./.github/workflows/build_and_publish_template.yml 15 | with: 16 | pypi_project_url: https://pypi.org/project/grain 17 | run_tests: ${{ inputs.run_tests }} 18 | is_nightly: false -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Build & Test 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | jobs: 10 | build-and-test: 11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 12 | runs-on: "${{ matrix.os }}" 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.10", "3.11", "3.12", "3.13"] 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - uses: "actions/checkout@v2" 21 | - name: Create directory 22 | run: | 23 | mkdir -p /tmp/grain 24 | cp -r . /tmp/grain 25 | - name: Build package 26 | run: | 27 | set -xe 28 | export PYTHON_VERSION=${{ matrix.python-version }} 29 | export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1) 30 | export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. 
-f2) 31 | export BAZEL_VERSION="7.1.1" 32 | export AUDITWHEEL_PLATFORM="manylinux2014_x86_64" 33 | export RUN_TESTS="true" 34 | cd /tmp/grain 35 | DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \ 36 | --build-arg AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \ 37 | --build-arg PYTHON_VERSION=${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION} \ 38 | --build-arg BAZEL_VERSION=${BAZEL_VERSION} \ 39 | -t grain:${PYTHON_VERSION} grain/oss 40 | docker run --rm -a stdin -a stdout -a stderr \ 41 | --env PYTHON_VERSION=${PYTHON_VERSION} \ 42 | --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \ 43 | --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \ 44 | --env BAZEL_VERSION=${BAZEL_VERSION} \ 45 | --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \ 46 | --env RUN_TESTS=${RUN_TESTS} \ 47 | -v /tmp/grain:/tmp/grain \ 48 | --name grain grain:${PYTHON_VERSION} \ 49 | bash grain/oss/build_whl.sh -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # Unit test / coverage reports 27 | htmlcov/ 28 | .tox/ 29 | .nox/ 30 | .coverage 31 | .coverage.* 32 | .cache 33 | nosetests.xml 34 | coverage.xml 35 | *.cover 36 | *.py,cover 37 | .hypothesis/ 38 | .pytest_cache/ 39 | cover/ 40 | 41 | # Translations 42 | *.mo 43 | *.pot 44 | 45 | # Sphinx documentation 46 | docs/_build/ 47 | 48 | # Jupyter Notebook 49 | .ipynb_checkpoints 50 | 51 | # IPython 52 | profile_default/ 53 | ipython_config.py 54 | 55 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 56 | __pypackages__/ 57 | 58 | # Environments 59 | .venv 60 | 61 | # Bazel outputs 62 | bazel-bin/ 63 | bazel-grain/ 64 | bazel-out/ 65 | bazel-testlogs/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v5.0.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-merge-conflict 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | 10 | - repo: https://github.com/mwouts/jupytext 11 | rev: v1.15.2 12 | hooks: 13 | - id: jupytext 14 | files: docs/tutorials/ 15 | args: [--sync] -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-lts-latest 5 | tools: 6 | python: "3.12" 7 | 8 | sphinx: 9 | configuration: docs/conf.py 10 | # Note this is set to false for now while the warnings are resolved 11 | fail_on_warning: false 12 | 13 | python: 14 | install: 15 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /BUILD: -------------------------------------------------------------------------------- 1 | load("@python//3.10:defs.bzl", compile_pip_requirements_3_10 = "compile_pip_requirements") 2 | load("@python//3.11:defs.bzl", compile_pip_requirements_3_11 = "compile_pip_requirements") 3 | load("@python//3.12:defs.bzl", compile_pip_requirements_3_12 = "compile_pip_requirements") 4 | load("@python//3.13:defs.bzl", compile_pip_requirements_3_13 = "compile_pip_requirements") 5 | 6 | py_library( 7 | name = "setup", 8 | srcs = ["setup.py"], 9 | srcs_version = "PY3", 10 | ) 11 | 12 | compile_pip_requirements_3_10( 13 | name = "requirements_3_10", 14 | requirements_in = "test_requirements.in", 15 | requirements_txt = "test_requirements_lock_3_10.txt", 16 | ) 17 | 18 | compile_pip_requirements_3_11( 19 | name = "requirements_3_11", 20 | requirements_in = "test_requirements.in", 21 | requirements_txt = "test_requirements_lock_3_11.txt", 22 | ) 23 | 24 | compile_pip_requirements_3_12( 25 | name = "requirements_3_12", 26 | requirements_in = "test_requirements.in", 27 | requirements_txt = "test_requirements_lock_3_12.txt", 28 | ) 29 | 30 | compile_pip_requirements_3_13( 31 | name = "requirements_3_13", 32 | requirements_in = "test_requirements.in", 33 | requirements_txt = "test_requirements_lock_3_13.txt", 34 | ) 35 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement (CLA). You (or your employer) retain the copyright to your 10 | contribution; this simply gives us permission to use and redistribute your 11 | contributions as part of the project. Head over to 12 | to see your current agreements on file or 13 | to sign a new one. 
14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code Reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | 26 | ## Community Guidelines 27 | 28 | This project follows 29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/). 30 | -------------------------------------------------------------------------------- /MODULE.bazel: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | module( 15 | name = "grain", 16 | version = "0.2.9", 17 | repo_name = "com_google_grain", 18 | ) 19 | 20 | bazel_dep(name = "bazel_skylib", version = "1.6.1") 21 | bazel_dep(name = "platforms", version = "0.0.8") 22 | bazel_dep(name = "rules_python", version = "0.37.0") 23 | bazel_dep(name = "rules_cc", version = "0.0.9") 24 | bazel_dep(name = "pybind11_bazel", version = "2.11.1") 25 | bazel_dep(name = "abseil-py", version = "2.1.0") 26 | bazel_dep(name = "abseil-cpp", version = "20230802.0.bcr.1") 27 | bazel_dep(name = "protobuf", version = "24.4", repo_name = "com_google_protobuf") 28 | 29 | http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 30 | 31 | http_archive( 32 | name = "pybind11", 33 | build_file = "@pybind11_bazel//:pybind11.BUILD", 34 | sha256 = "201966a61dc826f1b1879a24a3317a1ec9214a918c8eb035be2f30c3e9cfbdcb", 35 | strip_prefix = "pybind11-2.10.3", 36 | urls = ["https://github.com/pybind/pybind11/archive/refs/tags/v2.10.3.zip"], 37 | ) 38 | 39 | SUPPORTED_PYTHON_VERSIONS = [ 40 | "3.10", 41 | "3.11", 42 | "3.12", 43 | "3.13", 44 | ] 45 | 46 | DEFAULT_PYTHON_VERSION = "3.10" 47 | 48 | python_configure = use_extension("@pybind11_bazel//:python_configure.bzl", "extension") 49 | use_repo(python_configure, "local_config_python") 50 | 51 | python = use_extension("@rules_python//python/extensions:python.bzl", "python") 52 | 53 | [ 54 | python.toolchain( 55 | ignore_root_user_error = True, 56 | is_default = python_version == DEFAULT_PYTHON_VERSION, 57 | python_version = python_version, 58 | ) 59 | for python_version in SUPPORTED_PYTHON_VERSIONS 60 | ] 61 | 62 | use_repo(python, python = "python_versions") 63 | 64 | pip = use_extension("@rules_python//python/extensions:pip.bzl", "pip") 65 | 66 | # requirements_lock.txt is generated by 67 | # bazel run //:requirements.update 68 | [ 69 | pip.parse( 70 | hub_name = "pypi", 71 | python_version = version, 72 | requirements_lock = "test_requirements_lock_" + version.replace(".", "_") + ".txt", 73 | ) 74 | for version in SUPPORTED_PYTHON_VERSIONS 75 | ] 76 | 77 | use_repo(pip, "pypi") 
78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Grain - Feeding JAX Models 2 | 3 | [![Continuous integration](https://github.com/google/grain/actions/workflows/tests.yml/badge.svg)](https://github.com/google/grain/actions/workflows/tests.yml) 4 | [![PyPI version](https://img.shields.io/pypi/v/grain)](https://pypi.org/project/grain/) 5 | 6 | 7 | [**Installation**](#installation) 8 | | [**Quickstart**](#quickstart) 9 | | [**Reference docs**](https://google-grain.readthedocs.io/en/latest/) 10 | 11 | Grain is a Python library for reading and processing data for training and 12 | evaluating JAX models. It is flexible, fast and deterministic. 13 | 14 | Grain lets you define data processing steps in a simple declarative way: 15 | 16 | ```python 17 | import grain 18 | 19 | dataset = ( 20 | grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) 21 | .shuffle(seed=42) # Shuffles elements globally. 22 | .map(lambda x: x+1) # Maps each element. 23 | .batch(batch_size=2) # Batches consecutive elements. 24 | ) 25 | 26 | for batch in dataset: 27 | ... # Training step. 28 | ``` 29 | 30 | Grain is designed to work with JAX models but it does not require JAX to run 31 | and can be used with other frameworks as well. 32 | 33 | ## Installation 34 | 35 | Grain is available on [PyPI](https://pypi.org/project/grain/) and can be 36 | installed with `pip install grain`. 37 | 38 | ### Supported platforms 39 | 40 | Grain does not directly use GPU or TPU in its transformations; the processing 41 | within Grain will be done on the CPU by default. 42 | 43 | | | Linux | Mac | Windows | 44 | |---------|---------|---------|---------| 45 | | x86_64 | yes | no | no | 46 | | aarch64 | yes | yes | n/a | 47 | 48 | ## Quickstart 49 | 50 | - [Basic `Dataset` tutorial](https://google-grain.readthedocs.io/en/latest/tutorials/dataset_basic_tutorial.html) 51 | 52 | ## Existing users 53 | 54 | Grain is used by [MaxText](https://github.com/google/maxtext/tree/main), 55 | [Gemma](https://github.com/google-deepmind/gemma), 56 | [kauldron](https://github.com/google-research/kauldron) and multiple internal 57 | Google projects. -------------------------------------------------------------------------------- /docs/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Grain 2 | 3 | 4 | 5 | ## Contributing to the Grain project documentation 6 | 7 | ### Pre-requisites 8 | 9 | To contribute to the documentation, you will need to set up your development environment. 10 | 11 | You can create a virtual environment or conda environment and install the packages in 12 | `docs/requirements.txt`. 13 | 14 | ```bash 15 | # Create a virtual environment 16 | python3 -m venv .venv 17 | # Activate the virtual environment 18 | source .venv/bin/activate 19 | # Install the requirements 20 | pip install -r docs/requirements.txt 21 | ``` 22 | 23 | or with conda: 24 | 25 | ```bash 26 | # Create a conda environment 27 | conda create -n "grain-docs" python=3.12 28 | # Activate the conda environment 29 | conda activate grain-docs 30 | # Install the requirements 31 | python3 -m pip install -r docs/requirements.txt 32 | ``` 33 | 34 | ### Building the documentation locally 35 | 36 | To build the documentation locally, you can run the following command: 37 | 38 | ```bash 39 | # Change to the docs/ directory 40 | cd docs 41 | sphinx-build -b html . 
_build/html 42 | ``` 43 | 44 | You can then open the generated HTML files in your browser by opening 45 | `docs/_build/html/index.html`. 46 | 47 | ## Documentation via Jupyter notebooks 48 | 49 | The `pygrain` documentation includes Jupyter notebooks that are rendered 50 | directly into the website via the [myst-nb](https://myst-nb.readthedocs.io/) extension. 51 | To ease review and diff of notebooks, we keep markdown versions of the content 52 | synced via [jupytext](https://jupytext.readthedocs.io/). 53 | 54 | Note you will need to install `jupytext` to sync the notebooks with markdown files: 55 | 56 | ```bash 57 | # With pip 58 | python3 -m pip install jupytext 59 | 60 | # With conda 61 | conda install -c conda-forge jupytext 62 | ``` 63 | 64 | ### Adding a new notebook 65 | 66 | We aim to have one notebook per topic or tutorial covered. 67 | To add a new notebook to the repository, first move the notebook into the appropriate 68 | location in the `docs` directory: 69 | 70 | ```bash 71 | mv ~/new-tutorial.ipynb docs/tutorials/new_tutorial.ipynb 72 | ``` 73 | 74 | Next, we use `jupytext` to mark the notebook for syncing with Markdown: 75 | 76 | ```bash 77 | jupytext --set-formats ipynb,md:myst docs/tutorials/new_tutorial.ipynb 78 | ``` 79 | 80 | Finally, we can sync the notebook and markdown source: 81 | 82 | ```bash 83 | jupytext --sync docs/tutorials/new_tutorial.ipynb 84 | ``` 85 | 86 | To ensure that the new notebook is rendered as part of the site, be sure to add 87 | references to a `toctree` declaration somewhere in the source tree, for example 88 | in `docs/index.md`. You will also need to add references in `docs/conf.py` 89 | to specify whether the notebook should be executed, and to specify which file 90 | sphinx should use when generating the site. 91 | 92 | ### Editing an existing notebook 93 | 94 | When editing the text of an existing notebook, it is recommended to edit the 95 | markdown file only, and then automatically sync using `jupytext` via the 96 | `pre-commit` framework, which we use to check in GitHub CI that notebooks are 97 | properly synced. 98 | For example, say you have edited `docs/tutorials/new_tutorial.md`, then 99 | you can do the following: 100 | 101 | ```bash 102 | pip install pre-commit 103 | git add docs/tutorials/new_tutorial.* # stage the new changes 104 | pre-commit run # run pre-commit checks on added files 105 | git add docs/tutorials/new_tutorial.* # stage the files updated by pre-commit 106 | git commit -m "Update new tutorial" # commit to the branch -------------------------------------------------------------------------------- /docs/api_choice.md: -------------------------------------------------------------------------------- 1 | # Choice of API 2 | 3 | 4 | 5 | 6 | 7 | Grain offers two different ways of defining data processing pipelines: 8 | [`DataLoader`](tutorials/data_loader_tutorial.md) and [`Dataset`](tutorials/dataset_basic_tutorial.md). 9 | 10 | > TL;DR: If you need to do one of the following: 11 | > 12 | > * mix multiple data sources 13 | > * pack variable length elements 14 | > * split dataset elements and globally shuffle the splits 15 | > 16 | > then you should use `Dataset`, otherwise use simpler `DataLoader`. 17 | 18 | ## `DataLoader` 19 | 20 | `DataLoader` is a high-level API that uses the following abstractions to define 21 | data processing: 22 | 23 | * [`RandomAccessDataSource`](https://github.com/google/grain/tree/main/grain/_src/python/data_sources.py) 24 | that reads raw input data. 
25 | * A 26 | [`Sampler`](https://github.com/google/grain/tree/main/grain/_src/python/samplers.py) 27 | that defines the order in which the raw data should be read. 28 | * A flat sequence of 29 | [`Transformation`s](https://github.com/google/grain/tree/main/grain/_src/core/transforms.py) 30 | to apply to the raw data. 31 | 32 | You can specify other execution parameters for asynchronous data processing, 33 | sharding, shuffling, and `DataLoader` will automatically take care of inserting 34 | them in the right places between the data processing steps. 35 | 36 | These are simple and usually general enough to cover most data processing use 37 | cases. Prefer using `DataLoader` if your workflow can be described using the 38 | abstractions above. See [tutorial](tutorials/data_loader_tutorial.md) 39 | for more details. 40 | 41 | ## `Dataset` 42 | 43 | `Dataset` is a lower-level API that uses chaining syntax to define data 44 | transformation steps. It allows more general types of processing (e.g. dataset 45 | mixing) and more control over the execution (e.g. different order of data 46 | sharding and shuffling). `Dataset` transformations are composed in a way that 47 | allows to preserve random access property past the source and some of the 48 | transformations. This, among other things, can be used for debugging by 49 | evaluating dataset elements at specific positions without processing the entire 50 | dataset. 51 | 52 | There are 3 main classes comprising the `Dataset` API: 53 | 54 | * [`MapDataset`](https://github.com/google/grain/tree/main/grain/_src/python/dataset/dataset.py) 55 | defines a dataset that supports efficient random access. Think of it as an 56 | (infinite) `Sequence` that computes values lazily. 57 | * [`IterDataset`](https://github.com/google/grain/tree/main/grain/_src/python/dataset/dataset.py) 58 | defines a dataset that does not support efficient random access and only 59 | supports iterating over it. It's an `Iterable`. Any `MapDataset` can be 60 | turned into a `IterDataset` by calling `to_iter_dataset()`. 61 | * [`DatasetIterator`](https://github.com/google/grain/tree/main/grain/_src/python/dataset/dataset.py) 62 | defines a stateful iterator of an `IterDataset`. The state of the iterator 63 | can be saved and restored. 64 | 65 | Most data pipelines will start with one or more `MapDataset` (often derived from 66 | a `RandomAccessDataSource`) and switch to `IterDataset` late or not at all. See 67 | [tutorial](tutorials/dataset_basic_tutorial.md) 68 | for more details. 69 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | """Configuration file for the Sphinx documentation builder. 2 | 3 | For the full list of built-in configuration values, see the documentation: 4 | https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | """ 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | import os 13 | import pathlib 14 | import sys 15 | 16 | sys.path.insert(0, str(pathlib.Path('..', 'grain').resolve())) 17 | sys.path.insert(0, os.path.abspath('..')) 18 | 19 | # -- Project information ----------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 21 | 22 | project = 'Grain' 23 | copyright = '2024, Grain team' # pylint: disable=redefined-builtin 24 | author = 'Grain team' 25 | 26 | # -- General configuration --------------------------------------------------- 27 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 28 | 29 | extensions = [ 30 | 'myst_nb', 31 | 'sphinx_copybutton', 32 | 'sphinx_design', 33 | 'sphinx.ext.autodoc', 34 | 'sphinx.ext.autosummary', 35 | 'sphinx.ext.napoleon', 36 | ] 37 | 38 | templates_path = ['_templates'] 39 | source_suffix = ['.rst', '.ipynb', '.md'] 40 | exclude_patterns = [ 41 | '_build', 42 | 'Thumbs.db', 43 | '.DS_Store', 44 | 'tutorials/dataset_basic_tutorial.md', 45 | ] 46 | 47 | # Suppress warning in exception basic_data_tutorial 48 | suppress_warnings = [ 49 | 'misc.highlighting_failure', 50 | ] 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 54 | 55 | html_theme = 'sphinx_book_theme' 56 | html_title = 'Grain' 57 | html_static_path = ['_static'] 58 | 59 | # TODO: Add logo and favicon 60 | # html_logo = '_static/' 61 | # html_favicon = '_static/favicon.png' 62 | 63 | # Theme-specific options 64 | # https://sphinx-book-theme.readthedocs.io/en/stable/reference.html 65 | html_theme_options = { 66 | 'show_navbar_depth': 1, 67 | 'show_toc_level': 3, 68 | 'repository_url': 'https://github.com/google/grain', 69 | 'use_issues_button': True, 70 | 'use_repository_button': True, 71 | 'path_to_docs': 'docs/', 72 | 'navigation_with_keys': True, 73 | } 74 | 75 | # Autodoc settings 76 | autosummary_generate = True 77 | autodoc_typehints = 'description' 78 | # We mock dependencies and internal modules that require building to be able to 79 | # import the python symbols for extracting docstrings. 80 | autodoc_mock_imports = [ 81 | 'grain.proto.execution_summary_pb2', 82 | 'grain._src.python.experimental.index_shuffle.python.index_shuffle_module', 83 | 'cloudpickle', 84 | 'numpy', 85 | 'orbax', 86 | 'tree', 87 | 'absl', 88 | 'absl.logging', 89 | 'array_record', 90 | ] 91 | 92 | # -- Myst configurations ------------------------------------------------- 93 | myst_enable_extensions = ['colon_fence'] 94 | nb_execution_mode = 'force' 95 | nb_execution_allow_errors = False 96 | nb_merge_streams = True 97 | nb_execution_show_tb = True 98 | 99 | # Notebook cell execution timeout; defaults to 30. 100 | nb_execution_timeout = 100 101 | 102 | # List of patterns, relative to source directory, that match notebook 103 | # files that will not be executed. 
104 | nb_execution_excludepatterns = [ 105 | 'tutorials/dataset_advanced_tutorial.ipynb', 106 | 'tutorials/dataset_basic_tutorial.ipynb', 107 | 'tutorials/data_loader_tutorial.ipynb', 108 | 'tutorials/dataset_debugging_tutorial.ipynb', 109 | 'tutorials/data_sources/load_from_s3_tutorial.ipynb', 110 | 'tutorials/data_sources/load_from_gcs_tutorial.ipynb', 111 | 'tutorials/data_sources/parquet_dataset_tutorial.ipynb', 112 | 'tutorials/data_sources/arrayrecord_data_source_tutorial.ipynb', 113 | 'tutorials/data_sources/huggingface_dataset_tutorial.ipynb', 114 | 'tutorials/data_sources/pytorch_dataset_tutorial.ipynb', 115 | ] 116 | -------------------------------------------------------------------------------- /docs/data_loader/samplers.md: -------------------------------------------------------------------------------- 1 | # Samplers 2 | 3 | Samplers in Grain are responsible for determining the order in which records 4 | are processed. This allows Grain to implement global transformations (e.g. 5 | global shuffling, sharding, repeating for multiple epochs) before reading any 6 | records. 7 | 8 | 9 | 10 | Samplers need to implement the following iterator protocol: 11 | 12 | ```python 13 | class Sampler(Protocol): 14 | 15 | def __iter__(self): 16 | ... 17 | 18 | def __next__(self) -> record.RecordMetadata: 19 | ... 20 | 21 | @dataclasses.dataclass 22 | class RecordMetadata: 23 | """RecordMetadata contains metadata about individual records. 24 | 25 | RecordMetadata objects are emitted by the sampler to refer to which record to 26 | read next (record_key), what its index is (for keeping progress and 27 | checkpointing) as well as having an optional rng for stateless random 28 | transformations. In addition, they are also used to keep information about 29 | records as they flow through the pipeline from one operation to the other. 30 | """ 31 | index: int 32 | record_key: Optional[int] = None 33 | rng: Optional[np.random.Generator] = None 34 | ``` 35 | 36 | ## Index Sampler 37 | 38 | This is our recommended Sampler. It supports: 39 | 40 | * Sharding across multiple machines (`shard_options` parameter). 41 | * Global shuffle of the data (`shuffle` parameter). 42 | * Repeating records for multiple epochs (`num_epochs` parameter). Note that 43 | the shuffle order changes across epochs. Behind the scenes, this relies on 44 | [tf.random_index_shuffle](https://www.tensorflow.org/api_docs/python/tf/random_index_shuffle). 45 | * Stateless random operations. Each `RecordMetadata` object emitted by the 46 | `IndexSampler` contains an RNG uniquely seeded on a per-record basis. This 47 | RNG can be used for random augmentations while not relying on a global 48 | state. 
49 | 50 | ```python 51 | index_sampler = pygrain.IndexSampler( 52 | num_records=5, 53 | num_epochs=2, 54 | shard_options=pygrain.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True), 55 | shuffle=True, 56 | seed=0) 57 | for record_metadata in index_sampler: 58 | print(record_metadata) 59 | 60 | # Output 61 | # RecordMetadata(index=0, record_key=0, rng=Generator(Philox) at 0x7FB09947AF80) 62 | # RecordMetadata(index=1, record_key=4, rng=Generator(Philox) at 0x7FB0994789E0) 63 | # RecordMetadata(index=2, record_key=2, rng=Generator(Philox) at 0x7FB099478740) 64 | # RecordMetadata(index=3, record_key=3, rng=Generator(Philox) at 0x7FB0994789E0) 65 | # RecordMetadata(index=4, record_key=1, rng=Generator(Philox) at 0x7FB099478740) 66 | # RecordMetadata(index=5, record_key=1, rng=Generator(Philox) at 0x7FB0994789E0) 67 | # RecordMetadata(index=6, record_key=0, rng=Generator(Philox) at 0x7FB099478740) 68 | # RecordMetadata(index=7, record_key=3, rng=Generator(Philox) at 0x7FB0994789E0) 69 | # RecordMetadata(index=8, record_key=4, rng=Generator(Philox) at 0x7FB099478740) 70 | # RecordMetadata(index=9, record_key=2, rng=Generator(Philox) at 0x7FB0994789E0) 71 | ``` 72 | 73 | ## Implement your Own Sampler 74 | Grain can accommodate custom user-defined samplers. Users implementing their 75 | own sampler should ensure it: 76 | 77 | * implements the aforementioned interface. 78 | * is adequately performant. Since Grain's `DataLoader` iterates sequentially 79 | through the sampler to distribute indices to child processes, a slow sampler 80 | will become a bottleneck and reduce end-to-end pipeline performance. As a 81 | reference, we recommend sampler iteration performance of at least approx. 50,000 82 | elements / sec for most use cases. 83 | -------------------------------------------------------------------------------- /docs/data_loader/transformations.md: -------------------------------------------------------------------------------- 1 | # Transformations 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | The Grain Transforms interface describes transformations that are applied to data. In 10 | the case of local transformations (such as map, random map, filter), the 11 | transforms receive an element on which custom changes are applied. For global 12 | transformations (such as batching), one must provide the batch size. 13 | 14 | The Grain core transforms interface code is 15 | [here](https://github.com/google/grain/tree/main/grain/_src/core/transforms.py). 16 | 17 | ## MapTransform 18 | 19 | `MapTransform` is for 1:1 transformations of elements. Elements can be of any 20 | type; it is the user's responsibility to use the transformation such that the 21 | inputs it receives correspond to the signature. 22 | 23 | Example of transformation which implements `MapTransform` (for elements of type 24 | `int`): 25 | 26 | ```python 27 | class PlusOne(transforms.MapTransform): 28 | 29 | def map(self, x: int) -> int: 30 | return x + 1 31 | ``` 32 | 33 | ## MapWithIndexTransform 34 | 35 | `MapWithIndexTransform` is similar to `MapTransform` in being a 1:1 36 | transformation of elements, but also takes in the index/position of the element 37 | as the first argument. This is useful for pairing elements with an index key or 38 | even keeping it as metadata alongside the actual data. 
39 | 40 | Example of transformation which implements `MapWithIndexTransform` (for elements 41 | of type `int`): 42 | 43 | ```python 44 | class PlusOneWithIndexKey(transforms.MapWithIndexTransform): 45 | 46 | def map_with_index(self, i: int, x: int) -> tuple[int, int]: 47 | return (x + 1, i) 48 | ``` 49 | 50 | ## RandomMapTransform 51 | 52 | `RandomMapTransform` is for 1:1 random transformations of elements. The 53 | interface requires a `np.random.Generator` as parameter to the `random_map` 54 | function. 55 | 56 | Example of a `RandomMapTransform`: 57 | 58 | ```python 59 | class PlusRandom(transforms.RandomMapTransform): 60 | 61 | def random_map(self, x: int, rng: np.random.Generator) -> int: 62 | return x + rng.integers(100_000) 63 | ``` 64 | 65 | ## FlatMapTransform 66 | 67 | `FlatMapTransform` is for splitting operations of individual elements. The 68 | `max_fan_out` is the maximum number of splits that an element can generate. 69 | Please consult the code for detailed info. 70 | 71 | Example of a `FlatMapTransform`: 72 | 73 | ```python 74 | class FlatMapTransformExample(transforms.FlatMapTransform): 75 | max_fan_out: int 76 | 77 | def flat_map(self, element: int): 78 | for _ in range(self.max_fan_out): 79 | yield element 80 | ``` 81 | 82 | ## FilterTransform 83 | 84 | `FilterTransform` is for applying filtering to individual elements. Elements for 85 | which the filter function returns False will be removed. 86 | 87 | Example of a `FilterTransform` that removes all even elements: 88 | 89 | ```python 90 | class RemoveEvenElements(FilterTransform): 91 | 92 | def filter(self, element: int) -> bool: 93 | return element % 2 94 | ``` 95 | 96 | ## Batch 97 | 98 | To apply the `Batch` transform, pass `grain.Batch(batch_size=batch_size, 99 | drop_remainder=drop_remainder)`. 100 | 101 | Note: The batch size used when passing `Batch` transform will be the global 102 | batch size if it is done before sharding and the *per host* batch size if it is 103 | after. Typically usage with `IndexSampler` is after sharding. 104 | -------------------------------------------------------------------------------- /docs/grain.checkpoint.rst: -------------------------------------------------------------------------------- 1 | 2 | ``grain.checkpoint`` module 3 | =========================== 4 | 5 | .. automodule:: grain.checkpoint 6 | 7 | List of Members 8 | --------------- 9 | 10 | .. autosummary:: 11 | :toctree: _autosummary 12 | 13 | CheckpointHandler 14 | CheckpointSave 15 | CheckpointRestore -------------------------------------------------------------------------------- /docs/grain.constants.rst: -------------------------------------------------------------------------------- 1 | ``grain.constants`` module 2 | ========================== 3 | 4 | .. automodule:: grain.constants 5 | 6 | List of Constants 7 | ----------------- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | DATASET_INDEX 13 | EPOCH 14 | INDEX 15 | META_FEATURES 16 | RECORD 17 | RECORD_KEY 18 | SEED -------------------------------------------------------------------------------- /docs/grain.dataset.rst: -------------------------------------------------------------------------------- 1 | ``grain`` Dataset 2 | ================= 3 | 4 | .. automodule:: grain._src.python.dataset.dataset 5 | .. currentmodule:: grain 6 | 7 | List of Members 8 | --------------- 9 | 10 | .. autoclass:: _src.python.dataset.dataset.MapDatasetMeta 11 | :members: 12 | 13 | .. 
autoclass:: MapDataset 14 | :special-members: __init__, __len__, __getitem__, __iter__ 15 | :members: 16 | :show-inheritance: 17 | :inherited-members: 18 | :undoc-members: 19 | 20 | .. autoclass:: _src.python.dataset.dataset.IterDatasetMeta 21 | :members: 22 | 23 | .. autoclass:: IterDataset 24 | :special-members: __init__, __iter__ 25 | :members: 26 | :show-inheritance: 27 | :inherited-members: 28 | :undoc-members: 29 | 30 | .. autoclass:: DatasetIterator 31 | :special-members: __init__, __iter__ 32 | :members: 33 | :show-inheritance: 34 | :inherited-members: 35 | :undoc-members: 36 | -------------------------------------------------------------------------------- /docs/grain.experimental.rst: -------------------------------------------------------------------------------- 1 | ``grain.experimental`` module 2 | ============================= 3 | 4 | .. automodule:: grain.experimental 5 | 6 | List of Members 7 | --------------- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | FlatMapTransform 13 | DatasetOptions 14 | ExecutionTrackingMode 15 | apply_transformations 16 | ElasticIterator 17 | WithOptionsIterDataset 18 | ParquetIterDataset 19 | FlatMapMapDataset 20 | FlatMapIterDataset 21 | InterleaveIterDataset 22 | LimitIterDataset 23 | RngPool 24 | FirstFitPackIterDataset 25 | BOSHandling 26 | ConcatThenSplitIterDataset 27 | ThreadPrefetchIterDataset 28 | WindowShuffleMapDataset 29 | WindowShuffleIterDataset 30 | ZipMapDataset 31 | ZipIterDataset 32 | PackAndBatchOperation 33 | assert_equal_output_after_checkpoint 34 | -------------------------------------------------------------------------------- /docs/grain.multiprocessing.rst: -------------------------------------------------------------------------------- 1 | ``grain.multiprocessing`` module 2 | ================================ 3 | 4 | .. automodule:: grain.multiprocessing 5 | 6 | List of Members 7 | --------------- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | MultiprocessingOptions 13 | SharedMemoryArray 14 | -------------------------------------------------------------------------------- /docs/grain.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: grain 2 | 3 | Public API: ``grain`` package 4 | ============================= 5 | 6 | Subpackages 7 | ----------- 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | 12 | grain.checkpoint 13 | grain.constants 14 | grain.experimental 15 | grain.multiprocessing 16 | grain.samplers 17 | grain.sharding 18 | grain.sources 19 | grain.transforms 20 | 21 | 22 | Flexible low-level pipelines 23 | ---------------------------- 24 | 25 | .. autosummary:: 26 | 27 | MapDataset 28 | IterDataset 29 | DatasetIterator 30 | ReadOptions 31 | 32 | 33 | Simple high-level pipelines 34 | --------------------------- 35 | 36 | 37 | .. autosummary:: 38 | :toctree: _autosummary 39 | 40 | load 41 | DataLoader 42 | DataLoaderIterator 43 | Record 44 | RecordMetadata -------------------------------------------------------------------------------- /docs/grain.samplers.rst: -------------------------------------------------------------------------------- 1 | ``grain.samplers`` module 2 | ========================= 3 | 4 | .. automodule:: grain.samplers 5 | 6 | List of Members 7 | --------------- 8 | 9 | .. 
autosummary:: 10 | :toctree: _autosummary 11 | 12 | IndexSampler 13 | Sampler 14 | SequentialSampler 15 | -------------------------------------------------------------------------------- /docs/grain.sharding.rst: -------------------------------------------------------------------------------- 1 | ``grain.sharding`` module 2 | ========================= 3 | 4 | .. automodule:: grain.sharding 5 | 6 | List of Members 7 | --------------- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | NoSharding 13 | ShardByJaxProcess 14 | ShardOptions 15 | -------------------------------------------------------------------------------- /docs/grain.sources.rst: -------------------------------------------------------------------------------- 1 | ``grain.sources`` module 2 | ======================== 3 | 4 | .. automodule:: grain.sources 5 | 6 | List of Members 7 | --------------- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | ArrayRecordDataSource 13 | SharedMemoryDataSource 14 | RandomAccessDataSource 15 | RangeDataSource 16 | -------------------------------------------------------------------------------- /docs/grain.transforms.rst: -------------------------------------------------------------------------------- 1 | ``grain.transforms`` module 2 | =========================== 3 | 4 | .. automodule:: grain.transforms 5 | 6 | List of Members 7 | --------------- 8 | 9 | .. autosummary:: 10 | :toctree: _autosummary 11 | 12 | Batch 13 | Filter 14 | Map 15 | MapWithIndex 16 | RandomMap 17 | Transformation 18 | Transformations 19 | DatasetSelectionMap 20 | -------------------------------------------------------------------------------- /docs/images/data_flow_multiple_workers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/docs/images/data_flow_multiple_workers.png -------------------------------------------------------------------------------- /docs/images/data_flow_zero_workers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/docs/images/data_flow_zero_workers.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Grain - Feeding JAX Models 2 | 3 | 4 | 5 | Grain is a library for reading data for training and evaluating JAX models. It's 6 | open source, fast and deterministic. 7 | 8 | ::::{grid} 1 2 2 3 9 | :gutter: 1 1 1 2 10 | 11 | :::{grid-item-card} {octicon}`zap;1.5em;sd-mr-1` Powerful 12 | Users can bring arbitrary Python transformations. 13 | ::: 14 | 15 | :::{grid-item-card} {octicon}`tools;1.5em;sd-mr-1` Flexible 16 | Grain is designed to 17 | be modular. Users can readily override Grain components if need be with their 18 | own implementation. 19 | ::: 20 | 21 | :::{grid-item-card} {octicon}`versions;1.5em;sd-mr-1` Deterministic 22 | Multiple runs of the same pipeline will produce the same output. 23 | ::: 24 | 25 | :::{grid-item-card} {octicon}`check-circle;1.5em;sd-mr-1` Resilient to preemptions 26 | Grain is designed such that checkpoints have minimal size. After 27 | pre-emption, Grain can resume from where it left off and produce the same output 28 | as if it was never preempted. 
29 | ::: 30 | 31 | :::{grid-item-card} {octicon}`zap;1.5em;sd-mr-1` Performant 32 | We took care while designing Grain to ensure that it's performant (refer to the 33 | [Behind the Scenes](behind_the_scenes.md) section of the documentation.) We also 34 | tested it against multiple data modalities (e.g.Text/Audio/Images/Videos). 35 | ::: 36 | 37 | :::{grid-item-card} {octicon}`package;1.5em;sd-mr-1` With minimal dependencies 38 | Grain minimizes its set of dependencies when possible. For example, it should 39 | not depend on TensorFlow. 40 | ::: 41 | :::: 42 | 43 | ``` {toctree} 44 | :maxdepth: 1 45 | :hidden: 46 | :caption: Get started 47 | installation 48 | api_choice 49 | data_sources 50 | behind_the_scenes 51 | ``` 52 | 53 | ``` {toctree} 54 | :maxdepth: 1 55 | :hidden: 56 | :caption: Data Loader 57 | data_loader/samplers 58 | data_loader/transformations 59 | ``` 60 | 61 | ```{toctree} 62 | :maxdepth: 3 63 | :hidden: 64 | :caption: Tutorials 65 | tutorials/data_loader_tutorial 66 | tutorials/dataset_basic_tutorial 67 | tutorials/dataset_advanced_tutorial 68 | tutorials/dataset_debugging_tutorial 69 | tutorials/dataset_load_from_s3_tutorial 70 | tutorials/data_sources/index 71 | ``` 72 | 73 | ``` {toctree} 74 | :maxdepth: 1 75 | :hidden: 76 | :caption: API reference 77 | grain 78 | ``` 79 | 80 | ``` {toctree} 81 | :maxdepth: 1 82 | :hidden: 83 | :caption: Contributor guides 84 | CONTRIBUTING 85 | ``` -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | # Installing Grain 2 | 3 | 4 | 5 | To install Grain, you can use pip: 6 | 7 | ```bash 8 | pip install grain 9 | ``` 10 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Sphinx-related requirements. 2 | sphinx 3 | sphinx-book-theme>=1.0.1 4 | myst-nb 5 | myst-parser[linkify] 6 | sphinx-book-theme 7 | sphinx-copybutton 8 | sphinx-design 9 | # Avoiding an issue with the collapsible sidebar. 10 | pydata-sphinx-theme<0.16.0 11 | # To generate API documentation. 12 | sphinx-autoapi 13 | sphinx-autodoc2 14 | # To import the Grain package. We mock all other dependencies, but this one has 15 | # context managers that are tricky to mock. 16 | etils[epath,epy] -------------------------------------------------------------------------------- /docs/tutorials/data_sources/arrayrecord_data_source_tutorial.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupyter: 3 | jupytext: 4 | text_representation: 5 | extension: .md 6 | format_name: markdown 7 | format_version: '1.3' 8 | jupytext_version: 1.17.1 9 | kernelspec: 10 | display_name: Python 3 11 | name: python3 12 | --- 13 | 14 | 15 | # Reading ArrayRecord Files 16 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/arrayrecord_data_source_tutorial.ipynb) 17 | 18 | This tutorial provides an example of how to retrieve records from ArrayRecord 19 | files using `grain.sources.ArrayRecordDataSource`, also covers how to process 20 | and transform the data with Grain. 
21 | 22 | 23 | 24 | 25 | ## Install and Load Dependencies 26 | 27 | 28 | ```python id="tzWZLNklr4Iy" 29 | !pip install grain array_record 30 | ``` 31 | 32 | ```python id="8NF4E-cCbyjV" 33 | import pickle 34 | import grain 35 | import tensorflow_datasets as tfds 36 | from array_record.python import array_record_module 37 | ``` 38 | 39 | 40 | ## Write a temp ArrayRecord file 41 | 42 | 43 | ```python id="WrCQ-jH53t-K" 44 | # Load a public tensorflow dataset. 45 | test_tfds = tfds.data_source("bool_q", split="train") 46 | ``` 47 | 48 | ```python id="_0yBaN7hXmbu" 49 | # Write the dataset into a test array_record file. 50 | example_file_path = "./test.array_record" 51 | writer = array_record_module.ArrayRecordWriter( 52 | example_file_path, "group_size:1" 53 | ) 54 | record_count = 0 55 | for record in test_tfds: 56 | writer.write(pickle.dumps(record)) 57 | record_count += 1 58 | writer.close() 59 | 60 | print( 61 | f"Number of records written to array_record file {example_file_path} :" 62 | f" {record_count}" 63 | ) 64 | ``` 65 | 66 | ```python id="HKJ_49JCXmbu" 67 | # @title Load Data Source 68 | example_array_record_data_source = (grain.sources.ArrayRecordDataSource( 69 | example_file_path 70 | )) 71 | print(f"Number of records: {len(example_array_record_data_source)}") 72 | ``` 73 | 74 | ```python id="NVRGllY3Xmbu" 75 | print(example_array_record_data_source[0]) 76 | ``` 77 | 78 | 79 | ## Define Transformation Function 80 | 81 | 82 | ```python id="0AS5w9quXmbu" 83 | # Load a pre trained tokenizer 84 | from tokenizers import Tokenizer 85 | 86 | tokenizer = Tokenizer.from_pretrained("bert-base-cased") 87 | ``` 88 | 89 | ```python id="YiS85paBXmbu" 90 | class ParseAndTokenizeText(grain.transforms.Map): 91 | """This function takes a serialized dict (as bytes), decodes it, 92 | 93 | applies a tokenizer to a specified feature within the dict, 94 | and returns the first 10 tokens from results. 95 | """ 96 | 97 | def __init__(self, tokenizer, feature_name): 98 | self._tokenizer = tokenizer 99 | self._feature_name = feature_name 100 | 101 | def map(self, element: bytes) -> [str]: 102 | parsed_element = pickle.loads(element) 103 | # only pick the first 10 token IDs from the tokenized text for testing 104 | return self._tokenizer.encode( 105 | parsed_element[self._feature_name].decode('utf-8') 106 | ).tokens[:10] 107 | ``` 108 | 109 | 110 | ## Load and process data via the Dataset API 111 | 112 | 113 | ```python id="RPIy05gGUBzI" 114 | # Example using Grain's MapDataset with ArrayRecord file source. 115 | example_datasets = ( 116 | grain.MapDataset.source(example_array_record_data_source) 117 | .shuffle(seed=42) 118 | .map(ParseAndTokenizeText(tokenizer, "question")) 119 | .batch(batch_size=10) 120 | ) 121 | ``` 122 | 123 | ```python id="xqJSeQ9hdAmF" 124 | # Output a record at a random index 125 | print(example_datasets[100]) 126 | ``` 127 | -------------------------------------------------------------------------------- /docs/tutorials/data_sources/index.rst: -------------------------------------------------------------------------------- 1 | .. _data-sources-tutorials-section: 2 | 3 | Data Sources 4 | ============ 5 | 6 | This section contains tutorials for using Grain to read data from various sources. 7 | 8 | .. 
toctree:: 9 | :maxdepth: 1 10 | 11 | parquet_dataset_tutorial.md 12 | arrayrecord_data_source_tutorial.md 13 | load_from_s3_tutorial.md 14 | load_from_gcs_tutorial.md 15 | huggingface_dataset_tutorial.md 16 | pytorch_dataset_tutorial.md 17 | -------------------------------------------------------------------------------- /docs/tutorials/data_sources/load_from_gcs_tutorial.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupyter: 3 | jupytext: 4 | text_representation: 5 | extension: .md 6 | format_name: markdown 7 | format_version: '1.3' 8 | jupytext_version: 1.17.1 9 | kernelspec: 10 | display_name: Python 3 11 | name: python3 12 | --- 13 | 14 | 15 | 16 | # Reading from GCS 17 | 18 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/load_from_gcs_tutorial.ipynb) 19 | 20 | This document demonstrates how to access and load data from Google Cloud Storage using Grain. To achieve this, we'll utilize Cloud Storage [FUSE](https://cloud.google.com/storage/docs/cloud-storage-fuse/overview), an adapter that allows you to mount GCS buckets as local file systems. By using Cloud Storage FUSE to mount GCS buckets as local file systems, you can access cloud storage data just like local files. 21 | 22 | 23 | 24 | ## Mount a Cloud Storage location into the local filesystem 25 | 26 | 27 | ```python id="h6HqcZSQ_Y23" 28 | # Authenticate. 29 | from google.colab import auth 30 | auth.authenticate_user() 31 | ``` 32 | 33 | 34 | 35 | 36 | The gcsfuse CLI offers various configurable options, detailed at https://cloud.google.com/storage/docs/gcsfuse-cli. Utilizing certain options, such as the caching features described at https://cloud.google.com/storage/docs/cloud-storage-fuse/caching, can enhance read performance and lower costs. For instance, MaxText setup gcsfuse flags ([MaxText gcsfuse setting link](https://github.com/AI-Hypercomputer/maxtext/blob/4e36b61cf40698224c5251c4aa4086df24140bd1/setup_gcsfuse.sh#L48)) to reduce data loading time for training. We advise users to consider adopting similar settings or customizing their own gcsfuse options. 37 | 38 | 39 | ```python id="bqz6cD7xl7F3" 40 | # Mount a Cloud Storage bucket or location, without the gs:// prefix. 41 | mount_path = "my-bucket" # or a location like "my-bucket/path/to/mount" 42 | local_path = f"./mnt/gs/{mount_path}" 43 | 44 | !mkdir -p {local_path} 45 | # The flags below are configured to improve GCS data loading performance. Users are encouraged to explore alternative settings and we would greatly appreciate any feedback or insights shared with the Grain team. 46 | !gcsfuse --implicit-dirs --type-cache-max-size-mb=-1 --stat-cache-max-size-mb=-1 --kernel-list-cache-ttl-secs=-1 --metadata-cache-ttl-secs=-1 {mount_path} {local_path} 47 | ``` 48 | 49 | ```python id="j2e8nv0j_Y23" 50 | # Then you can access it like a local path. 51 | !ls -lh {local_path} 52 | ``` 53 | 54 | 55 | ## Read files using Grain 56 | 57 | If your data is in an ArrayRecord file, you can directly load it using `grain.sources.ArrayRecordDataSource`. For information on handling other file formats, please see the Grain data sources documentation at: https://google-grain.readthedocs.io/en/latest/data_sources.html 58 | 59 | 60 | ```python id="yisjIpbZ_Y23" 61 | # Install Grain. 
62 | !pip install grain 63 | ``` 64 | 65 | ```python id="pvNTx6sL_Y23" 66 | import grain 67 | 68 | source = grain.sources.ArrayRecordDataSource(local_path + "/local_file_name") 69 | 70 | # Create a dataset from the data source, then process the data. 71 | dataset = ( 72 | grain.MapDataset.source(source) 73 | .shuffle(seed=10) # Shuffles globally. 74 | .batch(batch_size=2) # Batches consecutive elements. 75 | ) 76 | ``` 77 | 78 | ```python id="bJIYx60H_Y23" 79 | # Output a record at a random index 80 | print(dataset[10]) 81 | ``` 82 | -------------------------------------------------------------------------------- /docs/tutorials/data_sources/load_from_s3_tutorial.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupyter: 3 | jupytext: 4 | text_representation: 5 | extension: .md 6 | format_name: markdown 7 | format_version: '1.3' 8 | jupytext_version: 1.17.1 9 | kernelspec: 10 | display_name: Python 3 11 | name: python3 12 | --- 13 | 14 | 15 | # Reading from AWS S3 16 | 17 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/load_from_s3_tutorial.ipynb) 18 | 19 | This document outlines how to read data from an Amazon S3 bucket and construct a Grain pipeline. We will leverage [Mountpoint for Amazon S3](https://docs.aws.amazon.com/AmazonS3/latest/userguide/mountpoint.html), an AWS tool that lets you mount your S3 bucket as a local file system, allowing you to access and read data as if it were stored locally. 20 | 21 | 22 | 23 | ## Install Mountpoint for Amazon S3 24 | 25 | 26 | ```python id="K6UTOyamWlWf" 27 | !wget https://s3.amazonaws.com/mountpoint-s3-release/latest/x86_64/mount-s3.deb 28 | ``` 29 | 30 | ```python id="iHA-C85NhwFJ" 31 | !sudo apt-get install -y ./mount-s3.deb 32 | ``` 33 | 34 | 35 | ## Configure AWS credentials 36 | 37 | 38 | ```python id="8fhEOwxcWlWf" 39 | # Install the AWS CLI, which provides the `aws configure` command. 40 | !pip install awscli 41 | ``` 42 | 43 | ```python id="5Lt_644G7G9R" 44 | !aws configure 45 | ``` 46 | 47 | 48 | ## Mount your S3 bucket to your local filepath 49 | 50 | 51 | ```python id="G6boYrD5WlWf" 52 | !mount-s3 my-s3-bucket /path/to/mount/files  # Replace my-s3-bucket with your bucket name. 53 | ``` 54 | 55 | 56 | ## Install Grain and other dependencies 57 | 58 | 59 | ```python id="3BZP9fBiWlWf" 60 | !pip install grain 61 | !pip install array_record 62 | ``` 63 | 64 | 65 | ## Write temp ArrayRecord files to the bucket 66 | 67 | 68 | ```python id="xVGVuDKNic0B" 69 | from array_record.python import array_record_module 70 | 71 | digits = [b"1", b"2", b"3", b"4", b"5"] 72 | 73 | writer = array_record_module.ArrayRecordWriter("/path/to/mount/files/data.array_record") 74 | for i in digits: 75 | writer.write(i) 76 | writer.close() 77 | ``` 78 | 79 | 80 | ## Read ArrayRecord files using Grain 81 | 82 | 83 | ```python id="3l4Pnc4bWlWf" 84 | import grain 85 | from pprint import pprint 86 | 87 | source = grain.sources.ArrayRecordDataSource(paths="/path/to/mount/files/data.array_record") 88 | 89 | dataset = ( 90 | grain.MapDataset.source(source) 91 | .shuffle(seed=10) # Shuffles globally. 92 | .batch(batch_size=2) # Batches consecutive elements.
93 | ) 94 | 95 | pprint(list(dataset)) 96 | ``` 97 | -------------------------------------------------------------------------------- /docs/tutorials/data_sources/parquet_dataset_tutorial.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupyter: 3 | jupytext: 4 | text_representation: 5 | extension: .md 6 | format_name: markdown 7 | format_version: '1.3' 8 | jupytext_version: 1.17.1 9 | kernelspec: 10 | display_name: Python 3 11 | name: python3 12 | --- 13 | 14 | 15 | # Reading Apache Parquet Files 16 | 17 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/data_sources/parquet_dataset_tutorial.ipynb) 18 | 19 | This tutorial provides an example of how to read data from [Apache Parquet](https://parquet.apache.org/) file, also covers how to process and transform the data with Grain. 20 | 21 | 22 | 23 | 24 | ## Generate a test Parquet file on local 25 | 26 | 27 | ```python id="dFnwN6NNw0Oe" 28 | import pyarrow as pa 29 | import pyarrow.parquet as pq 30 | 31 | # Generate a sample PyArrow table containing email subjects and bodies. 32 | table = pa.table({ 33 | 'email_subject': [ 34 | "Meeting Reminder: Project X Update", 35 | "Important Announcement Regarding Company Policy", 36 | "FWD: Quick Question", 37 | "Your Order Confirmation #12345", 38 | "Invitation to Team Building Activity" 39 | ], 40 | 'email_body': [ 41 | "Hi team,\n\nJust a reminder about our Project X update meeting tomorrow at 10 AM PST. Please come prepared to discuss your progress and any roadblocks.\n\nSee you there,\n[Your Name]", 42 | "Dear employees,\n\nPlease be advised of a new company policy regarding remote work, effective May 1st, 2025. You can find the full details on the company intranet.\n\nRegards,\nManagement", 43 | "Hi [Name],\n\nForwarding you this email as you might have the answer to this quick question:\n\n[Original Email Content]", 44 | "Dear [Customer Name],\n\nThank you for your recent order! This email confirms your order #12345. You can view the details and track its shipment here: [Link]\n\nSincerely,\nThe [Company Name] Team", 45 | "Hello everyone,\n\nYou're invited to participate in our upcoming team building activity on Friday, April 28th. It will be a fun afternoon of [Activity]. Please RSVP by Wednesday.\n\nBest,\n[Organizer Name]" 46 | ] 47 | }) 48 | 49 | # Write this table to a parquet file. 50 | writer = pq.ParquetWriter('emails.parquet', table.schema) 51 | writer.write_table(table) 52 | writer.close() 53 | 54 | ``` 55 | 56 | 57 | ## Load Dataset 58 | 59 | 60 | ```python id="fNIApmKT34ac" 61 | # Install Grain 62 | !pip install grain 63 | ``` 64 | 65 | ```python id="nb5yPasvDwaj" 66 | import grain 67 | import pprint 68 | ``` 69 | 70 | ```python id="62_D74LkdrQn" 71 | ds = grain.experimental.ParquetIterDataset('./emails.parquet') 72 | ``` 73 | 74 | ```python id="DlhbJX5zdrQo" 75 | list(ds)[0] 76 | ``` 77 | 78 | 79 | ## Transform Dataset 80 | 81 | 82 | ```python id="1qevksHMdrQo" 83 | # Load a pre trained tokenizer. 
84 | from tokenizers import Tokenizer 85 | tokenizer = Tokenizer.from_pretrained("bert-base-cased") 86 | ``` 87 | 88 | ```python id="PNPv8LEGdrQo" 89 | class TokenizeText(grain.transforms.Map): 90 | """Tokenizes the text values within each element using a provided tokenizer.""" 91 | def __init__(self, tokenizer): 92 | self._tokenizer = tokenizer 93 | 94 | def map(self, element): 95 | return [self._tokenizer.encode(item).tokens for item in element.values()] 96 | ``` 97 | 98 | ```python id="O5vnwj7cek30" 99 | # Tokenize the data using the provided tokenizer. 100 | ds = ds.map(TokenizeText(tokenizer)) 101 | ``` 102 | 103 | ```python id="46ZVbfmtek30" 104 | # Create an iterator object of the dataset. 105 | ds_iter = iter(ds) 106 | # Print the first element in the dataset. 107 | pprint.pprint(next(ds_iter)) 108 | ``` 109 | -------------------------------------------------------------------------------- /docs/tutorials/dataset_debugging_tutorial.md: -------------------------------------------------------------------------------- 1 | --- 2 | jupytext: 3 | formats: ipynb,md:myst 4 | main_language: python 5 | text_representation: 6 | extension: .md 7 | format_name: myst 8 | format_version: 0.13 9 | jupytext_version: 1.16.1 10 | kernelspec: 11 | display_name: Python 3 12 | name: python3 13 | --- 14 | 15 | +++ {"id": "OHoxgqr6sRKE"} 16 | 17 | # Performance & Debugging tool 18 | Grain offers two configurable modes that can be set to gain deeper insights into 19 | pipeline execution and identify potential issues. 20 | 21 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/grain/blob/main/docs/tutorials/dataset_debugging_tutorial.ipynb) 22 | 23 | ```{code-cell} 24 | :id: xw_-jT1r6zNM 25 | 26 | # @test {"output": "ignore"} 27 | !pip install grain 28 | ``` 29 | 30 | +++ {"id": "YLaRRlCPsRKE"} 31 | 32 | ## Visualization mode 33 | To get an overview of your dataset pipeline structure and a clear understanding of 34 | how the data flows, enable visualization mode. This will log a visual 35 | representation of your pipeline, allowing you to easily identify different 36 | transformation stages and their relationships. To enable visualization mode, set 37 | the flag `--grain_py_dataset_visualization_output_dir=""` or call 38 | `grain.config.update("py_dataset_visualization_output_dir", "")`. 39 | 40 | ```{code-cell} 41 | :id: 4y89Wx7PsRKE 42 | 43 | # @test {"output": "ignore"} 44 | import grain.python as grain 45 | 46 | grain.config.update("py_dataset_visualization_output_dir", "") 47 | ds = ( 48 | grain.MapDataset.range(20) 49 | .seed(seed=42) 50 | .shuffle() 51 | .batch(batch_size=2) 52 | .map(lambda x: x) 53 | .to_iter_dataset() 54 | ) 55 | it = iter(ds) 56 | 57 | # Visualization graph is constructed once the dataset produces the first element 58 | for _ in range(10): 59 | next(it) 60 | ``` 61 | 62 | +++ {"id": "_3h-u2I1i7wv"} 63 | 64 | ## Debug mode 65 | To troubleshoot performance issues in your dataset pipeline, enable debug mode. 66 | This will log a real-time execution summary of the pipeline at one-minute 67 | intervals. This execution summary provides detailed information on each 68 | transformation stage such as processing time, number of elements processed and 69 | other details that help in identifying the slower stages in the pipeline.
70 | To enable debug mode, set the flag `--grain_py_debug_mode=true` or call 71 | `grain.config.update("py_debug_mode", True)`. 72 | 73 | ```{code-cell} 74 | :id: bN45Z58E3jGS 75 | 76 | import time 77 | 78 | 79 | # Define a dummy slow preprocessing function 80 | def _dummy_slow_fn(x): 81 | time.sleep(10) 82 | return x 83 | ``` 84 | 85 | ```{code-cell} 86 | --- 87 | colab: 88 | height: 897 89 | id: bN45Z58E3jGS 90 | outputId: f3d640a8-1eae-414f-e6eb-e7c02c9a91df 91 | --- 92 | # @test {"output": "ignore"} 93 | import time 94 | 95 | grain.config.update("py_debug_mode", True) 96 | 97 | ds = ( 98 | grain.MapDataset.range(20) 99 | .seed(seed=42) 100 | .shuffle() 101 | .batch(batch_size=2) 102 | .map(_dummy_slow_fn) 103 | .to_iter_dataset() 104 | .map(_dummy_slow_fn) 105 | ) 106 | it = iter(ds) 107 | 108 | for _ in range(10): 109 | next(it) 110 | ``` 111 | 112 | +++ {"id": "eSu9SOP8_x6A"} 113 | 114 | In the above execution summary, 86% of the time is spent in the 115 | `MapDatasetIterator` node, making it the slowest stage of the pipeline. 116 | 117 | Note that although, based on the `total_processing_time`, it might appear that 118 | `MapMapDataset` (id: 2) is the slowest stage, nodes with ids 2 to 6 are 119 | executed in multiple threads and hence the `total_processing_time` of these 120 | nodes should be compared to the `total_processing_time` of the iterator node (id: 0). 121 | -------------------------------------------------------------------------------- /grain/BUILD: -------------------------------------------------------------------------------- 1 | 2 | load("@bazel_skylib//:bzl_library.bzl", "bzl_library") 3 | 4 | package(default_visibility = ["//grain:__subpackages__"]) 5 | 6 | licenses(["notice"]) 7 | 8 | exports_files(["LICENSE"]) 9 | 10 | # Backwards compatibility alias.
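# For example, deps on "//grain:python" elsewhere in this repository still resolve to the ":grain" target below.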
11 | alias( 12 | name = "python", 13 | actual = ":grain", 14 | visibility = ["//visibility:public"], 15 | ) 16 | 17 | py_library( 18 | name = "grain", 19 | srcs = [ 20 | "__init__.py", 21 | "_src/__init__.py", 22 | "checkpoint.py", 23 | "constants.py", 24 | "experimental.py", 25 | "multiprocessing.py", 26 | "python/__init__.py", 27 | "python/experimental.py", 28 | "samplers.py", 29 | "sharding.py", 30 | "sources.py", 31 | "transforms.py", 32 | ], 33 | data = ["//grain/_src/python/experimental/index_shuffle/python:index_shuffle_module.so"], 34 | srcs_version = "PY3", 35 | # Implicit build flag 36 | visibility = ["//visibility:public"], 37 | deps = [ 38 | "//grain/_src/core:config", 39 | "//grain/_src/core:constants", 40 | "//grain/_src/core:monitoring", 41 | "//grain/_src/core:sharding", 42 | "//grain/_src/core:transforms", 43 | "//grain/_src/python:checkpoint_handlers", 44 | "//grain/_src/python:data_loader", 45 | "//grain/_src/python:data_sources", 46 | "//grain/_src/python:load", 47 | "//grain/_src/python:operations", 48 | "//grain/_src/python:options", 49 | "//grain/_src/python:record", 50 | "//grain/_src/python:samplers", 51 | "//grain/_src/python:shared_memory_array", 52 | "//grain/_src/python/dataset", 53 | "//grain/_src/python/dataset:base", 54 | "//grain/_src/python/dataset:elastic_iterator", 55 | "//grain/_src/python/dataset:stats", 56 | "//grain/_src/python/dataset:visualize", 57 | "//grain/_src/python/dataset/sources:parquet_dataset", 58 | "//grain/_src/python/dataset/sources:tfrecord_dataset", 59 | "//grain/_src/python/dataset/transformations:interleave", 60 | "//grain/_src/python/dataset/transformations:limit", 61 | "//grain/_src/python/dataset/transformations:packing", 62 | "//grain/_src/python/dataset/transformations:packing_concat_then_split", 63 | "//grain/_src/python/dataset/transformations:zip", 64 | "//grain/_src/python/experimental/example_packing:packing", 65 | "//grain/_src/python/testing:experimental", 66 | "//grain/proto:execution_summary_py_pb2", 67 | ], 68 | ) 69 | -------------------------------------------------------------------------------- /grain/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Public API for Grain.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=unused-import 19 | # pylint: disable=g-multiple-import 20 | # pylint: disable=g-import-not-at-top 21 | 22 | # We import all public modules here to enable the use of `grain.foo.Bar` 23 | # instead of forcing users to write `from grain import foo as grain_foo`. 
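# For example, after `import grain` a user can write `grain.sources.ArrayRecordDataSource(...)` or subclass `grain.transforms.Map` directly, rather than importing each submodule.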
24 | from grain import ( 25 | experimental, 26 | checkpoint, 27 | constants, 28 | multiprocessing, 29 | samplers, 30 | sharding, 31 | sources, 32 | transforms, 33 | ) 34 | 35 | from grain._src.core.config import config 36 | from grain._src.python.data_loader import ( 37 | DataLoader, 38 | DataLoaderIterator, 39 | ) 40 | from grain._src.python.dataset.dataset import ( 41 | DatasetIterator, 42 | IterDataset, 43 | MapDataset, 44 | ) 45 | from grain._src.python.load import load 46 | from grain._src.python.options import ReadOptions 47 | from grain._src.python.record import Record, RecordMetadata 48 | -------------------------------------------------------------------------------- /grain/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /grain/_src/core/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//grain:__subpackages__"]) 2 | 3 | licenses(["notice"]) 4 | 5 | py_library( 6 | name = "config", 7 | srcs = ["config.py"], 8 | srcs_version = "PY3", 9 | deps = [ 10 | ":monitoring", 11 | "@abseil-py//absl/flags", 12 | ], 13 | ) 14 | 15 | py_library( 16 | name = "constants", 17 | srcs = ["constants.py"], 18 | srcs_version = "PY3", 19 | ) 20 | 21 | py_library( 22 | name = "exceptions", 23 | srcs = ["exceptions.py"], 24 | srcs_version = "PY3", 25 | ) 26 | 27 | py_library( 28 | name = "monitoring", 29 | srcs = ["monitoring.py"], 30 | srcs_version = "PY3", 31 | deps = [ 32 | ], 33 | ) 34 | 35 | py_library( 36 | name = "parallel", 37 | srcs = ["parallel.py"], 38 | srcs_version = "PY3", 39 | deps = [ 40 | ], 41 | ) 42 | 43 | py_test( 44 | name = "parallel_test", 45 | srcs = ["parallel_test.py"], 46 | srcs_version = "PY3", 47 | deps = [ 48 | ":parallel", 49 | "@abseil-py//absl/testing:absltest", 50 | "@abseil-py//absl/testing:parameterized", 51 | ], 52 | ) 53 | 54 | py_library( 55 | name = "grain_random", 56 | srcs = ["grain_random.py"], 57 | srcs_version = "PY3", 58 | deps = [ 59 | "@abseil-py//absl/logging", 60 | "@pypi//jax:pkg", 61 | "@pypi//numpy:pkg", 62 | ], 63 | ) 64 | 65 | py_library( 66 | name = "sharding", 67 | srcs = ["sharding.py"], 68 | srcs_version = "PY3", 69 | deps = [ 70 | "@abseil-py//absl/logging", 71 | ], 72 | ) 73 | 74 | py_test( 75 | name = "sharding_test", 76 | srcs = ["sharding_test.py"], 77 | srcs_version = "PY3", 78 | deps = [ 79 | ":sharding", 80 | "@abseil-py//absl/testing:absltest", 81 | "@abseil-py//absl/testing:parameterized", 82 | ], 83 | ) 84 | 85 | py_library( 86 | name = "usage_logging", 87 | srcs = ["usage_logging.py"], 88 | srcs_version = "PY3", 89 | ) 90 | 91 | py_library( 92 | name = "transforms", 93 | srcs = ["transforms.py"], 94 | srcs_version = "PY3", 95 | deps = [ 96 | "@pypi//numpy:pkg", 97 | ], 98 | ) 99 | 100 | py_test( 101 | name = 
"transforms_test", 102 | srcs = ["transforms_test.py"], 103 | srcs_version = "PY3", 104 | deps = [ 105 | ":transforms", 106 | "@abseil-py//absl/testing:absltest", 107 | "@abseil-py//absl/testing:parameterized", 108 | ], 109 | ) 110 | 111 | py_library( 112 | name = "tree_lib", 113 | srcs = [ 114 | "tree_lib.py", 115 | ], 116 | srcs_version = "PY3", 117 | deps = [ 118 | "@pypi//dm_tree:pkg", 119 | "@pypi//numpy:pkg", 120 | ], 121 | ) 122 | 123 | py_library( 124 | name = "tree_test_lib", 125 | testonly = 1, 126 | srcs = ["tree_lib_test.py"], 127 | srcs_version = "PY3", 128 | deps = [ 129 | ":tree_lib", 130 | "@abseil-py//absl/testing:absltest", 131 | "@abseil-py//absl/testing:parameterized", 132 | "@pypi//numpy:pkg", 133 | ], 134 | ) 135 | 136 | py_test( 137 | name = "tree_lib_test", 138 | srcs = ["tree_lib_test.py"], 139 | srcs_version = "PY3", 140 | deps = [ 141 | ":tree_test_lib", 142 | ], 143 | ) 144 | 145 | py_test( 146 | name = "tree_lib_jax_test", 147 | srcs = ["tree_lib_jax_test.py"], 148 | srcs_version = "PY3", 149 | deps = [ 150 | ":tree_lib", 151 | ":tree_test_lib", 152 | "@abseil-py//absl/testing:absltest", 153 | "@pypi//attrs:pkg", 154 | "@pypi//jax:pkg", 155 | "@pypi//numpy:pkg", 156 | ], 157 | ) 158 | 159 | py_test( 160 | name = "version_test", 161 | srcs = ["version_test.py"], 162 | srcs_version = "PY3", 163 | deps = [ 164 | "//grain:python", 165 | "@abseil-py//absl/testing:absltest", 166 | ], 167 | ) 168 | -------------------------------------------------------------------------------- /grain/_src/core/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Shared constants for various Grain APIs.""" 15 | 16 | # Below are names of meta features used by index_dataset (and pipelines building 17 | # on top of it). These features are generated on the fly and help to track 18 | # progress over the dataset. Users can read these but shouldn't alter them. They 19 | # start with "_" to indicate that they are "private". 20 | # Index into the stream of all records (globally unique). Starts with 0. 21 | INDEX = "_index" 22 | # Key of the record. If DATASET_INDEX is present it's the key in the dataset. 23 | # Starts with 0. 24 | RECORD_KEY = "_record_key" 25 | # Index of the dataset from which to take the record. Only present when mixing. 26 | # Starts with 0. 27 | DATASET_INDEX = "_dataset_index" 28 | # Epoch for the record. When mixing datasets this is the epoch over the dataset, 29 | # not the mixture. Starts with 1. 30 | EPOCH = "_epoch" 31 | # Random seed for stateless random operations. This is unique per record 32 | # and changes every epoch. 33 | SEED = "_seed" 34 | # Serialized record. 
35 | RECORD = "_record" 36 | 37 | META_FEATURES = frozenset( 38 | [INDEX, RECORD_KEY, DATASET_INDEX, EPOCH, SEED, RECORD] 39 | ) 40 | -------------------------------------------------------------------------------- /grain/_src/core/exceptions.py: -------------------------------------------------------------------------------- 1 | """Custom exception for PyGrain-specific internal errors.""" 2 | 3 | 4 | class PyGrainInternalError(Exception): 5 | pass 6 | -------------------------------------------------------------------------------- /grain/_src/core/monitoring.py: -------------------------------------------------------------------------------- 1 | """Grain metrics.""" 2 | 3 | import enum 4 | 5 | 6 | @enum.unique 7 | class Units(enum.Enum): 8 | """Grain metric units.""" 9 | 10 | NANOSECONDS = enum.auto() 11 | MILLISECONDS = enum.auto() 12 | 13 | 14 | class NoOpMetric: 15 | """Grain metric no-op implementation.""" 16 | 17 | def __init__(self, *args, **kwargs): 18 | del args, kwargs 19 | 20 | def IncrementBy(self, *args, **kwargs): 21 | del args, kwargs 22 | 23 | def Increment(self, *args, **kwargs): 24 | self.IncrementBy(1, *args, **kwargs) 25 | 26 | def Set(self, *args, **kwargs): 27 | del args, kwargs 28 | 29 | def Record(self, *args, **kwargs): 30 | del args, kwargs 31 | 32 | def Get(self, *args, **kwargs): 33 | del args, kwargs 34 | 35 | 36 | class Metadata: 37 | """Grain metric no-op metadata.""" 38 | 39 | def __init__(self, *args, **kwargs): 40 | del args, kwargs 41 | 42 | 43 | class Bucketer: 44 | """Grain metric no-op bucketer.""" 45 | 46 | def __init__(self, *args, **kwargs): 47 | del args, kwargs 48 | 49 | def PowersOf(self, *args, **kwargs): 50 | del args, kwargs 51 | 52 | 53 | Counter = Metric = EventMetric = NoOpMetric 54 | 55 | def get_monitoring_root() -> None: 56 | return None 57 | -------------------------------------------------------------------------------- /grain/_src/core/parallel.py: -------------------------------------------------------------------------------- 1 | """Provides a methods to run functions in parallel using a thread pool.""" 2 | from collections.abc import Mapping, Sequence 3 | from typing import Any, Callable, TypeVar 4 | 5 | from concurrent import futures 6 | 7 | T = TypeVar("T") 8 | 9 | 10 | def run_in_parallel( 11 | function: Callable[..., T], 12 | list_of_kwargs_to_function: Sequence[Mapping[str, Any]], 13 | num_workers: int, 14 | thread_name_prefix: str = "parallel_", 15 | ) -> list[T]: 16 | """Run a function on a list of kwargs in parallel with ThreadPoolExecutor. 17 | 18 | Works best when there is IO boundedness, not when there is CPU boundedness, as 19 | the threads used are bound by the GIL. 20 | 21 | Propagates first exception to the calling thread. If cancel_futures=True, 22 | then stop as many of the ongoing work units as possible. 23 | 24 | Example usage: 25 | def io_bound_function(p): 26 | get_contents_from_cns(p) 27 | 28 | run_in_parallel( 29 | function=io_bound_function, 30 | list_of_kwargs_to_function=[{"p": p} for p in long_list_of_paths], 31 | num_workers=3) 32 | 33 | Args: 34 | function: a function. 35 | list_of_kwargs_to_function: A list of dicts mapping from string to argument 36 | value. These will be passed into `function` as kwargs. 37 | num_workers: int. 38 | thread_name_prefix: The thread name prefix string. Processes are run in 39 | threads, and each thread is named. This parameter allows the user to 40 | control the prefix for that thread name. 
41 | 42 | Returns: 43 | list of return values from function, in the same order as the arguments in 44 | list_of_kwargs_to_function. 45 | """ 46 | if num_workers < 1: 47 | raise ValueError( 48 | "Number of workers must be greater than 0. Was {}".format(num_workers) 49 | ) 50 | 51 | thread_name = thread_name_prefix + getattr(function, "__name__", "unknown") 52 | with futures.ThreadPoolExecutor( 53 | num_workers, thread_name_prefix=thread_name 54 | ) as executor: 55 | fs = [] 56 | 57 | for kwargs in list_of_kwargs_to_function: 58 | f = executor.submit(function, **kwargs) 59 | fs.append(f) 60 | 61 | futures_as_completed = futures.as_completed(fs) 62 | 63 | for completed in futures_as_completed: 64 | if completed.exception(): 65 | # Cancel all remaining futures, if possible. 66 | for remaining_future in fs: 67 | remaining_future.cancel() 68 | 69 | # Propagate exception to main thread. 70 | raise completed.exception() 71 | 72 | return [f.result() for f in fs] 73 | -------------------------------------------------------------------------------- /grain/_src/core/parallel_test.py: -------------------------------------------------------------------------------- 1 | """Tests for parallel.""" 2 | import threading 3 | 4 | from absl.testing import absltest 5 | from absl.testing import parameterized 6 | from grain._src.core import parallel 7 | 8 | 9 | def ReturnThreadName(): 10 | return threading.current_thread().name 11 | 12 | 13 | def Identity(i): 14 | return i 15 | 16 | 17 | def FnThatAlwaysFails(arg): 18 | del arg 19 | raise ValueError("I always fail") 20 | 21 | 22 | def FnThatFailsOnOddInputs(i): 23 | if i % 2 == 1: 24 | raise ValueError("Failed on an odd input") 25 | return i 26 | 27 | 28 | class ParallelTest(parameterized.TestCase): 29 | 30 | @parameterized.named_parameters( 31 | dict( 32 | testcase_name=" empty list of args", 33 | num_workers=1, 34 | input_dict_list=[], 35 | expected=[], 36 | ), 37 | dict( 38 | testcase_name=" one worker, nonempty list", 39 | num_workers=1, 40 | input_dict_list=[dict(i=k) for k in range(1, 10)], 41 | expected=list(range(1, 10)), 42 | ), 43 | dict( 44 | testcase_name=" fewer workers than jobs, nonempty list", 45 | num_workers=3, 46 | input_dict_list=[dict(i=k) for k in range(1, 10)], 47 | expected=list(range(1, 10)), 48 | ), 49 | dict( 50 | testcase_name=" more workers than jobs, nonempty list", 51 | num_workers=20, 52 | input_dict_list=[dict(i=k) for k in range(1, 10)], 53 | expected=list(range(1, 10)), 54 | ), 55 | ) 56 | def testRunInParallel(self, input_dict_list, num_workers: int, expected): 57 | actual = parallel.run_in_parallel(Identity, input_dict_list, num_workers) 58 | self.assertEqual(actual, expected) 59 | 60 | def testRunInParallelOnAlwaysFailingFn(self): 61 | with self.assertRaisesRegex(ValueError, "I always fail"): 62 | parallel.run_in_parallel(FnThatAlwaysFails, [dict(arg="hi")], 10) 63 | 64 | @parameterized.named_parameters( 65 | dict( 66 | testcase_name=" one failing input, one worker", 67 | num_workers=1, 68 | input_dict_list=[{"i": 1}], 69 | ), 70 | dict( 71 | testcase_name=" one failing input, many workers", 72 | num_workers=5, 73 | input_dict_list=[{"i": 1}], 74 | ), 75 | dict( 76 | testcase_name=" one failing input, one succeeding input", 77 | num_workers=5, 78 | input_dict_list=[{"i": 1}, {"i": 2}], 79 | ), 80 | dict( 81 | testcase_name=" two failing inputs, one succeeding input", 82 | num_workers=5, 83 | input_dict_list=[{"i": 1}, {"i": 2}, {"i": 3}], 84 | ), 85 | ) 86 | def testRunInParallelFailsIfSomeFnCallsFail( 87 | self, 
input_dict_list, num_workers: int 88 | ): 89 | with self.assertRaisesRegex(ValueError, "Failed on an odd input"): 90 | parallel.run_in_parallel( 91 | FnThatFailsOnOddInputs, input_dict_list, num_workers 92 | ) 93 | 94 | def testRunInParallelForThreadNamePrefix(self): 95 | input_kwarg_list = [{}] 96 | thread_names = parallel.run_in_parallel( 97 | ReturnThreadName, input_kwarg_list, 5, thread_name_prefix="Customized-" 98 | ) 99 | self.assertStartsWith(thread_names[0], "Customized-") 100 | 101 | def testRunInParallelForDefaultThreadNamePrefix(self): 102 | input_kwarg_list = [{}] 103 | thread_names = parallel.run_in_parallel( 104 | ReturnThreadName, input_kwarg_list, 5 105 | ) 106 | self.assertStartsWith(thread_names[0], "parallel_") 107 | 108 | 109 | if __name__ == "__main__": 110 | absltest.main() 111 | -------------------------------------------------------------------------------- /grain/_src/core/sharding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Classes for handling sharding of data sources arcoss machines/VMs.""" 15 | import dataclasses 16 | 17 | from absl import logging 18 | 19 | 20 | @dataclasses.dataclass(frozen=True) 21 | class ShardOptions: 22 | """Dataclass to hold options for sharding a data source. 23 | 24 | Attributes: 25 | shard_index: The index of the shard to use in this process. Must be in [0, 26 | shard_count - 1]. 27 | shard_count: The total number of shards. 28 | drop_remainder: If True shard() will create even splits and drop the 29 | remainder examples (all shards will have the same number of examples). If 30 | False will distribute the remainder N over the first N shards. 31 | """ 32 | 33 | shard_index: int 34 | shard_count: int 35 | drop_remainder: bool = False 36 | 37 | def __post_init__(self): 38 | if self.shard_count <= 0: 39 | raise ValueError( 40 | "Number of shards must be a positive integer but got " 41 | f"{self.shard_count}." 42 | ) 43 | if self.shard_index < 0 or self.shard_index >= self.shard_count: 44 | raise ValueError( 45 | "Shard shard_index must be in [0, shard_count - 1], shard_count was " 46 | f"{self.shard_count} and shard_index was {self.shard_index}." 47 | ) 48 | 49 | 50 | class NoSharding(ShardOptions): 51 | """Doesn't shard data. 
Each process will load all data.""" 52 | 53 | def __init__(self): 54 | super().__init__(shard_index=0, shard_count=1, drop_remainder=False) 55 | 56 | 57 | class ShardByJaxProcess(ShardOptions): 58 | """Shards the data across JAX processes.""" 59 | 60 | def __init__(self, drop_remainder: bool = False): 61 | process_index, process_count = get_process_index_and_count() 62 | super().__init__( 63 | shard_index=process_index, 64 | shard_count=process_count, 65 | drop_remainder=drop_remainder, 66 | ) 67 | 68 | 69 | def even_split(num_examples: int, options: ShardOptions) -> tuple[int, int]: 70 | """Returns the interval for the shard when sharding `num_examples` evenly. 71 | 72 | This splits the interval [0, num_examples - 1] into `shard_count` intervals 73 | and returns the `shard_index`'s interval. If `drop_remainder` is True all 74 | intervals will have the same size. 75 | 76 | Args: 77 | num_examples: Number of examples to shard. 78 | options: Options for sharding the data in this process. 79 | 80 | Returns: 81 | Tuple with the start and end of the interval. The start is the first 82 | example that should be included in this interval and end - 1 is the last 83 | example to be include in the shard. 84 | """ 85 | examples_per_shard = num_examples // options.shard_count 86 | shard_start = examples_per_shard * options.shard_index 87 | shard_end = examples_per_shard * (options.shard_index + 1) 88 | 89 | # Handle remaining examples. 90 | num_unused_examples = num_examples % options.shard_count 91 | 92 | if num_unused_examples > 0: 93 | if options.drop_remainder: 94 | logging.warning( 95 | "Dropping %d examples of %d examples (shard %d).", 96 | num_unused_examples, 97 | num_examples, 98 | options.shard_count, 99 | ) 100 | else: 101 | shard_start += min(options.shard_index, num_unused_examples) 102 | shard_end += min(options.shard_index + 1, num_unused_examples) 103 | return shard_start, shard_end 104 | 105 | 106 | def get_process_index_and_count(): 107 | try: 108 | import jax # pylint:disable=g-import-not-at-top # pytype:disable=import-error 109 | 110 | return jax.process_index(), jax.process_count() 111 | except ImportError: 112 | return 0, 1 113 | -------------------------------------------------------------------------------- /grain/_src/core/sharding_test.py: -------------------------------------------------------------------------------- 1 | """Tests for sharding.""" 2 | from absl.testing import absltest 3 | from absl.testing import parameterized 4 | from grain._src.core import sharding 5 | 6 | 7 | class ShardingTest(parameterized.TestCase): 8 | 9 | @parameterized.parameters( 10 | # num_examples, shard_index, shard_count, drop_remainder, expected_output. 11 | (9, 0, 1, True, (0, 9)), 12 | (9, 0, 2, True, (0, 4)), 13 | (9, 1, 2, True, (4, 8)), # Last example gets dropped. 14 | (9, 0, 3, True, (0, 3)), 15 | (9, 1, 3, True, (3, 6)), 16 | (9, 2, 3, True, (6, 9)), 17 | (9, 0, 1, False, (0, 9)), 18 | (9, 0, 2, False, (0, 5)), # First shard gets an extra example. 19 | (9, 1, 2, False, (5, 9)), 20 | (8, 0, 3, False, (0, 3)), # First 2 shards get 1 example each. 
21 | (8, 1, 3, False, (3, 6)), 22 | (8, 2, 3, False, (6, 8)), 23 | ) 24 | def test_sharding( 25 | self, 26 | num_examples: int, 27 | shard_index: int, 28 | shard_count: int, 29 | drop_remainder, 30 | expected_output: tuple[int, int], 31 | ): 32 | shard_options = sharding.ShardOptions( 33 | shard_index=shard_index, 34 | shard_count=shard_count, 35 | drop_remainder=drop_remainder, 36 | ) 37 | actual_output = sharding.even_split(num_examples, shard_options) 38 | self.assertEqual(actual_output, expected_output) 39 | 40 | 41 | if __name__ == '__main__': 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /grain/_src/core/smoke_test_with_jax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Checks that OSS Grain Package works end-to-end with JAX.""" 15 | from typing import Sequence 16 | from absl import app 17 | import grain 18 | import jax.numpy as jnp 19 | 20 | 21 | def main(argv: Sequence[str]) -> None: 22 | del argv 23 | ds = grain.MapDataset.source(jnp.arange(10)).map(lambda x: x + 1) 24 | 25 | for _ in ds: 26 | pass 27 | 28 | 29 | if __name__ == "__main__": 30 | app.run(main) 31 | -------------------------------------------------------------------------------- /grain/_src/core/smoke_test_with_tf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Checks that OSS Grain Package works end-to-end with TF.""" 15 | from typing import Sequence 16 | from absl import app 17 | import grain 18 | import tensorflow as tf 19 | 20 | 21 | def main(argv: Sequence[str]) -> None: 22 | del argv 23 | ds = grain.MapDataset.source(range(10)).map(tf.convert_to_tensor) 24 | 25 | for _ in ds: 26 | pass 27 | 28 | 29 | if __name__ == "__main__": 30 | app.run(main) 31 | -------------------------------------------------------------------------------- /grain/_src/core/transforms_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import functools 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | from grain._src.core import transforms 19 | 20 | 21 | class _TestFilter(transforms.Filter): 22 | 23 | def filter(self, x): 24 | return x % 2 == 0 25 | 26 | 27 | class _TestFilterWithStr(transforms.Filter): 28 | 29 | def filter(self, x): 30 | return x % 2 == 0 31 | 32 | def __str__(self): 33 | return "CustomStr" 34 | 35 | 36 | class _TestMapWithRepr(transforms.MapTransform): 37 | 38 | def map(self, x): 39 | return x % 2 == 0 40 | 41 | def __repr__(self): 42 | return "CustomRepr" 43 | 44 | 45 | class GetPrettyTransformNameTest(parameterized.TestCase): 46 | 47 | @parameterized.parameters( 48 | dict( 49 | transform=lambda x: x, 50 | expected_substring=" @ .../_src/core/transforms_test.py:", 51 | ), 52 | dict( 53 | transform=transforms.get_pretty_transform_name, 54 | expected_substring=( 55 | "get_pretty_transform_name @ .../_src/core/transforms.py:" 56 | ), 57 | ), 58 | dict(transform=list, expected_substring="list"), 59 | dict( 60 | transform=functools.partial(lambda x, y: x + y, 1), 61 | expected_substring="functools.partial", 62 | ), 63 | dict(transform=_TestFilter(), expected_substring="_TestFilter"), 64 | dict( 65 | transform=_TestFilterWithStr(), 66 | expected_substring="CustomStr", 67 | ), 68 | dict( 69 | transform=_TestMapWithRepr(), 70 | expected_substring="CustomRepr", 71 | ), 72 | ) 73 | def test_get_pretty_transform_name(self, transform, expected_substring): 74 | self.assertIn( 75 | expected_substring, transforms.get_pretty_transform_name(transform) 76 | ) 77 | 78 | 79 | if __name__ == "__main__": 80 | absltest.main() 81 | -------------------------------------------------------------------------------- /grain/_src/core/tree_lib_jax_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Testes for tree_lib.py with JAX dependency present.""" 15 | 16 | from absl.testing import absltest 17 | import attrs 18 | from grain._src.core import tree_lib 19 | from grain._src.core import tree_lib_test 20 | import jax 21 | import numpy as np 22 | 23 | 24 | class MyTree: 25 | 26 | def __init__(self, a, b): 27 | self.a = a 28 | self.b = b 29 | 30 | def __eq__(self, other): 31 | return self.a == other.a and self.b == other.b 32 | 33 | 34 | class MyClass: 35 | 36 | def __init__(self, c): 37 | self.c = c 38 | 39 | 40 | @attrs.define 41 | class MyAttrs: 42 | d: int 43 | e: str 44 | 45 | 46 | class TreeJaxTest(tree_lib_test.TreeTest): 47 | 48 | def test_map_custom_tree(self): 49 | jax.tree_util.register_pytree_node( 50 | MyTree, lambda t: ((t.a, t.b), None), lambda _, args: MyTree(*args) 51 | ) 52 | self.assertEqual( 53 | tree_lib.map_structure(lambda x: x + 1, MyTree(1, 2)), MyTree(2, 3) 54 | ) 55 | 56 | def test_spec_like_with_class(self): 57 | self.assertEqual( 58 | tree_lib.spec_like({"B": 1232.4, "C": MyClass(1)}), 59 | { 60 | "B": "[]", 61 | "C": "[]", 62 | }, 63 | ) 64 | 65 | def test_spec_like_with_list(self): 66 | self.assertEqual( 67 | tree_lib.spec_like({ 68 | "B": 1232.4, 69 | "C": [ 70 | tree_lib_test.TestClass(a=1, b="v2"), 71 | tree_lib_test.TestClass(a=2, b="v2"), 72 | ], 73 | }), 74 | { 75 | "B": "[]", 76 | "C": "list[2]", 77 | }, 78 | ) 79 | 80 | def test_spec_like_with_unknown_shape(self): 81 | self.assertEqual( 82 | tree_lib.spec_like({ 83 | "B": [np.zeros([2]), np.zeros([1])], 84 | "C": [], 85 | }), 86 | {"B": "list[unknown shape]", "C": "list<>[0]"}, 87 | ) 88 | 89 | def test_spec_like_with_dataclass(self): 90 | self.assertEqual( 91 | tree_lib.spec_like(tree_lib_test.TestClass(a=1, b="v2")), 92 | "\n" 93 | "{'a': \"[]\", 'b': \"[]\"}[]", 94 | ) 95 | 96 | def test_spec_like_with_attrs(self): 97 | self.assertEqual( 98 | tree_lib.spec_like(MyAttrs(d=1, e="v2")), 99 | "\n" 100 | "{'d': \"[]\", 'e': \"[]\"}[]", 101 | ) 102 | 103 | 104 | if __name__ == "__main__": 105 | absltest.main() 106 | -------------------------------------------------------------------------------- /grain/_src/core/usage_logging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License 14 | """Internal usage logging.""" 15 | 16 | 17 | def log_event(tag: str, *, tag_2: str = "", tag_3: str = "") -> None: 18 | return 19 | -------------------------------------------------------------------------------- /grain/_src/core/version_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Checks that tests in OSS are run with the correct version of Python.""" 15 | # Make sure grain can be imported. 16 | from grain import python as grain # pylint: disable=unused-import 17 | 18 | import os 19 | import sys 20 | from absl.testing import absltest 21 | 22 | 23 | class VersionTest(absltest.TestCase): 24 | 25 | def test_python_version(self): 26 | expected = os.getenv("PYTHON_VERSION") 27 | current = f"{sys.version_info.major}.{sys.version_info.minor}" 28 | if current != expected: 29 | raise ValueError( 30 | f"expected version '{expected}' is different than returned" 31 | f" '{current}'" 32 | ) 33 | 34 | 35 | if __name__ == "__main__": 36 | absltest.main() 37 | -------------------------------------------------------------------------------- /grain/_src/python/checkpoint_handlers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module provides a PyGrain CheckpointHandler for integration with Orbax.""" 15 | import dataclasses 16 | import json 17 | from typing import Any, Optional, TypeVar 18 | 19 | from etils import epath 20 | from grain._src.core import sharding 21 | from grain._src.python import data_loader 22 | from grain._src.python.dataset import dataset 23 | 24 | IteratorType = TypeVar( 25 | "IteratorType", data_loader.DataLoaderIterator, dataset.DatasetIterator 26 | ) 27 | 28 | 29 | # Ipmlements orbax.checkpoint.CheckpointHandler. 30 | class CheckpointHandler: 31 | """Orbax CheckpointHandler for PyGrain iterators.""" 32 | 33 | def save( 34 | self, 35 | directory: epath.Path, 36 | # `item` is for backwards compatibility with older Orbax API, see 37 | # https://orbax.readthedocs.io/en/latest/api_refactor.html. 
38 | item: Optional[IteratorType] = None, 39 | args: Any = None, 40 | ): 41 | """Saves the given iterator to the checkpoint in `directory`.""" 42 | item = item or args.item # pytype:disable=attribute-error 43 | if isinstance(item, dataset.DatasetIterator): 44 | state = json.dumps(item.get_state(), indent=4) 45 | else: 46 | state = item.get_state().decode() 47 | process_index, process_count = sharding.get_process_index_and_count() 48 | filename = directory / f"process_{process_index}-of-{process_count}.json" 49 | filename.write_text(state) 50 | 51 | def restore( 52 | self, 53 | directory: epath.Path, 54 | item: Optional[IteratorType] = None, 55 | args: Any = None, 56 | ) -> IteratorType: 57 | """Restores the given iterator from the checkpoint in `directory`.""" 58 | item = item or args.item # pytype:disable=attribute-error 59 | process_index, process_count = sharding.get_process_index_and_count() 60 | filename = directory / f"process_{process_index}-of-{process_count}.json" 61 | if not filename.exists(): 62 | raise ValueError(f"File {filename} does not exist.") 63 | state = filename.read_text() 64 | if isinstance(item, dataset.DatasetIterator): 65 | state = json.loads(state) 66 | else: 67 | state = state.encode() 68 | item.set_state(state) 69 | return item 70 | 71 | # Required by interface but not supported by PyGrain checkpoints. 72 | def structure(self, directory: epath.Path) -> Any: 73 | del directory 74 | return None 75 | 76 | # Required by interface. 77 | 78 | def metadata(self, directory: epath.Path) -> Optional[Any]: 79 | del directory 80 | return None 81 | 82 | def finalize(self, directory: epath.Path): 83 | pass 84 | 85 | def close(self): 86 | pass 87 | 88 | @classmethod 89 | def typestr(cls): 90 | return f"{cls.__module__}.{cls.__qualname__}" 91 | 92 | 93 | try: 94 | # Register the handler to be used with the new checkpointing API if Orbax is 95 | # present. 96 | import orbax.checkpoint as ocp # pylint:disable=g-import-not-at-top # pytype:disable=import-error 97 | 98 | @ocp.args.register_with_handler(CheckpointHandler, for_save=True) # pytype:disable=wrong-arg-types 99 | @dataclasses.dataclass 100 | class CheckpointSave(ocp.args.CheckpointArgs): 101 | item: Any 102 | 103 | @ocp.args.register_with_handler(CheckpointHandler, for_restore=True) # pytype:disable=wrong-arg-types 104 | @dataclasses.dataclass 105 | class CheckpointRestore(ocp.args.CheckpointArgs): 106 | item: Any 107 | 108 | 109 | except (ImportError, TypeError, AttributeError): 110 | pass 111 | -------------------------------------------------------------------------------- /grain/_src/python/checkpointing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Utilities providing checkpointing capabilities for Grain iterators.""" 15 | 16 | import asyncio 17 | from typing import Callable, Protocol 18 | from etils import epath 19 | from grain._src.core import sharding 20 | 21 | 22 | class PathAwaitingCreation(Protocol): 23 | """A path that is in the process of being created. 24 | 25 | Please see orbax/checkpoint for the full definition of this type. 26 | """ 27 | 28 | async def await_creation(self) -> epath.Path: 29 | """Waits for the path to be created. 30 | 31 | This function MUST be called before accessing the physical path. Prefer to 32 | perform in the background operation, rather than the main-thread-blocking 33 | operation. 34 | 35 | Returns: 36 | The path that is now created. 37 | """ 38 | 39 | 40 | async def background_save(directory: PathAwaitingCreation, state: str): 41 | """An async function that saves iterator state in a background thread. 42 | 43 | Args: 44 | directory: The directory to save the state to. 45 | state: The state to save. 46 | """ 47 | directory = await directory.await_creation() 48 | process_index, process_count = sharding.get_process_index_and_count() 49 | filename = directory / f"process_{process_index}-of-{process_count}.json" 50 | await asyncio.to_thread(filename.write_text, state) 51 | 52 | 53 | async def background_load( 54 | directory: epath.Path, set_state_fn: Callable[[str], None] 55 | ): 56 | process_index, process_count = sharding.get_process_index_and_count() 57 | filename = directory / f"process_{process_index}-of-{process_count}.json" 58 | if not await asyncio.to_thread(filename.exists): 59 | raise ValueError(f"File {filename} does not exist.") 60 | state = await asyncio.to_thread(filename.read_text) 61 | set_state_fn(state) 62 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/base_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for base.py.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | from grain._src.python import data_sources 19 | from grain._src.python.dataset import base 20 | 21 | 22 | class RandomAccessDataSourceTest(parameterized.TestCase): 23 | 24 | @parameterized.parameters( 25 | data_sources.ArrayRecordDataSource, 26 | data_sources.RangeDataSource, 27 | data_sources.SharedMemoryDataSource, 28 | ) 29 | def test_protocol(self, source_cls): 30 | self.assertIsInstance(source_cls, base.RandomAccessDataSource) 31 | 32 | 33 | class DatasetOptionsTest(parameterized.TestCase): 34 | 35 | @parameterized.named_parameters( 36 | dict( 37 | testcase_name="no_conflicts", 38 | a=base.DatasetOptions(filter_warn_threshold_ratio=0.1), 39 | b=base.DatasetOptions(filter_raise_threshold_ratio=0.2), 40 | expected=base.DatasetOptions( 41 | filter_warn_threshold_ratio=0.1, 42 | filter_raise_threshold_ratio=0.2, 43 | ), 44 | ), 45 | dict( 46 | testcase_name="all_fields_default", 47 | a=base.DatasetOptions(), 48 | b=base.DatasetOptions( 49 | filter_warn_threshold_ratio=0.4, 50 | filter_raise_threshold_ratio=0.3, 51 | ), 52 | expected=base.DatasetOptions( 53 | filter_warn_threshold_ratio=0.4, 54 | filter_raise_threshold_ratio=0.3, 55 | ), 56 | ), 57 | dict( 58 | testcase_name="field_conflict", 59 | a=base.DatasetOptions(filter_raise_threshold_ratio=0.1), 60 | b=base.DatasetOptions(filter_raise_threshold_ratio=0.2), 61 | expected=base.DatasetOptions( 62 | filter_raise_threshold_ratio=0.1, 63 | ), 64 | ), 65 | ) 66 | def test_merge(self, a, b, expected): 67 | self.assertEqual(a.merge(b), expected) 68 | 69 | 70 | class IteratorContextTest(parameterized.TestCase): 71 | 72 | def test_merge(self): 73 | a = base.IteratorContext( 74 | dataset_options=base.DatasetOptions(filter_warn_threshold_ratio=0.1) 75 | ) 76 | b = base.IteratorContext( 77 | dataset_options=base.DatasetOptions( 78 | filter_warn_threshold_ratio=0.2, filter_raise_threshold_ratio=0.2 79 | ) 80 | ) 81 | a.merge(b) 82 | self.assertEqual( 83 | a, 84 | base.IteratorContext( 85 | dataset_options=base.DatasetOptions( 86 | filter_warn_threshold_ratio=0.1, 87 | filter_raise_threshold_ratio=0.2, 88 | ) 89 | ), 90 | ) 91 | 92 | def test_merge_with_different_mp_context(self): 93 | a = base.IteratorContext( 94 | mp_context=base.MultiprocessingContext(process_index=0, process_count=1) 95 | ) 96 | b = base.IteratorContext( 97 | mp_context=base.MultiprocessingContext(process_index=1, process_count=2) 98 | ) 99 | with self.assertRaisesRegex( 100 | ValueError, "Cannot merge contexts from different worker processes" 101 | ): 102 | a.merge(b) 103 | 104 | 105 | if __name__ == "__main__": 106 | absltest.main() 107 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/elastic_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Iterator supporting changes in the number of hosts (dataset shards).""" 15 | 16 | import functools 17 | from typing import Any 18 | 19 | from grain._src.core import sharding 20 | from grain._src.python import options 21 | from grain._src.python.dataset import dataset 22 | from grain._src.python.dataset.transformations import ( 23 | filter as filter_dataset, 24 | ) 25 | 26 | _GLOBAL_NEXT_INDEX_STATE_KEY = "global_next_index" 27 | 28 | 29 | class ElasticIterator(dataset.DatasetIterator): 30 | """Iterator supporting recovery from a checkpoint after changes in sharding. 31 | 32 | The input dataset is expected to be unbatched and unsharded. In order to 33 | provide elasticity guarantee this iterator includes both, batching and 34 | sharding. The iterator supports elastic re-configuration by having each 35 | shard produce the same exact checkpoint (while producing different data) as 36 | long as they are advanced the same number of steps. 37 | 38 | State of any shard can be used to restore the state of all of the shards after 39 | changes in sharding and global batch size. 40 | 41 | This iterator explicitly disallows many-to-one transformations without 42 | a fixed ratio, like `filter` and generic `IterDataset` transformations. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | ds: dataset.MapDataset, 48 | global_batch_size: int, 49 | shard_options: sharding.ShardOptions, 50 | *, 51 | read_options: options.ReadOptions = options.ReadOptions(), 52 | multiprocessing_options: options.MultiprocessingOptions | None = None, 53 | ): 54 | super().__init__() 55 | to_check = [ds] 56 | while to_check: 57 | next_ds = to_check.pop() 58 | if isinstance(next_ds, filter_dataset.FilterMapDataset): 59 | raise ValueError( 60 | "ElasticIterator does not support `filter` transformation." 61 | ) 62 | to_check.extend(next_ds.parents) 63 | self._ds = ds 64 | self._global_batch_size = global_batch_size 65 | self._shard_options = shard_options 66 | self._global_next_index = 0 67 | self._read_options = read_options 68 | self._multiprocessing_options = multiprocessing_options 69 | 70 | @functools.cached_property 71 | def _iterator(self) -> dataset.DatasetIterator: 72 | ds = self._ds[ 73 | self._global_next_index 74 | + self._shard_options.shard_index :: self._shard_options.shard_count 75 | ] 76 | host_batch_size, remainder = divmod( 77 | self._global_batch_size, self._shard_options.shard_count 78 | ) 79 | if remainder: 80 | raise ValueError( 81 | f"Global batch size {self._global_batch_size} is not divisible by" 82 | f" shard count {self._shard_options.shard_count}." 83 | ) 84 | ds = ds.batch(host_batch_size, drop_remainder=True) 85 | ds = ds.to_iter_dataset(read_options=self._read_options) 86 | if self._multiprocessing_options is not None: 87 | ds = ds.mp_prefetch(self._multiprocessing_options) 88 | return ds.__iter__() 89 | 90 | def __iter__(self) -> dataset.DatasetIterator: 91 | return self 92 | 93 | def __next__(self) -> Any: 94 | result = next(self._iterator) 95 | self._global_next_index += self._global_batch_size 96 | return result 97 | 98 | def get_state(self) -> dict[str, Any]: 99 | return { 100 | _GLOBAL_NEXT_INDEX_STATE_KEY: self._global_next_index, 101 | } 102 | 103 | def set_state(self, state: dict[str, Any]): 104 | self._global_next_index = state[_GLOBAL_NEXT_INDEX_STATE_KEY] 105 | # Reset the iterator if it was already created. 
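# (`_iterator` is a functools.cached_property, so dropping the cached value forces it to be rebuilt from the updated `_global_next_index` on next access.)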
106 | self.__dict__.pop("_iterator", None) 107 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/sources/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//grain:__subpackages__"]) 2 | 3 | licenses(["notice"]) 4 | 5 | py_library( 6 | name = "parquet_dataset", 7 | srcs = ["parquet_dataset.py"], 8 | srcs_version = "PY3", 9 | deps = [ 10 | "//grain/_src/python/dataset", 11 | "@pypi//etils:pkg", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "parquet_dataset_test", 17 | srcs = ["parquet_dataset_test.py"], 18 | srcs_version = "PY3", 19 | deps = [ 20 | ":parquet_dataset", 21 | "//grain:python", 22 | "@abseil-py//absl/flags", 23 | "@abseil-py//absl/testing:absltest", 24 | "@pypi//pyarrow:pkg", 25 | ], 26 | ) 27 | 28 | py_library( 29 | name = "tfrecord_dataset", 30 | srcs = ["tfrecord_dataset.py"], 31 | srcs_version = "PY3", 32 | deps = ["//grain/_src/python/dataset"], 33 | ) 34 | 35 | py_test( 36 | name = "tfrecord_dataset_test", 37 | srcs = ["tfrecord_dataset_test.py"], 38 | args = ["--test_srcdir=grain/_src/python/testdata"], 39 | data = [ 40 | "//grain/_src/python/testdata:morris_sequence_first_5.tfrecord", 41 | ], 42 | srcs_version = "PY3", 43 | deps = [ 44 | ":tfrecord_dataset", 45 | "//grain", 46 | "@abseil-py//absl/flags", 47 | "@abseil-py//absl/testing:absltest", 48 | ], 49 | ) 50 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/sources/parquet_dataset.py: -------------------------------------------------------------------------------- 1 | """Provides an `IterDataset` for Parquet file format.""" 2 | 3 | from typing import TypeVar 4 | 5 | from etils import epy 6 | from grain._src.python.dataset import dataset 7 | 8 | 9 | # lazy import for pyarrow 10 | with epy.lazy_imports(): 11 | import pyarrow.parquet as pq # pytype: disable=import-error # pylint: disable=g-import-not-at-top 12 | 13 | 14 | T = TypeVar("T") 15 | 16 | 17 | class _ParquetDatasetIterator(dataset.DatasetIterator[T]): 18 | """A DatasetIterator for Parquet file format.""" 19 | 20 | def __init__( 21 | self, path: str, row_group: int = 0, index_within_row_group: int = 0 22 | ): 23 | super().__init__() 24 | self._row_group = row_group 25 | self._index_within_row_group = index_within_row_group 26 | self._pq_path = path 27 | self._pq_file = pq.ParquetFile(self._pq_path) 28 | self._np_table = {} 29 | self._row_group_len = 0 30 | self._read_row_group_to_np_table() 31 | 32 | def _read_row_group_to_np_table(self): 33 | table = self._pq_file.read_row_group(self._row_group) 34 | self._row_group_len = len(table) 35 | self._np_table = {} 36 | for i in range(table.num_columns): 37 | self._np_table[table.field(i).name] = table.column(i).to_numpy() 38 | 39 | def __next__(self): 40 | if self._index_within_row_group >= self._row_group_len: 41 | if self._row_group < self._pq_file.num_row_groups - 1: 42 | self._row_group += 1 43 | self._index_within_row_group = 0 44 | self._read_row_group_to_np_table() 45 | return self.__next__() 46 | else: 47 | raise StopIteration() 48 | else: 49 | item = { 50 | k: v[self._index_within_row_group] for k, v in self._np_table.items() 51 | } 52 | self._index_within_row_group += 1 53 | return item 54 | 55 | def get_state(self): 56 | return { 57 | "row_group": self._row_group, 58 | "index_within_row_group": self._index_within_row_group, 59 | } 60 | 61 | def set_state(self, state): 62 | self._row_group = state["row_group"] 63 | 
self._index_within_row_group = state["index_within_row_group"] 64 | self._read_row_group_to_np_table() 65 | 66 | 67 | class ParquetIterDataset(dataset.IterDataset[T]): 68 | """An IterDataset for a parquet format file.""" 69 | 70 | def __init__(self, path: str): 71 | """Initializes ParquetIterDataset. 72 | 73 | Args: 74 | path: A path to a record io format file. 75 | """ 76 | super().__init__() 77 | self._path = path 78 | 79 | def __iter__(self) -> _ParquetDatasetIterator[T]: 80 | return _ParquetDatasetIterator(self._path) 81 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/sources/parquet_dataset_test.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | from absl.testing import absltest 3 | from grain._src.python.dataset.sources import parquet_dataset 4 | import grain.python as grain 5 | import pyarrow as pa 6 | import pyarrow.parquet as pq 7 | 8 | flags.FLAGS.mark_as_parsed() 9 | 10 | SOME_TEXT = [ 11 | [ 12 | "This is the first file the first record", 13 | "This is the first file the second record", 14 | "This is the first file the third record", 15 | "This is the first file the forth record", 16 | ], 17 | [ 18 | "This is the second file the first record", 19 | "This is the second file the second record", 20 | "This is the second file the third record", 21 | "This is the second file the forth record", 22 | ], 23 | ] 24 | INTERLEAVED_TEXT = [ 25 | "This is the first file the first record", 26 | "This is the second file the first record", 27 | "This is the first file the second record", 28 | "This is the second file the second record", 29 | "This is the first file the third record", 30 | "This is the second file the third record", 31 | "This is the first file the forth record", 32 | "This is the second file the forth record", 33 | ] 34 | WINDOWSHUFFLED_TEXT = [ 35 | "This is the first file the second record", 36 | "This is the second file the first record", 37 | "This is the first file the first record", 38 | "This is the first file the third record", 39 | "This is the second file the second record", 40 | "This is the second file the third record", 41 | "This is the second file the forth record", 42 | "This is the first file the forth record", 43 | ] 44 | 45 | 46 | class ParquetIterDatasetTest(absltest.TestCase): 47 | 48 | def setUp(self): 49 | super().setUp() 50 | self.filenames = [] 51 | for i in range(len(SOME_TEXT)): 52 | temp_file = self.create_tempfile() 53 | filename = temp_file.full_path 54 | self.filenames.append(filename) 55 | table = pa.table({"text": SOME_TEXT[i]}) 56 | pq.write_table(table, filename, row_group_size=2) 57 | 58 | def test_read_row_group(self): 59 | dataset = parquet_dataset.ParquetIterDataset(self.filenames[0]) 60 | records = list(dataset) 61 | self.assertSequenceEqual(records, [{"text": x} for x in SOME_TEXT[0]]) 62 | 63 | def test_checkpointing(self): 64 | dataset = parquet_dataset.ParquetIterDataset(self.filenames[0]) 65 | grain.experimental.assert_equal_output_after_checkpoint(dataset) 66 | 67 | def test_sharded_files_and_interleaved_dataset(self): 68 | dataset = grain.MapDataset.source(self.filenames) 69 | dataset = dataset.map(parquet_dataset.ParquetIterDataset) 70 | dataset = grain.experimental.InterleaveIterDataset( 71 | dataset, cycle_length=len(self.filenames) 72 | ) 73 | self.assertSequenceEqual( 74 | list(iter(dataset)), [{"text": x} for x in INTERLEAVED_TEXT] 75 | ) 76 | 77 | dataset = 
grain.experimental.WindowShuffleIterDataset( 78 | dataset, window_size=3, seed=42 79 | ) 80 | self.assertSequenceEqual( 81 | list(iter(dataset)), [{"text": x} for x in WINDOWSHUFFLED_TEXT] 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | absltest.main() 87 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/sources/tfrecord_dataset.py: -------------------------------------------------------------------------------- 1 | """Provides an `IterDataset` for TFRecord file format.""" 2 | 3 | import codecs 4 | import struct 5 | from typing import TypeVar 6 | 7 | from grain._src.python.dataset import dataset 8 | 9 | 10 | T = TypeVar("T") 11 | 12 | # Format of a single tf_record: 13 | # uint64 length of record in bytes 14 | # uint32 masked crc of length 15 | # bytes record data 16 | # uint32 masked crc of data 17 | _UNIT32_SIZE_IN_BYTES = 4 18 | _UNIT64_SIZE_IN_BYTES = 8 19 | 20 | 21 | class _TFRecordReader: 22 | """A reader for TFRecord files.""" 23 | 24 | def __init__(self, path: str): 25 | self._reader = open(path, "rb") 26 | 27 | def __next__(self) -> bytes: 28 | """Reads the next record from the reader.""" 29 | # Read the length and the length mask of the tf_record (uint64 and uint32 30 | # respectively) 31 | buf_length_expected = _UNIT64_SIZE_IN_BYTES + _UNIT32_SIZE_IN_BYTES 32 | buf = self._reader.read(buf_length_expected) 33 | if not buf: 34 | # If the buffer is empty, we have reached the end of the dataset. 35 | raise StopIteration() 36 | if len(buf) != buf_length_expected: 37 | raise ValueError( 38 | f"Not a valid TFRecord. Fewer than {buf_length_expected} bytes:" 39 | f" {codecs.encode(buf, 'hex')}" 40 | ) 41 | length, _ = struct.unpack(" int: 61 | return self._reader.tell() 62 | 63 | def __del__(self): 64 | if hasattr(self, "_reader") and self._reader: 65 | self._reader.close() 66 | 67 | 68 | class _TFRecordDatasetIterator(dataset.DatasetIterator[T]): 69 | """A DatasetIterator for TFRecord file format.""" 70 | 71 | def __init__(self, path: str): 72 | super().__init__() 73 | self._reader = _TFRecordReader(path) 74 | 75 | def __next__(self) -> T: 76 | return next(self._reader) 77 | 78 | def get_state(self) -> dict[str, int]: 79 | return { 80 | "reader_offset": self._reader.tell(), 81 | } 82 | 83 | def set_state(self, state: dict[str, int]): 84 | self._reader.seek(state["reader_offset"]) 85 | 86 | 87 | class TFRecordIterDataset(dataset.IterDataset[T]): 88 | """An IterDataset for a TFRecord format file.""" 89 | 90 | def __init__(self, path: str): 91 | super().__init__() 92 | self._path = path 93 | 94 | def __iter__(self) -> dataset.DatasetIterator[T]: 95 | return _TFRecordDatasetIterator[T](self._path) 96 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/sources/tfrecord_dataset_test.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from absl import flags 4 | from absl.testing import absltest 5 | from grain._src.python.dataset.sources import tfrecord_dataset 6 | import grain.python as pygrain 7 | 8 | 9 | class TFRecordIterDatasetTest(absltest.TestCase): 10 | 11 | def setUp(self): 12 | super().setUp() 13 | self.testdata_dir = pathlib.Path(flags.FLAGS.test_srcdir) 14 | self.testdata_file_path = self.testdata_dir 15 | self.testdata_file_path /= "morris_sequence_first_5.tfrecord" 16 | self.expected_data = [ 17 | b"1", 18 | b"1 1", 19 | b"2 1", 20 | b"1 2 1 1", 21 | b"1 1 1 2 2 1", 22 | ] 23 | 24 | def 
test_nonexistent_tfrecord_file(self): 25 | dataset = tfrecord_dataset.TFRecordIterDataset( 26 | str(self.testdata_dir / "non_existent_file.tfrecord") 27 | ) 28 | with self.assertRaises(FileNotFoundError): 29 | list(dataset) 30 | 31 | def test_empty_tfrecord_file(self): 32 | empty_tf_record_file = self.create_tempfile("empty_file.tfrecord") 33 | dataset = tfrecord_dataset.TFRecordIterDataset( 34 | empty_tf_record_file.full_path 35 | ) 36 | self.assertSequenceEqual(list(dataset), []) 37 | 38 | def test_invalid_tfrecord_file(self): 39 | truncated_length_tf_record_file = self.create_tempfile( 40 | "truncated_length_file.tfrecord" 41 | ) 42 | # Directly write the data to the file instead of using the tfrecord writer, 43 | # and without the length prefix. This will create an invalid tfrecord file. 44 | with open(truncated_length_tf_record_file, "wb") as f: 45 | f.write(b"1") 46 | 47 | dataset = tfrecord_dataset.TFRecordIterDataset( 48 | truncated_length_tf_record_file.full_path 49 | ) 50 | with self.assertRaises(ValueError): 51 | list(dataset) 52 | 53 | def test_read_tfrecord_file(self): 54 | dataset = tfrecord_dataset.TFRecordIterDataset(str(self.testdata_file_path)) 55 | self.assertSequenceEqual(list(dataset), self.expected_data) 56 | 57 | def test_checkpointing(self): 58 | dataset = tfrecord_dataset.TFRecordIterDataset(str(self.testdata_file_path)) 59 | pygrain.experimental.assert_equal_output_after_checkpoint(dataset) 60 | 61 | 62 | if __name__ == "__main__": 63 | absltest.main() 64 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/limit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Implements limit transformations.""" 15 | 16 | from typing import Any, TypeVar 17 | 18 | from grain._src.python.dataset import dataset 19 | from grain._src.python.dataset import stats 20 | 21 | Element = Any 22 | T = TypeVar("T") # pylint: disable=invalid-name 23 | 24 | 25 | class _LimitDatasetIterator(dataset.DatasetIterator[T]): 26 | """Iterator that limits the number of elements in the dataset.""" 27 | 28 | def __init__( 29 | self, 30 | parent: dataset.DatasetIterator[T], 31 | count: int, 32 | ): 33 | super().__init__(parent) 34 | self._count = count 35 | self._count_elements_read = 0 36 | 37 | @stats.record_next_duration_if_output 38 | def __next__(self): 39 | if self._count_elements_read >= self._count: 40 | raise StopIteration 41 | value = next(self._parent) 42 | self._count_elements_read += 1 43 | return value 44 | 45 | def get_state(self): 46 | return { 47 | "parent": self._parent.get_state(), 48 | "count_elements_read": self._count_elements_read, 49 | } 50 | 51 | def set_state(self, state): 52 | self._parent.set_state(state["parent"]) 53 | self._count_elements_read = state["count_elements_read"] 54 | 55 | 56 | class LimitIterDataset(dataset.IterDataset[T]): 57 | """Limits the number of elements in the dataset. 58 | 59 | Example usage: 60 | 61 | ``` 62 | list(LimitIterDataset(MapDataset.range(5).to_iter_dataset(), 2) == [0, 1] 63 | ``` 64 | 65 | Attributes: 66 | parent: The dataset to limit. 67 | count: The maximum number of elements to include in the dataset. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | parent: dataset.IterDataset[T], 73 | count: int, 74 | ): 75 | """Initializes the limit dataset.""" 76 | if count <= 0: 77 | raise ValueError(f"Count must be a non-negative integer. Got {count}") 78 | super().__init__(parent) 79 | self._count = count 80 | 81 | def __iter__(self) -> _LimitDatasetIterator[T]: 82 | parent_iter = self._parent.__iter__() 83 | return _LimitDatasetIterator(parent_iter, self._count) 84 | 85 | def __str__(self) -> str: 86 | return f"LimitIterDataset(count={self._count})" 87 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/limit_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for limit transformations.""" 15 | 16 | from absl.testing import absltest 17 | from absl.testing import parameterized 18 | from grain._src.python.dataset import dataset 19 | from grain._src.python.dataset.transformations import limit 20 | import grain._src.python.testing.experimental as test_util 21 | 22 | 23 | class LimitIterDatasetTest(parameterized.TestCase): 24 | 25 | @parameterized.parameters([0, -1, -5]) 26 | def test_non_positive_count_raises_error(self, count): 27 | ds = dataset.MapDataset.range(0, 10).to_iter_dataset() 28 | with self.assertRaises(ValueError): 29 | _ = limit.LimitIterDataset(ds, count=count) 30 | 31 | def test_stop_iteration_raised_after_limit_reached(self): 32 | ds = dataset.MapDataset.range(0, 10).to_iter_dataset() 33 | ds = limit.LimitIterDataset(ds, count=1) 34 | ds_iter = iter(ds) 35 | _ = next(ds_iter) 36 | with self.assertRaises(StopIteration): 37 | next(ds_iter) 38 | 39 | @parameterized.parameters([1, 3, 5, 7, 10]) 40 | def test_count(self, count): 41 | ds = dataset.MapDataset.range(0, 10).to_iter_dataset() 42 | ds = limit.LimitIterDataset(ds, count=count) 43 | actual_data = list(ds) 44 | self.assertLen(actual_data, count) 45 | self.assertEqual(actual_data, list(range(count))) 46 | 47 | def test_count_over_epochs(self): 48 | ds = dataset.MapDataset.range(0, 10).repeat(2).to_iter_dataset() 49 | ds = limit.LimitIterDataset(ds, count=15) 50 | actual_data = list(ds) 51 | self.assertLen(actual_data, 15) 52 | self.assertEqual(actual_data, list(range(10)) + list(range(5))) 53 | 54 | def test_limit_after_batch(self): 55 | def flatten_batches(batches): 56 | actual_data = [] 57 | for batch in batches: 58 | actual_data.extend(batch.tolist()) 59 | return actual_data 60 | 61 | ds = dataset.MapDataset.range(0, 10).batch(3).to_iter_dataset() 62 | 63 | ds_1 = limit.LimitIterDataset(ds, count=2) 64 | batches = list(ds_1) 65 | actual_data = flatten_batches(batches) 66 | self.assertEqual(actual_data, list(range(6))) 67 | 68 | ds_2 = limit.LimitIterDataset(ds, count=5) 69 | batches = list(ds_2) 70 | actual_data = flatten_batches(batches) 71 | self.assertLen(batches, 4) 72 | self.assertEqual(actual_data, list(range(10))) 73 | 74 | def test_checkpointing(self): 75 | ds = dataset.MapDataset.range(0, 10).batch(3).to_iter_dataset() 76 | limited_ds = limit.LimitIterDataset(ds, count=2) 77 | test_util.assert_equal_output_after_checkpoint(limited_ds) 78 | 79 | 80 | if __name__ == "__main__": 81 | absltest.main() 82 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/packing_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for batch transformation.""" 15 | 16 | from absl.testing import absltest 17 | from grain._src.python.dataset.transformations import testing_util 18 | 19 | 20 | class FirstFitPackIterDatasetTest(testing_util.BaseFirstFitPackIterDatasetTest): 21 | 22 | def setUp(self): 23 | super().setUp() 24 | self.kwargs = {} 25 | 26 | if __name__ == "__main__": 27 | absltest.main() 28 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/repeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Implements repeat transformation.""" 15 | import sys 16 | from typing import Optional, TypeVar 17 | 18 | from grain._src.python.dataset import dataset 19 | 20 | T = TypeVar("T") 21 | 22 | 23 | class RepeatMapDataset(dataset.MapDataset[T]): 24 | """Repeats the underlying dataset for num_epochs. 25 | 26 | This effectively just changes the length, which indicates the size of a single 27 | epoch, of the dataset. This makes it easier to iterate for a fixed number 28 | of steps. 29 | """ 30 | 31 | _MUTATES_ELEMENT_SPEC = False 32 | 33 | def __init__( 34 | self, 35 | parent: dataset.MapDataset[T], 36 | num_epochs: Optional[int] = None, 37 | ): 38 | super().__init__(parent) 39 | if num_epochs is not None and num_epochs <= 0: 40 | raise ValueError(f"num_epochs must be positive, but got {num_epochs}.") 41 | if len(parent) >= sys.maxsize: 42 | raise ValueError( 43 | f"Repeating already infinite dataset {parent} does nothing." 44 | ) 45 | self._num_epochs = num_epochs 46 | if num_epochs is None: 47 | if len(parent) == 0: # pylint: disable=g-explicit-length-test 48 | self._length: int = 0 49 | else: 50 | self._length: int = sys.maxsize 51 | else: 52 | self._length = num_epochs * len(parent) 53 | 54 | def __len__(self) -> int: 55 | return self._length 56 | 57 | def __str__(self) -> str: 58 | return f"RepeatMapDataset(num_epochs={self._num_epochs})" 59 | 60 | def __getitem__(self, index): 61 | if isinstance(index, slice): 62 | return self.slice(index) 63 | return self._stats.record_output_spec(self._parent[index]) 64 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/repeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for repeat transformation.""" 15 | import sys 16 | 17 | from absl.testing import absltest 18 | from grain._src.python.dataset import dataset 19 | from grain._src.python.dataset.transformations import repeat 20 | from typing_extensions import override 21 | 22 | 23 | class EmptyMapDataset(dataset.MapDataset[int]): 24 | 25 | def __init__(self): 26 | super().__init__(parents=[]) 27 | 28 | @override 29 | def __len__(self) -> int: 30 | return 0 31 | 32 | @override 33 | def __getitem__(self, index): 34 | raise IndexError("Index out of range") 35 | 36 | 37 | class RepeatMapDatasetTest(absltest.TestCase): 38 | 39 | def test_finite_num_epochs_changes_length(self): 40 | ds = dataset.MapDataset.range(6) 41 | self.assertLen(ds, 6) 42 | ds = repeat.RepeatMapDataset(ds, num_epochs=3) 43 | self.assertLen(ds, 18) 44 | 45 | def test_finite_num_epochs_produces_expected_elements_when_iterated(self): 46 | ds = dataset.MapDataset.range(4) 47 | ds = repeat.RepeatMapDataset(ds, num_epochs=3) 48 | self.assertSequenceEqual(list(ds), [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]) 49 | 50 | def test_infinite_epochs_sets_length_to_maxsize(self): 51 | ds = dataset.MapDataset.range(6) 52 | ds = repeat.RepeatMapDataset(ds, num_epochs=None) 53 | self.assertLen(ds, sys.maxsize) 54 | 55 | def test_repeat_after_setting_infinite_epochs_raises_value_error(self): 56 | ds = dataset.MapDataset.range(6) 57 | ds = repeat.RepeatMapDataset(ds, num_epochs=None) 58 | with self.assertRaises(ValueError): 59 | repeat.RepeatMapDataset(ds, num_epochs=2) 60 | 61 | def test_setting_zero_epochs_raises_value_error(self): 62 | ds = dataset.MapDataset.range(6) 63 | with self.assertRaises(ValueError): 64 | repeat.RepeatMapDataset(ds, num_epochs=0) 65 | 66 | def test_setting_negative_epochs_raises_value_error(self): 67 | ds = dataset.MapDataset.range(6) 68 | with self.assertRaises(ValueError): 69 | repeat.RepeatMapDataset(ds, num_epochs=-1) 70 | 71 | def test_infinite_epochs_of_empty_dataset_keeps_length_zero(self): 72 | ds = EmptyMapDataset() 73 | ds = repeat.RepeatMapDataset(ds, num_epochs=None) 74 | self.assertEmpty(ds) 75 | 76 | 77 | if __name__ == "__main__": 78 | absltest.main() 79 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/slice.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
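Before the slice implementation, a brief illustrative sketch (not part of slice.py) of the RepeatMapDataset shown above; it only restates what repeat.py and its tests already establish.

import sys

from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import repeat

# Repeating only changes the reported length; indexing wraps into the parent.
ds = repeat.RepeatMapDataset(dataset.MapDataset.range(4), num_epochs=3)
assert len(ds) == 12
assert list(ds) == [0, 1, 2, 3] * 3

# num_epochs=None marks the dataset as "infinite" (sys.maxsize elements),
# which is why repeating it a second time is rejected.
unbounded = repeat.RepeatMapDataset(dataset.MapDataset.range(4), num_epochs=None)
assert len(unbounded) == sys.maxsize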
14 | """Implements slice transformation.""" 15 | from typing import TypeVar 16 | 17 | from grain._src.python.dataset import dataset 18 | 19 | T = TypeVar("T") 20 | 21 | 22 | class SliceMapDataset(dataset.MapDataset[T]): 23 | """Slices a MapDataset similar to the slicing syntax in Python.""" 24 | 25 | _MUTATES_ELEMENT_SPEC = False 26 | 27 | def __init__(self, parent: dataset.MapDataset[T], sl: slice): 28 | super().__init__(parent) 29 | if not isinstance(sl, slice): 30 | raise ValueError(f"sl is not a slice: {type(sl)}") 31 | self._start, self._stop, self._step = sl.indices(len(parent)) 32 | self._length = len(range(self._start, self._stop, self._step)) 33 | 34 | def __len__(self) -> int: 35 | return self._length 36 | 37 | def __getitem__(self, index): 38 | if isinstance(index, slice): 39 | return SliceMapDataset(self, index) 40 | with self._stats.record_self_time(): 41 | parent_index = self._start + (index % len(self)) * self._step 42 | return self._parent[parent_index] 43 | 44 | def __str__(self) -> str: 45 | return f"SliceMapDataset[{self._start}:{self._stop}:{self._step}]" 46 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/slice_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for slice transformation.""" 15 | 16 | import itertools 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from grain._src.python.dataset import dataset 21 | import grain._src.python.dataset.transformations.slice as slice_ds 22 | from typing_extensions import override 23 | 24 | 25 | class EmptyMapDataset(dataset.MapDataset[int]): 26 | 27 | def __init__(self): 28 | super().__init__(parents=[]) 29 | 30 | @override 31 | def __len__(self) -> int: 32 | return 0 33 | 34 | @override 35 | def __getitem__(self, index): 36 | raise IndexError("Index out of range") 37 | 38 | 39 | class SliceMapDatasetTest(parameterized.TestCase): 40 | 41 | @parameterized.parameters( 42 | (0, 1, 20), 43 | (0, 2, 10), 44 | (1, 2, 10), 45 | (0, 3, 7), 46 | (1, 3, 7), 47 | (2, 3, 6), 48 | (30, 100, 0), 49 | ) 50 | def test_len(self, start: int, step: int, expected_len: int): 51 | ds = dataset.MapDataset.range(20) 52 | sl = slice(start, 20, step) 53 | range_ds_for_process = slice_ds.SliceMapDataset(ds, sl) 54 | self.assertLen(range_ds_for_process, expected_len) 55 | 56 | @parameterized.parameters( 57 | itertools.product(range(-8, 8), range(-9, 8), [-2, -1, 1, 2]) 58 | ) 59 | def test_getitem(self, start: int, stop: int, step: int): 60 | ds = dataset.MapDataset.range(20) 61 | ds = slice_ds.SliceMapDataset(ds, slice(start, stop, step)) 62 | ds_items = [ds[i] for i in range(len(ds))] 63 | self.assertSequenceEqual(ds_items, list(range(20))[start:stop:step]) 64 | 65 | @parameterized.parameters( 66 | itertools.product(range(-8, 8), range(-9, 8), [-2, -1, 1, 2]) 67 | ) 68 | def test_getitem_slice(self, start: int, stop: int, step: int): 69 | ds = dataset.MapDataset.range(20) 70 | ds = ds[start:stop:step] 71 | ds_items = [ds[i] for i in range(len(ds))] 72 | self.assertSequenceEqual(ds_items, list(range(20))[start:stop:step]) 73 | 74 | @parameterized.parameters( 75 | itertools.product(range(-8, 8), range(-9, 8), [-2, -1, 1, 2]) 76 | ) 77 | def test_iter(self, start: int, stop: int, step: int): 78 | ds = dataset.MapDataset.range(20) 79 | ds = slice_ds.SliceMapDataset(ds, slice(start, stop, step)) 80 | ds_iter = iter(ds) 81 | ds_items = list(ds_iter) 82 | self.assertSequenceEqual(ds_items, list(range(20))[start:stop:step]) 83 | 84 | def test_slice_of_empty_dataset_is_empty(self): 85 | ds = EmptyMapDataset() 86 | ds = slice_ds.SliceMapDataset(ds, slice(0, 10)) 87 | self.assertEmpty(ds) 88 | 89 | def test_accessing_items_beyond_len_minus_one_succeeds(self): 90 | ds = dataset.MapDataset.range(20) 91 | ds = slice_ds.SliceMapDataset(ds, slice(5)) # 0, 1, 2, 3, 4 92 | self.assertLen(ds, 5) 93 | self.assertEqual(ds[5], 0) 94 | self.assertEqual(ds[13], 3) 95 | self.assertEqual(ds[42], 2) 96 | 97 | def test_composing_slices_contains_correct_elements(self): 98 | ds = dataset.MapDataset.range(20) 99 | ds = slice_ds.SliceMapDataset(ds, slice(0, 15, 3)) # 0, 3, 6, 9, 12 100 | ds = slice_ds.SliceMapDataset(ds, slice(0, 20, 2)) # 0, 6, 12 101 | self.assertSequenceEqual(list(ds), [0, 6, 12]) 102 | 103 | 104 | if __name__ == "__main__": 105 | absltest.main() 106 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/source.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """LazyDataset data sources.""" 15 | from __future__ import annotations 16 | 17 | import functools 18 | from typing import Sequence, Union 19 | 20 | from absl import logging 21 | from grain._src.python import options 22 | from grain._src.python.dataset import base 23 | from grain._src.python.dataset import dataset 24 | 25 | 26 | class SourceMapDataset(dataset.MapDataset): 27 | """Simple wrapper for random access data sources.""" 28 | 29 | def __init__(self, source: base.RandomAccessDataSource): 30 | super().__init__() 31 | self._source = source 32 | 33 | def __len__(self) -> int: 34 | return len(self._source) 35 | 36 | def __str__(self) -> str: 37 | return f"SourceMapDataset(source={self._source.__class__.__name__})" 38 | 39 | def __getitem__(self, index): 40 | if isinstance(index, slice): 41 | return self.slice(index) 42 | with self._stats.record_self_time(): 43 | return self._stats.record_output_spec(self._source[index % len(self)]) 44 | 45 | def log_lineage(self): 46 | pass 47 | 48 | @property 49 | def paths(self) -> str | Sequence[str]: 50 | if hasattr(self._source, "paths"): 51 | assert isinstance(self._source, base.RandomAccessDataSource) 52 | return self._source.paths 53 | else: 54 | return [] 55 | 56 | 57 | def log_lineage_for_sources( 58 | root: Union[dataset.MapDataset, dataset.IterDataset], 59 | ): 60 | """Traverses tree of transformations and logs lineage on source datasets.""" 61 | pass 62 | 63 | 64 | class RangeMapDataset(dataset.MapDataset[int]): 65 | """Range data source, similar to python range() function.""" 66 | 67 | def __init__(self, start: int, stop: int | None = None, step: int = 1): 68 | super().__init__() 69 | self.start = 0 if stop is None else start 70 | self.stop = start if stop is None else stop 71 | self.step = step 72 | 73 | @functools.cached_property 74 | def _length(self) -> int: 75 | return len(range(self.start, self.stop, self.step)) 76 | 77 | def __len__(self) -> int: 78 | return self._length 79 | 80 | def __str__(self) -> str: 81 | return ( 82 | f"RangeMapDataset(start={self.start}, stop={self.stop}," 83 | f" step={self.step})" 84 | ) 85 | 86 | def __getitem__(self, index): 87 | if isinstance(index, slice): 88 | return self.slice(index) 89 | with self._stats.record_self_time(): 90 | return self._stats.record_output_spec( 91 | self.start + (index % self._length) * self.step 92 | ) 93 | 94 | def to_iter_dataset( 95 | self, 96 | read_options: options.ReadOptions | None = None, 97 | *, 98 | allow_nones: bool = False, 99 | ) -> dataset.IterDataset[int]: 100 | # Override the default multithreaded execution to avoid wasting memory. 101 | # The prefetch is not necessary since there's no IO. 
102 | return super().to_iter_dataset( 103 | read_options=( 104 | read_options or options.ReadOptions(prefetch_buffer_size=0) 105 | ), 106 | allow_nones=allow_nones, 107 | ) 108 | -------------------------------------------------------------------------------- /grain/_src/python/dataset/transformations/source_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for LazyDataset data sources.""" 15 | import random 16 | from unittest import mock 17 | 18 | from absl.testing import absltest 19 | from grain._src.python.dataset import dataset 20 | from grain._src.python.dataset.transformations import source 21 | 22 | 23 | class _Interleave(dataset.MapDataset): 24 | 25 | def __len__(self): 26 | return sum((len(p) for p in self.parents)) 27 | 28 | def __getitem__(self, index): 29 | index, parent_index = divmod(index, len(self.parents)) 30 | return self.parents[parent_index][index] 31 | 32 | 33 | class SourceMapDatasetTest(absltest.TestCase): 34 | 35 | def setUp(self): 36 | super().setUp() 37 | self.sample_data_source = [1, 2, 3, 4, 5] 38 | self.lazy_dataset_source = source.SourceMapDataset( # pytype: disable=wrong-arg-types 39 | self.sample_data_source 40 | ) 41 | 42 | def test_lazy_dataset_source_len(self): 43 | self.assertLen(self.lazy_dataset_source, 5) 44 | 45 | def test_lazy_dataset_source_sequential_get(self): 46 | indices_to_read = [0, 1, 2, 3, 4] 47 | expected_data = [1, 2, 3, 4, 5] 48 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read] 49 | self.assertEqual(expected_data, actual_data) 50 | 51 | def test_lazy_dataset_source_reverse_sequential_get(self): 52 | indices_to_read = [0, 1, 2, 3, 4] 53 | expected_data = [1, 2, 3, 4, 5] 54 | indices_to_read.reverse() 55 | expected_data.reverse() 56 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read] 57 | self.assertEqual(expected_data, actual_data) 58 | 59 | def test_lazy_dataset_source_random_get(self): 60 | indices_to_read = [0, 1, 2, 3, 4] 61 | random.shuffle(indices_to_read) 62 | expected_data = [self.sample_data_source[i] for i in indices_to_read] 63 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read] 64 | self.assertEqual(expected_data, actual_data) 65 | 66 | def test_lazy_dataset_source_random_modulo_get(self): 67 | len_data_source = len(self.lazy_dataset_source) 68 | indices_to_read = [100, 207, 303, 401] 69 | expected_data = [ 70 | self.sample_data_source[i % len_data_source] for i in indices_to_read 71 | ] 72 | actual_data = [self.lazy_dataset_source[i] for i in indices_to_read] 73 | self.assertEqual(expected_data, actual_data) 74 | 75 | 76 | class RangeMapDatasetTest(absltest.TestCase): 77 | 78 | def test_len(self): 79 | ds = source.RangeMapDataset(12) 80 | self.assertLen(ds, 12) 81 | ds = source.RangeMapDataset(0, 12) 82 | self.assertLen(ds, 12) 83 | ds = source.RangeMapDataset(2, 12) 84 | self.assertLen(ds, 10) 85 | 
ds = source.RangeMapDataset(2, 12, 1) 86 | self.assertLen(ds, 10) 87 | ds = source.RangeMapDataset(2, 12, 2) 88 | self.assertLen(ds, 5) 89 | ds = source.RangeMapDataset(2, 13, 2) 90 | self.assertLen(ds, 6) 91 | 92 | def test_getitem(self): 93 | ds = source.RangeMapDataset(12) 94 | for i in range(12): 95 | self.assertEqual(ds[i], i) 96 | for i in range(12): 97 | self.assertEqual(ds[i + 12], i) 98 | ds = source.RangeMapDataset(2, 9, 2) 99 | self.assertEqual(ds[0], 2) 100 | self.assertEqual(ds[1], 4) 101 | self.assertEqual(ds[2], 6) 102 | self.assertEqual(ds[3], 8) 103 | self.assertEqual(ds[4], 2) 104 | self.assertEqual(ds[5], 4) 105 | 106 | def test_iter(self): 107 | ds = source.RangeMapDataset(12) 108 | ds_iter = iter(ds) 109 | elements = [next(ds_iter) for _ in range(12)] 110 | self.assertEqual(elements, list(range(12))) 111 | ds = source.RangeMapDataset(2, 9, 2) 112 | ds_iter = iter(ds) 113 | elements = [next(ds_iter) for _ in range(4)] 114 | self.assertEqual(elements, [2, 4, 6, 8]) 115 | 116 | 117 | if __name__ == "__main__": 118 | absltest.main() 119 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/example_packing/BUILD: -------------------------------------------------------------------------------- 1 | # Experimental transformation for example packing in PyGrain. 2 | 3 | package(default_visibility = ["//grain:__subpackages__"]) 4 | 5 | py_library( 6 | name = "packing", 7 | srcs = ["packing.py"], 8 | srcs_version = "PY3", 9 | deps = [ 10 | "//grain/_src/core:tree_lib", 11 | "//grain/_src/python:record", 12 | "@abseil-py//absl/logging", 13 | "@pypi//numpy:pkg", 14 | ], 15 | ) 16 | 17 | py_test( 18 | name = "packing_test", 19 | srcs = ["packing_test.py"], 20 | srcs_version = "PY3", 21 | deps = [ 22 | ":packing", 23 | "//grain/_src/core:tree_lib", 24 | "//grain/_src/python:record", 25 | "@abseil-py//absl/testing:absltest", 26 | "@pypi//jax:pkg", 27 | "@pypi//numpy:pkg", 28 | ], 29 | ) 30 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/index_shuffle/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//grain:__subpackages__"]) 2 | 3 | licenses(["notice"]) 4 | 5 | cc_library( 6 | name = "index_shuffle", 7 | srcs = ["index_shuffle.cc"], 8 | hdrs = ["index_shuffle.h"], 9 | ) 10 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/index_shuffle/index_shuffle.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2023 Google LLC. All Rights Reserved. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | ==============================================================================*/ 15 | 16 | #ifndef GRAIN_RANDOM_RANDOM_INDEX_SHUFFLE_H_ 17 | #define GRAIN_RANDOM_RANDOM_INDEX_SHUFFLE_H_ 18 | 19 | #include 20 | #include 21 | 22 | namespace grain { 23 | namespace random { 24 | 25 | // Returns the position of `index` in a permutation of [0, ..., max_index]. 26 | // 27 | // Index must be number in [0, ..., max_index]. 28 | // Key is the random key for the permutation. 29 | // The returned index will also be in [0, ..., max_index]. For a fixed `key` 30 | // and `max_index` all the possible `index` values and the returned values form 31 | // a bijection. 32 | // Rounds must be a positive even integer >= 4. Larger values improve 33 | // 'randomness' of permutations for small `max_index` values. The time to 34 | // compute the result scales linearly with the number of rounds. We recommend 8 35 | // rounds for a good trade off. 36 | // 37 | // For more details on the algorithm see the top of the cc file. 38 | uint64_t index_shuffle(uint64_t index, uint64_t max_index, uint32_t seed, 39 | uint32_t rounds); 40 | 41 | } // namespace random 42 | } // namespace grain 43 | 44 | #endif // GRAIN_RANDOM_RANDOM_INDEX_SHUFFLE_H_ 45 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/index_shuffle/python/BUILD: -------------------------------------------------------------------------------- 1 | load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") 2 | 3 | package(default_visibility = ["//grain:__subpackages__"]) 4 | 5 | licenses(["notice"]) 6 | 7 | pybind_extension( 8 | name = "index_shuffle_module", 9 | srcs = ["index_shuffle_module.cc"], 10 | deps = [ 11 | "//grain/_src/python/experimental/index_shuffle", 12 | ], 13 | ) 14 | 15 | py_test( 16 | name = "index_shuffle_test", 17 | srcs = ["index_shuffle_test.py"], 18 | data = [":index_shuffle_module.so"], 19 | srcs_version = "PY3", 20 | deps = ["@abseil-py//absl/testing:absltest"], 21 | ) 22 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/index_shuffle/python/index_shuffle_module.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "grain/_src/python/experimental/index_shuffle/index_shuffle.h" 4 | 5 | namespace py = pybind11; 6 | 7 | PYBIND11_MODULE(index_shuffle_module, m) { 8 | constexpr char kDoc[] = 9 | "Returns the position of `index` in a permutation of [0, ..., " 10 | "max_index]."; 11 | m.doc() = kDoc; 12 | m.def("index_shuffle", &::grain::random::index_shuffle, kDoc, 13 | py::arg("index"), py::arg("max_index"), py::arg("seed"), 14 | py::arg("rounds")); 15 | } 16 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/index_shuffle/python/index_shuffle_python.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Pure Python version of `index_shuffle`. 15 | 16 | This is roughly 10x slower than the C++ index_shuffle but still sufficiently 17 | fast for many use cases. Use it if the C++ version (and it's CLIF wrapper) don't 18 | work for you. 19 | """ 20 | 21 | import hashlib 22 | 23 | 24 | def _fingerprint(*args) -> int: 25 | """A 128-bit fingerprint based on md5. 26 | 27 | For data shuffling - not for cryptography. 28 | 29 | Args: 30 | *args: any argument list that can be converted to a string 31 | 32 | Returns: 33 | an integer in [0, 2 ** 128) 34 | """ 35 | return int.from_bytes(hashlib.md5(str(args).encode()).digest(), "little") 36 | 37 | 38 | def index_shuffle(index: int, max_index: int, seed: int, rounds: int) -> int: 39 | """computes the position of `index` after a pseudorandom permutation on `[0, max_index])`. 40 | 41 | Based on Feistel ciphers. 42 | 43 | For data shuffling - not for cryptography. 44 | 45 | if i != j, then 46 | pseudorandom_permutation(n, i, seed) != pseudorandom_permutation(n, j, seed) 47 | 48 | Args: 49 | index: an integer in [0, max_index) 50 | max_index: A positive integer. 51 | seed: A posivtive integer used as seed for the pseudorandom permutation. 52 | rounds: Ignored. For compatibility with C++ version. 53 | 54 | Returns: 55 | An integer in [0, max_index]. 56 | """ 57 | del rounds 58 | if not isinstance(max_index, int): 59 | raise ValueError("n must be an integer") 60 | 61 | if index < 0 or index > max_index: 62 | raise ValueError("out of range") 63 | 64 | if max_index == 1: 65 | return 0 66 | 67 | # smallest k such that max_index fits in 2k bits 68 | k = (max_index.bit_length() + 1) // 2 69 | assert max_index <= 4**k 70 | # Permute repeatedly in [max_index, 4 ** k) until you land back in 71 | # [0, max_index]. This constitutes a permutation of [0, max_index]. 72 | while True: 73 | # Feistel ciper on 2k bits - i.e. a permutation of [0, 4 ** k) 74 | a, b = index // (2**k), index % (2**k) 75 | for r in range(3): 76 | a, b = b, a ^ (_fingerprint(b, r, seed) % (2**k)) 77 | index = a * (2**k) + b 78 | if index <= max_index: 79 | return int(index) 80 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/index_shuffle/python/index_shuffle_python_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
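A small illustrative sketch (not part of the test file below) of the pure-Python index_shuffle above: applying it to every position yields a permutation of the index range, which is what allows shuffling records by index without materializing a permutation in memory.

from grain._src.python.experimental.index_shuffle.python import index_shuffle_python

num_records = 10
shuffled = [
    index_shuffle_python.index_shuffle(
        index=i, max_index=num_records - 1, seed=42, rounds=8)
    for i in range(num_records)
]
# Every input index maps to a distinct output index in [0, num_records - 1],
# so the mapping is a bijection (rounds is ignored by this Python version).
assert sorted(shuffled) == list(range(num_records))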
14 | """Minimal unit test for the Python wrapper of index_shuffle.""" 15 | 16 | from absl.testing import absltest 17 | from grain._src.python.experimental.index_shuffle.python import index_shuffle_python 18 | 19 | 20 | class IndexShuffleTest(absltest.TestCase): 21 | 22 | def test_index_shuffle(self): 23 | max_index = 46_204 24 | seen = set() 25 | for x in range(max_index + 1): 26 | y = index_shuffle_python.index_shuffle(x, max_index, seed=52, rounds=4) 27 | self.assertBetween(y, 0, max_index) 28 | seen.add(y) 29 | self.assertLen(seen, max_index + 1) 30 | 31 | def test_index_shuffle_huge_number(self): 32 | max_index = 1_234_567_891 33 | seen = set() 34 | for x in range(10_000): 35 | y = index_shuffle_python.index_shuffle(x, max_index, seed=27, rounds=4) 36 | self.assertBetween(y, 0, max_index) 37 | seen.add(y) 38 | self.assertLen(seen, 10_000) 39 | 40 | def test_index_shuffle_single_record(self): 41 | self.assertEqual( 42 | 0, 43 | index_shuffle_python.index_shuffle( 44 | index=0, max_index=0, seed=0, rounds=4 45 | ), 46 | ) 47 | 48 | 49 | if __name__ == '__main__': 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /grain/_src/python/experimental/index_shuffle/python/index_shuffle_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Minimal unit test for the Python wrapper of index_shuffle.""" 15 | 16 | from absl.testing import absltest 17 | from grain._src.python.experimental.index_shuffle.python import index_shuffle_module as index_shuffle 18 | 19 | 20 | class IndexShuffleTest(absltest.TestCase): 21 | 22 | def test_index_shuffle(self): 23 | max_index = 46_204 24 | seen = set() 25 | for x in range(max_index + 1): 26 | y = index_shuffle.index_shuffle(x, max_index, seed=52, rounds=4) 27 | self.assertBetween(y, 0, max_index) 28 | seen.add(y) 29 | self.assertLen(seen, max_index + 1) 30 | 31 | def test_index_shuffle_huge_number(self): 32 | max_index = 1_234_567_891 33 | seen = set() 34 | for x in range(10_000): 35 | y = index_shuffle.index_shuffle(x, max_index, seed=27, rounds=4) 36 | self.assertBetween(y, 0, max_index) 37 | seen.add(y) 38 | self.assertLen(seen, 10_000) 39 | 40 | def test_index_shuffle_single_record(self): 41 | self.assertEqual( 42 | 0, index_shuffle.index_shuffle(index=0, max_index=0, seed=0, rounds=4) 43 | ) 44 | 45 | 46 | if __name__ == '__main__': 47 | absltest.main() 48 | -------------------------------------------------------------------------------- /grain/_src/python/grain_logging.py: -------------------------------------------------------------------------------- 1 | """A library for adding a custom identifier to Python log messages. 2 | 3 | It allows adding a prefix identifying the process generating the log statement. 4 | The main purpose is to make logs more readable when using multiprocessing. 
5 | """ 6 | 7 | import logging 8 | from absl import logging as absl_logging 9 | 10 | 11 | # Adds a prefix containing the `identifier` to all new Python log messages. 12 | def set_process_identifier_prefix(identifier: str) -> None: 13 | log_formatter = logging.Formatter(f'[{identifier}] %(message)s') 14 | absl_logging.get_absl_handler().setFormatter(log_formatter) 15 | -------------------------------------------------------------------------------- /grain/_src/python/grain_logging_test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from absl import logging as absl_logging 4 | from grain._src.python import grain_logging 5 | from absl.testing import absltest 6 | 7 | 8 | class GrainLoggingTest(absltest.TestCase): 9 | 10 | def test_prefix_is_part_of_message(self): 11 | # self.assertLogs() doesn't format the messages, so we have to resort to 12 | # formatting directly with the absl handler to test whether the 13 | # prefix is being added. 14 | grain_logging.set_process_identifier_prefix('foo prefix') 15 | with self.assertLogs() as cm: 16 | absl_logging.info('example message') 17 | self.assertLen(cm.records, 1) 18 | log_record = cm.records[0] 19 | self.assertIn( 20 | 'foo prefix', absl_logging.get_absl_handler().format(log_record) 21 | ) 22 | 23 | def test_message_is_kept(self): 24 | grain_logging.set_process_identifier_prefix('Foo') 25 | with self.assertLogs() as cm: 26 | absl_logging.info('some info message: %i', 1337) 27 | self.assertLen(cm.output, 1) 28 | self.assertIn('some info message: 1337', cm.output[0]) 29 | 30 | def test_message_formatting(self): 31 | log_record = logging.LogRecord( 32 | name=absl_logging.get_absl_logger().name, 33 | level=logging.INFO, 34 | pathname='file.cc', 35 | lineno=42, 36 | msg='some info message: %i', 37 | args=(789,), 38 | exc_info=None, 39 | ) 40 | grain_logging.set_process_identifier_prefix('FooBarBaz 123') 41 | # We don't want to enforce a specific prefix format, but we want to avoid 42 | # duplicating the absl prefix (e.g.: 43 | # I0814 11:48:49.888083 1726756 grain_pool.py:161 44 | # ). 
45 | self.assertTrue( 46 | re.search( 47 | r'.{0,5}FooBarBaz 123.{0,5} some info message: 789', 48 | absl_logging.get_absl_handler().format(log_record), 49 | ) 50 | ) 51 | 52 | 53 | if __name__ == '__main__': 54 | absltest.main() 55 | -------------------------------------------------------------------------------- /grain/_src/python/load.py: -------------------------------------------------------------------------------- 1 | """High level APIs that serve as a single endpoint for very common use cases.""" 2 | 3 | from typing import Optional 4 | 5 | from grain._src.core import monitoring as grain_monitoring 6 | from grain._src.core import sharding 7 | from grain._src.core import transforms 8 | from grain._src.core import usage_logging 9 | from grain._src.python import data_loader 10 | from grain._src.python import data_sources 11 | from grain._src.python import options 12 | from grain._src.python import samplers 13 | 14 | from grain._src.core import monitoring 15 | 16 | 17 | _api_usage_counter = monitoring.Counter( 18 | "/grain/python/load/api", 19 | monitoring.Metadata(description="API initialization counter."), 20 | root=grain_monitoring.get_monitoring_root(), 21 | fields=[("name", str)], 22 | ) 23 | 24 | 25 | def load( 26 | source: data_sources.RandomAccessDataSource, 27 | *, 28 | num_epochs: Optional[int] = None, 29 | shuffle: bool = False, 30 | seed: Optional[int] = None, 31 | shard_options: sharding.ShardOptions = sharding.NoSharding(), 32 | transformations: transforms.Transformations = (), 33 | batch_size: Optional[int] = None, 34 | drop_remainder: bool = False, 35 | worker_count: Optional[int] = 0, 36 | read_options: Optional[options.ReadOptions] = None, 37 | ) -> data_loader.DataLoader: 38 | """Convenient method for simple pipelines on top of a data source. 39 | 40 | Args: 41 | source: Data source to load from. This can be one of the file data sources 42 | provided by Grain, a TFDS data source (`tfds.data_source(...)`) or your 43 | custom data source. 44 | num_epochs: See IndexSampler. 45 | shuffle: See IndexSampler. 46 | seed: See IndexSampler. 47 | shard_options: See IndexSampler. 48 | transformations: List of local (stateless) transformations: 49 | batch_size: Optional batch size. If provided will apply BatchOperation(). 50 | drop_remainder: Whether to drop partial batches. 51 | worker_count: Number of child processes launched to parallelize the 52 | transformations among. Zero means processing runs in the same process. 53 | read_options: Read options for the data loader. See ReadOptions. 54 | 55 | Returns: 56 | DataLoader for this dataset. 
57 | """ 58 | usage_logging.log_event("load", tag_3="PyGrain") 59 | _api_usage_counter.Increment("load") 60 | sampler = samplers.IndexSampler( 61 | num_records=len(source), 62 | shuffle=shuffle, 63 | seed=seed, 64 | num_epochs=num_epochs, 65 | shard_options=shard_options, 66 | ) 67 | if batch_size is not None: 68 | transformations = list(transformations) 69 | transformations.append( 70 | transforms.Batch(batch_size, drop_remainder=drop_remainder) 71 | ) 72 | return data_loader.DataLoader( 73 | data_source=source, 74 | sampler=sampler, 75 | operations=transformations, 76 | worker_count=worker_count, 77 | read_options=read_options, 78 | ) 79 | -------------------------------------------------------------------------------- /grain/_src/python/load_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for load().""" 15 | 16 | from absl import flags 17 | from absl.testing import absltest 18 | from absl.testing import parameterized 19 | from grain._src.core import transforms 20 | import multiprocessing as mp 21 | from grain._src.python import data_sources 22 | from grain._src.python import load 23 | import numpy as np 24 | 25 | 26 | FLAGS = flags.FLAGS 27 | 28 | 29 | class FilterEven(transforms.Filter): 30 | 31 | def filter(self, x: int) -> bool: 32 | return x % 2 == 0 33 | 34 | 35 | class PlusOne(transforms.MapTransform): 36 | 37 | def map(self, x: int) -> int: 38 | return x + 1 39 | 40 | 41 | class DataLoaderTest(parameterized.TestCase): 42 | 43 | def test_with_range_source(self): 44 | range_data_source = data_sources.RangeDataSource(start=0, stop=8, step=1) 45 | transformations = [ 46 | PlusOne(), 47 | FilterEven(), 48 | ] 49 | data_loader = load.load( 50 | range_data_source, transformations=transformations, num_epochs=1 51 | ) 52 | expected = [2, 4, 6, 8] 53 | actual = list(data_loader) 54 | np.testing.assert_equal(actual, expected) 55 | 56 | def test_with_range_source_with_batch(self): 57 | range_data_source = data_sources.RangeDataSource(start=0, stop=8, step=1) 58 | transformations = [ 59 | PlusOne(), 60 | FilterEven(), 61 | ] 62 | data_loader = load.load( 63 | range_data_source, 64 | transformations=transformations, 65 | batch_size=2, 66 | num_epochs=1, 67 | ) 68 | expected = [np.array([2, 4]), np.array([6, 8])] 69 | actual = list(data_loader) 70 | np.testing.assert_equal(actual, expected) 71 | 72 | 73 | if __name__ == "__main__": 74 | absltest.main() 75 | -------------------------------------------------------------------------------- /grain/_src/python/multiprocessing_common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module defines common functions for multiprocessing/threading.""" 15 | 16 | import dataclasses 17 | import multiprocessing 18 | from multiprocessing import pool 19 | import queue 20 | from typing import TypeVar, Union, Callable 21 | 22 | T = TypeVar('T') 23 | 24 | _QUEUE_WAIT_TIMEOUT_SECONDS = 0.5 25 | _ASYNC_RESULT_WAIT_TIMEOUT_SECONDS = 0.5 26 | 27 | 28 | @dataclasses.dataclass 29 | class _SystemTerminated: 30 | """When system terminates, this is returned instead of actual elements.""" 31 | 32 | 33 | SYSTEM_TERMINATED = _SystemTerminated() 34 | 35 | 36 | def add_element_to_queue( 37 | element: T, 38 | elements_queue: queue.Queue[T], 39 | should_stop: Callable[[], bool], 40 | ) -> bool: 41 | """Try adding element to queue as long as should_stop() is not True. 42 | 43 | Args: 44 | element: Element to add. 45 | elements_queue: Target queue. 46 | should_stop: Callable checked before each attempt; once it returns True 47 | the addition is abandoned without further re-tries. 48 | 49 | Returns: 50 | Bool indicating whether addition was successful. 51 | """ 52 | while not should_stop(): 53 | try: 54 | elements_queue.put(element, timeout=_QUEUE_WAIT_TIMEOUT_SECONDS) 55 | return True 56 | except queue.Full: 57 | pass 58 | return False 59 | 60 | 61 | def get_element_from_queue( 62 | elements_queue: queue.Queue[T], 63 | should_stop: Callable[[], bool], 64 | ) -> Union[T, _SystemTerminated]: 65 | """Try getting element from queue as long as should_stop() is not True.""" 66 | while not should_stop(): 67 | try: 68 | return elements_queue.get(timeout=_QUEUE_WAIT_TIMEOUT_SECONDS) 69 | except queue.Empty: 70 | pass 71 | return SYSTEM_TERMINATED 72 | 73 | 74 | def get_async_result( 75 | async_result: pool.AsyncResult[T], 76 | should_stop: Callable[[], bool], 77 | ) -> Union[T, _SystemTerminated]: 78 | """Wait for async result as long as should_stop() is not True.""" 79 | while not should_stop(): 80 | try: 81 | return async_result.get(timeout=_ASYNC_RESULT_WAIT_TIMEOUT_SECONDS) 82 | except multiprocessing.TimeoutError: 83 | pass 84 | return SYSTEM_TERMINATED 85 | -------------------------------------------------------------------------------- /grain/_src/python/multiprocessing_common_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
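A minimal usage sketch for the queue helpers defined in multiprocessing_common.py above (an illustration, not part of the repository); the `producer` thread and `stop_event` names are invented for the example, and only the module shown here plus the standard library is assumed.

import queue
import threading

from grain._src.python import multiprocessing_common

stop_event = threading.Event()
buffer: queue.Queue = queue.Queue(maxsize=2)

def producer():
  # Retries politely while the queue is full and gives up once stop_event is set.
  for i in range(5):
    if not multiprocessing_common.add_element_to_queue(
        i, buffer, should_stop=stop_event.is_set
    ):
      return

threading.Thread(target=producer, daemon=True).start()
received = []
for _ in range(5):
  element = multiprocessing_common.get_element_from_queue(
      buffer, should_stop=stop_event.is_set
  )
  if element is multiprocessing_common.SYSTEM_TERMINATED:
    break
  received.append(element)
stop_event.set()
assert received == [0, 1, 2, 3, 4]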
14 | """Tests for multiprocessing common functions.""" 15 | 16 | import multiprocessing 17 | from multiprocessing import pool 18 | import queue 19 | 20 | from absl.testing import absltest 21 | from grain._src.python import multiprocessing_common 22 | 23 | 24 | class MultiProcessingCommonTest(absltest.TestCase): 25 | 26 | def test_add_element_to_queue(self): 27 | test_queue = multiprocessing.Queue() 28 | element = 1 29 | termination_event = multiprocessing.Event() 30 | self.assertTrue( 31 | multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types 32 | element=element, 33 | elements_queue=test_queue, 34 | should_stop=termination_event.is_set, 35 | ) 36 | ) 37 | self.assertEqual(test_queue.get(), 1) 38 | 39 | def test_add_element_to_queue_already_terminated(self): 40 | test_queue = multiprocessing.Queue() 41 | element = 1 42 | termination_event = multiprocessing.Event() 43 | termination_event.set() 44 | self.assertFalse( 45 | multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types 46 | element=element, 47 | elements_queue=test_queue, 48 | should_stop=termination_event.is_set, 49 | ) 50 | ) 51 | with self.assertRaises(queue.Empty): 52 | test_queue.get(timeout=0.1) 53 | 54 | def test_get_element_from_queue(self): 55 | test_queue = multiprocessing.Queue() 56 | expected_element = 1 57 | test_queue.put(expected_element) 58 | termination_event = multiprocessing.Event() 59 | actual_element = multiprocessing_common.get_element_from_queue( # pytype: disable=wrong-arg-types 60 | elements_queue=test_queue, 61 | should_stop=termination_event.is_set, 62 | ) 63 | self.assertEqual(actual_element, expected_element) 64 | 65 | def test_get_element_from_queue_already_terminated(self): 66 | test_queue = multiprocessing.Queue() 67 | expected_element = 1 68 | test_queue.put(expected_element) 69 | termination_event = multiprocessing.Event() 70 | termination_event.set() 71 | actual_element = multiprocessing_common.get_element_from_queue( # pytype: disable=wrong-arg-types 72 | elements_queue=test_queue, 73 | should_stop=termination_event.is_set, 74 | ) 75 | self.assertEqual(actual_element, multiprocessing_common.SYSTEM_TERMINATED) 76 | 77 | def test_get_async_result(self): 78 | thread_pool = pool.ThreadPool(1) 79 | async_result = thread_pool.apply_async(func=lambda x: x + 1, args=(1,)) 80 | termination_event = multiprocessing.Event() 81 | result = multiprocessing_common.get_async_result( 82 | should_stop=termination_event.is_set, async_result=async_result 83 | ) 84 | self.assertEqual(result, 2) 85 | 86 | def test_get_async_result_already_terminated(self): 87 | thread_pool = pool.ThreadPool(1) 88 | async_result = thread_pool.apply_async(func=lambda x: x + 1, args=(1,)) 89 | termination_event = multiprocessing.Event() 90 | termination_event.set() 91 | result = multiprocessing_common.get_async_result( 92 | should_stop=termination_event.is_set, async_result=async_result 93 | ) 94 | self.assertEqual(result, multiprocessing_common.SYSTEM_TERMINATED) 95 | 96 | 97 | if __name__ == "__main__": 98 | absltest.main() 99 | -------------------------------------------------------------------------------- /grain/_src/python/options.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Dataclasses for holding options.""" 15 | import dataclasses 16 | 17 | 18 | @dataclasses.dataclass(slots=True) 19 | class ReadOptions: 20 | """Options for reading data from the DataSource. 21 | 22 | These settings configure a single Python process. Each process uses separate 23 | threads and buffers for reading and processing data. 24 | 25 | Example: With ReadOptions.num_threads=8 and 26 | MultiprocessingOptions.num_workers=10 there will be 80 threads reading the 27 | data (8 threads in each of 10 Python processes). 28 | 29 | Attributes: 30 | num_threads: Number of threads reading from the DataSource in parallel. If 31 | the data are already loaded in memory, we recommend setting this to 0 to 32 | avoid Python GIL contention by multiple threads. 33 | prefetch_buffer_size: Size of the buffer for reading elements per Python 34 | process (not per thread). Useful when reading from a distributed file 35 | system. 36 | """ 37 | 38 | # The current default values were chosen by running a few selected 39 | # benchmarks reading from remote hard drives. 40 | # These values should work well for datasets with elements between 1 and 41 | # 10 KiB on disk. 42 | num_threads: int = 16 43 | prefetch_buffer_size: int = 500 44 | 45 | 46 | @dataclasses.dataclass(slots=True) 47 | class MultiprocessingOptions: 48 | """Options for using Python multiprocessing. 49 | 50 | Attributes: 51 | num_workers: Number of Python worker processes. More processes can speed up 52 | the pipeline if it's compute bound and bottlenecked on CPython's GIL. 53 | The default value of 0 means no Python multiprocessing, and as a result 54 | all data loading and transformation will run in the main Python process. 55 | per_worker_buffer_size: Size of the buffer for preprocessed elements that 56 | each worker maintains. These are elements after all transformations. If 57 | your transformations include batching this means a single element is a 58 | batch. 59 | enable_profiling: If True, profiling info is logged. This is only available 60 | when num_workers >= 1. 61 | """ 62 | 63 | num_workers: int = 0 64 | per_worker_buffer_size: int = 1 65 | enable_profiling: bool = False 66 | -------------------------------------------------------------------------------- /grain/_src/python/record.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
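A small illustrative sketch of how the option dataclasses in options.py above fit together (an illustration, not part of the repository); it simply instantiates the two classes and checks the thread arithmetic described in the ReadOptions docstring.

from grain._src.python.options import MultiprocessingOptions, ReadOptions

# 4 reader threads per process and 2 worker processes -> 8 reader threads total.
read_options = ReadOptions(num_threads=4, prefetch_buffer_size=256)
mp_options = MultiprocessingOptions(num_workers=2, per_worker_buffer_size=8)
assert read_options.num_threads * mp_options.num_workers == 8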
14 | """Defines the record class used by various modules in the Grain Python backend.""" 15 | 16 | import dataclasses 17 | from typing import Generic, Optional, TypeVar 18 | import numpy as np 19 | 20 | T = TypeVar("T") 21 | 22 | 23 | @dataclasses.dataclass(slots=True) 24 | class RecordMetadata: 25 | """RecordMetadata contains metadata about individual records. 26 | 27 | Metadata can be emitted by the sampler to indicate which record to read next. 28 | In addition, it is used to keep information about records as they flow 29 | through the pipeline from one operation to the next. 30 | """ 31 | 32 | index: int 33 | record_key: Optional[int] = None 34 | rng: Optional[np.random.Generator] = None 35 | 36 | def remove_record_key(self): 37 | """Removes the record key if it exists.""" 38 | if self.record_key is None: 39 | return self 40 | else: 41 | return dataclasses.replace(self, record_key=None) 42 | 43 | def __str__(self): 44 | return ( 45 | f"RecordMetadata(index={self.index}, record_key={self.record_key}, " 46 | f"rng={self.rng})" 47 | ) 48 | 49 | 50 | @dataclasses.dataclass(slots=True) 51 | class Record(Generic[T]): 52 | metadata: RecordMetadata 53 | data: T 54 | -------------------------------------------------------------------------------- /grain/_src/python/record_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
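An illustrative sketch of the record containers defined in record.py above (an illustration, not part of the repository); the payload dict is arbitrary example data.

import numpy as np

from grain._src.python.record import Record, RecordMetadata

metadata = RecordMetadata(index=0, record_key=7, rng=np.random.default_rng(42))
rec = Record(metadata=metadata, data={"text": "hello"})
# Stages that should not leak the record key downstream can strip it.
stripped = rec.metadata.remove_record_key()
assert stripped.record_key is None and stripped.index == 0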
14 | """Tests for record.""" 15 | 16 | from grain._src.python import record 17 | import numpy as np 18 | from absl.testing import absltest 19 | 20 | 21 | class RecordTest(absltest.TestCase): 22 | 23 | def test_RecordMetadata_str(self): 24 | record_metadata = record.RecordMetadata( 25 | index=0, record_key=0, rng=np.random.default_rng() 26 | ) 27 | self.assertEqual( 28 | str(record_metadata), 29 | "RecordMetadata(index=0, record_key=0, rng=Generator(PCG64))", 30 | ) 31 | 32 | def test_RecordMetadata_str_none_rng(self): 33 | record_metadata = record.RecordMetadata(index=0, record_key=0) 34 | self.assertStartsWith( 35 | str(record_metadata), 36 | "RecordMetadata(index=0, record_key=0, rng=None", 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | absltest.main() 42 | -------------------------------------------------------------------------------- /grain/_src/python/testdata/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//grain:__subpackages__"]) 2 | 3 | licenses(["notice"]) 4 | 5 | exports_files(glob(["*.array_record-*"])) 6 | 7 | exports_files(glob(["*.tfrecord"])) 8 | -------------------------------------------------------------------------------- /grain/_src/python/testdata/digits.array_record-00000-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/grain/_src/python/testdata/digits.array_record-00000-of-00002 -------------------------------------------------------------------------------- /grain/_src/python/testdata/digits.array_record-00001-of-00002: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/grain/_src/python/testdata/digits.array_record-00001-of-00002 -------------------------------------------------------------------------------- /grain/_src/python/testdata/morris_sequence_first_5.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/grain/48732668b80f475b5d91aa64cb4326c4db8daed4/grain/_src/python/testdata/morris_sequence_first_5.tfrecord -------------------------------------------------------------------------------- /grain/_src/python/testing/BUILD: -------------------------------------------------------------------------------- 1 | package(default_visibility = ["//grain:__subpackages__"]) 2 | 3 | py_library( 4 | name = "experimental", 5 | srcs = ["experimental.py"], 6 | srcs_version = "PY3", 7 | deps = [ 8 | "//grain/_src/core:tree_lib", 9 | "@pypi//numpy:pkg", 10 | ], 11 | ) 12 | -------------------------------------------------------------------------------- /grain/_src/python/testing/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /grain/_src/python/testing/experimental.py: -------------------------------------------------------------------------------- 1 | """API to test checkpointing.""" 2 | 3 | import itertools 4 | from typing import Any 5 | 6 | from grain._src.core import tree_lib 7 | import numpy as np 8 | 9 | 10 | def assert_equal_output_after_checkpoint( 11 | ds: Any, 12 | ): 13 | """Tests restoring an iterator to various checkpointed states. 14 | 15 | Args: 16 | ds: The dataset to test. It is recommended to use a small dataset, 17 | potentially created using `grain.python.experimental.LimitIterDataset`, to 18 | restrict the number of steps being tested. The underlying dataset iterator 19 | must implement `get_state` and `set_state` for checkpointing. 20 | """ 21 | 22 | iterator = ds.__iter__() 23 | checkpoints = [] 24 | expected_values = [] 25 | state_spec = None 26 | for i in itertools.count(): 27 | current_state = iterator.get_state() 28 | if state_spec is None: 29 | state_spec = tree_lib.spec_like(current_state) 30 | else: 31 | np.testing.assert_equal( 32 | state_spec, 33 | tree_lib.spec_like(current_state), 34 | f"State spec does not match the original state spec at step {i}." 35 | f" Expected: {state_spec}," 36 | f" Actual: {tree_lib.spec_like(current_state)}", 37 | ) 38 | try: 39 | value = next(iterator) 40 | except StopIteration: 41 | break 42 | checkpoints.append(current_state) 43 | expected_values.append(value) 44 | 45 | assert expected_values, "Dataset did not produce any elements." 46 | 47 | # Restore the iterator at every state, and compare the values. 48 | for i, state in enumerate(checkpoints): 49 | new_iterator = ds.__iter__() 50 | new_iterator.set_state(state) 51 | np.testing.assert_equal( 52 | new_iterator.get_state(), 53 | state, 54 | f"Restored state does not match the original state at step {i}." 55 | f" Expected: {state}, Actual: {new_iterator.get_state()}", 56 | ) 57 | 58 | # Test the values at the current state. 59 | new_values = list(new_iterator) 60 | np.testing.assert_equal( 61 | new_values, 62 | expected_values[i:], 63 | f"Restored values mismatch at step {i} for state {state}." 64 | f" \nExpected: {expected_values[i:]}, \nActual: {new_values}", 65 | ) 66 | -------------------------------------------------------------------------------- /grain/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
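A hedged usage sketch for the checkpointing test helper in testing/experimental.py above (an illustration, not part of the repository). It assumes Grain's `MapDataset.range` constructor and the `to_iter_dataset()` conversion; any small `IterDataset` whose iterator implements `get_state`/`set_state` would work the same way.

from grain._src.python.dataset.dataset import MapDataset
from grain._src.python.testing import experimental as testing_experimental

# A tiny deterministic pipeline keeps the number of checkpointed steps small.
ds = MapDataset.range(5).map(lambda x: x * x).to_iter_dataset()
testing_experimental.assert_equal_output_after_checkpoint(ds)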
14 | """APIs for saving and restoring pipeline state.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=g-import-not-at-top 19 | # pylint: disable=unused-import 20 | # pylint: disable=g-multiple-import 21 | 22 | from grain._src.python.checkpoint_handlers import ( 23 | CheckpointHandler, 24 | ) 25 | 26 | # These are imported only if Orbax is present. 27 | try: 28 | from grain._src.python.checkpoint_handlers import ( 29 | CheckpointSave, 30 | CheckpointRestore, 31 | ) 32 | except ImportError: 33 | pass 34 | -------------------------------------------------------------------------------- /grain/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Metadata constants.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=unused-import 19 | # pylint: disable=g-multiple-import 20 | 21 | from grain._src.core.constants import ( 22 | DATASET_INDEX, 23 | EPOCH, 24 | INDEX, 25 | META_FEATURES, 26 | RECORD, 27 | RECORD_KEY, 28 | SEED, 29 | ) 30 | -------------------------------------------------------------------------------- /grain/experimental.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Experimental Grain APIs.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=g-bad-import-order 19 | # pylint: disable=g-multiple-import 20 | # pylint: disable=unused-import 21 | 22 | from grain._src.core.transforms import FlatMapTransform 23 | 24 | from grain._src.python.dataset.base import ( 25 | DatasetOptions, 26 | ExecutionTrackingMode, 27 | ) 28 | from grain._src.python.dataset.dataset import ( 29 | apply_transformations, 30 | WithOptionsIterDataset, 31 | ) 32 | from grain._src.python.dataset.elastic_iterator import ElasticIterator 33 | from grain._src.python.dataset.sources.parquet_dataset import ParquetIterDataset 34 | from grain._src.python.dataset.sources.tfrecord_dataset import TFRecordIterDataset 35 | 36 | from grain._src.python.dataset.transformations.flatmap import ( 37 | FlatMapMapDataset, 38 | FlatMapIterDataset, 39 | ) 40 | from grain._src.python.dataset.transformations.interleave import ( 41 | InterleaveIterDataset, 42 | ) 43 | from grain._src.python.dataset.transformations.limit import LimitIterDataset 44 | from grain._src.python.dataset.transformations.map import RngPool 45 | from grain._src.python.dataset.transformations.packing import ( 46 | FirstFitPackIterDataset, 47 | ) 48 | from grain._src.python.dataset.transformations.packing_concat_then_split import ( 49 | BOSHandling, 50 | ConcatThenSplitIterDataset, 51 | ) 52 | from grain._src.python.dataset.transformations.prefetch import ( 53 | ThreadPrefetchIterDataset, 54 | ) 55 | from grain._src.python.dataset.transformations.shuffle import ( 56 | WindowShuffleMapDataset, 57 | WindowShuffleIterDataset, 58 | ) 59 | from grain._src.python.dataset.transformations.zip import ( 60 | ZipMapDataset, 61 | ZipIterDataset, 62 | ) 63 | from grain._src.python.experimental.example_packing.packing import ( 64 | PackAndBatchOperation, 65 | ) 66 | from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import ( 67 | index_shuffle, 68 | ) 69 | 70 | # This should eventually live under grain.testing. 71 | from grain._src.python.testing.experimental import ( 72 | assert_equal_output_after_checkpoint, 73 | ) 74 | -------------------------------------------------------------------------------- /grain/multiprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """APIs to work with multiprocessing.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=unused-import 19 | 20 | from grain._src.python.options import MultiprocessingOptions 21 | from grain._src.python.shared_memory_array import SharedMemoryArray 22 | -------------------------------------------------------------------------------- /grain/oss/Dockerfile: -------------------------------------------------------------------------------- 1 | # Constructs the environment within which we will build the grain pip wheels. 
2 | 3 | 4 | ARG AUDITWHEEL_PLATFORM 5 | 6 | FROM quay.io/pypa/${AUDITWHEEL_PLATFORM} 7 | LABEL maintainer="Grain team " 8 | 9 | ARG PYTHON_VERSION 10 | ARG BAZEL_VERSION 11 | 12 | ENV DEBIAN_FRONTEND=noninteractive 13 | 14 | RUN ulimit -n 1024 && yum install -y rsync 15 | 16 | ENV PYTHON_BIN_PATH=/opt/python/cp${PYTHON_VERSION}-cp${PYTHON_VERSION}/bin 17 | ENV PATH="${PYTHON_BIN_PATH}:${PATH}" 18 | 19 | ENV PYTHON_BIN=${PYTHON_BIN_PATH}/python 20 | 21 | # Download the correct bazel version and make sure it's on path. 22 | RUN BAZEL_ARCH_SUFFIX="$(uname -m | sed s/aarch64/arm64/)" \ 23 | && curl -sSL --fail -o /usr/local/bin/bazel "https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-linux-$BAZEL_ARCH_SUFFIX" \ 24 | && chmod a+x /usr/local/bin/bazel 25 | 26 | # Install dependencies needed for grain. 27 | RUN --mount=type=cache,target=/root/.cache \ 28 | $PYTHON_BIN -m pip install -U \ 29 | setuptools; 30 | 31 | 32 | WORKDIR "/tmp/grain" -------------------------------------------------------------------------------- /grain/oss/array_record/Dockerfile.patch: -------------------------------------------------------------------------------- 1 | diff --git a/oss/build.Dockerfile b/oss/build.Dockerfile 2 | index 5fefa86..17fd296 100644 3 | --- a/oss/build.Dockerfile 4 | +++ b/oss/build.Dockerfile 5 | @@ -11,7 +11,7 @@ ARG BAZEL_VERSION 6 | 7 | ENV DEBIAN_FRONTEND=noninteractive 8 | 9 | -RUN yum install -y rsync 10 | +RUN ulimit -n 1024 && yum install -y rsync 11 | ENV PATH="${PYTHON_BIN}:${PATH}" 12 | 13 | # Download the correct bazel version and make sure it's on path. 14 | -------------------------------------------------------------------------------- /grain/oss/array_record/WORKSPACE.patch: -------------------------------------------------------------------------------- 1 | diff --git a/WORKSPACE b/WORKSPACE 2 | index e63922f..a655bdb 100644 3 | --- a/WORKSPACE 4 | +++ b/WORKSPACE 5 | @@ -3,13 +3,11 @@ workspace(name = "array_record") 6 | load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository") 7 | load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") 8 | 9 | -# Abseil LTS 20230125.0 10 | http_archive( 11 | name = "com_google_absl", 12 | - sha256 = "3ea49a7d97421b88a8c48a0de16c16048e17725c7ec0f1d3ea2683a2a75adc21", # SHARED_ABSL_SHA 13 | - strip_prefix = "abseil-cpp-20230125.0", 14 | + strip_prefix = "abseil-cpp-20230802.1", 15 | urls = [ 16 | - "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230125.0.tar.gz", 17 | + "https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.1.tar.gz", 18 | ], 19 | ) 20 | # Version: pypi-v0.11.0, 2020/10/27 21 | @@ -70,15 +68,13 @@ http_archive( 22 | load("@pybind11_bazel//:python_configure.bzl", "python_configure") 23 | python_configure(name = "local_config_python") 24 | 25 | -# V21.12, 20230130 26 | # proto_library, cc_proto_library, and java_proto_library rules implicitly 27 | # depend on @com_google_protobuf for protoc and proto runtimes. 28 | # This statement defines the @com_google_protobuf repo. 
29 | http_archive( 30 | name = "com_google_protobuf", 31 | - sha256 = "22fdaf641b31655d4b2297f9981fa5203b2866f8332d3c6333f6b0107bb320de", 32 | - strip_prefix = "protobuf-21.12", 33 | - urls = ["https://github.com/protocolbuffers/protobuf/archive/v21.12.tar.gz"], 34 | + strip_prefix = "protobuf-23.1", 35 | + urls = ["https://github.com/protocolbuffers/protobuf/archive/v23.1.tar.gz"], 36 | ) 37 | 38 | load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps") 39 | @@ -87,9 +83,9 @@ protobuf_deps() 40 | # Riegeli does not cut releases, so we reference the head 41 | http_archive( 42 | name = "com_google_riegeli", 43 | - strip_prefix = "riegeli-master", 44 | + strip_prefix = "riegeli-904c0c263b8632265103f0066c168a92c7713b07", 45 | urls = [ 46 | - "https://github.com/google/riegeli/archive/master.zip", 47 | + "https://github.com/google/riegeli/archive/904c0c263b8632265103f0066c168a92c7713b07.zip", 48 | ], 49 | ) 50 | # Riegeli's dependencies 51 | @@ -131,9 +127,8 @@ http_archive( 52 | http_archive( 53 | name = "highwayhash", 54 | build_file = "@com_google_riegeli//third_party:highwayhash.BUILD", 55 | - sha256 = "5380cb7cf19e7c9591f31792b7794d48084f6a3ab7c03d637cd6a32cf2ee8686", 56 | - strip_prefix = "highwayhash-a7f68e2f95fac08b24327d74747521cf634d5aff", 57 | - urls = ["https://github.com/google/highwayhash/archive/a7f68e2f95fac08b24327d74747521cf634d5aff.zip"], # 2023-08-09 58 | + strip_prefix = "highwayhash-3d6a8d35a6bc823b9dbe08804fc2a2d08d373cd7", 59 | + urls = ["https://github.com/google/highwayhash/archive/3d6a8d35a6bc823b9dbe08804fc2a2d08d373cd7.zip"], 60 | ) 61 | 62 | # Tensorflow, 20230705 63 | -------------------------------------------------------------------------------- /grain/oss/array_record/array_record_reader.patch: -------------------------------------------------------------------------------- 1 | diff --git a/cpp/array_record_reader.cc b/cpp/array_record_reader.cc 2 | index bc7675e..623911a 100644 3 | --- a/cpp/array_record_reader.cc 4 | +++ b/cpp/array_record_reader.cc 5 | @@ -196,7 +196,7 @@ void ArrayRecordReaderBase::Initialize() { 6 | max_parallelism = state_->pool->NumThreads(); 7 | if (state_->options.max_parallelism().has_value()) { 8 | max_parallelism = 9 | - std::min(max_parallelism, state_->options.max_parallelism().value()); 10 | + std::min(max_parallelism, state_->options.max_parallelism().value()); 11 | } 12 | } 13 | state_->options.set_max_parallelism(max_parallelism); 14 | @@ -331,7 +331,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecords( 15 | return absl::OkStatus(); 16 | } 17 | uint64_t num_chunk_groups = 18 | - CeilOfRatio(state_->chunk_offsets.size(), state_->chunk_group_size); 19 | + CeilOfRatio(state_->chunk_offsets.size(), state_->chunk_group_size); 20 | const auto reader = get_backing_reader(); 21 | Reader* mutable_reader = const_cast( 22 | reinterpret_cast(reader.get())); 23 | @@ -340,7 +340,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecords( 24 | uint64_t chunk_idx_start = buf_idx * state_->chunk_group_size; 25 | // inclusive index, not the conventional exclusive index. 26 | uint64_t last_chunk_idx = 27 | - std::min((buf_idx + 1) * state_->chunk_group_size - 1, 28 | + std::min((buf_idx + 1) * state_->chunk_group_size - 1, 29 | state_->chunk_offsets.size() - 1); 30 | uint64_t buf_len = state_->ChunkEndOffset(last_chunk_idx) - 31 | state_->chunk_offsets[chunk_idx_start]; 32 | @@ -406,9 +406,9 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsInRange( 33 | "Invalid range [%d, %d). 
Total records: %d", begin, end, NumRecords())); 34 | } 35 | uint64_t chunk_idx_begin = begin / state_->record_group_size; 36 | - uint64_t chunk_idx_end = CeilOfRatio(end, state_->record_group_size); 37 | + uint64_t chunk_idx_end = CeilOfRatio(end, state_->record_group_size); 38 | uint64_t num_chunks = chunk_idx_end - chunk_idx_begin; 39 | - uint64_t num_chunk_groups = CeilOfRatio(num_chunks, state_->chunk_group_size); 40 | + uint64_t num_chunk_groups = CeilOfRatio(num_chunks, state_->chunk_group_size); 41 | 42 | const auto reader = get_backing_reader(); 43 | Reader* mutable_reader = 44 | @@ -418,7 +418,7 @@ absl::Status ArrayRecordReaderBase::ParallelReadRecordsInRange( 45 | uint64_t chunk_idx_start = 46 | chunk_idx_begin + buf_idx * state_->chunk_group_size; 47 | // inclusive index, not the conventional exclusive index. 48 | - uint64_t last_chunk_idx = std::min( 49 | + uint64_t last_chunk_idx = std::min( 50 | chunk_idx_begin + (buf_idx + 1) * state_->chunk_group_size - 1, 51 | chunk_idx_end - 1); 52 | uint64_t buf_len = state_->ChunkEndOffset(last_chunk_idx) - 53 | @@ -617,7 +617,7 @@ bool ArrayRecordReaderBase::SeekRecord(uint64_t record_index) { 54 | if (!ok()) { 55 | return false; 56 | } 57 | - state_->record_idx = std::min(record_index, state_->num_records); 58 | + state_->record_idx = std::min(record_index, state_->num_records); 59 | return true; 60 | } 61 | 62 | @@ -667,7 +667,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) { 63 | std::vector decoders; 64 | decoders.reserve(state_->chunk_group_size); 65 | uint64_t chunk_start = buffer_idx * state_->chunk_group_size; 66 | - uint64_t chunk_end = std::min(state_->chunk_offsets.size(), 67 | + uint64_t chunk_end = std::min(state_->chunk_offsets.size(), 68 | (buffer_idx + 1) * state_->chunk_group_size); 69 | const auto reader = get_backing_reader(); 70 | for (uint64_t chunk_idx = chunk_start; chunk_idx < chunk_end; ++chunk_idx) { 71 | @@ -708,7 +708,7 @@ bool ArrayRecordReaderBase::ReadAheadFromBuffer(uint64_t buffer_idx) { 72 | chunk_offsets.reserve(state_->chunk_group_size); 73 | uint64_t chunk_start = buffer_to_add * state_->chunk_group_size; 74 | uint64_t chunk_end = 75 | - std::min(state_->chunk_offsets.size(), 76 | + std::min(state_->chunk_offsets.size(), 77 | (buffer_to_add + 1) * state_->chunk_group_size); 78 | for (uint64_t chunk_idx = chunk_start; chunk_idx < chunk_end; ++chunk_idx) { 79 | chunk_offsets.push_back(state_->chunk_offsets[chunk_idx]); 80 | -------------------------------------------------------------------------------- /grain/oss/array_record/build_whl.patch: -------------------------------------------------------------------------------- 1 | diff --git a/oss/build_whl.sh b/oss/build_whl.sh 2 | index 275c868..27080a9 100755 3 | --- a/oss/build_whl.sh 4 | +++ b/oss/build_whl.sh 5 | @@ -27,11 +27,15 @@ function main() { 6 | [ -e .bazelrc ] && rm .bazelrc 7 | 8 | write_to_bazelrc "build -c opt" 9 | + write_to_bazelrc "build --action_env MACOSX_DEPLOYMENT_TARGET=11.0" 10 | write_to_bazelrc "build --cxxopt=-std=c++17" 11 | write_to_bazelrc "build --host_cxxopt=-std=c++17" 12 | - write_to_bazelrc "build --linkopt=\"-lrt -lm\"" 13 | write_to_bazelrc "build --experimental_repo_remote_exec" 14 | write_to_bazelrc "build --python_path=\"${PYTHON_BIN}\"" 15 | + PLATFORM="$(uname)" 16 | + if [[ "$PLATFORM" != "Darwin" ]]; then 17 | + write_to_bazelrc "build --linkopt=\"-lrt -lm\"" 18 | + fi 19 | 20 | if [ -n "${CROSSTOOL_TOP}" ]; then 21 | write_to_bazelrc "build --crosstool_top=${CROSSTOOL_TOP}" 22 | 
@@ -42,8 +46,8 @@ function main() { 23 | # https://github.com/bazelbuild/bazel/issues/8622 24 | export USE_BAZEL_VERSION=5.4.0 25 | bazel clean 26 | - bazel build ... 27 | - bazel test --verbose_failures --test_output=errors ... 28 | + bazel build ... --action_env PYTHON_BIN_PATH="${PYTHON_BIN}" 29 | + bazel test --verbose_failures --test_output=errors ... --action_env PYTHON_BIN_PATH="${PYTHON_BIN}" 30 | 31 | DEST="/tmp/array_record/all_dist" 32 | # Create the directory, then do dirname on a non-existent file inside it to 33 | @@ -71,7 +75,11 @@ function main() { 34 | 35 | pushd ${TMPDIR} 36 | echo $(date) : "=== Building wheel" 37 | - ${PYTHON_BIN} setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION} 38 | + plat_name="" 39 | + if [[ "$(uname)" == "Darwin" ]]; then 40 | + plat_name="--plat-name macosx_11_0_$(uname -m)" 41 | + fi 42 | + ${PYTHON_BIN} setup.py bdist_wheel --python-tag py3${PYTHON_MINOR_VERSION} $plat_name 43 | 44 | if [ -n "${AUDITWHEEL_PLATFORM}" ]; then 45 | echo $(date) : "=== Auditing wheel" 46 | -------------------------------------------------------------------------------- /grain/oss/array_record/runner_common.patch: -------------------------------------------------------------------------------- 1 | diff --git a/oss/runner_common.sh b/oss/runner_common.sh 2 | index 2ee2c8c..433698f 100644 3 | --- a/oss/runner_common.sh 4 | +++ b/oss/runner_common.sh 5 | @@ -2,7 +2,7 @@ 6 | 7 | # Builds ArrayRecord from source code located in SOURCE_DIR producing wheels 8 | # under $SOURCE_DIR/all_dist. 9 | -function build_and_test_array_record() { 10 | +function build_and_test_array_record_linux() { 11 | SOURCE_DIR=$1 12 | 13 | # Automatically decide which platform to build for by checking on which 14 | @@ -15,7 +15,7 @@ function build_and_test_array_record() { 15 | 16 | # Build wheels for multiple Python minor versions. 17 | PYTHON_MAJOR_VERSION=3 18 | - for PYTHON_MINOR_VERSION in 9 10 11 12 19 | + for PYTHON_MINOR_VERSION in 10 11 12 20 | do 21 | PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION} 22 | PYTHON_BIN=/opt/python/cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}-cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}/bin 23 | @@ -41,4 +41,78 @@ function build_and_test_array_record() { 24 | done 25 | 26 | ls ${SOURCE_DIR}/all_dist/*.whl 27 | +} 28 | + 29 | +function install_and_init_pyenv { 30 | + pyenv_root=${1:-$HOME/.pyenv} 31 | + export PYENV_ROOT=$pyenv_root 32 | + if [[ ! -d $PYENV_ROOT ]]; then 33 | + echo "Installing pyenv.." 34 | + git clone https://github.com/pyenv/pyenv.git "$PYENV_ROOT" 35 | + pushd "$PYENV_ROOT" 36 | + git checkout "v2.4.21" 37 | + popd 38 | + export PATH="/home/kbuilder/.local/bin:$PYENV_ROOT/bin:$PATH" 39 | + eval "$(pyenv init --path)" 40 | + fi 41 | + 42 | + echo "Python setup..." 43 | + pyenv install -s "$PYENV_PYTHON_VERSION" 44 | + pyenv global "$PYENV_PYTHON_VERSION" 45 | + export PYTHON_BIN=$(pyenv which python) 46 | +} 47 | + 48 | +function setup_env_vars_py { 49 | + # This controls the python binary to use. 50 | + PYTHON_MAJOR_VERSION=$1 51 | + PYTHON_MINOR_VERSION=$2 52 | + # This is for pyenv install. 
53 | + PYENV_PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION} 54 | + PYTHON="python$PYENV_PYTHON_VERSION" 55 | +} 56 | + 57 | +function update_bazel_macos { 58 | + BAZEL_VERSION=$1 59 | + ARCH="$(uname -m)" 60 | + curl -L https://github.com/bazelbuild/bazel/releases/download/${BAZEL_VERSION}/bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh -O 61 | + ls 62 | + chmod +x bazel-*.sh 63 | + ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh --user 64 | + rm -f ./bazel-${BAZEL_VERSION}-installer-darwin-${ARCH}.sh 65 | + # Add new bazel installation to path 66 | + export PATH="$HOME/bin:$PATH" 67 | +} 68 | + 69 | +function install_ar_deps { 70 | + $PYTHON_BIN -m pip install -U \ 71 | + absl-py \ 72 | + build \ 73 | + etils[epath] \ 74 | + setuptools \ 75 | + twine \ 76 | + wheel; 77 | +} 78 | + 79 | +function build_and_test_array_record_macos() { 80 | + SOURCE_DIR=$1 81 | + # Set up Bazel. 82 | + # Using a previous version of Bazel to avoid: 83 | + # https://github.com/bazelbuild/bazel/issues/8622 84 | + export BAZEL_VERSION="5.4.0" 85 | + update_bazel_macos ${BAZEL_VERSION} 86 | + bazel --version 87 | + 88 | + PYTHON_MAJOR_VERSION=3 89 | + for PYTHON_MINOR_VERSION in 10 11 12 90 | + do 91 | + # Set up Pyenv. 92 | + PYTHON_VERSION=${PYTHON_MAJOR_VERSION}.${PYTHON_MINOR_VERSION} 93 | + echo "Creating array_record wheel for Python Version $PYTHON_VERSION" 94 | + setup_env_vars_py $PYTHON_MAJOR_VERSION $PYTHON_MINOR_VERSION 95 | + install_and_init_pyenv 96 | + install_ar_deps 97 | + 98 | + # Build and test ArrayRecord. 99 | + bash ${SOURCE_DIR}/oss/build_whl.sh 100 | + done 101 | } 102 | \ No newline at end of file 103 | -------------------------------------------------------------------------------- /grain/oss/array_record/setup.patch: -------------------------------------------------------------------------------- 1 | diff --git a/setup.py b/setup.py 2 | index cfb0bac..4763c00 100644 3 | --- a/setup.py 4 | +++ b/setup.py 5 | @@ -25,7 +25,7 @@ class BinaryDistribution(Distribution): 6 | 7 | setup( 8 | name='array_record', 9 | - version='0.5.1', 10 | + version='0.6.0', 11 | description='A file format that achieves a new frontier of IO efficiency', 12 | author='ArrayRecord team', 13 | author_email='no-reply@google.com', 14 | -------------------------------------------------------------------------------- /grain/proto/BUILD: -------------------------------------------------------------------------------- 1 | load("@com_google_protobuf//:protobuf.bzl", "py_proto_library") 2 | 3 | default_visibility = ["//grain:__subpackages__"] 4 | 5 | package(default_visibility = default_visibility) 6 | 7 | py_proto_library( 8 | name = "execution_summary_py_pb2", 9 | # For profiling tooling. 10 | srcs = ["execution_summary.proto"], 11 | ) 12 | -------------------------------------------------------------------------------- /grain/proto/execution_summary.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package grain.python.execution_summary; 4 | 5 | message ExecutionSummary { 6 | message Node { 7 | // Unique ID of the node. 8 | int32 id = 2; 9 | // Human-readable name of the node. 10 | string name = 3; 11 | // Node IDs of the parent nodes. 12 | repeated int32 inputs = 4; 13 | // Ratio of time spent by the pipeline waiting for the given transformation 14 | // node. 15 | double wait_time_ratio = 5; 16 | // Cumulative processing time spent in the node from the start in 17 | // nanoseconds.
18 | int64 total_processing_time_ns = 6; 19 | // Minimum per-element processing time in nanoseconds. 20 | int64 min_processing_time_ns = 7; 21 | // Maximum per-element processing time in nanoseconds. 22 | int64 max_processing_time_ns = 8; 23 | // Number of elements produced by the node. 24 | int64 num_produced_elements = 9; 25 | // Human-readable specification of the produced elements. 26 | string output_spec = 10; 27 | // Whether the node is the root node. 28 | bool is_output = 11; 29 | // Whether the node is a prefetch node. Child nodes of prefetch will have 30 | // their wait time ratio derived from the ratio of the prefetch node. 31 | // Sum of all ratios in a single pipeline is 1. 32 | bool is_prefetch = 12; 33 | // Bytes consumed by the node. Currently, bytes consumed and bytes produced 34 | // by the node are best estimated only for prefetch nodes. The difference is 35 | // used to estimate the memory usage of the node. 36 | int64 bytes_consumed = 13; 37 | // Bytes produced by the node. 38 | int64 bytes_produced = 14; 39 | } 40 | // Map of node IDs to nodes in the pipeline. 41 | map<int32, Node> nodes = 1; 42 | } 43 | -------------------------------------------------------------------------------- /grain/python/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Public API for Grain. 15 | 16 | Backwards compatibility re-import. Prefer adding new APIs to top-level modules.
17 | """ 18 | 19 | # pylint: disable=g-importing-member 20 | # pylint: disable=g-import-not-at-top 21 | # pylint: disable=g-multiple-import 22 | # pylint: disable=unused-import 23 | 24 | 25 | from grain._src.core.config import config 26 | from grain._src.core.constants import ( 27 | DATASET_INDEX, 28 | EPOCH, 29 | INDEX, 30 | META_FEATURES, 31 | RECORD, 32 | RECORD_KEY, 33 | SEED, 34 | ) 35 | from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions 36 | from grain._src.core.transforms import ( 37 | Batch, 38 | Filter as FilterTransform, 39 | MapTransform, 40 | MapWithIndex as MapWithIndexTransform, 41 | RandomMapTransform, 42 | Transformation, 43 | Transformations, 44 | ) 45 | 46 | from grain._src.python.checkpoint_handlers import ( 47 | CheckpointHandler as PyGrainCheckpointHandler, 48 | ) 49 | from grain._src.python.data_loader import ( 50 | DataLoader, 51 | DataLoaderIterator as PyGrainDatasetIterator, 52 | ) 53 | from grain._src.python.data_sources import ( 54 | ArrayRecordDataSource, 55 | SharedMemoryDataSource as InMemoryDataSource, 56 | RandomAccessDataSource, 57 | RangeDataSource, 58 | ) 59 | from grain._src.python.dataset.base import DatasetSelectionMap 60 | from grain._src.python.dataset.dataset import ( 61 | MapDataset, 62 | IterDataset, 63 | DatasetIterator, 64 | ) 65 | 66 | from grain._src.python.load import load 67 | from grain._src.python.operations import ( 68 | BatchOperation, 69 | FilterOperation, 70 | MapOperation, 71 | Operation, 72 | RandomMapOperation, 73 | ) 74 | from grain._src.python.options import ReadOptions, MultiprocessingOptions 75 | from grain._src.python.record import (Record, RecordMetadata) 76 | from grain._src.python.samplers import ( 77 | IndexSampler, 78 | Sampler, 79 | SequentialSampler, 80 | ) 81 | from grain._src.python.shared_memory_array import SharedMemoryArray 82 | from grain.python import experimental 83 | 84 | # These are imported only if Orbax is present. 85 | try: 86 | from grain._src.python.checkpoint_handlers import ( 87 | CheckpointSave as PyGrainCheckpointSave, 88 | CheckpointRestore as PyGrainCheckpointRestore, 89 | ) 90 | except ImportError: 91 | pass 92 | -------------------------------------------------------------------------------- /grain/python/experimental.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Experimental Grain APIs. 15 | 16 | Backwards compatibility re-import. Prefer adding new APIs to top-level 17 | `experimental.py` file. 
18 | """ 19 | 20 | # pylint: disable=g-importing-member 21 | # pylint: disable=g-bad-import-order 22 | # pylint: disable=g-multiple-import 23 | # pylint: disable=unused-import 24 | 25 | from grain._src.python.dataset.base import ( 26 | DatasetOptions, 27 | ExecutionTrackingMode, 28 | ) 29 | from grain._src.python.dataset.dataset import ( 30 | apply_transformations, 31 | WithOptionsIterDataset, 32 | ) 33 | from grain._src.python.dataset.sources.parquet_dataset import ParquetIterDataset 34 | 35 | from grain._src.python.dataset.transformations.flatmap import ( 36 | FlatMapMapDataset, 37 | FlatMapIterDataset, 38 | ) 39 | from grain._src.python.dataset.transformations.interleave import ( 40 | InterleaveIterDataset, 41 | ) 42 | from grain._src.python.dataset.transformations.limit import LimitIterDataset 43 | from grain._src.python.dataset.transformations.map import RngPool 44 | from grain._src.python.dataset.transformations.mix import ConcatenateMapDataset 45 | from grain._src.python.dataset.transformations.packing import FirstFitPackIterDataset 46 | from grain._src.python.dataset.transformations.packing_concat_then_split import ( 47 | BOSHandling, 48 | ConcatThenSplitIterDataset, 49 | ) 50 | from grain._src.python.dataset.transformations.prefetch import ( 51 | MultiprocessPrefetchIterDataset, 52 | ThreadPrefetchIterDataset, 53 | ) 54 | from grain._src.python.dataset.transformations.shuffle import ( 55 | WindowShuffleMapDataset, 56 | WindowShuffleIterDataset, 57 | ) 58 | from grain._src.python.dataset.transformations.zip import ( 59 | ZipMapDataset, 60 | ZipIterDataset, 61 | ) 62 | from grain._src.core.transforms import ( 63 | FlatMapTransform, 64 | MapWithIndex as MapWithIndexTransform, 65 | ) 66 | from grain._src.python.experimental.example_packing.packing import PackAndBatchOperation 67 | 68 | # This should eventually live under grain.testing. 69 | from grain._src.python.testing.experimental import assert_equal_output_after_checkpoint 70 | 71 | from grain._src.python.experimental.index_shuffle.python.index_shuffle_module import index_shuffle 72 | -------------------------------------------------------------------------------- /grain/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Sampler APIs.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=g-multiple-import 19 | # pylint: disable=unused-import 20 | 21 | from grain._src.python.samplers import ( 22 | IndexSampler, 23 | Sampler, 24 | SequentialSampler, 25 | ) 26 | -------------------------------------------------------------------------------- /grain/sharding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """APIs for sharding pipelines for distributed training.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=g-multiple-import 19 | # pylint: disable=unused-import 20 | 21 | from grain._src.core.sharding import NoSharding, ShardByJaxProcess, ShardOptions 22 | -------------------------------------------------------------------------------- /grain/sources.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """APIs for reading data from various file formats.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=g-multiple-import 19 | # pylint: disable=unused-import 20 | 21 | # Note to developers: 22 | # - When adding a new OSS data source make the format lib dependency optional 23 | # by lazily importing the source and providing an extra installation, e.g. 24 | # `pip install grain[parquet]`. This will allow users to avoid installing 25 | # all supported format dependencies. 26 | from grain._src.python.data_sources import ( 27 | ArrayRecordDataSource, 28 | SharedMemoryDataSource, 29 | RandomAccessDataSource, 30 | RangeDataSource, 31 | ) 32 | -------------------------------------------------------------------------------- /grain/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
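An illustrative sharding sketch tying the two public shims above (grain/sharding.py and grain/sources.py) to the high-level `load()` API shown earlier in this listing (an illustration, not part of the repository); the shard numbers are made up for the example.

from grain._src.python import load
from grain.sharding import ShardOptions
from grain.sources import RangeDataSource

source = RangeDataSource(start=0, stop=100, step=1)
loader = load.load(
    source,
    num_epochs=1,
    shard_options=ShardOptions(shard_index=0, shard_count=4, drop_remainder=True),
)
# Shard 0 of 4 sees a disjoint quarter of the 100 records.
first_shard = list(loader)
assert len(first_shard) == 25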
14 | """Data transformation APIs.""" 15 | 16 | 17 | # pylint: disable=g-importing-member 18 | # pylint: disable=g-multiple-import 19 | # pylint: disable=unused-import 20 | 21 | from grain._src.core.transforms import ( 22 | Batch, 23 | Filter, 24 | MapTransform as Map, 25 | MapWithIndex, 26 | RandomMapTransform as RandomMap, 27 | Transformation, 28 | Transformations, 29 | ) 30 | 31 | from grain._src.python.dataset.base import DatasetSelectionMap 32 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "grain" 7 | version = "0.2.10" 8 | description = "Grain: A library for loading and transforming data for ML training." 9 | keywords = [] 10 | authors = [ 11 | {name = "Grain team", email = "grain-dev@google.com"}, 12 | ] 13 | dependencies = [ 14 | 'absl-py', 15 | 'array-record', 16 | 'cloudpickle', 17 | 'dm-tree', 18 | 'etils[epath,epy]', 19 | 'more-itertools>=9.1.0', 20 | 'numpy', 21 | 'protobuf>=3.20.3', 22 | ] 23 | readme = "README.md" 24 | license = { file = "LICENSE" } 25 | requires-python = ">=3.10" 26 | classifiers = [ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: Apache Software License", 29 | "Operating System :: POSIX :: Linux", 30 | ] 31 | 32 | [project.optional-dependencies] 33 | testing = [ 34 | 'attrs', 35 | 'dill', 36 | 'jax', 37 | 'jaxlib', 38 | 'jaxtyping', 39 | 'pyarrow', 40 | 'tensorflow-datasets', 41 | ] 42 | parquet = [ 43 | 'pyarrow', 44 | ] 45 | 46 | [project.urls] 47 | homepage = "https://github.com/google/grain" 48 | 49 | [tool.setuptools.packages.find] 50 | include = ["grain*"] 51 | 52 | [tool.setuptools.package-data] 53 | "*" = ["*.so"] -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup.py file for grain. 2 | 3 | Most project configs are in `pyproject.toml` -- prefer to modify 4 | `pyproject.toml` over this file if possible. 5 | """ 6 | 7 | import setuptools 8 | from setuptools import dist 9 | 10 | 11 | class BinaryDistribution(dist.Distribution): 12 | """This class makes 'bdist_wheel' include an ABI tag on the wheel.""" 13 | 14 | def has_ext_modules(self): 15 | return True 16 | 17 | 18 | setuptools.setup( 19 | distclass=BinaryDistribution, 20 | ) 21 | -------------------------------------------------------------------------------- /test_requirements.in: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # This is the list of our Python third_party dependencies that Bazel should 16 | # pull from PyPi. 
17 | # Note that requirements.txt must be re-generated using 18 | # bazel run //:requirements.update in the OSS version. 19 | array-record 20 | absl-py 21 | dm-tree 22 | etils[epath,epy] 23 | cloudpickle 24 | jax 25 | jaxtyping 26 | numpy 27 | attrs 28 | pyarrow 29 | parameterized --------------------------------------------------------------------------------