├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── config.yml │ ├── documentation.yml │ └── feature-request.yml ├── PULL_REQUEST_TEMPLATE.md ├── scripts │ ├── check_complete_doc.py │ └── validate_binaries.sh └── workflows │ ├── _build_test_upload.yml │ ├── build_wheels_linux.yml │ ├── lint.yml │ ├── nightly_release.yml │ ├── nodes_ci.yml │ ├── pull_release.yml │ ├── stateful_dataloader_ci.yml │ ├── test_release.yml │ ├── validate-binaries.yml │ └── validate-nightly-binaries.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .prettierignore ├── .prettierrc.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── benchmarks ├── cloud │ ├── README.md │ ├── aws_s3_results.md │ ├── ec2.yml │ └── result_images │ │ ├── aws_s3_high_complex.jpg │ │ ├── aws_s3_low_complex.jpg │ │ └── aws_s3_med_complex.jpg └── torchvision_classification │ ├── helpers.py │ ├── presets.py │ ├── train.py │ └── utils.py ├── docs ├── Makefile ├── README.md ├── make.bat ├── requirements.txt └── source │ ├── _static │ └── css │ │ └── custom.css │ ├── _templates │ ├── class_method_template.rst │ ├── class_template.rst │ ├── function.rst │ └── layout.html │ ├── conf.py │ ├── docutils.conf │ ├── getting_started_with_torchdata_nodes.rst │ ├── index.rst │ ├── migrate_to_nodes_from_utils.rst │ ├── stateful_dataloader_tutorial.rst │ ├── torchdata.nodes.rst │ ├── torchdata.stateful_dataloader.rst │ └── what_is_torchdata_nodes.rst ├── examples ├── __init__.py ├── criteo_torcharrow │ ├── common.py │ ├── day_11_first_3k_rows.tsv │ ├── day_11_first_3k_rows_original.tsv │ └── tsv_to_parquet.py ├── nodes │ ├── hf_datasets_nodes_mnist.ipynb │ ├── hf_imdb_bert.ipynb │ ├── imagenet_benchmark.py │ ├── multi_dataset_weighted_sampling.ipynb │ └── torchata_nodes_basics.ipynb ├── text │ └── utils.py └── vision │ ├── __init__.py │ └── fakedata │ ├── caltech101 │ ├── 101_ObjectCategories.tar.gz │ └── Annotations.tar │ ├── caltech256 │ └── 256_ObjectCategories.tar │ └── imagefolder │ 
├── cat │ ├── 1.jpg │ ├── 2.jpg │ └── 3.jpg │ └── dog │ ├── 1.jpg │ ├── 2.jpg │ └── 3.jpg ├── mypy.ini ├── packaging ├── README.md ├── build_conda.sh ├── build_wheel.sh ├── env-var-script.txt ├── manylinux │ ├── install_openssl_curl.sh │ └── python_helper.sh ├── pkg_helpers.bash ├── post_build_script_linux.sh ├── pre_build_script_linux.sh └── torchdata │ ├── bld.bat │ ├── build.sh │ └── meta.yaml ├── pyproject.toml ├── requirements.txt ├── scripts └── release_notes │ ├── commitlist.py │ └── common.py ├── setup.py ├── test ├── _fakedata │ ├── README.md │ ├── _create_fake_data.py │ ├── bytes.tar │ ├── bytes.tar.gz │ ├── bytes │ │ ├── 0.bt │ │ ├── 1.bt │ │ └── 2.bt │ ├── csv.tar │ ├── csv.tar.gz │ ├── csv │ │ ├── 0.csv │ │ ├── 1.csv │ │ └── 2.csv │ ├── json.tar │ ├── json.tar.gz │ ├── json │ │ ├── 0.json │ │ ├── 1.json │ │ └── 2.json │ ├── tfrecord │ │ ├── example.tfrecord │ │ └── sequence_example.tfrecord │ ├── txt.tar │ ├── txt.tar.gz │ └── txt │ │ ├── 0.txt │ │ ├── 1.txt │ │ └── 2.txt ├── _utils │ ├── __init__.py │ └── _common_utils_for_test.py ├── nodes │ ├── __init__.py │ ├── test_adapters.py │ ├── test_base_node.py │ ├── test_batch.py │ ├── test_cycler.py │ ├── test_filter.py │ ├── test_header.py │ ├── test_loader.py │ ├── test_map.py │ ├── test_multi_node_round_robin_sampler.py │ ├── test_multi_node_weighted_sampler.py │ ├── test_pin_memory.py │ ├── test_prefetch.py │ ├── test_snapshot_store.py │ └── utils.py ├── requirements.txt ├── smoke_test │ └── smoke_test.py └── stateful_dataloader │ ├── test_dataloader.py │ ├── test_hugging_face.py │ ├── test_incremental_state.py │ ├── test_sampler.py │ └── test_state_dict.py ├── tools ├── __init__.py ├── setup_helpers │ ├── __init__.py │ └── extension.py └── todo.py ├── torchdata ├── __init__.py ├── nodes │ ├── README.md │ ├── __init__.py │ ├── _apply_udf.py │ ├── _populate_queue.py │ ├── adapters.py │ ├── base_node.py │ ├── batch.py │ ├── constants.py │ ├── cycler.py │ ├── exception_wrapper.py │ ├── filter.py │ ├── 
header.py │ ├── loader.py │ ├── map.py │ ├── pin_memory.py │ ├── prefetch.py │ ├── samplers │ │ ├── __init__.py │ │ ├── multi_node_round_robin_sampler.py │ │ ├── multi_node_weighted_sampler.py │ │ ├── stop_criteria.py │ │ └── utils.py │ ├── snapshot_store.py │ └── types.py └── stateful_dataloader │ ├── README.md │ ├── __init__.py │ ├── incremental_state.py │ ├── sampler.py │ ├── stateful.py │ ├── stateful_dataloader.py │ └── worker.py └── version.txt /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203,E402,E501,F821,W503,W504, 4 | per-file-ignores = 5 | __init__.py: F401, F403, F405 6 | test/*: F401 7 | _extension.py: F401 8 | exclude = 9 | ./.git, 10 | ./third_party, 11 | *.pyi, 12 | ./examples/text/utils.py, # Copy from TorchText 13 | docs/source/conf.py, 14 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: 🐛 Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: > 8 | #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the 9 | existing and past issues](https://github.com/pytorch/data/issues?q=is%3Aissue+sort%3Acreated-desc+). 10 | - type: textarea 11 | attributes: 12 | label: 🐛 Describe the bug 13 | description: | 14 | Please provide a clear and concise description of what the bug is. 15 | 16 | If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. 
We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: 17 | 18 | ```python 19 | # All necessary imports at the beginning 20 | import torch 21 | import torchdata 22 | 23 | # A succinct reproducing example trimmed down to the essential parts: 24 | t = torch.rand(5, 10) # Note: the bug is here, we should pass requires_grad=True 25 | t.sum().backward() 26 | ``` 27 | 28 | If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. 29 | 30 | Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. 31 | placeholder: | 32 | A clear and concise description of what the bug is. 33 | 34 | ```python 35 | Sample code to reproduce the problem 36 | ``` 37 | 38 | ``` 39 | The error message you got, with the full traceback. 40 | ``` 41 | validations: 42 | required: true 43 | - type: textarea 44 | attributes: 45 | label: Versions 46 | description: | 47 | Please run the following and paste the output below. Make sure the version numbers of all relevant packages (e.g. torch, torchdata, other domain packages) are included. 48 | ```sh 49 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 50 | # For security purposes, please check the contents of collect_env.py before running it. 51 | python collect_env.py 52 | ``` 53 | validations: 54 | required: true 55 | - type: markdown 56 | attributes: 57 | value: > 58 | Thanks for contributing 🎉! 
59 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Questions and Discussion 4 | url: https://discuss.pytorch.org/ 5 | about: Ask questions and discuss with other PyTorch community members. Please use the 'data' category. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: 📚 Documentation 2 | description: Report an issue related to inline documentation 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 📚 The doc issue 8 | description: > 9 | A clear and concise description of what content in https://pytorch.org/data is an issue. If this has to do with 10 | the general https://pytorch.org website, please file an issue at 11 | https://github.com/pytorch/pytorch.github.io/issues/new/choose instead. 12 | validations: 13 | required: true 14 | - type: textarea 15 | attributes: 16 | label: Suggest a potential alternative/fix 17 | description: > 18 | Tell us how we could improve the documentation in this regard. 19 | - type: markdown 20 | attributes: 21 | value: > 22 | Thanks for contributing 🎉! 
23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new TorchData feature 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 🚀 The feature 8 | description: > 9 | A clear and concise description of the feature proposal 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Motivation, pitch 15 | description: > 16 | Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., 17 | *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link 18 | here too. 19 | validations: 20 | required: true 21 | - type: textarea 22 | attributes: 23 | label: Alternatives 24 | description: > 25 | A description of any alternative solutions or features you've considered, if any. 26 | - type: textarea 27 | attributes: 28 | label: Additional context 29 | description: > 30 | Add any other context or screenshots about the feature request. 31 | - type: markdown 32 | attributes: 33 | value: > 34 | Thanks for contributing 🎉! 35 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Please read through our [contribution guide](https://github.com/pytorch/data/blob/main/CONTRIBUTING.md) prior to 2 | creating your pull request. 3 | 4 | - If you are adding a new node, ensure you read that section in the contribution guide, as it includes requirements for 5 | functionality and testing. 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Verify that the documented API matches the exported API.

For each target package, compare the names exported via ``__all__`` in
``torchdata/<target>/__init__.py`` against the names listed in
``docs/source/torchdata.<target>.rst`` and report any mismatch in either
direction. Exits with status 1 when the two sets disagree.
"""

import os
import sys


def collect_init_dps(init_file_location):
    """Return the set of names listed in the ``__all__`` block of *init_file_location*.

    Scans for the first line starting with ``__all__ ``, then collects every
    subsequent line that starts with a double quote (one quoted name per line,
    e.g. ``"StatefulDataLoader",``), stopping at the first line that does not.
    """
    init_dps = set()
    with open(init_file_location) as init_file:
        while (line := init_file.readline()) != "":
            if line.startswith("__all__ "):
                while (line := init_file.readline()) != "" and (stripped_line := line.strip()).startswith('"'):
                    init_dps.add(stripped_line.replace(",", "").replace('"', ""))
                break
    return init_dps


def collect_rst_dps(rst_file_location):
    """Return the set of names documented in *rst_file_location*.

    Names are gathered from the indented blocks that follow each
    ``class_template.rst`` or ``function.rst`` template directive; the line
    immediately after the directive is skipped, and collection stops at the
    first (near-)empty line. Unlike :func:`collect_init_dps`, scanning
    continues through the whole file so multiple directives are handled.
    """
    rst_dps = set()
    with open(rst_file_location) as rst_file:
        while (line := rst_file.readline()) != "":
            if line.count("class_template.rst") > 0 or line.count("function.rst") > 0:
                rst_file.readline()  # skip the line directly after the directive
                while (line := rst_file.readline()) != "" and len(stripped_line := line.strip()) > 1:
                    rst_dps.add(stripped_line)
    return rst_dps


def compare_sets(set_a, set_b, ignore_set=None):
    """Return the elements of *set_a* that are missing from *set_b*.

    Elements contained in *ignore_set* (if given) are excluded from the
    result, so known/intentional omissions can be whitelisted.
    """
    res = set_a.difference(set_b)
    if ignore_set is not None:
        res.difference_update(ignore_set)
    return res


def main():
    init_file = "__init__.py"
    docs_source_folder = os.path.join("docs", "source")
    exit_code = 0

    # Bug fix: the ignore set was previously written as ``{}``, which is an
    # empty *dict*, not a set; it only worked because iterating an empty dict
    # yields nothing. Use ``set()`` so the intent (and type) is explicit.
    for target, ignore_set in [("stateful_dataloader", set())]:
        init_path = os.path.join("torchdata", target, init_file)
        rst_path = os.path.join(docs_source_folder, "torchdata." + target + ".rst")

        init_set = collect_init_dps(init_path)
        rst_set = collect_rst_dps(rst_path)

        dif_init = compare_sets(init_set, rst_set, ignore_set)
        dif_rst = compare_sets(rst_set, init_set)

        for elem in dif_init:
            print(f"Please add {elem} to {rst_path}")
            exit_code = 1
        for elem in dif_rst:
            print(f"{elem} is present in {rst_path} but not in {init_path}")
            exit_code = 1

    sys.exit(exit_code)


if __name__ == "__main__":
    main()
| steps: 26 | - name: Get Release Type 27 | run: | 28 | if [[ "${{ inputs.branch }}" == v* ]] && [[ ${{ inputs.pre_dev_release }} == false ]]; then 29 | RELEASE_TYPE=official 30 | elif [[ "${{ inputs.branch }}" == release/* ]] && [[ ${{ inputs.pre_dev_release }} == true ]]; then 31 | RELEASE_TYPE=test 32 | else 33 | if [[ "${{ github.base_ref }}" == release/* ]]; then 34 | RELEASE_TYPE=test 35 | else 36 | RELEASE_TYPE=nightly 37 | fi 38 | fi 39 | echo "Release Type: $RELEASE_TYPE" 40 | echo "type=$RELEASE_TYPE" >> $GITHUB_OUTPUT 41 | id: get_release_type 42 | 43 | build_docs: 44 | if: always() && inputs.branch != '' 45 | needs: get_release_type 46 | runs-on: ubuntu-latest 47 | steps: 48 | - name: Setup Python 3.9 49 | uses: actions/setup-python@v5 50 | with: 51 | python-version: 3.9 52 | - name: Checkout 53 | uses: actions/checkout@v4 54 | with: 55 | ref: ${{ inputs.branch }} 56 | submodules: recursive 57 | - name: Install Dependencies 58 | run: | 59 | echo `python3 --version` 60 | python3 -m pip install --upgrade pip 61 | python3 -m pip install setuptools 62 | python3 -m pip install matplotlib 63 | sudo apt-get install -y yarn 64 | - name: Install PyTorch & TorchData 65 | run: | 66 | pip3 install numpy 67 | # Add version requirement to PyTorch except nightly release 68 | if [[ -z "${{ inputs.pytorch_version }}" ]]; then 69 | PYTORCH_VERSION=torch 70 | else 71 | PYTORCH_VERSION=torch==${{ inputs.pytorch_version }} 72 | fi 73 | 74 | PIP_CHANNEL=${{ needs.get_release_type.outputs.type }} 75 | if [[ $PIP_CHANNEL == 'official' ]]; then 76 | pip3 install "$PYTORCH_VERSION" -f https://download.pytorch.org/whl/torch_stable.html 77 | else 78 | pip3 install --pre "$PYTORCH_VERSION" --index-url "https://download.pytorch.org/whl/$PIP_CHANNEL/cpu" 79 | fi 80 | 81 | pip3 install -r requirements.txt 82 | pip3 install . 
83 | - name: Check env 84 | run: echo `which sphinx-build` 85 | - name: Build the docset 86 | run: | 87 | cd ./docs 88 | sudo apt-get install -y graphviz 89 | pip3 install -r requirements.txt 90 | make html 91 | cd .. 92 | - name: Export Target Folder 93 | run: | 94 | TARGET_FOLDER=${{ inputs.branch }} 95 | if [[ $TARGET_FOLDER == release/* ]]; then 96 | TARGET_FOLDER=${TARGET_FOLDER:8} 97 | elif [[ $TARGET_FOLDER == tags/* ]]; then 98 | TARGET_FOLDER=${TARGET_FOLDER:5} 99 | elif [[ $TARGET_FOLDER == v* ]] && [[ ${{ inputs.pre_dev_release }} == false ]]; then 100 | if [[ $TARGET_FOLDER == v*.*.* ]]; then 101 | TARGET_FOLDER=${TARGET_FOLDER%.*} 102 | fi 103 | TARGET_FOLDER=${TARGET_FOLDER:1} 104 | fi 105 | echo "value=$TARGET_FOLDER" >> $GITHUB_OUTPUT 106 | id: target_folder 107 | - name: Deploy 108 | uses: JamesIves/github-pages-deploy-action@v4.4.1 109 | with: 110 | token: ${{ secrets.GITHUB_TOKEN }} 111 | branch: gh-pages # The branch the action should deploy to. 112 | folder: docs/build/html # The folder the action should deploy. 113 | target-folder: ${{ steps.target_folder.outputs.value }} # The destination folder the action should deploy to.
114 | -------------------------------------------------------------------------------- /.github/workflows/build_wheels_linux.yml: -------------------------------------------------------------------------------- 1 | name: Build Linux Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: linux 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | with-cuda: disable 29 | with-rocm: disable 30 | python-versions: '["3.9"]' 31 | build: 32 | needs: generate-matrix 33 | strategy: 34 | fail-fast: false 35 | matrix: 36 | include: 37 | - repository: pytorch/data 38 | pre-script: packaging/pre_build_script_linux.sh 39 | post-script: "" 40 | smoke-test-script: test/smoke_test/smoke_test.py 41 | package-name: torchdata 42 | name: ${{ matrix.repository }} 43 | uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main 44 | with: 45 | repository: ${{ matrix.repository }} 46 | ref: "" 47 | test-infra-repository: pytorch/test-infra 48 | test-infra-ref: main 49 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 50 | pre-script: ${{ matrix.pre-script }} 51 | post-script: ${{ matrix.post-script }} 52 | smoke-test-script: ${{ matrix.smoke-test-script }} 53 | package-name: ${{ matrix.package-name }} 54 | env-var-script: packaging/env-var-script.txt 55 | trigger-event: ${{ github.event_name }} 56 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: 
-------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - release/* 8 | pull_request: 9 | 10 | jobs: 11 | style: 12 | if: ${{ github.repository_owner == 'pytorch' }} 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Setup Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.9" 19 | - name: Check out source repository 20 | uses: actions/checkout@v4 21 | - name: Install lint utilities 22 | run: | 23 | pip install pre-commit 24 | pre-commit install-hooks 25 | - name: Lint Python code and config files 26 | run: pre-commit run --all-files 27 | - name: Required modifications 28 | if: ${{ failure() }} 29 | run: git --no-pager diff 30 | 31 | mypy: 32 | if: ${{ github.repository_owner == 'pytorch' }} 33 | runs-on: ubuntu-latest 34 | steps: 35 | - name: Get PyTorch Channel 36 | shell: bash 37 | run: | 38 | if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then 39 | PT_CHANNEL="https://download.pytorch.org/whl/test/cpu" 40 | else 41 | PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu" 42 | fi 43 | echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT 44 | id: pytorch_channel 45 | - name: Setup Python environment 46 | uses: actions/setup-python@v5 47 | with: 48 | python-version: 3.9 49 | - name: Check out source repository 50 | uses: actions/checkout@v4 51 | - name: Install PyTorch 52 | run: | 53 | pip3 install networkx 54 | pip3 install --pre torch --index-url "${{ steps.pytorch_channel.outputs.value }}" 55 | - name: Install dependencies 56 | run: | 57 | pip3 install -r requirements.txt 58 | pip3 install mypy==1.8.0 numpy types-requests 59 | - name: Build TorchData 60 | run: | 61 | pip3 install . 62 | - name: Run mypy 63 | env: 64 | MYPY_FORCE_COLOR: 1 65 | TERM: xterm-color 66 | run: | 67 | set -eux 68 | STATUS= 69 | if ! 
mypy --config=mypy.ini; then 70 | STATUS=fail 71 | fi 72 | if [ -n "$STATUS" ]; then 73 | echo 'Please fix the above mypy warnings.' 74 | false 75 | fi 76 | 77 | complete_documentation: 78 | if: ${{ github.repository_owner == 'pytorch' }} 79 | runs-on: ubuntu-latest 80 | steps: 81 | - name: Setup Python 82 | uses: actions/setup-python@v5 83 | with: 84 | python-version: "3.9" 85 | - name: Check out source repository 86 | uses: actions/checkout@v4 87 | - name: Check if documentation is complete 88 | run: python ./.github/scripts/check_complete_doc.py 89 | -------------------------------------------------------------------------------- /.github/workflows/nightly_release.yml: -------------------------------------------------------------------------------- 1 | name: Push Nightly Release 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: 00 11 * * * 7 | 8 | permissions: 9 | id-token: write 10 | contents: write 11 | 12 | jobs: 13 | build_test_upload: 14 | if: | 15 | github.repository == 'pytorch/data' && (github.ref_name == 'main' || github.event_name == 'workflow_dispatch') 16 | uses: ./.github/workflows/_build_test_upload.yml 17 | with: 18 | branch: "main" 19 | pre_dev_release: true 20 | pytorch_version: "" 21 | -------------------------------------------------------------------------------- /.github/workflows/nodes_ci.yml: -------------------------------------------------------------------------------- 1 | name: Run Nodes Tests 2 | on: 3 | push: 4 | branches: 5 | - main 6 | - release/* 7 | tags: 8 | pull_request: 9 | types: [opened, synchronize, reopened, labeled] 10 | branches: 11 | - main 12 | # For PR created by ghstack 13 | - gh/*/*/base 14 | - release/* 15 | 16 | jobs: 17 | test: 18 | if: ${{ github.repository_owner == 'pytorch' }} 19 | runs-on: ${{ matrix.os }} 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: 24 | - macos-latest 25 | - ubuntu-latest 26 | - windows-latest 27 | python-version: 28 | - 3.9 29 | - "3.10" 30 | - "3.11" 31 | - "3.12" 
32 | - "3.13" 33 | exclude: 34 | - os: macos-latest 35 | python-version: "3.13" 36 | - os: windows-latest 37 | python-version: "3.13" 38 | steps: 39 | - name: Get PyTorch Channel 40 | shell: bash 41 | run: | 42 | if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then 43 | PT_CHANNEL="https://download.pytorch.org/whl/test/cpu" 44 | else 45 | PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu" 46 | fi 47 | echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT 48 | id: pytorch_channel 49 | - name: Setup additional system libraries 50 | if: startsWith( matrix.os, 'ubuntu' ) 51 | run: | 52 | sudo add-apt-repository multiverse 53 | sudo apt update 54 | sudo apt install rar unrar libssl-dev libcurl4-openssl-dev zlib1g-dev 55 | - name: Setup Python ${{ matrix.python-version }} 56 | uses: actions/setup-python@v5 57 | with: 58 | python-version: ${{ matrix.python-version }} 59 | - name: Setup msbuild on Windows 60 | if: matrix.os == 'windows-latest' 61 | uses: microsoft/setup-msbuild@v1.1 62 | - name: Set up Visual Studio shell 63 | if: matrix.os == 'windows-latest' 64 | uses: egor-tensin/vs-shell@v2 65 | with: 66 | arch: x64 67 | - name: Check out source repository 68 | uses: actions/checkout@v4 69 | with: 70 | submodules: recursive 71 | - name: Install dependencies 72 | run: | 73 | pip3 install -r requirements.txt 74 | pip3 install networkx 75 | pip3 install --pre torch --index-url "${{ steps.pytorch_channel.outputs.value }}" 76 | pip3 install cmake ninja 77 | echo "/home/runner/.local/bin" >> $GITHUB_PATH 78 | - name: Build TorchData 79 | run: | 80 | pip3 install . 81 | env: 82 | BUILD_S3: 0 83 | - name: Install test requirements 84 | run: pip3 install -r test/requirements.txt 85 | - name: Run Node tests with pytest - dataloader 86 | if: ${{ ! 
contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 87 | run: pytest --durations=0 --no-header -v test/nodes/ 88 | -------------------------------------------------------------------------------- /.github/workflows/pull_release.yml: -------------------------------------------------------------------------------- 1 | name: Test Release Pipelines 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | paths: 7 | - .github/workflows/pull_release.yml 8 | - .github/workflows/_build_test_upload.yml 9 | 10 | permissions: 11 | contents: write 12 | id-token: write 13 | 14 | jobs: 15 | build_test_upload: 16 | if: github.repository == 'pytorch/data' 17 | uses: ./.github/workflows/_build_test_upload.yml 18 | with: 19 | branch: "" 20 | pre_dev_release: true 21 | pytorch_version: "" 22 | do-upload: false 23 | -------------------------------------------------------------------------------- /.github/workflows/stateful_dataloader_ci.yml: -------------------------------------------------------------------------------- 1 | name: Run StatefulDataLoader Tests 2 | on: 3 | push: 4 | branches: 5 | - main 6 | - release/* 7 | tags: 8 | pull_request: 9 | types: [opened, synchronize, reopened, labeled] 10 | branches: 11 | - main 12 | # For PR created by ghstack 13 | - gh/*/*/base 14 | - release/* 15 | 16 | jobs: 17 | test: 18 | if: ${{ github.repository_owner == 'pytorch' }} 19 | runs-on: ${{ matrix.os }} 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: 24 | - macos-latest 25 | - ubuntu-latest 26 | - windows-latest 27 | python-version: 28 | - 3.9 29 | - "3.10" 30 | - "3.11" 31 | - "3.12" 32 | - "3.13" 33 | exclude: 34 | - os: macos-latest 35 | python-version: "3.13" 36 | - os: windows-latest 37 | python-version: "3.13" 38 | steps: 39 | - name: Get PyTorch Channel 40 | shell: bash 41 | run: | 42 | if [[ "${{ github.base_ref }}" == release/* ]] || [[ "${{ github.ref }}" == refs/heads/release/* ]] || [[ "${{ github.ref }}" == refs/tags/v* ]]; then 43 | 
PT_CHANNEL="https://download.pytorch.org/whl/test/cpu" 44 | else 45 | PT_CHANNEL="https://download.pytorch.org/whl/nightly/cpu" 46 | fi 47 | echo "value=$PT_CHANNEL" >> $GITHUB_OUTPUT 48 | id: pytorch_channel 49 | - name: Setup additional system libraries 50 | if: startsWith( matrix.os, 'ubuntu' ) 51 | run: | 52 | sudo add-apt-repository multiverse 53 | sudo apt update 54 | sudo apt install rar unrar libssl-dev libcurl4-openssl-dev zlib1g-dev 55 | - name: Setup Python ${{ matrix.python-version }} 56 | uses: actions/setup-python@v5 57 | with: 58 | python-version: ${{ matrix.python-version }} 59 | - name: Setup msbuild on Windows 60 | if: matrix.os == 'windows-latest' 61 | uses: microsoft/setup-msbuild@v1.1 62 | - name: Set up Visual Studio shell 63 | if: matrix.os == 'windows-latest' 64 | uses: egor-tensin/vs-shell@v2 65 | with: 66 | arch: x64 67 | - name: Check out source repository 68 | uses: actions/checkout@v4 69 | with: 70 | submodules: recursive 71 | - name: Install dependencies 72 | run: | 73 | pip3 install -r requirements.txt 74 | pip3 install networkx 75 | pip3 install --pre torch --index-url "${{ steps.pytorch_channel.outputs.value }}" 76 | pip3 install cmake ninja 77 | echo "/home/runner/.local/bin" >> $GITHUB_PATH 78 | - name: Build TorchData 79 | run: | 80 | pip3 install . 81 | env: 82 | BUILD_S3: 0 83 | - name: Install test requirements 84 | run: pip3 install -r test/requirements.txt 85 | - name: Run StatefulDataLoader tests with pytest - dataloader 86 | if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 87 | run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_dataloader.py 88 | - name: Run StatefulDataSampler tests with pytest - datasampler 89 | if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 90 | run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_sampler.py 91 | - name: Run StatefulDataLoader tests with pytest - state_dict 0 92 | if: ${{ ! 
contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 93 | run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard0 94 | - name: Run StatefulDataLoader tests with pytest - state_dict 1 95 | if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 96 | run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard1 97 | - name: Run StatefulDataLoader tests with pytest - state_dict 2 98 | if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 99 | run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard2 100 | - name: Run StatefulDataLoader tests with pytest - state_dict 3 101 | if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 102 | run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard3 103 | - name: Run StatefulDataLoader HuggingFace tests 104 | if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }} 105 | run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_hugging_face.py 106 | -------------------------------------------------------------------------------- /.github/workflows/test_release.yml: -------------------------------------------------------------------------------- 1 | name: Push Test Release 2 | 3 | on: 4 | # [ Note: Manually Trigger the Workflow ] 5 | # 1. Go to Actions under pytorch/data repo 6 | # 2. In the left sidebar, click the workflow you want to run 7 | # 3. Above the list of workflow runs, select Run workflow 8 | # 4. Use the Branch dropdown to select the release/* branch 9 | # 5. 
Click Run workflow 10 | workflow_dispatch: 11 | # Automatically trigger test/official release 12 | # Requred Feature of GHA: Run schedule on specific branch 13 | # Otherwise, all changes for release need to be landed into main branch 14 | # See: https://github.community/t/scheduled-builds-of-non-default-branch/16306 15 | # schedule: 16 | # - cron: 30 23 * * * 17 | 18 | # [ Note: Workflow/Job Level ENV ] 19 | # Workflow/Job level env doesn't work even though document indicates this feature 20 | # https://github.com/actions/runner/issues/480 21 | # https://github.community/t/how-to-set-and-access-a-workflow-variable/17335 22 | # env: 23 | # RELEASE_BRANCH: "" 24 | 25 | permissions: 26 | id-token: write 27 | contents: write 28 | 29 | jobs: 30 | build_test_upload: 31 | if: github.repository == 'pytorch/data' && startsWith(github.ref_name, 'release/') 32 | uses: ./.github/workflows/_build_test_upload.yml 33 | with: 34 | branch: "release/0.11" 35 | pre_dev_release: true 36 | pytorch_version: "2.6.0" 37 | -------------------------------------------------------------------------------- /.github/workflows/validate-binaries.yml: -------------------------------------------------------------------------------- 1 | name: Validate binaries 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | channel: 7 | description: "Channel to use (nightly, test, release, all)" 8 | required: false 9 | type: string 10 | default: release 11 | os: 12 | description: "Operating system to generate for (linux, windows, macos, macos-arm64)" 13 | required: true 14 | type: string 15 | ref: 16 | description: "Reference to checkout, defaults to empty" 17 | default: "" 18 | required: false 19 | type: string 20 | workflow_dispatch: 21 | inputs: 22 | channel: 23 | description: "Channel to use (nightly, test, release, all)" 24 | required: true 25 | type: choice 26 | options: 27 | - release 28 | - nightly 29 | - test 30 | - all 31 | os: 32 | description: "Operating system to generate for (linux, windows, macos)" 33 
| required: true 34 | type: choice 35 | default: all 36 | options: 37 | - windows 38 | - linux 39 | - macos 40 | - all 41 | ref: 42 | description: "Reference to checkout, defaults to empty" 43 | default: "" 44 | required: false 45 | type: string 46 | jobs: 47 | validate-binaries: 48 | uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main 49 | with: 50 | package_type: "wheel" 51 | os: ${{ inputs.os }} 52 | channel: ${{ inputs.channel }} 53 | repository: "pytorch/data" 54 | smoke_test: "source ./.github/scripts/validate_binaries.sh" 55 | install_torch: true 56 | -------------------------------------------------------------------------------- /.github/workflows/validate-nightly-binaries.yml: -------------------------------------------------------------------------------- 1 | # Scheduled validation of the nightly binaries 2 | name: Cron job to validate TorchData Nightly Binaries 3 | 4 | on: 5 | schedule: 6 | # At 5:30 pm UTC (7:30 am PDT) 7 | - cron: "30 17 * * *" 8 | # Have the ability to trigger this job manually through the API 9 | workflow_dispatch: 10 | pull_request: 11 | paths: 12 | - .github/workflows/validate-nightly-binaries.yml 13 | - .github/workflows/validate-binaries.yml 14 | - .github/scripts/validate_binaries.sh 15 | jobs: 16 | nightly: 17 | uses: ./.github/workflows/validate-binaries.yml 18 | with: 19 | channel: nightly 20 | os: all 21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | dist/* 3 | *.egg-info/* 4 | conda-bld/* 5 | 6 | torchdata/version.py 7 | torchdata/datapipes/iter/__init__.pyi 8 | torchdata/datapipes/map/__init__.pyi 9 | 10 | # Editor temporaries 11 | *.swn 12 | *.swo 13 | *.swp 14 | *.swm 15 | *~ 16 | .vscode 17 | .idea 18 | 19 | # macOS dir files 20 | .DS_Store 21 | 22 | ## General 23 | 24 | */*.so* 25 | */**/*.so* 26 | torchdata/*.so* 27 | 28 | # Compiled Object files 29 
| *.slo 30 | *.lo 31 | *.o 32 | *.cuo 33 | *.obj 34 | 35 | # Compiled Dynamic libraries 36 | *.so 37 | *.dylib 38 | *.dll 39 | 40 | # Compiled python 41 | *.pyc 42 | *.pyd 43 | 44 | # setup script artifacts for Velox 45 | f4d-deps 46 | 47 | # sphinx documentation 48 | docs/build 49 | docs/source/generated 50 | # pytorch-sphinx-theme gets installed here 51 | docs/src 52 | 53 | # AWSSDK 54 | third_party/aws_sdk 55 | 56 | # Release Notes 57 | scripts/release_notes/results 58 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | node: 16.14.2 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.0.1 7 | hooks: 8 | - id: check-docstring-first 9 | - id: mixed-line-ending 10 | args: [--fix=lf] 11 | - id: end-of-file-fixer 12 | 13 | - repo: https://github.com/pre-commit/mirrors-prettier 14 | rev: v2.5.1 15 | hooks: 16 | - id: prettier 17 | types_or: 18 | - markdown 19 | - toml 20 | - yaml 21 | 22 | - repo: https://github.com/asottile/pyupgrade 23 | rev: v2.31.0 24 | hooks: 25 | - id: pyupgrade 26 | args: [--py37-plus] 27 | 28 | - repo: https://github.com/omnilib/ufmt 29 | rev: v1.3.2 30 | hooks: 31 | - id: ufmt 32 | additional_dependencies: 33 | - black == 21.12b0 34 | - usort == 1.0.0 35 | 36 | - repo: https://github.com/pycqa/flake8 37 | rev: 6.1.0 38 | hooks: 39 | - id: flake8 40 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | # Ignore artifacts: 2 | packaging 3 | -------------------------------------------------------------------------------- /.prettierrc.yaml: -------------------------------------------------------------------------------- 1 | proseWrap: always 2 | printWidth: 120 3 | 
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make 6 | participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, 7 | disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, 8 | socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 9 | 10 | ## Our Standards 11 | 12 | Examples of behavior that contributes to creating a positive environment include: 13 | 14 | - Using welcoming and inclusive language 15 | - Being respectful of differing viewpoints and experiences 16 | - Gracefully accepting constructive criticism 17 | - Focusing on what is best for the community 18 | - Showing empathy towards other community members 19 | 20 | Examples of unacceptable behavior by participants include: 21 | 22 | - The use of sexualized language or imagery and unwelcome sexual attention or advances 23 | - Trolling, insulting/derogatory comments, and personal or political attacks 24 | - Public or private harassment 25 | - Publishing others' private information, such as a physical or electronic address, without explicit permission 26 | - Other conduct which could reasonably be considered inappropriate in a professional setting 27 | 28 | ## Our Responsibilities 29 | 30 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take 31 | appropriate and fair corrective action in response to any instances of unacceptable behavior. 
32 | 33 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, 34 | issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any 35 | contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 36 | 37 | ## Scope 38 | 39 | This Code of Conduct applies within all project spaces, and it also applies when an individual is representing the 40 | project or its community in public spaces. Examples of representing a project or community include using an official 41 | project e-mail address, posting via an official social media account, or acting as an appointed representative at an 42 | online or offline event. Representation of a project may be further defined and clarified by project maintainers. 43 | 44 | ## Enforcement 45 | 46 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at 47 | . All complaints will be reviewed and investigated and will result in a response that is deemed 48 | necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to 49 | the reporter of an incident. Further details of specific enforcement policies may be posted separately. 50 | 51 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent 52 | repercussions as determined by other members of the project's leadership. 
53 | 54 | ## Attribution 55 | 56 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at 57 | https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 58 | 59 | [homepage]: https://www.contributor-covenant.org 60 | 61 | For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq 62 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to TorchData 2 | 3 | We want to make contributing to this project as easy and transparent as possible. 4 | 5 | ## TL;DR 6 | 7 | We appreciate all contributions. If you are interested in contributing to TorchData, there are many ways to help out. 8 | Your contributions may fall into the following categories: 9 | 10 | - It helps the project if you can 11 | 12 | - Report issues that you're facing 13 | - Give a :+1: on issues that others reported and that are relevant to you 14 | 15 | - Answering questions on the issue tracker, investigating bugs are very valuable contributions to the project. 16 | 17 | - You would like to improve the documentation. This is no less important than improving the library itself! If you find 18 | a typo in the documentation, do not hesitate to submit a GitHub pull request. 19 | 20 | - If you would like to fix a bug: 21 | 22 | - comment on the issue that you want to work on this issue 23 | - send a PR with your fix, see below. 24 | 25 | - If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the 26 | feature with us. 27 | - We have a checklist of things to go through while adding a new Node. See below. 28 | - If you would like to feature a usage example in our documentation, discuss that with us in an issue. 29 | 30 | ## Issues 31 | 32 | We use GitHub issues to track public bugs. 
Please follow the existing templates if possible and ensure that the 33 | description is clear and has sufficient instructions to be able to reproduce the issue. 34 | 35 | For questions related to the usage of this library, please post a question on the 36 | [PyTorch forum, under the "data" category](https://discuss.pytorch.org/c/data/37). 37 | 38 | ## Pull Requests 39 | 40 | We actively welcome your pull requests. 41 | 42 | 1. Fork the repo and create your branch from `main`. 43 | 2. If you've added code that should be tested, add tests. 44 | 3. If you've changed APIs, update the documentation and examples. 45 | 4. Ensure the test suite passes. 46 | 5. If you haven't already, complete the Contributor License Agreement ("CLA"). 47 | 48 | ## Development installation 49 | 50 | ### Install PyTorch Nightly 51 | 52 | ```bash 53 | conda install pytorch -c pytorch-nightly 54 | # or with pip (see https://pytorch.org/get-started/locally/) 55 | # pip install numpy 56 | # pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu 57 | ``` 58 | 59 | ### Install TorchData and Test Requirements 60 | 61 | ```bash 62 | git clone https://github.com/pytorch/data.git 63 | cd data 64 | pip install -e . 65 | pip install -r test/requirements.txt 66 | ``` 67 | 68 | ### Code style 69 | 70 | `torchdata` enforces a fairly strict code format through [`pre-commit`](https://pre-commit.com). 
You can install it with 71 | 72 | ```bash 73 | conda install -c conda-forge pre-commit 74 | # or pip install pre-commit 75 | cd data 76 | with-proxy conda install pre-commit 77 | pre-commit install --install-hooks 78 | ``` 79 | 80 | ### Running mypy and unit-tests locally 81 | 82 | Currently we don't run mypy as part of pre-commit hooks 83 | 84 | ```bash 85 | mypy --config-file=mypy.ini 86 | ``` 87 | 88 | ```bash 89 | pytest --durations=0 --no-header -v test/nodes/ 90 | ``` 91 | 92 | ### Adding a new Node 93 | 94 | When adding a new Node, there are a few things that need to be done to ensure it is working and documented properly. 95 | 96 | The following simplifying assumptions are made of node implementations: 97 | 98 | - state is managed solely by the BaseNode, not through any iterators returned from them. 99 | - state_dict() returns the state of the most recently requested iterator. 100 | - load_state_dict() will set the state for the next iterator. 101 | 102 | 1. Functionality - Nodes must subclass BaseNode and implement the required methods. 103 | 104 | - `.iterator(self, initial_state: Optional[Dict[str, Any]])` - return a new iterator/generator that is properly 105 | initialized with the optional initial_state 106 | - `.get_state(self) -> Dict[str, Any]` - return a dictionary representing the state of the most recently returned 107 | iterator, or if not yet requested, the initial state. 108 | - ensure you're calling `state_dict()/load_state_dict()` on ancestor BaseNodes. 
Here is a simple example of a pretty 109 | useless node: 110 | 111 | ```python 112 | class MyNode(BaseNode[T]): 113 | def __init__(self, parent: BaseNode[T]): 114 | self.parent = parent 115 | self.idx = 0 # not very useful state 116 | 117 | def iterator(self, initial_state: Optional[Dict[str, Any]]) -> Iterator[T]: 118 | if initial_state is not None: 119 | self.parent.load_state_dict(initial_state["parent"]) 120 | self.idx = initial_state["idx"] 121 | 122 | for item in self.parent: 123 | self.idx += 1 124 | yield item 125 | 126 | def get_state(self) -> Dict[str, Any]: 127 | return { 128 | "parent": self.parent.state_dict(), # note we call state_dict() and not get_state() here 129 | "idx": self.idx, 130 | } 131 | ``` 132 | 133 | 2. Typing - Include type-hints for all public functions and methods 134 | 3. Testing - please add unit tests to ensure that the Node is functioning properly. 135 | - In addition to testing basic functionality, state management must also be tested. 136 | - For basic state testing, you may use `test.nodes.utils.run_test_save_load_state`. See `test/nodes/test_batch.py` 137 | for an example. 138 | 4. Documentation - ensure that the Node has a docstring, and a usage example. 139 | 5. Import - import the Node in the correct `__init__.py` file. 140 | 141 | ## Contributor License Agreement ("CLA") 142 | 143 | In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of 144 | Facebook's open source projects. 145 | 146 | Complete your CLA here: 147 | 148 | ## License 149 | 150 | By contributing to TorchData, you agree that your contributions will be licensed under the LICENSE file in the root 151 | directory of this source tree. 152 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021-present, Facebook, Inc. 
4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchData 2 | 3 | [**What is TorchData?**](#what-is-torchdata) | [**Stateful DataLoader**](#stateful-dataloader) | 4 | [**Install guide**](#installation) | [**Contributing**](#contributing) | [**License**](#license) 5 | 6 | ## 7 | 8 | ## What is TorchData? 9 | 10 | The TorchData project is an iterative enhancement to the PyTorch torch.utils.data.DataLoader and 11 | torch.utils.data.Dataset/IterableDataset to make them scalable, performant dataloading solutions. We will be iterating 12 | on the enhancements under [the torchdata repo](torchdata). 13 | 14 | Our first change begins with adding checkpointing to torch.utils.data.DataLoader, which can be found in 15 | [stateful_dataloader, a drop-in replacement for torch.utils.data.DataLoader](torchdata/stateful_dataloader), by defining 16 | `load_state_dict` and `state_dict` methods that enable mid-epoch checkpointing, and an API for users to track custom 17 | iteration progress, and other custom states from the dataloader workers such as token buffers and/or RNG states. 18 | 19 | ## Stateful DataLoader 20 | 21 | `torchdata.stateful_dataloader.StatefulDataLoader` is a drop-in replacement for torch.utils.data.DataLoader which 22 | provides state_dict and load_state_dict functionality. See 23 | [the Stateful DataLoader main page](torchdata/stateful_dataloader) for more information and examples. Also check out the 24 | examples 25 | [in this Colab notebook](https://colab.research.google.com/drive/1tonoovEd7Tsi8EW8ZHXf0v3yHJGwZP8M?usp=sharing). 26 | 27 | ## torchdata.nodes 28 | 29 | torchdata.nodes is a library of composable iterators (not iterables!) that let you chain together common dataloading and 30 | pre-proc operations. It follows a streaming programming model, although "sampler + Map-style" can still be configured if 31 | you desire. 
See [torchdata.nodes main page](torchdata/nodes) for more details. Stay tuned for tutorial on 32 | torchdata.nodes coming soon! 33 | 34 | ## Installation 35 | 36 | ### Version Compatibility 37 | 38 | The following is the corresponding `torchdata` versions and supported Python versions. 39 | 40 | | `torch` | `torchdata` | `python` | 41 | | -------------------- | ------------------ | ----------------- | 42 | | `master` / `nightly` | `main` / `nightly` | `>=3.9`, `<=3.13` | 43 | | `2.6.0` | `0.11.0` | `>=3.9`, `<=3.13` | 44 | | `2.5.0` | `0.10.0` | `>=3.9`, `<=3.12` | 45 | | `2.5.0` | `0.9.0` | `>=3.9`, `<=3.12` | 46 | | `2.4.0` | `0.8.0` | `>=3.8`, `<=3.12` | 47 | | `2.0.0` | `0.6.0` | `>=3.8`, `<=3.11` | 48 | | `1.13.1` | `0.5.1` | `>=3.7`, `<=3.10` | 49 | | `1.12.1` | `0.4.1` | `>=3.7`, `<=3.10` | 50 | | `1.12.0` | `0.4.0` | `>=3.7`, `<=3.10` | 51 | | `1.11.0` | `0.3.0` | `>=3.7`, `<=3.10` | 52 | 53 | ### Local pip or conda 54 | 55 | First, set up an environment. We will be installing a PyTorch binary as well as torchdata. If you're using conda, create 56 | a conda environment: 57 | 58 | ```bash 59 | conda create --name torchdata 60 | conda activate torchdata 61 | ``` 62 | 63 | If you wish to use `venv` instead: 64 | 65 | ```bash 66 | python -m venv torchdata-env 67 | source torchdata-env/bin/activate 68 | ``` 69 | 70 | Install torchdata: 71 | 72 | Using pip: 73 | 74 | ```bash 75 | pip install torchdata 76 | ``` 77 | 78 | Using conda: 79 | 80 | ```bash 81 | conda install -c pytorch torchdata 82 | ``` 83 | 84 | ### From source 85 | 86 | ```bash 87 | pip install . 88 | ``` 89 | 90 | In case building TorchData from source fails, install the nightly version of PyTorch following the linked guide on the 91 | [contributing page](CONTRIBUTING.md#install-pytorch-nightly). 92 | 93 | ### From nightly 94 | 95 | The nightly version of TorchData is also provided and updated daily from main branch. 
96 | 97 | Using pip: 98 | 99 | ```bash 100 | pip install --pre torchdata --index-url https://download.pytorch.org/whl/nightly/cpu 101 | ``` 102 | 103 | Using conda: 104 | 105 | ```bash 106 | conda install torchdata -c pytorch-nightly 107 | ``` 108 | 109 | ## Contributing 110 | 111 | We welcome PRs! See the [CONTRIBUTING](CONTRIBUTING.md) file. 112 | 113 | ## Beta Usage and Feedback 114 | 115 | We'd love to hear from and work with early adopters to shape our designs. Please reach out by raising an issue if you're 116 | interested in using this tooling for your project. 117 | 118 | ## License 119 | 120 | TorchData is BSD licensed, as found in the [LICENSE](LICENSE) file. 121 | -------------------------------------------------------------------------------- /benchmarks/cloud/README.md: -------------------------------------------------------------------------------- 1 | This folder contains templates that are useful for cloud setups 2 | 3 | Idea would be to provision a machine by configuring it in a YAML file and then running a benchmark script on it 4 | automatically. This is critical both for ad hoc benchmarking that are reproducible but also including real world 5 | benchmarks in a release. 
6 | 7 | We've provided some useful `yml` templates for you to get started 8 | 9 | https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/using-cfn-cli-creating-stack.html 10 | 11 | ## Setup aws cli 12 | 13 | `aws configure` and enter your credentials 14 | 15 | ## Setup stack (machine configuration) 16 | 17 | ```sh 18 | aws cloudformation create-stack \ 19 | --stack-name torchdatabenchmark \ 20 | --template-body ec2.yml \ 21 | --parameters ParameterKey=InstanceTypeParameter,ParameterValue=p3.2xlarge ParameterKey=DiskType,ParameterValue=gp3 22 | ``` 23 | 24 | ## Ssh into machine and run job 25 | 26 | ``` 27 | ssh elastic_ip 28 | git clone https://github.com/pytorch/data 29 | cd data/benchmarks 30 | python run_benchmark.py 31 | ``` 32 | 33 | Visually inspect logs 34 | 35 | ## Shut down stack 36 | 37 | `aws cloudformation delete-stack --stack-name torchdatabenchmark` 38 | -------------------------------------------------------------------------------- /benchmarks/cloud/aws_s3_results.md: -------------------------------------------------------------------------------- 1 | ## AWS S3 Benchmark - Data Loading from S3 and Attached Volume 2 | 3 | As we introduce 4 | [various methods to load data from cloud service providers](https://pytorch.org/data/0.5/tutorial.html#working-with-cloud-storage-providers), 5 | we are interested to know how the performance of loading from cloud compares to loading data from a local disk. We 6 | created a [small script](aws_s3.py) to benchmark their throughputs (MiB/s) on an AWS EC2 instance. 
7 | 8 | ### AWS EC2 Setup 9 | 10 | - Instance Type: [c5.24xlarge](https://aws.amazon.com/ec2/instance-types/c5/) 11 | - File system: gp2, AWS S3 12 | - Dataset (50GB) 13 | - Each tar archive is ~100MB with ~1050 images 14 | - Each image (256 px × 256 px) is about 100kb in size 15 | - Epochs: 16 | - 1 warm up epoch to isolate disk caching 17 | - 3-5 actual epochs 18 | - Data transformation complexity 19 | - Low (~3ms/image), Medium (~7ms), High (~10ms) 20 | - Number of workers: 2-12 21 | - We limit it to a reasonable number of workers as that is approximately what is likely used per GPU 22 | 23 | Our observation is that the performance gap between loading from S3 and local disk tends to be relatively small when 24 | your pre-processing/transformation operations are compute intensive. The exact result can vary depending on your 25 | instance’s disk IO throughput and network bandwidth. 26 | 27 | ![Low Transform Complexity](result_images/aws_s3_low_complex.jpg) 28 | ![Medium Transform Complexity](result_images/aws_s3_med_complex.jpg) 29 | ![High Transform Complexity](result_images/aws_s3_high_complex.jpg) 30 | -------------------------------------------------------------------------------- /benchmarks/cloud/ec2.yml: -------------------------------------------------------------------------------- 1 | # This script sets up an Ec2 instance with elastic IP and a disk volume 2 | Parameters: 3 | InstanceTypeParameter: 4 | Type: String 5 | Default: c5n.large 6 | AllowedValues: 7 | - c5n.large 8 | - p2.2xlarge 9 | - p3.2xlarge 10 | - p3.8xlarge 11 | Description: Instance type CPU, GPU 12 | DiskSize: 13 | Type: Number 14 | Default: 100 15 | Description: Disk size in GB 16 | DiskType: 17 | Type: String 18 | Default: gp2 19 | AllowedValues: 20 | - gp2 21 | - gp3 22 | - io1 23 | - io2 24 | - sc1 25 | - st1 26 | - standard 27 | Description: Enter Disk type SSD, HDD 28 | 29 | Resources: 30 | MyInstance: 31 | Type: AWS::EC2::Instance 32 | Properties: 33 | AvailabilityZone: us-west-2a 
34 | ImageId: ami-0306d46d05aaf8663 # Deep Learning AMI 35 | InstanceType: 36 | Ref: InstanceTypeParameter 37 | SecurityGroups: 38 | - !Ref SSHSecurityGroup 39 | 40 | # Elastic IP so I can easily ssh into the machine 41 | MyEIP: 42 | Type: AWS::EC2::EIP 43 | Properties: 44 | InstanceId: !Ref MyInstance 45 | 46 | # Open security group for SSH 47 | SSHSecurityGroup: 48 | Type: AWS::EC2::SecurityGroup 49 | Properties: 50 | GroupDescription: Enable SSH access via port 22 51 | SecurityGroupIngress: 52 | - CidrIp: 0.0.0.0/0 53 | FromPort: 22 54 | IpProtocol: tcp 55 | ToPort: 22 56 | 57 | NewVolume: 58 | Type: AWS::EC2::Volume 59 | Properties: 60 | Size: 61 | Ref: DiskSize 62 | VolumeType: 63 | Ref: DiskType 64 | AvailabilityZone: !GetAtt MyInstance.AvailabilityZone 65 | Tags: 66 | - Key: MyTag 67 | Value: TagValue 68 | DeletionPolicy: Snapshot 69 | 70 | MountPoint: 71 | Type: AWS::EC2::VolumeAttachment 72 | Properties: 73 | InstanceId: !Ref MyInstance 74 | VolumeId: !Ref NewVolume 75 | Device: /dev/sdh 76 | -------------------------------------------------------------------------------- /benchmarks/cloud/result_images/aws_s3_high_complex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/benchmarks/cloud/result_images/aws_s3_high_complex.jpg -------------------------------------------------------------------------------- /benchmarks/cloud/result_images/aws_s3_low_complex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/benchmarks/cloud/result_images/aws_s3_low_complex.jpg -------------------------------------------------------------------------------- /benchmarks/cloud/result_images/aws_s3_med_complex.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/benchmarks/cloud/result_images/aws_s3_med_complex.jpg -------------------------------------------------------------------------------- /benchmarks/torchvision_classification/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import itertools 8 | import random 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.distributed as dist 13 | import torchvision 14 | from PIL import Image 15 | 16 | 17 | # TODO: maybe infinite buffer can / is already natively supported by torchdata? 18 | INFINITE_BUFFER_SIZE = 1_000_000_000 19 | 20 | IMAGENET_TRAIN_LEN = 1_281_167 21 | IMAGENET_TEST_LEN = 50_000 22 | 23 | 24 | def _decode(path, root, category_to_int): 25 | category = Path(path).relative_to(root).parts[0] 26 | 27 | image = Image.open(path).convert("RGB") 28 | label = category_to_int(category) 29 | 30 | return image, label 31 | 32 | 33 | def _apply_tranforms(img_and_label, transforms): 34 | img, label = img_and_label 35 | return transforms(img), label 36 | 37 | 38 | class PreLoadedMapStyle: 39 | # All the data is pre-loaded and transformed in __init__, so the DataLoader should be crazy fast. 40 | # This is just to assess how fast a model could theoretically be trained if there was no data bottleneck at all. 
41 | def __init__(self, dir, transform, buffer_size=100): 42 | dataset = torchvision.datasets.ImageFolder(dir, transform=transform) 43 | self.size = len(dataset) 44 | self.samples = [dataset[torch.randint(0, len(dataset), size=(1,)).item()] for i in range(buffer_size)] 45 | 46 | def __len__(self): 47 | return self.size 48 | 49 | def __getitem__(self, idx): 50 | return self.samples[idx % len(self.samples)] 51 | 52 | 53 | class MapStyleToIterable(torch.utils.data.IterableDataset): 54 | # This converts a MapStyle dataset into an iterable one. 55 | # Not sure this kind of Iterable dataset is actually useful to benchmark. It 56 | # was necessary when benchmarking async-io stuff, but not anymore. 57 | # If anything, it shows how tricky Iterable datasets are to implement. 58 | def __init__(self, dataset, shuffle): 59 | self.dataset = dataset 60 | self.shuffle = shuffle 61 | 62 | self.size = len(self.dataset) 63 | self.seed = 0 # has to be hard-coded for all DDP workers to have the same shuffling 64 | 65 | def __len__(self): 66 | return self.size // dist.get_world_size() 67 | 68 | def __iter__(self): 69 | 70 | worker_info = torch.utils.data.get_worker_info() 71 | num_dl_workers = worker_info.num_workers 72 | dl_worker_id = worker_info.id 73 | 74 | num_ddp_workers = dist.get_world_size() 75 | ddp_worker_id = dist.get_rank() 76 | 77 | num_total_workers = num_ddp_workers * num_dl_workers 78 | current_worker_id = ddp_worker_id + (num_ddp_workers * dl_worker_id) 79 | 80 | indices = range(self.size) 81 | if self.shuffle: 82 | rng = random.Random(self.seed) 83 | indices = rng.sample(indices, k=self.size) 84 | indices = itertools.islice(indices, current_worker_id, None, num_total_workers) 85 | 86 | samples = (self.dataset[i] for i in indices) 87 | yield from samples 88 | 89 | 90 | # TODO: maybe only generate these when --no-transforms is passed? 
91 | _RANDOM_IMAGE_TENSORS = [torch.randn(3, 224, 224) for _ in range(300)] 92 | 93 | 94 | def no_transforms(_): 95 | # see --no-transforms doc 96 | return random.choice(_RANDOM_IMAGE_TENSORS) 97 | -------------------------------------------------------------------------------- /benchmarks/torchvision_classification/presets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torchvision.transforms import transforms 9 | 10 | 11 | class ClassificationPresetTrain: 12 | def __init__( 13 | self, 14 | *, 15 | crop_size, 16 | mean=(0.485, 0.456, 0.406), 17 | std=(0.229, 0.224, 0.225), 18 | hflip_prob=0.5, 19 | ): 20 | trans = [transforms.RandomResizedCrop(crop_size)] 21 | if hflip_prob > 0: 22 | trans.append(transforms.RandomHorizontalFlip(hflip_prob)) 23 | 24 | trans.extend( 25 | [ 26 | transforms.PILToTensor(), 27 | transforms.ConvertImageDtype(torch.float), 28 | transforms.Normalize(mean=mean, std=std), 29 | ] 30 | ) 31 | 32 | self.transforms = transforms.Compose(trans) 33 | 34 | def __call__(self, img): 35 | return self.transforms(img) 36 | 37 | 38 | class ClassificationPresetEval: 39 | def __init__( 40 | self, 41 | *, 42 | crop_size, 43 | resize_size=256, 44 | mean=(0.485, 0.456, 0.406), 45 | std=(0.229, 0.224, 0.225), 46 | ): 47 | 48 | self.transforms = transforms.Compose( 49 | [ 50 | transforms.Resize(resize_size), 51 | transforms.CenterCrop(crop_size), 52 | transforms.PILToTensor(), 53 | transforms.ConvertImageDtype(torch.float), 54 | transforms.Normalize(mean=mean, std=std), 55 | ] 56 | ) 57 | 58 | def __call__(self, img): 59 | return self.transforms(img) 60 | -------------------------------------------------------------------------------- /docs/Makefile: 
-------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | ifneq ($(EXAMPLES_PATTERN),) 5 | EXAMPLES_PATTERN_OPTS := -D sphinx_gallery_conf.filename_pattern="$(EXAMPLES_PATTERN)" 6 | endif 7 | 8 | # You can set these variables from the command line. 9 | SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) 10 | SPHINXBUILD = sphinx-build 11 | SPHINXPROJ = torchdata 12 | SOURCEDIR = source 13 | BUILDDIR = build 14 | 15 | # Put it first so that "make" without argument is like "make help". 16 | help: 17 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 18 | 19 | docset: html 20 | doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/data/ --force $(BUILDDIR)/html/ 21 | 22 | # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. 23 | cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png 24 | convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png 25 | 26 | html-noplot: # Avoids running the gallery examples, which may take time 27 | $(SPHINXBUILD) -D plot_gallery=0 -b html "${SOURCEDIR}" "$(BUILDDIR)"/html 28 | @echo 29 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 30 | 31 | clean: 32 | rm -rf $(BUILDDIR)/* 33 | rm -rf $(SOURCEDIR)/generated_examples/ # sphinx-gallery 34 | rm -rf $(SOURCEDIR)/gen_modules/ # sphinx-gallery 35 | rm -rf $(SOURCEDIR)/sg_execution_times.rst # sphinx-gallery 36 | rm -rf $(SOURCEDIR)/generated/ # autosummary 37 | 38 | .PHONY: help Makefile docset 39 | 40 | # Catch-all target: route all unknown targets to Sphinx using the new 41 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
42 | %: Makefile 43 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 44 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | ## Building the Documentation 2 | 3 | To build the documentation, you will need [Sphinx](http://www.sphinx-doc.org) and the PyTorch theme. 4 | 5 | ```bash 6 | cd docs/ 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | You can then build the documentation by running `make ` from the `docs/` folder. Run `make` to get a list of all 11 | available output formats. 12 | 13 | ```bash 14 | make html 15 | ``` 16 | 17 | ## Improving the Documentation 18 | 19 | Feel free to open an issue or pull request to inform us of any inaccuracy or potential improvement that we can make to 20 | our documentation. 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==5.0.0 2 | # torch 3 | # PyTorch Theme 4 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 5 | # For Graph Visualization 6 | graphviz 7 | -------------------------------------------------------------------------------- /docs/source/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the BSD-style license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | h1 { 10 | font-size: 2rem; 11 | letter-spacing: 1.78px; 12 | line-height: 2.5rem; 13 | margin: 1.375rem 0; 14 | text-transform: none; /* Overwrite upper-case conversion of titles */ 15 | } 16 | -------------------------------------------------------------------------------- /docs/source/_templates/class_method_template.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline }} 7 | 8 | .. autoclass:: {{ name }} 9 | :members: 10 | -------------------------------------------------------------------------------- /docs/source/_templates/class_template.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. 
currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline }} 7 | 8 | .. autoclass:: {{ name }} 9 | -------------------------------------------------------------------------------- /docs/source/_templates/function.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | .. currentmodule:: {{ module }} 4 | 5 | 6 | {{ name | underline }} 7 | 8 | .. autofunction:: {{ name }} 9 | -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {% block sidebartitle %} 4 | 7 | {% include "searchbox.html" %} 8 | {% endblock %} 9 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Configuration file for the Sphinx documentation builder. 8 | # 9 | # This file only contains a selection of the most common options. For a full 10 | # list see the documentation: 11 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 12 | 13 | # -- Path setup -------------------------------------------------------------- 14 | 15 | # If extensions (or modules to document with autodoc) are in another directory, 16 | # add these directories to sys.path here. If the directory is relative to the 17 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
18 | # 19 | import os 20 | import sys 21 | 22 | import pytorch_sphinx_theme 23 | import torchdata 24 | 25 | # sys.path.insert(0, os.path.abspath('.')) 26 | 27 | current_dir = os.path.dirname(__file__) 28 | target_dir = os.path.abspath(os.path.join(current_dir, "../..")) 29 | sys.path.insert(0, target_dir) 30 | print(target_dir) 31 | 32 | 33 | # -- Project information ----------------------------------------------------- 34 | 35 | project = "TorchData" 36 | copyright = "2021 - Present, Torch Contributors" 37 | author = "Torch Contributors" 38 | 39 | # The short X.Y version 40 | version = "main (" + torchdata.__version__ + " )" 41 | 42 | # The full version, including alpha/beta/rc tags 43 | release = "main" 44 | 45 | 46 | # -- General configuration --------------------------------------------------- 47 | 48 | # Add any Sphinx extension module names here, as strings. They can be 49 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 50 | # ones. 51 | extensions = [ 52 | "sphinx.ext.napoleon", 53 | "sphinx.ext.autodoc", 54 | "sphinx.ext.autosummary", 55 | "sphinx.ext.intersphinx", 56 | "sphinx.ext.doctest", 57 | "sphinx.ext.graphviz", 58 | ] 59 | 60 | # Do not execute standard reST doctest blocks so that documentation can 61 | # be successively migrated to sphinx's doctest directive. 62 | doctest_test_doctest_blocks = "" 63 | 64 | # Add any paths that contain templates here, relative to this directory. 65 | templates_path = ["_templates"] 66 | 67 | # List of patterns, relative to source directory, that match files and 68 | # directories to ignore when looking for source files. 69 | # This pattern also affects html_static_path and html_extra_path. 70 | exclude_patterns = [] 71 | 72 | 73 | # -- Options for HTML output ------------------------------------------------- 74 | 75 | # The theme to use for HTML and HTML Help pages. See the documentation for 76 | # a list of builtin themes. 
77 | # 78 | # html_theme = 'alabaster' 79 | html_theme = "pytorch_sphinx_theme" 80 | html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] 81 | 82 | html_theme_options = { 83 | "collapse_navigation": False, 84 | "display_version": True, 85 | "logo_only": True, 86 | "pytorch_project": "docs", 87 | "navigation_with_keys": True, 88 | "analytics_id": "UA-117752657-2", 89 | } 90 | 91 | # Add any paths that contain custom static files (such as style sheets) here, 92 | # relative to this directory. They are copied after the builtin static files, 93 | # so a file named "default.css" will overwrite the builtin "default.css". 94 | html_static_path = ["_static"] 95 | 96 | html_css_files = [ 97 | "css/custom.css", 98 | ] 99 | 100 | signature_replacements = {} 101 | 102 | 103 | def process_signature(app, what, name, obj, options, signature, return_annotation): 104 | """Replacing long type annotations in signature with more succinct ones.""" 105 | if isinstance(signature, str): 106 | for old, new in signature_replacements.items(): 107 | if old in signature: 108 | signature = signature.replace(old, new) 109 | return signature, return_annotation 110 | 111 | 112 | def setup(app): 113 | 114 | app.connect("autodoc-process-signature", process_signature) 115 | 116 | 117 | intersphinx_mapping = { 118 | "graphviz": ("https://graphviz.readthedocs.io/en/stable/", None), 119 | } 120 | -------------------------------------------------------------------------------- /docs/source/docutils.conf: -------------------------------------------------------------------------------- 1 | [html writers] 2 | table_style: colwidths-auto # Necessary for the table generated by autosummary to look decent 3 | -------------------------------------------------------------------------------- /docs/source/getting_started_with_torchdata_nodes.rst: -------------------------------------------------------------------------------- 1 | Getting Started With ``torchdata.nodes`` (beta) 2 | 
=============================================== 3 | 4 | Install torchdata with pip. 5 | 6 | .. code:: bash 7 | 8 | pip install torchdata>=0.10.0 9 | 10 | Generator Example 11 | ~~~~~~~~~~~~~~~~~ 12 | 13 | Wrap a generator (or any iterable) to convert it to a BaseNode and get started 14 | 15 | .. code:: python 16 | 17 | from torchdata.nodes import IterableWrapper, ParallelMapper, Loader 18 | 19 | node = IterableWrapper(range(10)) 20 | node = ParallelMapper(node, map_fn=lambda x: x**2, num_workers=3, method="thread") 21 | loader = Loader(node) 22 | result = list(loader) 23 | print(result) 24 | # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] 25 | 26 | Sampler Example 27 | ~~~~~~~~~~~~~~~ 28 | 29 | Samplers are still supported, and you can use your existing 30 | ``torch.utils.data.Dataset``\'s. See :ref:`migrate-to-nodes-from-utils` for an in-depth 31 | example. 32 | 33 | .. code:: python 34 | 35 | from torch.utils.data import RandomSampler 36 | from torchdata.nodes import SamplerWrapper, ParallelMapper, Loader 37 | 38 | 39 | class SquaredDataset(torch.utils.data.Dataset): 40 | def __getitem__(self, i: int) -> int: 41 | return i**2 42 | def __len__(self): 43 | return 10 44 | 45 | dataset = SquaredDataset() 46 | sampler = RandomSampler(dataset) 47 | 48 | # For fine-grained control of iteration order, define your own sampler 49 | node = SamplerWrapper(sampler) 50 | # Simply apply dataset's __getitem__ as a map function to the indices generated from sampler 51 | node = ParallelMapper(node, map_fn=dataset.__getitem__, num_workers=3, method="thread") 52 | # Loader is used to convert a node (iterator) into an Iterable that may be reused for multi epochs 53 | loader = Loader(node) 54 | print(list(loader)) 55 | # [25, 36, 9, 49, 0, 81, 4, 16, 64, 1] 56 | print(list(loader)) 57 | # [0, 4, 1, 64, 49, 25, 9, 16, 81, 36] 58 | -------------------------------------------------------------------------------- /docs/source/index.rst: 
-------------------------------------------------------------------------------- 1 | .. TorchData documentation master file, created by 2 | sphinx-quickstart on Thu Jan 20 09:56:06 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | TorchData 7 | ###################################### 8 | This library is part of the `PyTorch 9 | `_ project. PyTorch is an open source 10 | machine learning framework. 11 | 12 | :mod:`torchdata` is a Beta library of common modular data loading primitives for 13 | easily constructing flexible and performant data pipelines. And, there are 14 | a few features still in prototype stage. 15 | 16 | Features described in this documentation are classified by release status: 17 | 18 | *Stable:* These features will be maintained long-term and there should generally 19 | be no major performance limitations or gaps in documentation. 20 | We also expect to maintain backwards compatibility (although 21 | breaking changes can happen and notice will be given one release ahead 22 | of time). 23 | 24 | *Beta:* Features are tagged as Beta because the API may change based on 25 | user feedback, because the performance needs to improve, or because 26 | coverage across operators is not yet complete. For Beta features, we are 27 | committing to seeing the feature through to the Stable classification. 28 | We are not, however, committing to backwards compatibility. 29 | 30 | *Prototype:* These features are typically not available as part of 31 | binary distributions like PyPI or Conda, except sometimes behind run-time 32 | flags, and are at an early stage for feedback and testing. 33 | 34 | .. toctree:: 35 | :maxdepth: 2 36 | :caption: Developer Notes: 37 | 38 | what_is_torchdata_nodes.rst 39 | 40 | .. toctree:: 41 | :maxdepth: 2 42 | :caption: API Reference: 43 | 44 | torchdata.nodes.rst 45 | torchdata.stateful_dataloader.rst 46 | 47 | 48 | .. 
toctree:: 49 | :maxdepth: 2 50 | :caption: Tutorial and Examples: 51 | 52 | getting_started_with_torchdata_nodes.rst 53 | migrate_to_nodes_from_utils.rst 54 | stateful_dataloader_tutorial.rst 55 | 56 | 57 | .. toctree:: 58 | :maxdepth: 1 59 | :caption: PyTorch Libraries 60 | 61 | PyTorch 62 | torchtune 63 | torchaudio 64 | torchvision 65 | TorchElastic 66 | TorchServe 67 | PyTorch on XLA Devices 68 | 69 | 70 | Indices 71 | ================== 72 | 73 | * :ref:`genindex` 74 | -------------------------------------------------------------------------------- /docs/source/torchdata.nodes.rst: -------------------------------------------------------------------------------- 1 | ``torchdata.nodes`` (beta) 2 | ========================== 3 | 4 | .. automodule:: torchdata.nodes 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/source/torchdata.stateful_dataloader.rst: -------------------------------------------------------------------------------- 1 | :tocdepth: 3 2 | 3 | Stateful DataLoader 4 | =================== 5 | 6 | .. automodule:: torchdata.stateful_dataloader 7 | 8 | StatefulDataLoader is a drop-in replacement for `torch.utils.data.DataLoader `_ which offers ``state_dict`` / ``load_state_dict`` methods for handling mid-epoch checkpointing which operate on the previous/next iterator requested from the dataloader (resp.). 9 | 10 | By default, the state includes the number of batches yielded and uses this to naively fast-forward the sampler (map-style) or the dataset (iterable-style). However if the sampler and/or dataset include ``state_dict`` / ``load_state_dict`` methods, then it will call them during its own ``state_dict`` / ``load_state_dict`` calls. Under the hood, :class:`StatefulDataLoader` handles aggregation and distribution of state across multiprocess workers (but not across ranks). 11 | 12 | .. 
autoclass:: StatefulDataLoader 13 | :members: 14 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /examples/criteo_torcharrow/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Callable, List, TypeVar 8 | 9 | T = TypeVar("T") 10 | 11 | # Criteo Data Set Parameters 12 | INT_FEATURE_COUNT = 13 13 | CAT_FEATURE_COUNT = 26 14 | DEFAULT_LABEL_NAME = "label" 15 | DEFAULT_INT_NAMES: List[str] = [f"int_{idx}" for idx in range(INT_FEATURE_COUNT)] 16 | DEFAULT_CAT_NAMES: List[str] = [f"cat_{idx}" for idx in range(CAT_FEATURE_COUNT)] 17 | DEFAULT_COLUMN_NAMES: List[str] = [ 18 | DEFAULT_LABEL_NAME, 19 | *DEFAULT_INT_NAMES, 20 | *DEFAULT_CAT_NAMES, 21 | ] 22 | 23 | 24 | def safe_cast(val: T, dest_type: Callable[[T], T], default: T) -> T: 25 | """ 26 | Helper function to safely cast data with default as fallback. 
27 | """ 28 | try: 29 | return dest_type(val) 30 | except ValueError: 31 | return default 32 | 33 | 34 | def safe_hex_to_int(num): 35 | try: 36 | return int(safe_cast(num, str, "0") or "0", 16) 37 | except Exception: 38 | return float("NaN") 39 | -------------------------------------------------------------------------------- /examples/criteo_torcharrow/tsv_to_parquet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | This file pre-process the source file and save it as a TSV file and a Parquet file. 9 | You do not need to re-run this file if "day_11_first_3k_rows.parquet" and "day_11_first_3k_rows.tsv" exist locally 10 | """ 11 | 12 | import pandas 13 | import pyarrow 14 | import pyarrow.parquet as parquet 15 | from common import DEFAULT_CAT_NAMES, DEFAULT_COLUMN_NAMES, safe_hex_to_int 16 | 17 | 18 | # Read TSV File with Pandas 19 | tsv_fname = "day_11_first_3k_rows_original.tsv" 20 | df = pandas.read_csv(tsv_fname, sep="\t") 21 | df.columns = DEFAULT_COLUMN_NAMES 22 | 23 | # Convert hex strings to interger 24 | for i, row in df.iterrows(): 25 | for cat_col in DEFAULT_CAT_NAMES: 26 | df.at[i, cat_col] = safe_hex_to_int(row[cat_col]) 27 | 28 | # Convert to PyArrow table and write to disk as parquet file 29 | table = pyarrow.Table.from_pandas(df=df) 30 | parquet_fname = "day_11_first_3k_rows.parquet" 31 | parquet.write_table(table, parquet_fname) 32 | 33 | # Write to a new .tsv file 34 | df.to_csv("day_11_first_3k_rows.tsv", sep="\t") 35 | -------------------------------------------------------------------------------- /examples/nodes/hf_imdb_bert.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | 
"id": "d8513771-36ac-4d03-b890-35108bce2211", 6 | "metadata": {}, 7 | "source": [ 8 | "### Loading and processing IMDB movie review dataset\n", 9 | "In this example, we will load the IMDB dataset from Hugging Face, \n", 10 | "use `torchdata.nodes` to process it and generate training batches." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "eb3b507c-2ad1-410d-a834-6847182de684", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from datasets import load_dataset\n", 21 | "from transformers import BertTokenizer, BertForSequenceClassification" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "089f1126-7125-4274-9d71-5c949ccc7bbd", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "import torch\n", 32 | "from torch.utils.data import default_collate, RandomSampler, SequentialSampler" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "2afac7d9-3d66-4195-8647-dc7034d306f2", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# Load IMDB dataset from huggingface datasets and select the \"train\" split\n", 43 | "dataset = load_dataset(\"imdb\", streaming=False)\n", 44 | "dataset = dataset[\"train\"]\n", 45 | "# Since dataset is a Map-style dataset, we can setup a sampler to shuffle the data\n", 46 | "# Please refer to the migration guide here https://pytorch.org/data/main/migrate_to_nodes_from_utils.html\n", 47 | "# to migrate from torch.utils.data to torchdata.nodes\n", 48 | "\n", 49 | "sampler = RandomSampler(dataset)\n", 50 | "# Use a standard bert tokenizer\n", 51 | "tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", 52 | "# Now we can set up some torchdata.nodes to create our pre-proc pipeline" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "09e08a47-573c-4d32-9a02-36cd8150db60", 58 | "metadata": {}, 59 | "source": [ 60 | "All torchdata.nodes.BaseNode implementations are Iterators.\n", 61 
MapStyleWrapper combines a sampler and a map-style dataset to create an Iterator.
We use torch.utils.data.default_collate for this\n", 103 | "node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n", 104 | "node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n", 105 | "\n", 106 | "# we can optionally apply pin_memory to the batches\n", 107 | "if torch.cuda.is_available():\n", 108 | " node = PinMemory(node)\n", 109 | "\n", 110 | "# Since nodes are iterators, they need to be manually .reset() between epochs.\n", 111 | "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n", 112 | "loader = Loader(node)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": 5, 118 | "id": "60fd54f3-62ef-47aa-a790-853cb4899f13", 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "name": "stdout", 123 | "output_type": "stream", 124 | "text": [ 125 | "{'input_ids': tensor([[ 101, 1045, 2572, ..., 2143, 2000, 102],\n", 126 | " [ 101, 2004, 1037, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", 127 | " [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "# Inspect a batch\n", 133 | "batch = next(iter(loader))\n", 134 | "print(batch)\n", 135 | "# In a batch we get three keys, as defined in the method `bert_transform`.\n", 136 | "# Since the batch size is 2, two samples are stacked together for each key." 
137 | ] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3 (ipykernel)", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.12.6" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 5 161 | } 162 | -------------------------------------------------------------------------------- /examples/nodes/multi_dataset_weighted_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "79a14c63-a085-493f-8db9-6af3e1d744b5", 6 | "metadata": {}, 7 | "source": [ 8 | "### `MultiNodeWeightedSampler` example\n", 9 | "In this notebook, we will explore the usage of `MultiNodeWeightedSampler` in `torchdata.nodes`.\n", 10 | "\n", 11 | "`MultiNodeWeightedSampler` allows us to sample with a probability from multiple datsets. We will make three datasets, and then see how does the composition of the output depends on the weights defined in the `MultiNodeWeightedSampler`." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "id": "0b283748-9b3f-4b9e-bbc5-db0791f4d900", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from torchdata.nodes import Mapper, MultiNodeWeightedSampler, IterableWrapper, Loader\n", 22 | "import collections\n", 23 | "\n", 24 | "# defining a simple map_fn as a place holder example\n", 25 | "def map_fn(item):\n", 26 | " return {\"x\":item}\n", 27 | "\n", 28 | "\n", 29 | "def constant_stream(value: int):\n", 30 | " while True:\n", 31 | " yield value\n", 32 | "\n", 33 | "# First, we create a dictionary of three datasets, with each dataset converted into BaseNode using the IterableWrapper\n", 34 | "num_datasets = 3\n", 35 | "datasets = {\n", 36 | " \"ds0\": IterableWrapper(constant_stream(0)),\n", 37 | " \"ds1\": IterableWrapper(constant_stream(1)),\n", 38 | " \"ds2\": IterableWrapper(constant_stream(2)),\n", 39 | "}\n", 40 | "\n", 41 | "# Next, we have to define weights for sampling from a particular dataset\n", 42 | "weights = {\"ds0\": 0.5, \"ds1\": 0.25, \"ds2\": 0.25}\n", 43 | "\n", 44 | "# Finally we instatiate the MultiNodeWeightedSampler to sample from our datasets\n", 45 | "multi_node_sampler = MultiNodeWeightedSampler(datasets, weights)\n", 46 | "\n", 47 | "# Since nodes are iterators, they need to be manually .reset() between epochs.\n", 48 | "# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n", 49 | "loader = Loader(multi_node_sampler)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 2, 55 | "id": "77784ba3-b917-4083-aed4-dba2374110d5", 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "fractions = {0: 0.49791, 2: 0.25067, 1: 0.25142}\n", 63 | "The original weights were = {'ds0': 0.5, 'ds1': 0.25, 'ds2': 0.25}\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "# Let's take a look at the output for 100k numbers, compute the fraction of 
each dataset in that batch\n", 69 | "# and compare the batch composition with our given weights\n", 70 | "n = 100000\n", 71 | "it = iter(loader)\n", 72 | "samples = [next(it) for _ in range(n)]\n", 73 | "fractions = {k: v/len(samples) for k, v in collections.Counter(samples).items()}\n", 74 | "print(f\"fractions = {fractions}\")\n", 75 | "print(f\"The original weights were = {weights}\")" 76 | ] 77 | } 78 | ], 79 | "metadata": { 80 | "kernelspec": { 81 | "display_name": "Python 3 (ipykernel)", 82 | "language": "python", 83 | "name": "python3" 84 | }, 85 | "language_info": { 86 | "codemirror_mode": { 87 | "name": "ipython", 88 | "version": 3 89 | }, 90 | "file_extension": ".py", 91 | "mimetype": "text/x-python", 92 | "name": "python", 93 | "nbconvert_exporter": "python", 94 | "pygments_lexer": "ipython3", 95 | "version": "3.12.6" 96 | } 97 | }, 98 | "nbformat": 4, 99 | "nbformat_minor": 5 100 | } 101 | -------------------------------------------------------------------------------- /examples/vision/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /examples/vision/fakedata/caltech101/101_ObjectCategories.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/caltech101/101_ObjectCategories.tar.gz -------------------------------------------------------------------------------- /examples/vision/fakedata/caltech101/Annotations.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/caltech101/Annotations.tar -------------------------------------------------------------------------------- /examples/vision/fakedata/caltech256/256_ObjectCategories.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/caltech256/256_ObjectCategories.tar -------------------------------------------------------------------------------- /examples/vision/fakedata/imagefolder/cat/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/imagefolder/cat/1.jpg -------------------------------------------------------------------------------- /examples/vision/fakedata/imagefolder/cat/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/imagefolder/cat/2.jpg -------------------------------------------------------------------------------- /examples/vision/fakedata/imagefolder/cat/3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/imagefolder/cat/3.jpg -------------------------------------------------------------------------------- /examples/vision/fakedata/imagefolder/dog/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/imagefolder/dog/1.jpg -------------------------------------------------------------------------------- /examples/vision/fakedata/imagefolder/dog/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/imagefolder/dog/2.jpg -------------------------------------------------------------------------------- /examples/vision/fakedata/imagefolder/dog/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/examples/vision/fakedata/imagefolder/dog/3.jpg -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_unused_configs = True 3 | warn_redundant_casts = True 4 | show_error_codes = True 5 | show_column_numbers = True 6 | check_untyped_defs = True 7 | pretty = True 8 | 9 | files = torchdata 10 | 11 | exclude = examples, test, packaging 12 | 13 | python_version = 3.9 14 | 15 | # 16 | # Third party dependencies that don't have types. 
17 | # 18 | [mypy-aistore.*] 19 | ignore_missing_imports = True 20 | 21 | [mypy-datasets.*] 22 | ignore_missing_imports = True 23 | 24 | [mypy-dill.*] 25 | ignore_missing_imports = True 26 | 27 | [mypy-expecttest.*] 28 | ignore_missing_imports = True 29 | 30 | [mypy-fsspec.*] 31 | ignore_missing_imports = True 32 | 33 | [mypy-google.*] 34 | ignore_missing_imports = True 35 | 36 | [mypy-graphviz.*] 37 | ignore_missing_imports = True 38 | 39 | [mypy-iopath.*] 40 | ignore_missing_imports = True 41 | 42 | [mypy-rarfile.*] 43 | ignore_missing_imports = True 44 | 45 | [mypy-scipy.*] 46 | ignore_missing_imports = True 47 | 48 | [mypy-setuptools.*] 49 | ignore_missing_imports = True 50 | 51 | [mypy-torcharrow.*] 52 | ignore_missing_imports = True 53 | 54 | [mypy-packaging.*] 55 | ignore_missing_imports = True 56 | 57 | [mypy-pandas.*] 58 | ignore_missing_imports = True 59 | 60 | [mypy-portalocker.*] 61 | ignore_missing_imports = True 62 | 63 | [mypy-psutil.*] 64 | ignore_missing_imports = True 65 | 66 | [mypy-pyarrow.*] 67 | ignore_missing_imports = True 68 | -------------------------------------------------------------------------------- /packaging/README.md: -------------------------------------------------------------------------------- 1 | # Build TorchData Release 2 | 3 | These are a collection of scripts that are to be used for release activities. 
4 | 5 | ## Conda 6 | 7 | ### Release 8 | 9 | ```bash 10 | PYTHON_VERSION=3.9 PYTORCH_VERSION=1.11.0 packaging/build_conda.sh 11 | ``` 12 | 13 | ### Nightly 14 | 15 | ```bash 16 | PYTHON_VERSION=3.9 packaging/build_conda.sh 17 | ``` 18 | 19 | ## Wheel 20 | 21 | ### Release 22 | 23 | ```bash 24 | PYTHON_VERSION=3.9 PYTORCH_VERSION=1.11.0 packaging/build_wheel.sh 25 | ``` 26 | 27 | ### Nightly 28 | 29 | ```bash 30 | PYTHON_VERSION=3.9 packaging/build_wheel.sh 31 | ``` 32 | ## [`AWSSDK`](https://github.com/aws/aws-sdk-cpp) 33 | 34 | The following table is the corresponding `torchdata` binaries with pre-compiled `AWSSDK` extension on different operating systems. 35 | 36 | | `torchdata` | `Wheel` | `Conda` | 37 | | ------------------ | ------------------ | ------------------ | 38 | | Linux | :heavy_check_mark: | :heavy_check_mark: | 39 | | Windows | :heavy_check_mark: | :x: | 40 | | MacOS (x86_64) | :heavy_check_mark: | :heavy_check_mark: | 41 | | MacOS (arm64) | :heavy_check_mark: | :heavy_check_mark: | 42 | 43 | ### Manylinux 44 | 45 | `AWSSDK` requires OpenSSL and cURL. In order to provide `manylinux2014_x86_64` wheels with `AWSSDK` enabled, `torchdata` distributions are bundled with OpenSSL(1.1.1o) and cURL(7.38.1). If anything is out of date, please open an issue to request upgrading them. 46 | -------------------------------------------------------------------------------- /packaging/build_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 11 | . 
"$script_dir/pkg_helpers.bash" 12 | 13 | export CU_VERSION=cpu 14 | export NO_CUDA_PACKAGE=1 15 | export BUILD_TYPE="conda" 16 | 17 | export SOURCE_ROOT_DIR="$PWD" 18 | setup_env 19 | setup_conda_pytorch_constraint 20 | 21 | mkdir -p conda-bld 22 | conda build \ 23 | -c defaults \ 24 | $CONDA_CHANNEL_FLAGS \ 25 | --no-anaconda-upload \ 26 | --output-folder conda-bld \ 27 | --python "$PYTHON_VERSION" \ 28 | packaging/torchdata 29 | -------------------------------------------------------------------------------- /packaging/build_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 11 | . "$script_dir/pkg_helpers.bash" 12 | 13 | export CU_VERSION=cpu 14 | export NO_CUDA_PACKAGE=1 15 | export BUILD_TYPE="wheel" 16 | 17 | export SOURCE_ROOT_DIR="$PWD" 18 | setup_env 19 | pip_install future wheel 20 | setup_pip_pytorch_version 21 | 22 | pip_install -r requirements.txt 23 | python setup.py clean 24 | # TODO: Add windows support 25 | python setup.py bdist_wheel 26 | -------------------------------------------------------------------------------- /packaging/env-var-script.txt: -------------------------------------------------------------------------------- 1 | export MACOSX_DEPLOYMENT_TARGET="10.13" 2 | -------------------------------------------------------------------------------- /packaging/manylinux/install_openssl_curl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | OPENSSL_URL="https://www.openssl.org/source/" 9 | OPENSSL_NAME="openssl-1.1.1o" 10 | OPENSSL_SHA256="9384a2b0570dd80358841464677115df785edb941c71211f75076d72fe6b438f" 11 | OPENSSL_BUILD_FLAGS="no-ssl2 no-zlib no-shared no-comp no-dynamic-engine enable-ec_nistp_64_gcc_128" 12 | 13 | CURL_URL="https://github.com/curl/curl/releases/download" 14 | CURL_NAME="curl-7.83.1" 15 | CURL_BUILD_FLAGS="--disable-shared" 16 | 17 | function check_sha256sum { 18 | local fname=$1 19 | local sha256=$2 20 | echo "${sha256} ${fname}" > ${fname}.sha256 21 | sha256sum -c ${fname}.sha256 22 | rm ${fname}.sha256 23 | } 24 | 25 | yum erase -y openssl-devel curl-devel 26 | 27 | pushd ${WORKSPACE} 28 | 29 | # OpenSSL 30 | curl -fsSL -o ${OPENSSL_NAME}.tar.gz ${OPENSSL_URL}/${OPENSSL_NAME}.tar.gz 31 | check_sha256sum ${OPENSSL_NAME}.tar.gz ${OPENSSL_SHA256} 32 | tar zxf ${OPENSSL_NAME}.tar.gz 33 | 34 | pushd ${OPENSSL_NAME} 35 | 36 | ./config $OPENSSL_BUILD_FLAGS --prefix=${WORKSPACE}/ssl --openssldir=${WORKSPACE}/ssl 37 | make -j4 > /dev/null 38 | # avoid installing the docs 39 | # https://github.com/openssl/openssl/issues/6685#issuecomment-403838728 40 | make install_sw > /dev/null 41 | 42 | popd 43 | rm -rf ${OPENSSL_NAME} ${OPENSSL_NAME}.tar.gz 44 | 45 | # cURL 46 | curl -fsSL -o ${CURL_NAME}.tar.gz ${CURL_URL}/${CURL_NAME//./_}/${CURL_NAME}.tar.gz 47 | tar zxf ${CURL_NAME}.tar.gz 48 | 49 | pushd ${CURL_NAME} 50 | 51 | ./configure ${CURL_BUILD_FLAGS} --with-openssl=${WORKSPACE}/ssl --prefix=${WORKSPACE}/curl 52 | make -j4 > /dev/null 53 | make install > /dev/null 54 | 55 | popd 56 | rm -rf ${CURL_NAME} ${CURL_NAME}.tar.gz 57 | 58 | popd 59 | -------------------------------------------------------------------------------- /packaging/manylinux/python_helper.sh: -------------------------------------------------------------------------------- 
1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | python_nodot="$(echo $PYTHON_VERSION | tr -d '.')" 9 | case $PYTHON_VERSION in 10 | 3.[6-7]*) 11 | DESIRED_PYTHON="cp${python_nodot}-cp${python_nodot}m" 12 | ;; 13 | 3.*) 14 | DESIRED_PYTHON="cp${python_nodot}-cp${python_nodot}" 15 | ;; 16 | esac 17 | 18 | pydir="/opt/python/$DESIRED_PYTHON" 19 | export PATH="$pydir/bin:$PATH" 20 | -------------------------------------------------------------------------------- /packaging/post_build_script_linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | pip3 install auditwheel pkginfo 5 | 6 | for pkg in dist/torchdata*.whl; do 7 | echo "PkgInfo of $pkg:" 8 | pkginfo $pkg 9 | 10 | auditwheel repair $pkg --plat manylinux2014_x86_64 -w wheelhouse 11 | 12 | pkg_name=`basename ${pkg%-linux_x86_64.whl}` 13 | auditwheel show wheelhouse/${pkg_name}-manylinux_2_17_x86_64.manylinux2014_x86_64.whl 14 | done 15 | -------------------------------------------------------------------------------- /packaging/pre_build_script_linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | source packaging/manylinux/python_helper.sh 5 | yum -y install zlib 6 | # Docker path is /__w by default 7 | export WORKSPACE="/__w" 8 | # Install static OpenSSL/libcrypto library 9 | ./packaging/manylinux/install_openssl_curl.sh 10 | 11 | python -m pip install cmake ninja 12 | -------------------------------------------------------------------------------- /packaging/torchdata/bld.bat: -------------------------------------------------------------------------------- 1 | @REM Copyright (c) Meta Platforms, Inc. and affiliates. 2 | @REM All rights reserved. 
3 | @REM 4 | @REM This source code is licensed under the BSD-style license found in the 5 | @REM LICENSE file in the root directory of this source tree. 6 | 7 | @echo off 8 | 9 | git config --system core.longpaths true 10 | 11 | git submodule update --init --recursive 12 | if errorlevel 1 exit /b 1 13 | 14 | pip install . 15 | if errorlevel 1 exit /b 1 16 | -------------------------------------------------------------------------------- /packaging/torchdata/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | git submodule update --init --recursive 11 | pip install . 12 | -------------------------------------------------------------------------------- /packaging/torchdata/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: torchdata 3 | version: "{{ environ.get('BUILD_VERSION') }}" 4 | 5 | source: 6 | path: "{{ environ.get('SOURCE_ROOT_DIR') }}" 7 | 8 | requirements: 9 | # TODO: Figure out how to add build-time python dependency (PyTorch for codegen) 10 | build: 11 | - cmake 12 | - ninja 13 | - python 14 | - setuptools 15 | - cpuonly 16 | - curl # [not win] 17 | - openssl # [unix] 18 | - zlib # [unix] 19 | - pytorch>=2.0 20 | run: 21 | - python 22 | - urllib3>=1.25 23 | - requests 24 | - pytorch>=2.0 25 | 26 | build: 27 | string: py{{py}} 28 | script_env: 29 | - BUILD_VERSION 30 | 31 | test: 32 | imports: 33 | - torchdata 34 | - torchdata.stateful_dataloader 35 | source_files: 36 | - test 37 | requires: 38 | - cpuonly 39 | - pytest 40 | - expecttest 41 | # fsspec doesn't support Python 3.11 42 | # - fsspec 43 | # The following packages are not on the default conda channel 44 | # - iopath 45 | # 
- rarfile 46 | 47 | about: 48 | home: https://github.com/pytorch/data 49 | license: BSD 50 | license_file: LICENSE 51 | summary: "Common modular data loading primitives for easily constructing flexible and performant data pipelines for PyTorch users" 52 | doc_url: https://pytorch.org/data 53 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.usort] 2 | 3 | first_party_detection = false 4 | 5 | [tool.black] 6 | 7 | line-length = 120 8 | target-version = ["py39"] 9 | 10 | [build-system] 11 | requires = [ 12 | "setuptools", 13 | "wheel", 14 | "ninja", 15 | "cmake", 16 | "torch", 17 | ] 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | urllib3 >= 1.25 2 | requests 3 | -------------------------------------------------------------------------------- /scripts/release_notes/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
# Release-note sections a PR can be bucketed into; "Untopiced" is the fallback.
topics = [
    "bc_breaking",
    "deprecations",
    "new_features",
    "improvements",
    "bug_fixes",
    "performance",
    "docs",
    "devs",
    "Untopiced",
]


# Everything we need to know about a commit to write a release-note entry.
Features = namedtuple(
    "Features",
    [
        "title",
        "body",
        "pr_number",
        "files_changed",
        "labels",
    ],
)


def dict_to_features(dct):
    """Rebuild a ``Features`` record from a plain dict (inverse of ``features_to_dict``)."""
    # Drive the lookup off Features._fields so this stays in sync if fields ever change.
    return Features(**{field: dct[field] for field in Features._fields})


def features_to_dict(features):
    """Convert a ``Features`` record into a plain JSON-serializable dict."""
    return dict(features._asdict())


def run(command):
    """Run ``command`` through the shell and return ``(return_code, stdout, stderr)``.

    stdout/stderr are decoded with the locale's preferred encoding and stripped.
    """
    proc = subprocess.run(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    enc = locale.getpreferredencoding()
    return proc.returncode, proc.stdout.decode(enc).strip(), proc.stderr.decode(enc).strip()


def _git_log_field(fmt, commit_hash):
    # Shared helper: one pretty-format field of a commit, or None if git fails.
    rc, out, _ = run(f"git log -n 1 --pretty=format:{fmt} {commit_hash}")
    return out if rc == 0 else None


def commit_body(commit_hash):
    """Return the commit message body, or None if git fails (e.g. unknown hash)."""
    return _git_log_field("%b", commit_hash)


def commit_title(commit_hash):
    """Return the commit subject line, or None if git fails."""
    return _git_log_field("%s", commit_hash)


def commit_files_changed(commit_hash):
    """Return the list of file paths touched by ``commit_hash``, or None if git fails."""
    rc, out, _ = run(f"git diff-tree --no-commit-id --name-only -r {commit_hash}")
    return out.split("\n") if rc == 0 else None
def get_ghstack_token():
    """Return the GitHub OAuth token configured in ``~/.ghstackrc``.

    Raises:
        RuntimeError: if the rc file contains no ``github_oauth = ...`` entry.
    """
    rc_path = os.path.expanduser("~/.ghstackrc")
    with open(rc_path, "r+") as fh:
        contents = fh.read()
    found = re.findall("github_oauth = (.*)", contents)
    if not found:
        raise RuntimeError("Can't find a github oauth token")
    return found[0]
class CommitDataCache:
    """Disk-backed cache mapping commit hashes to ``Features`` records.

    The cache is persisted as JSON at ``path`` and reloaded on construction,
    so repeated runs avoid re-querying git/GitHub for known commits.
    """

    def __init__(self, path="results/data.json"):
        self.path = path
        self.data = {}
        if os.path.exists(path):
            self.data = self.read_from_disk()

    def get(self, commit):
        """Return cached ``Features`` for ``commit``, fetching and persisting on a miss."""
        try:
            return self.data[commit]
        except KeyError:
            self.data[commit] = get_features(commit)
            # Persist immediately so an interrupted run keeps what it fetched.
            self.write_to_disk()
            return self.data[commit]

    def read_from_disk(self):
        """Load the JSON cache file and revive each entry as a ``Features`` record."""
        with open(self.path) as fh:
            raw = json.load(fh)
        return {commit: dict_to_features(dct) for commit, dct in raw.items()}

    def write_to_disk(self):
        """Serialize all cached records back to the JSON cache file."""
        serializable = {commit: feats._asdict() for commit, feats in self.data.items()}
        with open(self.path, "w") as fh:
            json.dump(serializable, fh)
Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import distutils.command.clean 9 | import os 10 | import shutil 11 | import subprocess 12 | import sys 13 | 14 | from pathlib import Path 15 | 16 | from setuptools import find_packages, setup 17 | 18 | from tools.setup_helpers.extension import get_ext_modules 19 | 20 | ROOT_DIR = Path(__file__).parent.resolve() 21 | 22 | 23 | ################################################################################ 24 | # Parameters parsed from environment 25 | ################################################################################ 26 | RUN_BUILD_DEP = True 27 | for _, arg in enumerate(sys.argv): 28 | if arg in ["clean", "egg_info", "sdist"]: 29 | RUN_BUILD_DEP = False 30 | 31 | 32 | def _get_version(): 33 | with open(os.path.join(ROOT_DIR, "version.txt")) as f: 34 | version = f.readline().strip() 35 | 36 | sha = "Unknown" 37 | try: 38 | sha = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=str(ROOT_DIR)).decode("ascii").strip() 39 | except Exception: 40 | pass 41 | 42 | os_build_version = os.getenv("BUILD_VERSION") 43 | if os_build_version: 44 | version = os_build_version 45 | elif sha != "Unknown": 46 | version += "+" + sha[:7] 47 | 48 | return version, sha 49 | 50 | 51 | def _export_version(version, sha): 52 | version_path = ROOT_DIR / "torchdata" / "version.py" 53 | with open(version_path, "w") as f: 54 | f.write(f"__version__ = '{version}'\n") 55 | f.write(f"git_version = {repr(sha)}\n") 56 | 57 | 58 | def _get_requirements(): 59 | req_list = [] 60 | with Path("requirements.txt").open("r") as f: 61 | for line in f: 62 | req = line.strip() 63 | if len(req) == 0 or req.startswith("#"): 64 | continue 65 | req_list.append(req) 66 | return req_list 67 | 68 | 69 | # Use new version of torch on main branch 70 | pytorch_package_dep = 
class clean(distutils.command.clean.clean):
    """Extends the stock ``clean`` command to also remove torchdata build artifacts."""

    def run(self):
        # Let distutils do its normal cleanup first.
        distutils.command.clean.clean.run(self)

        # Drop compiled extension binaries left under torchdata/.
        for ext in ("so", "dylib", "pyd"):
            for artifact in (ROOT_DIR / "torchdata").glob(f"**/*.{ext}"):
                print(f"removing extension '{artifact}'")
                artifact.unlink()

        # Wipe the top-level build directory, if present.
        build_dir = ROOT_DIR / "build"
        if build_dir.exists():
            print(f"removing '{build_dir}' (and everything under it)")
            shutil.rmtree(str(build_dir), ignore_errors=True)
Scientific/Engineering :: Artificial Intelligence", 131 | ], 132 | # Package Info 133 | packages=find_packages(exclude=["test*", "examples*", "tools*", "build*"]), 134 | zip_safe=False, 135 | # C++ Extension Modules 136 | ext_modules=get_ext_modules(), 137 | cmdclass={"clean": clean}, 138 | ) 139 | -------------------------------------------------------------------------------- /test/_fakedata/README.md: -------------------------------------------------------------------------------- 1 | # Fake Data 2 | 3 | This directory is only used for testing purpose. 4 | 5 | ## Data Structure 6 | 7 | ### Files 8 | 9 | | Folder | File | Data | 10 | | ------ | ------- | -------------------------------------------------- | 11 | | bytes | fn.bt | b'fn_0123456789abcdef' | 12 | | csv | fn.csv | key,item
0,fn_0
NUMBER_OF_FILES = 3
# (folder, suffix, content template, write-as-bytes?) per fake-data flavor.
FILES = [
    ("bytes", "bt", "{fn}_0123456789abcdef\n", True),
    ("csv", "csv", "key,item\n0,{fn}_0\n1,{fn}_1\n"),
    ("json", "json", '{{"{fn}_0": [{{"{fn}_01": 1}}, {{"{fn}_02": 2}}], "{fn}_1": 1}}\n'),
    ("txt", "txt", "{fn}_0123456789abcdef\n"),
]


def create_files(folder, suffix, data, encoding=False):
    """Write NUMBER_OF_FILES sample files into ``folder`` and archive it as .tar and .tar.gz.

    Args:
        folder: target directory (created if missing); also the archive base name.
        suffix: file extension without the dot.
        data: content template; ``{fn}`` expands to the file's index.
        encoding: when True, encode the content and write in binary mode.
    """
    os.makedirs(folder, exist_ok=True)
    for index in range(NUMBER_OF_FILES):
        stem = str(index)
        payload = data.format(fn=stem)
        target = folder + "/" + stem + "." + suffix
        if encoding:
            with open(target, "wb") as fh:
                fh.write(payload.encode())
        else:
            with open(target, "wt") as fh:
                fh.write(payload)

    # Produce both an uncompressed and a gzip-compressed archive of the folder.
    for archive_name, tar_mode in ((folder + ".tar", "w"), (folder + ".tar.gz", "w:gz")):
        with tarfile.open(archive_name, mode=tar_mode) as archive:
            archive.add(folder)


def create_tfrecord_files(path: str):
    """Generate example/sequence_example tfrecord fixtures under ``path``.

    Silently no-ops (after printing a notice) when TensorFlow is not installed,
    so the rest of the fake-data generation still runs.
    """
    try:
        import tensorflow as tf
    except ImportError:
        print("TensorFlow not found!")
        print("We will not generate tfrecord files.")
        return

    # Small builders so the Example/SequenceExample bodies below stay readable.
    def _floats(values):
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    def _ints(values):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    def _bytes():
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"test str"]))

    os.makedirs(path, exist_ok=True)

    with tf.io.TFRecordWriter(os.path.join(path, "example.tfrecord")) as writer:
        for i in range(4):
            x = tf.range(i * 10, (i + 1) * 10)
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "x_float": _floats(x),
                        "x_int": _ints(tf.cast(x * 10, "int64")),
                        "x_byte": _bytes(),
                    }
                )
            )
            writer.write(example.SerializeToString())

    with tf.io.TFRecordWriter(os.path.join(path, "sequence_example.tfrecord")) as writer:
        for i in range(4):
            x = tf.range(i * 10, (i + 1) * 10)
            rep = 2 * i + 3
            seq = tf.train.SequenceExample(
                context=tf.train.Features(
                    feature={
                        "x_float": _floats(x),
                        "x_int": _ints(tf.cast(x * 10, "int64")),
                        "x_byte": _bytes(),
                    }
                ),
                feature_lists=tf.train.FeatureLists(
                    feature_list={
                        "x_float_seq": tf.train.FeatureList(feature=[_floats(x)] * rep),
                        "x_int_seq": tf.train.FeatureList(feature=[_ints(tf.cast(x * 10, "int64"))] * rep),
                        "x_byte_seq": tf.train.FeatureList(feature=[_bytes()] * rep),
                    }
                ),
            )
            writer.write(seq.SerializeToString())


if __name__ == "__main__":
    for args in FILES:
        create_files(*args)
    create_tfrecord_files("tfrecord")
0,1_0 3 | 1,1_1 4 | -------------------------------------------------------------------------------- /test/_fakedata/csv/2.csv: -------------------------------------------------------------------------------- 1 | key,item 2 | 0,2_0 3 | 1,2_1 4 | -------------------------------------------------------------------------------- /test/_fakedata/json.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/test/_fakedata/json.tar.gz -------------------------------------------------------------------------------- /test/_fakedata/json/0.json: -------------------------------------------------------------------------------- 1 | {"0_0": [{"0_01": 1}, {"0_02": 2}], "0_1": 1} 2 | -------------------------------------------------------------------------------- /test/_fakedata/json/1.json: -------------------------------------------------------------------------------- 1 | {"1_0": [{"1_01": 1}, {"1_02": 2}], "1_1": 1} 2 | -------------------------------------------------------------------------------- /test/_fakedata/json/2.json: -------------------------------------------------------------------------------- 1 | {"2_0": [{"2_01": 1}, {"2_02": 2}], "2_1": 1} 2 | -------------------------------------------------------------------------------- /test/_fakedata/tfrecord/example.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/test/_fakedata/tfrecord/example.tfrecord -------------------------------------------------------------------------------- /test/_fakedata/tfrecord/sequence_example.tfrecord: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/test/_fakedata/tfrecord/sequence_example.tfrecord 
-------------------------------------------------------------------------------- /test/_fakedata/txt.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/test/_fakedata/txt.tar.gz -------------------------------------------------------------------------------- /test/_fakedata/txt/0.txt: -------------------------------------------------------------------------------- 1 | 0_0123456789abcdef 2 | -------------------------------------------------------------------------------- /test/_fakedata/txt/1.txt: -------------------------------------------------------------------------------- 1 | 1_0123456789abcdef 2 | -------------------------------------------------------------------------------- /test/_fakedata/txt/2.txt: -------------------------------------------------------------------------------- 1 | 2_0123456789abcdef 2 | -------------------------------------------------------------------------------- /test/_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /test/_utils/_common_utils_for_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
T_co = TypeVar("T_co", covariant=True)


# Platform flags used to guard platform-specific tests.
IS_LINUX = sys.platform == "linux"
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"

IS_M1 = IS_MACOS and "arm" in platform.platform()


def get_name(path_and_stream):
    """Map a ``(path, stream)`` pair to ``(basename, stream)``."""
    path, stream = path_and_stream
    return os.path.basename(path), stream


def create_temp_dir(dir=None):
    """Create a TemporaryDirectory handle; the caller releases it in tearDown().

    The `noqa: P201` suppresses the warning about not releasing the handle here
    on purpose -- ownership is transferred to the caller.
    """
    return tempfile.TemporaryDirectory(dir=dir)  # noqa: P201


def create_temp_files(temp_dir, prefix=1, empty=True):
    """Create a text file and a binary file (plus, optionally, an empty file) in ``temp_dir``.

    Args:
        temp_dir: a TemporaryDirectory handle (e.g. from ``create_temp_dir``).
        prefix: integer used to derive each file's name prefix.
        empty: when True, also create a zero-byte ``.empty`` file.

    Returns:
        Tuple of the created file paths (3 entries when ``empty``, else 2).
    """
    root = temp_dir.name

    with tempfile.NamedTemporaryFile(dir=root, delete=False, prefix=str(prefix), suffix=".txt") as handle:
        txt_path = handle.name
    with open(txt_path, "w") as fh:
        fh.write("0123456789abcdef")

    with tempfile.NamedTemporaryFile(dir=root, delete=False, prefix=str(prefix + 1), suffix=".byte") as handle:
        byte_path = handle.name
    with open(byte_path, "wb") as fh:
        fh.write(b"0123456789abcdef")

    if not empty:
        return txt_path, byte_path

    with tempfile.NamedTemporaryFile(dir=root, delete=False, prefix=str(prefix + 2), suffix=".empty") as handle:
        empty_path = handle.name
    return txt_path, byte_path, empty_path


def check_hash_fn(filepath, expected_hash, hash_type="md5"):
    """Return True iff the digest of ``filepath`` matches ``expected_hash``.

    Args:
        filepath: file to hash, read in 1 MiB chunks to bound memory use.
        expected_hash: hex digest string to compare against.
        hash_type: "md5" (default) or "sha256".

    Raises:
        ValueError: on any other ``hash_type``.
    """
    if hash_type == "sha256":
        digest = hashlib.sha256()
    elif hash_type == "md5":
        digest = hashlib.md5()
    else:
        raise ValueError("Invalid hash_type requested, should be one of {}".format(["sha256", "md5"]))

    with open(filepath, "rb") as fh:
        while chunk := fh.read(1024 ** 2):
            digest.update(chunk)

    return digest.hexdigest() == expected_hash
class TestMapStyle(TestCase):
    """Tests for MapStyleWrapper over map-style datasets and plain dicts."""

    def test_default_sampler(self):
        # Sequential sampler (a plain range) must yield rows in index order,
        # and reset() must allow a full second epoch.
        n = 20
        node = MapStyleWrapper(DummyMapDataset(n), sampler=range(n))
        for epoch in range(2):
            node.reset()
            result = list(node)
            self.assertEqual(len(result), n)
            for i, row in enumerate(result):
                self.assertEqual(row["step"], i)
                self.assertEqual(row["test_tensor"].item(), i)
                self.assertEqual(row["test_str"], f"str_{i}")

    def test_random_sampler(self):
        # With RandomSampler each epoch must cover all n indices exactly once
        # (set equality below), but in a different order across epochs.
        n = 20
        ds = DummyMapDataset(n)
        node = MapStyleWrapper(ds, sampler=RandomSampler(ds))
        results = []
        for epoch in range(2):
            node.reset()
            result = list(node)
            results.append(result)
            self.assertEqual(len(result), n)
            self.assertEqual({row["step"] for row in result}, set(range(n)))
            self.assertEqual({row["test_tensor"].item() for row in result}, set(range(n)))
            self.assertEqual(
                {row["test_str"] for row in result},
                {f"str_{i}" for i in range(n)},
            )

        self.assertNotEqual(results[0], results[1])  # Should have different values per epoch

    def test_dict(self):
        # MapStyleWrapper also accepts a plain dict as the "dataset"; the
        # sampler then yields keys instead of integer indices.
        n = 20
        orig_ds = DummyMapDataset(n)
        d = {f"i{i}": orig_ds[i] for i in range(n)}
        sampler = list(d.keys())
        node = MapStyleWrapper(d, sampler=sampler)
        for epoch in range(2):
            node.reset()
            result = list(node)
            self.assertEqual(len(result), n)
            for i, row in enumerate(result):
                self.assertEqual(row["step"], i)
                self.assertEqual(row["test_tensor"].item(), i)
                self.assertEqual(row["test_str"], f"str_{i}")

    @parameterized.expand([0, 7])
    def test_save_load_state_fast_forward(self, midpoint: int):
        # Non-stateful sampler: resume relies on fast-forwarding.
        n = 20
        node = MapStyleWrapper(DummyMapDataset(n), sampler=range(n))
        run_test_save_load_state(self, node, midpoint)

    @parameterized.expand([0, 7])
    def test_save_load_state_stateful(self, midpoint: int):
        # Stateful sampler: resume uses the sampler's own state_dict.
        n = 20
        node = MapStyleWrapper(DummyMapDataset(n), sampler=StatefulRange(n))
        run_test_save_load_state(self, node, midpoint)
class TestBaseNode(TestCase):
    """Smoke test for the BaseNode save/load-state contract."""

    def test_save_load_state(self):
        # Wrap a plain range and verify that checkpointing at the midpoint
        # (5 of 10 items) restores and replays correctly.
        node = IterableWrapper(range(10))
        midpoint = 5
        run_test_save_load_state(self, node, midpoint)
class TestUnbatcher(TestCase):
    """Tests for Unbatcher, which flattens batches back into single samples."""

    def test_unbatcher(self) -> None:
        # Batcher(drop_last=False) followed by Unbatcher must round-trip all
        # n samples in their original order, including the partial last batch.
        batch_size = 6
        n = 20
        src = MockSource(num_samples=n)
        node = Batcher(src, batch_size=batch_size, drop_last=False)
        node = Unbatcher(node)

        results = list(node)
        self.assertEqual(len(results), n)
        for i in range(n):
            self.assertEqual(results[i]["step"], i)
            self.assertEqual(results[i]["test_tensor"], torch.tensor([i]))
            self.assertEqual(results[i]["test_str"], f"str_{i}")

    @parameterized.expand(itertools.product([0, 2], [True, False]))
    def test_save_load_state_fast_forward(self, midpoint: int, drop_last: bool):
        # State save/load must work regardless of drop_last and of where in
        # the stream the checkpoint is taken.
        batch_size = 6
        src = MockSource(num_samples=20)
        node = Batcher(src, batch_size=batch_size, drop_last=drop_last)
        node = Unbatcher(node)
        run_test_save_load_state(self, node, midpoint)
class TestHeader(TestCase):
    """Tests for Header, which yields at most the first ``n`` items of its
    source and tracks how many it has yielded in ``_num_yielded``."""

    def test_header_basic(self) -> None:
        # Test with a simple range
        source = IterableWrapper(range(10))
        node = Header(source, n=5)

        results = list(node)
        self.assertEqual(results, [0, 1, 2, 3, 4])

        # Verify counter
        self.assertEqual(node._num_yielded, 5)

        # Test with n larger than source
        source = IterableWrapper(range(3))
        node = Header(source, n=10)

        results = list(node)
        self.assertEqual(results, [0, 1, 2])

        # Verify counter with n larger than source
        self.assertEqual(node._num_yielded, 3)

        # Test with n=0 (should yield nothing)
        source = IterableWrapper(range(10))
        node = Header(source, n=0)

        results = list(node)
        self.assertEqual(results, [])

        # Verify counter with n=0
        self.assertEqual(node._num_yielded, 0)

    def test_header_with_mock_source(self) -> None:
        num_samples = 20
        source = MockSource(num_samples=num_samples)
        node = Header(source, n=7)  # Limit to first 7 items

        # Test multi epoch
        for _ in range(2):
            node.reset()
            results = list(node)
            self.assertEqual(len(results), 7)

            # Verify counter after each epoch
            self.assertEqual(node._num_yielded, 7)

            for i, result in enumerate(results):
                expected_step = i
                self.assertEqual(result["step"], expected_step)
                self.assertEqual(result["test_tensor"].item(), expected_step)
                self.assertEqual(result["test_str"], f"str_{expected_step}")

    def test_header_empty_source(self) -> None:
        # An empty source yields nothing even though n > 0.
        source = IterableWrapper([])
        node = Header(source, n=5)

        results = list(node)
        self.assertEqual(results, [])

        # Verify counter with empty source
        self.assertEqual(node._num_yielded, 0)

    @parameterized.expand(itertools.product([0, 3, 7]))
    def test_save_load_state(self, midpoint: int) -> None:
        n = 50
        source = StatefulRangeNode(n=n)
        node = Header(source, n=20)  # Limit to first 20 items
        run_test_save_load_state(self, node, midpoint)

    def test_header_reset_state(self) -> None:
        # reset(state) must restore both the position and the yield counter.
        source = IterableWrapper(range(10))
        node = Header(source, n=5)

        # Consume first two items
        self.assertEqual(next(node), 0)
        self.assertEqual(next(node), 1)

        # Check counter after consuming two items
        self.assertEqual(node._num_yielded, 2)

        # Get state and reset
        state = node.state_dict()
        node.reset(state)

        # Counter should be preserved after reset with state
        self.assertEqual(node._num_yielded, 2)

        # Should continue from where we left off
        self.assertEqual(next(node), 2)
        self.assertEqual(next(node), 3)
        self.assertEqual(next(node), 4)

        # Counter should be updated after consuming more items
        self.assertEqual(node._num_yielded, 5)

        # Should raise StopIteration after all items are consumed
        with self.assertRaises(StopIteration):
            next(node)

    def test_counter_reset(self) -> None:
        # Test that counter is properly reset
        source = IterableWrapper(range(10))
        node = Header(source, n=5)

        # Consume all items
        list(node)

        # Verify counter after first pass
        self.assertEqual(node._num_yielded, 5)

        # Reset without state
        node.reset()

        # Counter should be reset to 0
        self.assertEqual(node._num_yielded, 0)

        # Consume some items
        next(node)  # 0
        next(node)  # 1

        # Verify counter after partial consumption
        self.assertEqual(node._num_yielded, 2)

    def test_invalid_input(self) -> None:
        # Test with negative n
        source = IterableWrapper(range(10))
        with self.assertRaises(ValueError):
            Header(source, n=-1)
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
class TestPinMemory(TestCase):
    """Tests for the PinMemory node; requires CUDA since pinning is a no-op
    signal without a device to transfer to."""

    def test_pin_memory(self) -> None:
        # Batch -> collate -> pin -> prefetch; contents must be unchanged by
        # pinning and must survive a reset into a second epoch.
        batch_size = 6
        src = MockSource(num_samples=20)
        node = Batcher(src, batch_size=batch_size)
        node = Mapper(node, Collate())
        node = PinMemory(node)
        root = Prefetcher(node, prefetch_factor=2)

        # 2 epochs
        for epoch in range(2):
            root.reset()
            results = list(root)
            self.assertEqual(len(results), 3, epoch)
            for i in range(3):
                for j in range(batch_size):
                    self.assertEqual(results[i]["step"][j], i * batch_size + j)
                    self.assertEqual(results[i]["test_tensor"][j], torch.tensor([i * batch_size + j]))
                    self.assertEqual(results[i]["test_str"][j], f"str_{i * batch_size + j}")

    def test_exception_handling(self):
        # An exception raised inside an item's pin_memory() must propagate to
        # the consumer rather than being swallowed by the pipeline.
        class PinMemoryFails:
            def pin_memory(self):
                raise ValueError("test exception")

        batch_size = 6
        src = MockSource(num_samples=20)
        node = Mapper(src, lambda x: dict(fail=PinMemoryFails(), **x))
        node = Batcher(node, batch_size=batch_size)
        node = Mapper(node, Collate())
        node = PinMemory(node)
        root = Prefetcher(node, prefetch_factor=2)

        with self.assertRaisesRegex(ValueError, "test exception"):
            list(root)

    def test_iter_init_error(self):
        # Errors raised during source reset/iteration startup must surface.
        node = IterInitError()
        node = PinMemory(node)
        root = Prefetcher(node, prefetch_factor=2)

        with self.assertRaisesRegex(ValueError, "Iter Init Error"):
            list(root)

    @parameterized.expand(itertools.product([0, 7, 33], [0, 1, 9]))
    def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int):
        # Checkpoint/restore across the full pipeline for several midpoints
        # and snapshot frequencies.
        batch_size = 6
        n = 200
        node = StatefulRangeNode(n=n)
        node = Batcher(node, batch_size=batch_size, drop_last=False)
        node = Mapper(node, Collate())
        node = PinMemory(node, snapshot_frequency=snapshot_frequency)
        node = Prefetcher(node, prefetch_factor=8)

        run_test_save_load_state(self, node, midpoint)
class TestPrefetcher(TestCase):
    """Tests for the Prefetcher node (background prefetching wrapper)."""

    def test_prefetcher(self) -> None:
        # Prefetching must not change contents or order, and reset() must
        # cleanly shut down and restart the prefetch machinery between epochs.
        batch_size = 6
        src = MockSource(num_samples=20)
        node = Batcher(src, batch_size=batch_size, drop_last=True)
        root = Prefetcher(node, prefetch_factor=2)

        # Test multi epoch shutdown and restart
        for _ in range(2):
            root.reset()
            results = list(root)
            self.assertEqual(len(results), 3)
            for i in range(3):
                for j in range(batch_size):
                    self.assertEqual(results[i][j]["step"], i * batch_size + j)
                    self.assertEqual(results[i][j]["test_tensor"], torch.tensor([i * batch_size + j]))
                    self.assertEqual(results[i][j]["test_str"], f"str_{i * batch_size + j}")

    def test_iter_init_error(self):
        # Startup errors in the wrapped node must propagate to the consumer.
        node = IterInitError()
        root = Prefetcher(node, prefetch_factor=2)

        with self.assertRaisesRegex(ValueError, "Iter Init Error"):
            list(root)

    @parameterized.expand(itertools.product([0, 7, 32], [0, 1, 9]))
    def test_save_load_state_stateful(self, midpoint: int, snapshot_frequency: int):
        # Checkpoint/restore with a stateful source at several midpoints and
        # snapshot frequencies.
        batch_size = 6
        n = 200
        src = StatefulRangeNode(n=n)
        node = Batcher(src, batch_size=batch_size, drop_last=False)
        node = Prefetcher(node, prefetch_factor=8, snapshot_frequency=snapshot_frequency)
        run_test_save_load_state(self, node, midpoint)
class TestQueueSnapshotStore(TestCase):
    """Tests for QueueSnapshotStore: versioned snapshot queue semantics and
    error propagation from worker threads during startup."""

    def test_snapshot_store(self) -> None:
        # Repeat many times to shake out any ordering flakiness.
        for _ in range(100):
            store = QueueSnapshotStore()
            store.append({"a": 1}, 0)
            store.append({"a": 2}, 10)

            self.assertEqual(len(store._q.queue), 2)

            # pop_version returns the snapshot only for an exact version
            # match; non-matching versions return None. Popping version 10
            # discards the older entry as a side effect of the search.
            val = store.pop_version(0)
            self.assertEqual(val, {"a": 1})
            self.assertEqual(len(store._q.queue), 1)
            val = store.pop_version(1)
            self.assertIsNone(val)
            self.assertEqual(len(store._q.queue), 1)
            val = store.pop_version(7)
            self.assertIsNone(val)
            self.assertEqual(len(store._q.queue), 1)
            val = store.pop_version(10)
            self.assertEqual(val, {"a": 2})
            self.assertEqual(len(store._q.queue), 0)

            val = store.pop_version(11)
            self.assertIsNone(val)
            self.assertEqual(len(store._q.queue), 0)

            # Versions must be strictly increasing, even relative to
            # already-popped entries.
            with self.assertRaisesRegex(ValueError, "is not strictly greater than"):
                store.append({"a": 3}, 3)

            self.assertEqual(len(store._q.queue), 0)

            with self.assertRaisesRegex(ValueError, "is not strictly greater than"):
                store.append({"a": 4}, 10)
            self.assertEqual(len(store._q.queue), 0)

            store.append({"a": 4}, 11)
            store.append({"a": 5}, 19)
            val = store.pop_version(19)
            self.assertEqual(val, {"a": 5})
            self.assertEqual(len(store._q.queue), 0)

    def test_init_error(self) -> None:
        # A StartupExceptionWrapper appended by the worker must be re-raised
        # by get_initial_snapshot in the consumer thread.
        for _ in range(10):
            store = QueueSnapshotStore()
            sleep_time = 0.1
            thread = threading.Thread(target=_worker_init_error, args=(store, sleep_time))
            thread.start()
            with self.assertRaisesRegex(RuntimeError, "Test Startup Exception"):
                store.get_initial_snapshot(thread, sleep_time)
            thread.join()

    def test_timeout_error(self) -> None:
        # If no snapshot arrives within the timeout while the worker is still
        # alive, get_initial_snapshot raises a timeout RuntimeError.
        for _ in range(10):
            store = QueueSnapshotStore()
            sleep_time = 0.1
            thread = threading.Thread(target=_worker_raises_after, args=(sleep_time,))
            thread.start()
            with self.assertRaisesRegex(RuntimeError, "Failed to get initial snapshot"):
                store.get_initial_snapshot(thread, sleep_time * 0.1)
            thread.join()

    def test_thread_dead_error(self) -> None:
        # Test when thread is alive for longer than QUEUE_TIMEOUT but dies afterwards
        for _ in range(10):  # Should be reliable
            store = QueueSnapshotStore()
            thread = threading.Thread(target=_worker_raises_after, args=(QUEUE_TIMEOUT * 3.0,))
            thread.start()
            with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"):
                store.get_initial_snapshot(thread, QUEUE_TIMEOUT * 5.0)
            thread.join()

    def test_future_dead_error(self) -> None:
        # Test when thread is alive for longer than QUEUE_TIMEOUT but dies afterwards
        for _ in range(10):  # Should be reliable
            store = QueueSnapshotStore()
            pool = ThreadPoolExecutor()
            future = pool.submit(_worker_raises_after, QUEUE_TIMEOUT * 3.0)
            with self.assertRaisesRegex(RuntimeError, r"thread.is_alive\(\)=False"):
                store.get_initial_snapshot(future, QUEUE_TIMEOUT * 5.0)
            pool.shutdown()
class Collate:
    """Collate a list of same-keyed dicts into a single dict of lists."""

    def __call__(self, x):
        # Keys are taken from the first row; all rows are assumed to share
        # the same keys.
        return {key: [row[key] for row in x] for key in x[0]}
class DummyMapDataset(torch.utils.data.Dataset):
    """Map-style dataset that returns a deterministic dict per index.

    Each row contains the dataset ``name``, the integer ``step`` (the index
    itself), a 1-element tensor, and a formatted string, so tests can assert
    exact contents.
    """

    def __init__(self, num_samples: int, name: str = "test") -> None:
        self.num_samples = num_samples
        self.name = name

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, i: int) -> dict:
        row = {"name": self.name, "step": i}
        row["test_tensor"] = torch.tensor([i])
        row["test_str"] = f"str_{i}"
        return row
def run_test_save_load_state(test, node: BaseNode, midpoint: int):
    """Shared harness asserting the Loader checkpoint/restore contract.

    Runs two full epochs over ``node``, snapshotting state at ``midpoint`` of
    each epoch and at the end of epoch 0, then verifies that restoring from
    each snapshot reproduces the originally observed remainder of the stream.

    Args:
        test: The TestCase instance providing assert methods.
        node: The BaseNode under test (wrapped in a Loader here).
        midpoint: Number of items to consume before taking the mid-epoch
            snapshot (0 means snapshot before any item).
    """
    ##############################
    # Generate initial, midpoint, and end state_dict's
    x = Loader(node)

    initial_state_dict = x.state_dict()
    it = iter(x)
    results = []
    for _ in range(midpoint):
        results.append(next(it))
    state_dict = x.state_dict()
    for val in it:
        results.append(val)

    state_dict_0_end = x.state_dict()

    # store epoch 1's results
    it = iter(x)
    results_1 = []
    for _ in range(midpoint):
        results_1.append(next(it))
    state_dict_1 = x.state_dict()
    for val in it:
        results_1.append(val)

    ##############################
    # Test restoring from midpoint
    x.load_state_dict(state_dict)
    results_after = list(x)
    test.assertEqual(results_after, results[midpoint:])

    # Test for second epoch after resume
    results_after_1 = list(x)
    test.assertEqual(results_after_1, results_1)

    ##############################
    # Test restoring from midpoint of epoch 1
    x.load_state_dict(state_dict_1)
    results_after_2 = list(x)
    test.assertEqual(results_after_2, results_1[midpoint:])

    ##############################
    # Test initialize from beginning after resume
    x.load_state_dict(initial_state_dict)
    full_results = list(x)
    test.assertEqual(full_results, results)
    full_results_1 = list(x)
    test.assertEqual(full_results_1, results_1)

    ##############################
    # Test restoring from end-of-epoch 0
    # With restart_on_stop_iteration=False, a loader restored at end-of-epoch
    # state yields nothing.
    x = Loader(node, restart_on_stop_iteration=False)
    x.load_state_dict(state_dict_0_end)
    results_after_dict_0_with_restart_false = list(x)
    test.assertEqual(results_after_dict_0_with_restart_false, [])

    ##############################
    # Test restoring from end of epoch 0 with restart_on_stop_iteration=True
    # (the default): iteration restarts and yields epoch 1's results.
    x = Loader(node)
    x.load_state_dict(state_dict_0_end)
    results_after_dict_0 = list(x)
    test.assertEqual(results_after_dict_0, results_1)
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import itertools 8 | 9 | from datasets.info import DatasetInfo 10 | from datasets.iterable_dataset import ExamplesIterable, IterableDataset 11 | from torch.testing._internal.common_utils import IS_MACOS, TestCase 12 | from torchdata.stateful_dataloader import StatefulDataLoader 13 | 14 | 15 | DEFAULT_N_EXAMPLES = 20 16 | DEFAULT_FILEPATH = "file.txt" 17 | 18 | 19 | def generate_examples_fn(**kwargs): 20 | kwargs = kwargs.copy() 21 | n = kwargs.pop("n", DEFAULT_N_EXAMPLES) 22 | filepaths = kwargs.pop("filepaths", None) 23 | for filepath in filepaths or [DEFAULT_FILEPATH]: 24 | if filepaths is not None: 25 | kwargs["filepath"] = filepath 26 | for i in range(n): 27 | yield f"{filepath}_{i}", {"id": i, **kwargs} 28 | 29 | 30 | def identity(x): 31 | return x 32 | 33 | 34 | class TestStatefulDataLoaderIterable_shard0(TestCase): 35 | def _get_dataset(self): 36 | ex_iterable = ExamplesIterable(generate_examples_fn, {}) 37 | return IterableDataset(ex_iterable, info=DatasetInfo(description="dummy"), split="train") 38 | 39 | def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1): 40 | dataset = self._get_dataset() 41 | dl = StatefulDataLoader( 42 | dataset=dataset, 43 | num_workers=num_workers, 44 | collate_fn=identity, 45 | snapshot_every_n_steps=every_n_steps, 46 | persistent_workers=pw, 47 | multiprocessing_context="forkserver" if IS_MACOS and num_workers else None, 48 | ) 49 | it = iter(dl) 50 | for _ in range(interrupt): 51 | next(it) 52 | 53 | state_dict = dl.state_dict() 54 | exp = [] 55 | for data in it: 56 | exp.append(data) 57 | 58 | # Restore new instance from state 59 | batches = [] 60 | dl = StatefulDataLoader( 61 | dataset=dataset, 62 | num_workers=num_workers, 63 | collate_fn=identity, 64 | snapshot_every_n_steps=every_n_steps, 65 | persistent_workers=pw, 66 | 
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None, 67 | ) 68 | dl.load_state_dict(state_dict) 69 | for batch in iter(dl): 70 | batches.append(batch) 71 | 72 | self.assertEqual(exp, batches) 73 | 74 | def test_no_mp(self): 75 | for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): 76 | with self.subTest(batch_size=batch_size, interrupt=interrupt): 77 | self._run_and_checkpoint( 78 | num_workers=0, 79 | batch_size=batch_size, 80 | pw=False, 81 | interrupt=interrupt, 82 | ) 83 | 84 | def test_mp_x(self): 85 | for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): 86 | with self.subTest(batch_size=batch_size, interrupt=interrupt): 87 | self._run_and_checkpoint( 88 | num_workers=3, 89 | batch_size=batch_size, 90 | pw=False, 91 | interrupt=interrupt, 92 | ) 93 | 94 | def test_mp_pw(self): 95 | for batch_size, interrupt in itertools.product([None, 7], [0, 1, 10]): 96 | with self.subTest(batch_size=batch_size, interrupt=interrupt): 97 | self._run_and_checkpoint( 98 | num_workers=3, 99 | batch_size=batch_size, 100 | pw=True, 101 | interrupt=interrupt, 102 | ) 103 | 104 | def test_mp_every_n_steps(self): 105 | batch_size = 7 106 | for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]): 107 | with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt): 108 | self._run_and_checkpoint( 109 | num_workers=3, 110 | batch_size=batch_size, 111 | pw=True, 112 | interrupt=interrupt, 113 | ) 114 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /tools/setup_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tools/setup_helpers/extension.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import os 9 | from pathlib import Path 10 | 11 | 12 | __all__ = ["get_ext_modules"] 13 | 14 | 15 | _THIS_DIR = Path(__file__).parent.resolve() 16 | _ROOT_DIR = _THIS_DIR.parent.parent.resolve() 17 | 18 | 19 | def _get_build(var, default=False): 20 | if var not in os.environ: 21 | return default 22 | 23 | val = os.environ.get(var, "0") 24 | trues = ["1", "true", "TRUE", "on", "ON", "yes", "YES"] 25 | falses = ["0", "false", "FALSE", "off", "OFF", "no", "NO"] 26 | if val in trues: 27 | return True 28 | if val not in falses: 29 | print(f"WARNING: Unexpected environment variable value `{var}={val}`. " f"Expected one of {trues + falses}") 30 | return False 31 | 32 | 33 | def get_ext_modules(): 34 | return [] 35 | -------------------------------------------------------------------------------- /tools/todo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # Scrip can be used with 8 | # find -name '*.py' | grep -v third_party | perl -ne'print "python tools/todo.py $_"' | head -n 5 | bash 9 | 10 | import configparser 11 | import os 12 | import re 13 | import shutil 14 | import sys 15 | import tempfile 16 | 17 | from github import Github # pip install PyGithub 18 | 19 | file_name = sys.argv[1] 20 | 21 | config = configparser.ConfigParser(allow_no_value=True) 22 | with open(os.path.join(os.path.expanduser("~"), ".ghstackrc")) as stream: 23 | config.read_string(stream.read()) 24 | 25 | GITHUB_KEY = config["ghstack"]["github_oauth"] 26 | 27 | 28 | def get_git_branch_hash(): 29 | stream = os.popen("git rev-parse origin/main") 30 | return stream.read().rstrip() 31 | 32 | 33 | def generate_issue_id(id_or_name, title, file_name, line_number): 34 | git_branch_hash = get_git_branch_hash() 35 | # print(file_name, line_number, title, id_or_name) 36 | match = re.match(r"\((\d+)\)", id_or_name) 37 | if match: 38 | return int(match.group(1)) 39 | match = re.match(r"\((.*)\)", id_or_name) 40 | name = None 41 | if match: 42 | name = match.group(1) 43 | if name is not None: 44 | owner = f"cc @{name}" 45 | else: 46 | owner = "" 47 | g = Github(GITHUB_KEY) 48 | repo = g.get_repo("pytorch/data") 49 | # label_be = repo.get_label("better-engineering" ) 50 | # labels = [label_be] 51 | line_reference = f"https://github.com/pytorch/data/blob/{git_branch_hash}/{file_name}#L{line_number}" 52 | line_reference = line_reference.replace("/./", "/") 53 | body = """ 54 | This issue is generated from the TODO line 55 | 56 | {line_reference} 57 | 58 | {owner} 59 | """.format( 60 | owner=owner, 61 | line_reference=line_reference, 62 | ) 63 | title = f"[TODO] {title}" 64 | issue = repo.create_issue(title=title, body=body, labels=[]) 65 | print(f"Created issue https://github.com/pytorch/data/issues/{issue.number}") 66 | return issue.number 67 | 68 | 69 | def update_file(file_name): 70 | try: 71 | f = tempfile.NamedTemporaryFile(delete=False) 72 | 
shutil.copyfile(file_name, f.name) 73 | with open(f.name) as f_inp: 74 | with open(file_name, "w") as f_out: 75 | for line_number, line in enumerate(f_inp.readlines()): 76 | if not re.search(r"ignore-todo", line, re.IGNORECASE): 77 | match = re.search(r"(.*?)#\s*todo\s*(\([^)]+\)){0,1}:{0,1}(.*)", line, re.IGNORECASE) 78 | if match: 79 | # print(line) 80 | prefix = match.group(1) 81 | text = match.group(3) 82 | issue_id = generate_issue_id(str(match.group(2)), text, file_name, line_number + 1) 83 | line = f"{prefix}# TODO({issue_id}):{text}\n" # ignore-todo 84 | f_out.write(line) 85 | except Exception as e: 86 | shutil.copyfile(f.name, file_name) 87 | raise e 88 | finally: 89 | os.unlink(f.name) 90 | 91 | 92 | update_file(file_name) 93 | -------------------------------------------------------------------------------- /torchdata/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | try: 8 | from .version import __version__ # noqa: F401 9 | except ImportError: 10 | pass 11 | -------------------------------------------------------------------------------- /torchdata/nodes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper 8 | from .base_node import BaseNode, T 9 | from .batch import Batcher, Unbatcher 10 | from .cycler import Cycler 11 | from .filter import Filter 12 | from .header import Header 13 | from .loader import Loader 14 | from .map import Mapper, ParallelMapper 15 | from .pin_memory import PinMemory 16 | from .prefetch import Prefetcher 17 | from .samplers.multi_node_weighted_sampler import MultiNodeWeightedSampler 18 | from .samplers.stop_criteria import StopCriteria 19 | from .types import Stateful 20 | 21 | 22 | __all__ = [ 23 | "BaseNode", 24 | "Batcher", 25 | "Cycler", 26 | "Filter", 27 | "Header", 28 | "IterableWrapper", 29 | "Loader", 30 | "MapStyleWrapper", 31 | "Mapper", 32 | "MultiNodeWeightedSampler", 33 | "ParallelMapper", 34 | "PinMemory", 35 | "Prefetcher", 36 | "SamplerWrapper", 37 | "Stateful", 38 | "StopCriteria", 39 | "T", 40 | "Unbatcher", 41 | ] 42 | 43 | assert sorted(__all__) == __all__ 44 | -------------------------------------------------------------------------------- /torchdata/nodes/_apply_udf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import multiprocessing.synchronize as python_mp_synchronize 8 | import queue 9 | import threading 10 | from typing import Callable, Union 11 | 12 | import torch 13 | import torch.multiprocessing as mp 14 | 15 | from torch._utils import ExceptionWrapper 16 | 17 | from .constants import QUEUE_TIMEOUT 18 | 19 | 20 | def _apply_udf( 21 | worker_id: int, 22 | in_q: Union[queue.Queue, mp.Queue], 23 | out_q: Union[queue.Queue, mp.Queue], 24 | udf: Callable, 25 | stop_event: Union[threading.Event, python_mp_synchronize.Event], 26 | ): 27 | """_apply_udf assumes in_q emits tuples of (x, idx) where x is the 28 | payload, idx is the index of the result, potentially used for maintaining 29 | ordered outputs. For every input it pulls, a tuple (y, idx) is put on the out_q 30 | where the output of udf(x), an ExceptionWrapper, or StopIteration (if it pulled 31 | StopIteration from in_q). 32 | """ 33 | torch.set_num_threads(1) 34 | while True: 35 | if stop_event.is_set() and in_q.empty(): 36 | break 37 | 38 | try: 39 | item, idx = in_q.get(block=True, timeout=QUEUE_TIMEOUT) 40 | except queue.Empty: 41 | continue 42 | 43 | if isinstance(item, ExceptionWrapper): 44 | out_q.put((item, idx), block=False) 45 | elif isinstance(item, StopIteration): 46 | out_q.put((item, idx), block=False) 47 | else: 48 | try: 49 | y = udf(item) 50 | except Exception: 51 | y = ExceptionWrapper(where="in _apply_udf") 52 | 53 | out_q.put((y, idx), block=False) 54 | -------------------------------------------------------------------------------- /torchdata/nodes/_populate_queue.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 

import queue
import threading
from typing import Any, Dict, Optional, Union

import torch.multiprocessing as mp

from torchdata.nodes.base_node import BaseNode

from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper
from torchdata.nodes.snapshot_store import MonotonicIndex, SnapshotStore

from .constants import QUEUE_TIMEOUT


def _populate_queue(
    source: BaseNode,
    q: Union[queue.Queue, mp.Queue],
    snapshot_store: SnapshotStore,
    snapshot_frequency: int,
    semaphore: threading.BoundedSemaphore,
    stop_event: threading.Event,
):
    """_populate_queue pulls items from ``source`` via ``next(source)``, waits for
    semaphore.acquire, and puts the outputs onto q. It never releases the
    semaphore. It continues to put items on the q as long as it can acquire the
    semaphore, stop_event is not set, and StopIteration has not been raised by
    the source.

    This function will always put tuples of (x, idx) on the q where idx
    starts from 0 and is monotonically increasing. x may be the output of
    next(source), StopIteration, or an ExceptionWrapper.

    Before the main loop, ``source.state_dict()`` is stored as the initial
    snapshot; if that raises, a StartupExceptionWrapper is stored in the
    snapshot store instead (without waiting on the semaphore) and the function
    returns immediately. When ``snapshot_frequency > 0``, a fresh
    ``source.state_dict()`` snapshot is recorded every ``snapshot_frequency``
    yielded items, versioned by the item's queue index.

    Note: this is only intended to be used by a single thread at once. This
    function iterates source directly, so if it is called with multiple threads,
    you may get duplicates if source is not sharded properly.
    """

    # Include a monotonic index starting from 0 to each item in the queue
    idx = MonotonicIndex()

    def _put(
        item,
        block: bool = True,
        snapshot: Optional[Union[Dict[str, Any], StartupExceptionWrapper]] = None,
    ):
        # Snapshot (if any) is recorded under the same index as the item so a
        # consumer can pair an item with the source state that produced it.
        _idx = idx.get()
        if snapshot:
            snapshot_store.append(snapshot=snapshot, version=_idx)
        # NOTE(review): with block=True this uses a 1.0s timeout (queue.Full may
        # propagate), but every call below passes block=False — confirm the
        # blocking path is intentionally unused.
        q.put((item, _idx), block=block, timeout=1.0 if block else None)

    try:
        assert (
            isinstance(snapshot_frequency, int) and snapshot_frequency >= 0
        ), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}"
        snapshot_store.append_initial_snapshot(snapshot=source.state_dict())
    except Exception:
        # Startup failure: record it where the consumer looks for the initial
        # snapshot, then bail out without producing any items.
        e = StartupExceptionWrapper(where="in _populate_queue startup for device")
        snapshot_store.append_initial_snapshot(snapshot=e)
        return

    yielded = 0
    while not stop_event.is_set():
        # Bounded acquire so stop_event is re-checked every QUEUE_TIMEOUT seconds.
        if not semaphore.acquire(blocking=True, timeout=QUEUE_TIMEOUT):
            continue
        try:
            item = next(source)  # FIXME: This may hang!
            yielded += 1
            snapshot = None
            if snapshot_frequency > 0 and yielded % snapshot_frequency == 0:
                snapshot = source.state_dict()
            _put(item, block=False, snapshot=snapshot)
        except StopIteration as e:
            # Forward StopIteration as a payload so the consumer can observe
            # end-of-stream in order.
            _put(e, block=False)
            break
        except Exception:
            item = ExceptionWrapper(where="in _populate_queue")
            _put(item, block=False)
            break
--------------------------------------------------------------------------------
/torchdata/nodes/base_node.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6 | 7 | import logging 8 | from typing import Any, Dict, Iterator, Optional, TypeVar 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | T = TypeVar("T", covariant=True) 14 | 15 | 16 | class BaseNode(Iterator[T]): 17 | """BaseNodes are the base class for creating composable dataloading DAGs in ``torchdata.nodes``. 18 | 19 | Most end-users will not iterate over a BaseNode instance directly, but instead 20 | wrap it in a :class:`torchdata.nodes.Loader` which converts the DAG into a more familiar Iterable. 21 | 22 | .. code-block:: python 23 | 24 | node = MyBaseNodeImpl() 25 | loader = Loader(node) 26 | # loader supports state_dict() and load_state_dict() 27 | 28 | for epoch in range(5): 29 | for idx, batch in enumerate(loader): 30 | ... 31 | 32 | # or if using node directly: 33 | node = MyBaseNodeImpl() 34 | for epoch in range(5): 35 | node.reset() 36 | for idx, batch in enumerate(loader): 37 | ... 38 | """ 39 | 40 | def __init__(self, *args, **kwargs): 41 | """Subclasses must implement this method and call super().__init__(*args, **kwargs)""" 42 | self.__initialized = False 43 | 44 | def __iter__(self): 45 | return self 46 | 47 | def reset(self, initial_state: Optional[dict] = None): 48 | """Resets the iterator to the beginning, or to the state passed in by initial_state. 49 | 50 | Reset is a good place to put expensive initialization, as it will be lazily called when ``next()`` or ``state_dict()`` is called. 51 | Subclasses must call ``super().reset(initial_state)``. 52 | 53 | Args: 54 | initial_state: Optional[dict] - a state dict to pass to the node. If None, reset to the beginning. 55 | """ 56 | 57 | self.__initialized = True 58 | 59 | def get_state(self) -> Dict[str, Any]: 60 | """Subclasses must implement this method, instead of ``state_dict()``. Should only be called by BaseNode. 
61 | 62 | Returns: 63 | Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future 64 | """ 65 | raise NotImplementedError(type(self)) 66 | 67 | def next(self) -> T: 68 | """Subclasses must implement this method, instead of ``__next__``. Should only be called by BaseNode. 69 | 70 | Returns: 71 | T - the next value in the sequence, or throw StopIteration 72 | """ 73 | raise NotImplementedError(type(self)) 74 | 75 | def __next__(self): 76 | try: 77 | self.__initialized 78 | except AttributeError: 79 | raise NotImplementedError(f"self.__initialized not found, did you call super().__init__()? {type(self)=}") 80 | if not self.__initialized: 81 | self.reset(None) 82 | if not self.__initialized: 83 | raise NotImplementedError( 84 | f"Failed to initialize after .reset(), did you call super().reset() in your .reset() method? {type(self)=}" 85 | ) 86 | return self.next() 87 | 88 | def state_dict(self) -> Dict[str, Any]: 89 | """Get a state_dict for this BaseNode. 90 | 91 | Returns: 92 | Dict[str, Any] - a state dict that may be passed to ``reset()`` at some point in the future. 93 | """ 94 | try: 95 | self.__initialized 96 | except AttributeError: 97 | raise NotImplementedError(f"self.__initialized not found, did you call super().__init__()? {type(self)=}") 98 | 99 | if not self.__initialized: 100 | self.reset(None) 101 | if not self.__initialized: 102 | raise NotImplementedError( 103 | f"Failed to initialize after .reset(), did you call super().reset() in your .reset() method? {type(self)=}" 104 | ) 105 | return self.get_state() 106 | -------------------------------------------------------------------------------- /torchdata/nodes/batch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, List, Optional, Sequence 8 | 9 | from torchdata.nodes.base_node import BaseNode, T 10 | 11 | 12 | class Batcher(BaseNode[List[T]]): 13 | """Batcher node batches the data from the source node into batches of size batch_size. 14 | If the source node is exhausted, it will return the batch or raise StopIteration. 15 | If drop_last is True, the last batch will be dropped if it is smaller than batch_size. 16 | If drop_last is False, the last batch will be returned even if it is smaller than batch_size. 17 | 18 | Args: 19 | source (BaseNode[T]): The source node to batch the data from. 20 | batch_size (int): The size of the batch. 21 | drop_last (bool): Whether to drop the last batch if it is smaller than batch_size. Default is True. 22 | """ 23 | 24 | SOURCE_KEY = "source" 25 | 26 | def __init__(self, source: BaseNode[T], batch_size: int, drop_last: bool = True): 27 | super().__init__() 28 | self.source = source 29 | self.batch_size = batch_size 30 | self.drop_last = drop_last 31 | 32 | def reset(self, initial_state: Optional[Dict[str, Any]] = None): 33 | super().reset(initial_state) 34 | if initial_state is not None: 35 | self.source.reset(initial_state[self.SOURCE_KEY]) 36 | else: 37 | self.source.reset() 38 | 39 | def next(self) -> List[T]: 40 | batch: List[T] = [] 41 | while len(batch) < self.batch_size: 42 | try: 43 | item = next(self.source) 44 | except StopIteration: 45 | break 46 | batch.append(item) 47 | if len(batch) == self.batch_size: 48 | return batch 49 | 50 | if len(batch) == self.batch_size: 51 | return batch 52 | elif len(batch) and not self.drop_last: 53 | return batch 54 | else: 55 | raise StopIteration() 56 | 57 | def get_state(self) -> Dict[str, Any]: 58 | return {self.SOURCE_KEY: self.source.state_dict()} 59 | 60 | 61 | class Unbatcher(BaseNode[T]): 62 | """Unbatcher will 
flatten batches pulled from source, and 63 | yields elements in sequential order when next() is called on it. 64 | 65 | Args: 66 | source (BaseNode[T]): The source node to pull batches from. 67 | """ 68 | 69 | SOURCE_KEY = "source" 70 | BATCH_IDX_KEY = "batch_idx" 71 | 72 | def __init__(self, source: BaseNode[Sequence[T]]): 73 | super().__init__(self) 74 | self.source = source 75 | 76 | def reset(self, initial_state: Optional[Dict[str, Any]] = None): 77 | super().reset(initial_state) 78 | if initial_state is not None: 79 | self.source.reset(initial_state[self.SOURCE_KEY]) 80 | self._cached_state_dict = initial_state[self.SOURCE_KEY] 81 | try: 82 | self._batch = next(self.source) 83 | self._batch_idx = initial_state[self.BATCH_IDX_KEY] 84 | except StopIteration: 85 | # next(self.source) will be called upon subsequent self.next() call 86 | # and raise StopIteration in the correct place. 87 | self._batch = [] 88 | self._batch_idx = 0 89 | else: 90 | self.source.reset() 91 | self._batch = [] 92 | self._cached_state_dict = None 93 | self._batch_idx = 0 94 | 95 | def next(self) -> T: 96 | while self._batch_idx >= len(self._batch): 97 | self._cached_state_dict = self.source.state_dict() 98 | self._batch = next(self.source) 99 | self._batch_idx = 0 100 | 101 | self._batch_idx += 1 102 | return self._batch[self._batch_idx - 1] 103 | 104 | def get_state(self) -> Dict[str, Any]: 105 | if self._cached_state_dict is None: 106 | self._cached_state_dict = self.source.state_dict() 107 | 108 | return { 109 | self.SOURCE_KEY: self._cached_state_dict, 110 | self.BATCH_IDX_KEY: self._batch_idx, 111 | } 112 | -------------------------------------------------------------------------------- /torchdata/nodes/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | QUEUE_TIMEOUT = 0.1 8 | -------------------------------------------------------------------------------- /torchdata/nodes/cycler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, Optional, TypeVar 8 | 9 | from torchdata.nodes import BaseNode 10 | 11 | 12 | T = TypeVar("T") 13 | 14 | 15 | class Cycler(BaseNode[T]): 16 | """Node that cycles through source node a limited or unlimited number of times. 17 | 18 | This node will continuously loop through the source node. When the source node 19 | is exhausted, it will be reset and iteration will start from the beginning again. 20 | The node keeps track of how many times it has completed a full cycle through 21 | the source and the total number of items yielded. 22 | 23 | Args: 24 | source_node (BaseNode[T]): The source node to cycle through. 25 | max_cycles (Optional[int]): Maximum number of cycles to perform. If None, 26 | cycles indefinitely. Must be positive if specified. Default: None. 
27 | """ 28 | 29 | SOURCE_KEY = "source" 30 | NUM_CYCLES_KEY = "num_cycles" 31 | HAS_STARTED_KEY = "has_started" 32 | NUM_YIELDED_KEY = "num_yielded" 33 | MAX_CYCLES_KEY = "max_cycles" 34 | 35 | def __init__(self, source_node: BaseNode[T], max_cycles: Optional[int] = None): 36 | super().__init__() 37 | if max_cycles is not None and max_cycles <= 0: 38 | raise ValueError("max_cycles must be positive if specified") 39 | 40 | self.source = source_node 41 | self.max_cycles = max_cycles 42 | self._num_cycles = 0 43 | self._has_started = False 44 | self._num_yielded = 0 45 | 46 | def reset(self, initial_state: Optional[Dict[str, Any]] = None): 47 | """Reset the node to its initial state or to the provided state. 48 | 49 | Args: 50 | initial_state: Optional state dictionary to restore from. 51 | """ 52 | super().reset(initial_state) 53 | if initial_state is not None: 54 | # Be strict about required keys in the state 55 | self._num_cycles = initial_state[self.NUM_CYCLES_KEY] 56 | self._has_started = initial_state[self.HAS_STARTED_KEY] 57 | self._num_yielded = initial_state[self.NUM_YIELDED_KEY] 58 | self.max_cycles = initial_state[self.MAX_CYCLES_KEY] 59 | self.source.reset(initial_state[self.SOURCE_KEY]) 60 | else: 61 | self._num_cycles = 0 62 | self._has_started = False 63 | self._num_yielded = 0 64 | self.source.reset(None) 65 | 66 | def next(self) -> T: 67 | """Get the next item from the source node, cycling if necessary. 68 | 69 | Returns: 70 | The next item from the source node. 71 | 72 | Raises: 73 | StopIteration: If the source node is empty or max_cycles is reached. 
74 | """ 75 | try: 76 | item = next(self.source) 77 | self._has_started = True 78 | self._num_yielded += 1 79 | return item 80 | except StopIteration: 81 | # If this is the first time we're trying to get an item and it fails, 82 | # the source is empty - just propagate the StopIteration without cycling 83 | if not self._has_started: 84 | raise StopIteration 85 | 86 | # Otherwise, source is exhausted after yielding some items 87 | # Increment cycle count and check max_cycles limit 88 | self._num_cycles += 1 89 | 90 | # If we've reached max_cycles, stop iteration 91 | if self.max_cycles is not None and self._num_cycles >= self.max_cycles: 92 | raise StopIteration 93 | 94 | # Reset source and continue 95 | self.source.reset(None) 96 | 97 | # Try to get the first item after reset 98 | # This could still raise StopIteration if the source becomes empty 99 | # after reset (e.g., a dynamic source that changes over time) 100 | try: 101 | item = next(self.source) 102 | self._num_yielded += 1 103 | return item 104 | except StopIteration: 105 | raise 106 | 107 | def get_state(self) -> Dict[str, Any]: 108 | """Get the current state of the node. 109 | 110 | Returns: 111 | Dict[str, Any]: A dictionary containing the state of the source node, 112 | number of cycles completed, whether iteration has started, 113 | total number of items yielded, and the maximum number of cycles. 114 | """ 115 | return { 116 | self.SOURCE_KEY: self.source.state_dict(), 117 | self.NUM_CYCLES_KEY: self._num_cycles, 118 | self.HAS_STARTED_KEY: self._has_started, 119 | self.NUM_YIELDED_KEY: self._num_yielded, 120 | self.MAX_CYCLES_KEY: self.max_cycles, 121 | } 122 | -------------------------------------------------------------------------------- /torchdata/nodes/exception_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch._utils import ExceptionWrapper


class StartupExceptionWrapper(ExceptionWrapper):
    """Marker subclass for exceptions raised during startup (e.g. while taking
    the initial snapshot in ``_populate_queue``), as opposed to per-item
    processing errors wrapped in plain ExceptionWrapper."""

    pass
--------------------------------------------------------------------------------
/torchdata/nodes/filter.py:
--------------------------------------------------------------------------------
# NOTE(review): unlike the sibling node modules, this file carries no Meta
# copyright/license banner — confirm whether the standard header should be added.
from typing import Any, Callable, Dict, Optional, TypeVar

from torchdata.nodes import BaseNode


T = TypeVar("T")


class Filter(BaseNode[T]):
    """Node that filters items from source node based on predicate function.

    This node applies a filter function to each item from the source node and only yields
    items that satisfy the condition (when filter_fn returns True). It keeps track of both
    the number of items that were filtered out (rejected) and the number of items that were
    yielded (accepted).

    Args:
        source_node (BaseNode[T]): The source node to filter items from.
        filter_fn (Callable[[T], bool]): A function that takes an item and returns True if the item
            should be included, False otherwise.
    """

    SOURCE_KEY = "source"
    NUM_FILTERED_KEY = "num_filtered"
    NUM_YIELDED_KEY = "num_yielded"

    def __init__(self, source_node: BaseNode[T], filter_fn: Callable[[T], bool]):
        super().__init__()
        self.source = source_node
        self.filter_fn = filter_fn
        self._num_filtered = 0  # Count of items that did NOT pass the filter
        self._num_yielded = 0  # Count of items that DID pass the filter

    def reset(self, initial_state: Optional[Dict[str, Any]] = None):
        """Reset the node to its initial state or to the provided state.

        Args:
            initial_state: Optional state dictionary to restore from.
        """
        super().reset(initial_state)
        if initial_state is not None:
            # NOTE(review): lenient .get() restore (missing keys fall back to
            # defaults), whereas Cycler/Header are "strict about required keys" —
            # confirm whether the leniency here is intentional.
            self.source.reset(initial_state.get(self.SOURCE_KEY))
            self._num_filtered = initial_state.get(self.NUM_FILTERED_KEY, 0)
            self._num_yielded = initial_state.get(self.NUM_YIELDED_KEY, 0)
        else:
            self.source.reset(None)
            self._num_filtered = 0
            self._num_yielded = 0

    def next(self) -> T:
        """Get the next item that passes the filter.

        Returns:
            The next item that satisfies the filter condition.

        Raises:
            StopIteration: If there are no more items in the source node.
        """
        # Skip (and count) rejected items until one passes or the source ends.
        while True:
            item = next(self.source)
            if self.filter_fn(item):
                self._num_yielded += 1
                return item
            self._num_filtered += 1

    def get_state(self) -> Dict[str, Any]:
        """Get the current state of the node.

        Returns:
            A dictionary containing the state of the source node and counters.
        """
        return {
            self.SOURCE_KEY: self.source.state_dict(),
            self.NUM_FILTERED_KEY: self._num_filtered,
            self.NUM_YIELDED_KEY: self._num_yielded,
        }
--------------------------------------------------------------------------------
/torchdata/nodes/header.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional, TypeVar

from torchdata.nodes import BaseNode


T = TypeVar("T")


class Header(BaseNode[T]):
    """Node that yields only the first N items from source node.

    This node limits the number of items yielded from the source node to at most N.
19 | After N items have been yielded, it will raise StopIteration on subsequent calls 20 | to next(), even if the source node has more items available. 21 | 22 | Args: 23 | source_node (BaseNode[T]): The source node to pull items from. 24 | n (int): The maximum number of items to yield. Must be non-negative. 25 | """ 26 | 27 | SOURCE_KEY = "source" 28 | NUM_YIELDED_KEY = "num_yielded" 29 | 30 | def __init__(self, source_node: BaseNode[T], n: int): 31 | super().__init__() 32 | if n < 0: 33 | raise ValueError("n must be non-negative") 34 | self.source = source_node 35 | self.n = n 36 | self._num_yielded = 0 # Count of items yielded so far 37 | 38 | def reset(self, initial_state: Optional[Dict[str, Any]] = None): 39 | """Reset the node to its initial state or to the provided state. 40 | 41 | Args: 42 | initial_state: Optional state dictionary to restore from. 43 | """ 44 | super().reset(initial_state) 45 | if initial_state is not None: 46 | # Be strict about required keys in the state 47 | self.source.reset(initial_state[self.SOURCE_KEY]) 48 | self._num_yielded = initial_state[self.NUM_YIELDED_KEY] 49 | else: 50 | self.source.reset(None) 51 | self._num_yielded = 0 52 | 53 | def next(self) -> T: 54 | """Get the next item from the source node if fewer than N items have been yielded. 55 | 56 | Returns: 57 | The next item from the source node. 58 | 59 | Raises: 60 | StopIteration: If N items have already been yielded or the source is exhausted. 61 | """ 62 | if self._num_yielded >= self.n: 63 | raise StopIteration 64 | 65 | item = next(self.source) 66 | self._num_yielded += 1 67 | return item 68 | 69 | def get_state(self) -> Dict[str, Any]: 70 | """Get the current state of the node. 71 | 72 | Returns: 73 | Dict[str, Any] - A dictionary containing the state of the source node and number of cycles completed. 
74 | """ 75 | return { 76 | self.SOURCE_KEY: self.source.state_dict(), 77 | self.NUM_YIELDED_KEY: self._num_yielded, 78 | } 79 | -------------------------------------------------------------------------------- /torchdata/nodes/loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, Generic, Optional 8 | 9 | from torchdata.nodes.base_node import BaseNode, T 10 | 11 | 12 | class Loader(Generic[T]): 13 | """Wraps the root BaseNode (an iterator) and provides a stateful iterable interface. 14 | 15 | The state of the last-returned iterator is returned by the state_dict() method, and can be 16 | loaded using the load_state_dict() method. 17 | 18 | Args: 19 | root (BaseNode[T]): The root node of the data pipeline. 20 | restart_on_stop_iteration (bool): Whether to restart the iterator when it reaches the end. Default is True 21 | """ 22 | 23 | def __init__(self, root: BaseNode[T], restart_on_stop_iteration: bool = True): 24 | super().__init__() 25 | self.root = root 26 | self.restart_on_stop_iteration = restart_on_stop_iteration 27 | self._next_iter_state_dict: Optional[Dict[str, Any]] = None 28 | self._it: Optional[LoaderIterator[T]] = None 29 | # Tracks whether an iterator was created solely for getting a state_dict, in which case 30 | # we don't want to reset the iterator. Consider these two cases, which should behave the same 31 | # it = iter(loader) 32 | # sd = loader.state_dict() # No extra __iter__ call as _it already exists 33 | # for _ in it: ... 34 | # -------- 35 | # sd = loader.state_dict() # Calls __iter__ since _it is None 36 | # it = iter(loader) # We don't want to reset the iterator here again 37 | # for _ in it: ... 
38 | self._iter_for_state_dict: bool = False 39 | 40 | def __iter__(self): 41 | if self._it is None: 42 | self._it = LoaderIterator(self) 43 | elif self._iter_for_state_dict: 44 | self._iter_for_state_dict = False 45 | return self._it # This was already pre-called to get a state dict 46 | 47 | if self._next_iter_state_dict is not None: 48 | self._it.reset(initial_state=self._next_iter_state_dict) 49 | self._next_iter_state_dict = None 50 | if self.restart_on_stop_iteration and not self._it.has_next(): 51 | self._it.reset(None) 52 | else: 53 | self._it.reset(None) 54 | 55 | return self._it 56 | 57 | def load_state_dict(self, state_dict: Dict[str, Any]): 58 | """Loads a state_dict which will be used to initialize the next iter() requested 59 | from this loader. 60 | 61 | Args: 62 | state_dict (Dict[str, Any]): The state_dict to load. Should be generated from a call to state_dict(). 63 | """ 64 | self._next_iter_state_dict = state_dict 65 | 66 | def state_dict(self) -> Dict[str, Any]: 67 | """Returns a state_dict which can be passed to load_state_dict() in the future to 68 | resume iteration. 69 | 70 | The state_dict will come from the iterator returned by the most recent call to iter(). 71 | If no iterator has been created, a new iterator will be created and the state_dict returned from it. 72 | """ 73 | if self._it is None: 74 | iter(self) 75 | self._iter_for_state_dict = True 76 | return self._it.state_dict() # type:ignore[union-attr] 77 | 78 | 79 | class LoaderIterator(BaseNode[T]): 80 | """An iterator class that wraps a root node and works with the Loader class. 81 | 82 | The LoaderIterator object saves state of the underlying root node, and calls reset on the root node when 83 | the iterator is exhausted or on a reset call. We look one step ahead to determine if the iterator is exhausted. 84 | The state of the iterator is saved in the state_dict() method, and can be loaded on reset calls. 
85 | 86 | Args: 87 | loader (Loader[T]): The loader object that contains the root node. 88 | """ 89 | 90 | NUM_YIELDED_KEY = "num_yielded" 91 | ROOT_KEY = "root" 92 | 93 | def __init__( 94 | self, 95 | loader: Loader[T], 96 | ): 97 | super().__init__() 98 | self.loader = loader 99 | self.root = loader.root 100 | self._cached_item = None 101 | self._cached_state_dict: Optional[Dict[str, Any]] = None 102 | self._num_yielded = 0 103 | 104 | def reset(self, initial_state: Optional[Dict[str, Any]] = None): 105 | super().reset(initial_state) 106 | if initial_state is not None: 107 | self.root.reset(initial_state[self.ROOT_KEY]) 108 | self._num_yielded = initial_state[self.NUM_YIELDED_KEY] 109 | else: 110 | self.root.reset(None) 111 | self._num_yielded = 0 112 | self._cached_item = None 113 | 114 | def has_next(self) -> bool: 115 | if self._cached_item is None: 116 | try: 117 | # Cache the current state dict 118 | self._cached_state_dict = self.state_dict() 119 | # Load and save the next item 120 | self._cached_item = next(self) 121 | except StopIteration: 122 | pass 123 | return self._cached_item is not None 124 | 125 | def next(self): 126 | if self._cached_item is not None: 127 | item = self._cached_item 128 | self._cached_item = None 129 | self._cached_state_dict = None 130 | else: 131 | item = next(self.root) 132 | self._num_yielded += 1 133 | return item 134 | 135 | def get_state(self) -> Dict[str, Any]: 136 | if self._cached_state_dict is not None: 137 | return self._cached_state_dict 138 | return { 139 | self.ROOT_KEY: self.root.state_dict(), 140 | self.NUM_YIELDED_KEY: self._num_yielded, 141 | } 142 | -------------------------------------------------------------------------------- /torchdata/nodes/pin_memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import functools 8 | import queue 9 | import threading 10 | 11 | from typing import Any, Dict, Optional, Union 12 | 13 | import torch 14 | import torch.multiprocessing 15 | 16 | from torch.utils.data._utils.pin_memory import pin_memory 17 | from torchdata.nodes.base_node import BaseNode, T 18 | 19 | from torchdata.nodes.exception_wrapper import ExceptionWrapper, StartupExceptionWrapper 20 | from torchdata.nodes.map import _SingleThreadedMapper 21 | from torchdata.nodes.snapshot_store import MonotonicIndex, SnapshotStore 22 | 23 | 24 | def _pin_memory_loop( 25 | source: BaseNode, 26 | q: queue.Queue, 27 | snapshot_store: SnapshotStore, 28 | snapshot_frequency: int, 29 | semaphore: threading.BoundedSemaphore, 30 | stop_event: threading.Event, 31 | device_id: Union[int, str], 32 | device: Optional[str], 33 | ): 34 | """This is fork of from torch.utils.data._utils.pin_memory import _pin_memory_loop 35 | to remove the index tuples. 36 | 37 | This setting is thread local, and prevents the copy in pin_memory from 38 | consuming all CPU cores. 
39 | """ 40 | 41 | idx = MonotonicIndex() 42 | 43 | def _put( 44 | item, 45 | block: bool = True, 46 | snapshot: Optional[Union[Dict[str, Any], StartupExceptionWrapper]] = None, 47 | ): 48 | _idx = idx.get() 49 | if snapshot: 50 | snapshot_store.append(snapshot=snapshot, version=_idx) 51 | q.put((item, _idx), block=block) 52 | 53 | try: 54 | torch.set_num_threads(1) 55 | 56 | torch.multiprocessing._set_thread_name("pt_data_pin") 57 | 58 | if device == "cuda": 59 | torch.cuda.set_device(device_id) 60 | elif device == "xpu": 61 | torch.xpu.set_device(device_id) # type: ignore[attr-defined] 62 | elif device == torch._C._get_privateuse1_backend_name(): 63 | custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) 64 | custom_device_mod.set_device(device_id) 65 | 66 | assert ( 67 | isinstance(snapshot_frequency, int) and snapshot_frequency >= 0 68 | ), f"snapshot_frequency must be non-negative integer! Got {snapshot_frequency}" 69 | snapshot_store.append_initial_snapshot(snapshot=source.state_dict()) 70 | except Exception: 71 | e = StartupExceptionWrapper(where=f"in _pin_memory_loop startup for device {device_id}") 72 | snapshot_store.append_initial_snapshot(snapshot=e) 73 | return 74 | 75 | yielded = 0 76 | while not stop_event.is_set(): 77 | if not semaphore.acquire(blocking=True, timeout=0.1): 78 | continue 79 | try: 80 | item = next(source) 81 | item = pin_memory(item, device) 82 | yielded += 1 83 | snapshot = None 84 | if snapshot_frequency > 0 and yielded % snapshot_frequency == 0: 85 | snapshot = source.state_dict() 86 | _put(item, block=False, snapshot=snapshot) 87 | except StopIteration as e: 88 | item = e 89 | _put(item, block=False) 90 | break 91 | except Exception: 92 | item = ExceptionWrapper(where=f"in _pin_memory_loop for device {device_id}") 93 | _put(item, block=False) 94 | break 95 | 96 | 97 | class PinMemory(BaseNode[T]): 98 | """Pins the data of the underlying node to a device. 
This is backed by torch.utils.data._utils.pin_memory._pin_memory_loop. 99 | 100 | Args: 101 | source (BaseNode[T]): The source node to pin the data from. 102 | pin_memory_device (str): The device to pin the data to. Default is "". 103 | snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 104 | 1, which means that the state of the source node will be snapshotted after every item. If set 105 | to a higher value, the state of the source node will be snapshotted after every snapshot_frequency 106 | items. 107 | """ 108 | 109 | def __init__( 110 | self, 111 | source: BaseNode[T], 112 | pin_memory_device: str = "", 113 | snapshot_frequency: int = 1, 114 | ): 115 | super().__init__() 116 | self.source = source 117 | self.snapshot_frequency = snapshot_frequency 118 | if len(pin_memory_device) == 0: 119 | self._pin_memory_device = None 120 | else: 121 | self._pin_memory_device = pin_memory_device 122 | 123 | if self._pin_memory_device == "xpu": 124 | self._current_device = torch.xpu.current_device() # type: ignore[attr-defined] 125 | elif self._pin_memory_device == torch._C._get_privateuse1_backend_name(): 126 | custom_device_mod = getattr(torch, torch._C._get_privateuse1_backend_name()) 127 | self._current_device = custom_device_mod.current_device() 128 | else: 129 | self._current_device = torch.cuda.current_device() 130 | 131 | self._it: Optional[_SingleThreadedMapper[T]] = None 132 | 133 | def reset(self, initial_state: Optional[Dict[str, Any]] = None): 134 | super().reset(initial_state) 135 | if self._it is not None: 136 | self._it._shutdown() 137 | del self._it 138 | 139 | self._it = _SingleThreadedMapper( 140 | source=self.source, 141 | prefetch_factor=1, 142 | worker=functools.wraps(_pin_memory_loop)( 143 | functools.partial( 144 | _pin_memory_loop, 145 | device_id=self._current_device, 146 | device=self._pin_memory_device, 147 | ) 148 | ), 149 | snapshot_frequency=self.snapshot_frequency, 150 | 
initial_state=initial_state, 151 | ) 152 | 153 | def next(self): 154 | return next(self._it) # type: ignore[arg-type, union-attr] 155 | 156 | def get_state(self) -> Dict[str, Any]: 157 | return self._it.get_state() # type: ignore[union-attr] 158 | -------------------------------------------------------------------------------- /torchdata/nodes/prefetch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, Optional 8 | 9 | from torchdata.nodes import BaseNode, T 10 | 11 | from torchdata.nodes.map import _SingleThreadedMapper 12 | 13 | from ._populate_queue import _populate_queue 14 | 15 | 16 | class Prefetcher(BaseNode[T]): 17 | """Prefetches data from the source node and stores it in a queue. 18 | 19 | Args: 20 | source (BaseNode[T]): The source node to prefetch data from. 21 | prefetch_factor (int): The number of items to prefetch ahead of time. 22 | snapshot_frequency (int): The frequency at which to snapshot the state of the source node. Default is 23 | 1, which means that the state of the source node will be snapshotted after every item. If set 24 | to a higher value, the state of the source node will be snapshotted after every snapshot_frequency 25 | items. 
26 | """ 27 | 28 | def __init__(self, source: BaseNode[T], prefetch_factor: int, snapshot_frequency: int = 1): 29 | super().__init__() 30 | self.source = source 31 | self.prefetch_factor = prefetch_factor 32 | self.snapshot_frequency = snapshot_frequency 33 | self._it: Optional[_SingleThreadedMapper[T]] = None 34 | 35 | def reset(self, initial_state: Optional[Dict[str, Any]] = None): 36 | super().reset(initial_state) 37 | if self._it is not None: 38 | self._it._shutdown() 39 | del self._it 40 | self._it = _SingleThreadedMapper( 41 | source=self.source, 42 | prefetch_factor=self.prefetch_factor, 43 | worker=_populate_queue, 44 | snapshot_frequency=self.snapshot_frequency, 45 | initial_state=initial_state, 46 | ) 47 | 48 | def next(self): 49 | return next(self._it) # type: ignore[arg-type] 50 | 51 | def get_state(self) -> Dict[str, Any]: 52 | return self._it.get_state() # type: ignore[union-attr] 53 | -------------------------------------------------------------------------------- /torchdata/nodes/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/data/c66155a50f5b703730075987842d133fb6d7099f/torchdata/nodes/samplers/__init__.py -------------------------------------------------------------------------------- /torchdata/nodes/samplers/stop_criteria.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | class StopCriteria: 9 | """ 10 | Stopping criteria for the dataset samplers. 11 | 12 | 1) CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED: Stop once the last unseen dataset is exhausted. 13 | All datasets are seen at least once. 
In certain cases, some datasets may be 14 | seen more than once when there are still non-exhausted datasets. 15 | 16 | 2) ALL_DATASETS_EXHAUSTED: Stop once all have the datasets are exhausted. Each 17 | dataset is seen exactly once. No wraparound or restart will be performed. 18 | 19 | 3) FIRST_DATASET_EXHAUSTED: Stop when the first dataset is exhausted. 20 | 21 | 4) CYCLE_FOREVER: Cycle through the datasets by reinitializing each exhausted source nodes. 22 | This is useful when trainer want control over certain number of steps instead of epochs. 23 | """ 24 | 25 | CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED = "CYCLE_UNTIL_ALL_DATASETS_EXHAUSTED" 26 | ALL_DATASETS_EXHAUSTED = "ALL_DATASETS_EXHAUSTED" 27 | FIRST_DATASET_EXHAUSTED = "FIRST_DATASET_EXHAUSTED" 28 | CYCLE_FOREVER = "CYCLE_FOREVER" 29 | -------------------------------------------------------------------------------- /torchdata/nodes/samplers/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | def _get_rank_seed(seed: int, generator_rank: torch.Generator, rank: int, world_size: int, epoch: int) -> int: 14 | generator_rank.manual_seed(seed * world_size + rank) 15 | return int(torch.randint(0, 2 ** 32 - 1, size=(epoch + 1,), generator=generator_rank)[-1].item()) 16 | 17 | 18 | def get_rank_and_world_size() -> tuple[int, int]: 19 | """ 20 | Returns the rank and world size of the current process. 21 | If distributed is initialized, returns the rank and world size from the distributed environment. 22 | If distributed is not initialized, returns the rank and world size from the environment variables. 
23 | If neither distributed nor environment variables are set, returns a rank of 0 and a world size of 1. 24 | """ 25 | if dist.is_available() and dist.is_initialized(): 26 | rank, world_size = dist.get_rank(), dist.get_world_size() 27 | else: 28 | _rank = os.environ.get("RANK", "0") 29 | _world_size = os.environ.get("WORLD_SIZE", "1") 30 | try: 31 | rank = int(_rank) 32 | world_size = int(_world_size) 33 | except ValueError: 34 | rank = 0 35 | world_size = 1 36 | 37 | if rank >= world_size or rank < 0: 38 | raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {world_size - 1}]") 39 | 40 | return rank, world_size 41 | -------------------------------------------------------------------------------- /torchdata/nodes/snapshot_store.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import queue 8 | import threading 9 | import time 10 | from concurrent.futures import Future 11 | from dataclasses import dataclass 12 | from typing import Any, Optional, Protocol, Union 13 | 14 | from torchdata.nodes.constants import QUEUE_TIMEOUT 15 | 16 | from torchdata.nodes.exception_wrapper import ExceptionWrapper 17 | 18 | 19 | @dataclass 20 | class MonotonicIndex: 21 | initial: int = 0 22 | 23 | def __post_init__(self): 24 | self._idx = self.initial 25 | 26 | def get(self) -> int: 27 | idx = self._idx 28 | self._idx += 1 29 | return idx 30 | 31 | 32 | class SnapshotStore(Protocol): 33 | """Protocol for passing snapshot state around between threads and processes""" 34 | 35 | def append(self, snapshot: Any, version: int): 36 | ... 37 | 38 | def pop_version(self, version: int) -> Optional[Any]: 39 | ... 40 | 41 | def append_initial_snapshot(self, snapshot: Any): 42 | ... 
43 | 44 | def get_initial_snapshot(self, thread: Union[Future, threading.Thread], timeout: float) -> Any: 45 | ... 46 | 47 | 48 | class QueueSnapshotStore(SnapshotStore): 49 | """A snapshot store that uses a queue to store snapshots""" 50 | 51 | SNAPSHOT_INIT_VERSION = -1 52 | 53 | def __init__(self) -> None: 54 | self._q: queue.Queue = queue.Queue() 55 | self._lock = threading.Lock() 56 | self._max_version: int = -1000 57 | 58 | def append(self, snapshot: Any, version: int) -> None: 59 | with self._lock: 60 | if version <= self._max_version: 61 | raise ValueError(f"{version=} is not strictly greater than {self._max_version=}") 62 | self._max_version = version 63 | self._q.put((version, snapshot)) 64 | 65 | def pop_version(self, version: int) -> Optional[Any]: 66 | ver, val = None, None 67 | with self._lock: 68 | # pop all items that have a lesser version index 69 | while self._q.queue and version >= self._q.queue[0][0]: 70 | ver, val = self._q.get_nowait() 71 | 72 | if ver == version: 73 | return val 74 | else: 75 | return None 76 | 77 | def append_initial_snapshot(self, snapshot: Any) -> None: 78 | self.append(snapshot, self.SNAPSHOT_INIT_VERSION) 79 | 80 | def get_initial_snapshot(self, thread: Union[Future, threading.Thread], timeout: float = 60.0) -> Any: 81 | snapshot = None 82 | ver = None 83 | 84 | ack_t0 = time.time() 85 | while snapshot is None and time.time() - ack_t0 < timeout: 86 | try: 87 | ver, snapshot = self._q.get(timeout=QUEUE_TIMEOUT) 88 | except queue.Empty: 89 | pass 90 | # Don't test this until after QUEUE_TIMEOUT has elapsed because 91 | # thread may inadvertently report "is_alive()==False" 92 | if isinstance(thread, Future) and not thread.running(): 93 | break 94 | if isinstance(thread, threading.Thread) and not thread.is_alive(): 95 | break 96 | 97 | if snapshot is not None and isinstance(snapshot, ExceptionWrapper): 98 | snapshot.reraise() 99 | 100 | if snapshot is None or ver != self.SNAPSHOT_INIT_VERSION: 101 | error_msg = 
thread.is_alive() if isinstance(thread, threading.Thread) else thread.running() 102 | raise RuntimeError( 103 | f"Failed to get initial snapshot after {time.time() - ack_t0} seconds! thread.is_alive()={error_msg} {snapshot=}, {ver=}" 104 | ) 105 | 106 | return snapshot 107 | -------------------------------------------------------------------------------- /torchdata/nodes/types.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from typing import Any, Dict, Protocol, runtime_checkable 9 | 10 | 11 | @runtime_checkable 12 | class Stateful(Protocol): 13 | """Protocol for objects implementing both ``state_dict()`` and ``load_state_dict(state_dict: Dict[str, Any])``""" 14 | 15 | def state_dict(self) -> Dict[str, Any]: 16 | ... 17 | 18 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 19 | ... 20 | -------------------------------------------------------------------------------- /torchdata/stateful_dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .stateful import Stateful 8 | from .stateful_dataloader import StatefulDataLoader 9 | 10 | __all__ = ["Stateful", "StatefulDataLoader"] 11 | -------------------------------------------------------------------------------- /torchdata/stateful_dataloader/stateful.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Dict, Protocol, runtime_checkable 8 | 9 | 10 | @runtime_checkable 11 | class Stateful(Protocol): 12 | def state_dict(self) -> Dict[str, Any]: 13 | ... 14 | 15 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 16 | ... 17 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.12.0a0 2 | --------------------------------------------------------------------------------