├── s3torchconnectorclient ├── python │ ├── src │ │ └── s3torchconnectorclient │ │ │ ├── py.typed │ │ │ ├── _logger_patch.py │ │ │ ├── __init__.py │ │ │ └── _mountpoint_s3_client.pyi │ └── tst │ │ ├── unit │ │ ├── test_s3exception.py │ │ └── test_structs.py │ │ └── integration │ │ └── conftest.py ├── rust-toolchain.toml ├── MANIFEST.in ├── rust │ ├── build.rs │ └── src │ │ ├── python_structs.rs │ │ ├── build_info.rs │ │ ├── python_structs │ │ ├── py_list_object_result.rs │ │ ├── py_restore_status.rs │ │ ├── py_head_object_result.rs │ │ └── py_object_info.rs │ │ ├── exception.rs │ │ ├── lib.rs │ │ ├── mock_client.rs │ │ ├── get_object_stream.rs │ │ ├── put_object_stream.rs │ │ └── logger_setup.rs ├── deny.toml └── Cargo.toml ├── s3torchbenchmarking ├── src │ └── s3torchbenchmarking │ │ ├── pytorch_checkpointing │ │ ├── __init__.py │ │ └── benchmark.py │ │ ├── lightning_checkpointing │ │ ├── __init__.py │ │ ├── sample_counter.py │ │ ├── checkpoint_profiler.py │ │ └── benchmark.py │ │ ├── __init__.py │ │ ├── dataset │ │ └── __init__.py │ │ ├── dcp_ddp │ │ ├── __init__.py │ │ ├── README.md │ │ ├── save_benchmark.py │ │ └── load_benchmark.py │ │ ├── dcp_fsdp │ │ ├── __init__.py │ │ ├── llama_model_config.py │ │ ├── README.md │ │ └── load_benchmark.py │ │ ├── constants.py │ │ ├── benchmark_utils.py │ │ └── dcp_common.py ├── conf │ ├── dataloader │ │ ├── fsspec.yaml │ │ ├── mountpoint.yaml │ │ ├── mountpointcache.yaml │ │ ├── s3mapdataset.yaml │ │ └── s3iterabledataset.yaml │ ├── aws │ │ └── dynamodb.yaml │ ├── hydra │ │ └── callbacks │ │ │ └── collate_results.yaml │ ├── lightning_checkpointing.yaml │ ├── pytorch_checkpointing.yaml │ ├── dcp_ddp_save.yaml │ ├── dcp_ddp_load.yaml │ ├── dcp_fsdp_save.yaml │ ├── dcp_fsdp_load.yaml │ └── dataset.yaml ├── utils │ ├── run_dataset_benchmarks.sh │ ├── run_checkpoint_benchmarks.sh │ ├── run_lightning_benchmarks.sh │ ├── run_dcp_ddp_benchmarks.sh │ ├── run_dcp_fsdp_benchmarks.sh │ ├── prepare_nvme.sh │ ├── generate_datasets_files.sh │ ├── run_benchmarks.sh │ └── collect_and_write_to_dynamodb.py ├── tst │ └── test_compatibility.py └── pyproject.toml ├── NOTICE ├── .github ├── CODEOWNERS ├── workflows │ ├── ci-integration-pr.yml │ ├── ci-integration-main.yml │ ├── ci.yml │ ├── notify-issue.yml │ ├── notify-pr.yml │ ├── notify-comment.yml │ ├── notify-wheel-failure.yml │ ├── documentation.yml │ ├── generate_third_party_licenses.yml │ └── rust-checks.yml ├── ISSUE_TEMPLATE │ ├── feature_request.yml │ └── bug_report.yml ├── dependabot.yml └── pull_request_template.md ├── s3torchconnector ├── tst │ ├── e2e │ │ ├── dcp │ │ │ ├── __init__.py │ │ │ └── test_e2e_s3_storage_reader.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── lightning_transformer.py │ │ │ └── net.py │ │ ├── test_e2e_s3checkpoint.py │ │ ├── test_common.py │ │ ├── test_s3_client.py │ │ └── test_mountpoint_client_parallel_access.py │ ├── unit │ │ ├── __init__.py │ │ ├── dcp │ │ │ ├── __init__.py │ │ │ ├── test_s3_storage_writer.py │ │ │ └── test_s3_storage_reader.py │ │ ├── lightning │ │ │ └── __init__.py │ │ ├── test_version.py │ │ ├── test_lightning_missing.py │ │ ├── _hypothesis_python_primitives.py │ │ ├── _checkpoint_byteorder_patch.py │ │ ├── test_user_agent.py │ │ ├── test_s3_client_config.py │ │ ├── test_s3reader_constructor.py │ │ ├── test_s3writer.py │ │ └── test_s3dataset_common.py │ └── conftest.py ├── src │ └── s3torchconnector │ │ ├── _version.py │ │ ├── lightning │ │ └── __init__.py │ │ ├── _s3client │ │ ├── __init__.py │ │ ├── s3client_config.py │ │ └── _mock_s3client.py │ │ ├── 
_s3bucket_key_data.py │ │ ├── s3reader │ │ ├── __init__.py │ │ ├── protocol.py │ │ ├── s3reader.py │ │ └── constructor.py │ │ ├── dcp │ │ └── __init__.py │ │ ├── __init__.py │ │ ├── _user_agent.py │ │ ├── _s3dataset_common.py │ │ ├── s3checkpoint.py │ │ ├── _s3_bucket_iterable.py │ │ └── s3writer.py ├── docs │ ├── Makefile │ ├── make.bat │ ├── index.rst │ └── conf.py └── pyproject.toml ├── CODE_OF_CONDUCT.md ├── conftest.py ├── examples └── lightning │ ├── checkpoint_reading.py │ ├── checkpoint_manual_save.py │ ├── async_checkpoint_writing.py │ └── checkpoint_writing.py ├── LICENSE ├── .gitignore ├── run_cibuildwheel_on_ec2.sh └── CONTRIBUTING.md /s3torchconnectorclient/python/src/s3torchconnectorclient/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/pytorch_checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/lightning_checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dataloader/fsspec.yaml: -------------------------------------------------------------------------------- 1 | kind: fsspec 2 | batch_size: 128 3 | num_workers: 8 -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dataloader/mountpoint.yaml: -------------------------------------------------------------------------------- 1 | kind: mountpoint 2 | batch_size: 128 3 | num_workers: 8 -------------------------------------------------------------------------------- /s3torchconnectorclient/rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "1.84" 3 | components = ["rust-src"] -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Amazon S3 Connector for PyTorch 2 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dataloader/mountpointcache.yaml: -------------------------------------------------------------------------------- 1 | kind: mountpointcache 2 | batch_size: 128 3 | num_workers: 8 4 | -------------------------------------------------------------------------------- /s3torchconnectorclient/MANIFEST.in: -------------------------------------------------------------------------------- 1 | # MANIFEST.in 2 | include Cargo.toml 3 | recursive-include rust *.rs 4 | global-include py.typed 5 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Sets default owners for everything to the s3-connector-for-pytorch *team* 2 | * @awslabs/s3-connector-for-pytorch 3 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/dcp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/__init__.py: -------------------------------------------------------------------------------- 1 | from .hydra_callback import ResultCollatingCallback 2 | 3 | __all__ = ["ResultCollatingCallback"] 4 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/dcp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/aws/dynamodb.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | # DynamoDB config; used to save run results 3 | dynamodb: 4 | region: ??? 5 | table: ??? 6 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/run_dataset_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Run dataset benchmarks. 4 | 5 | ./utils/run_benchmarks.sh -s dataset -d ./nvme/ "$@" 6 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_ddp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_fsdp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/hydra/callbacks/collate_results.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | hydra: 3 | callbacks: 4 | my_callback: 5 | _target_: s3torchbenchmarking.ResultCollatingCallback 6 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/run_checkpoint_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Run PyTorch Checkpointing benchmarks. 4 | 5 | ./utils/run_benchmarks.sh -s pytorch_checkpointing -d ./nvme/ "$@" 6 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/run_lightning_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Run PyTorch Lightning Checkpointing benchmarks. 4 | 5 | ./utils/run_benchmarks.sh -s lightning_checkpointing -d ./nvme/ "$@" 6 | -------------------------------------------------------------------------------- /s3torchbenchmarking/tst/test_compatibility.py: -------------------------------------------------------------------------------- 1 | def test_imports(): 2 | from s3torchbenchmarking import datagen 3 | from s3torchbenchmarking.dataset import benchmark 4 | 5 | assert benchmark is not None 6 | assert datagen is not None 7 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/build.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | fn main() { 7 | built::write_built_file().expect("Failed to acquire build-time information"); 8 | } 9 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_version.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import importlib.metadata 5 | 6 | # __package__ is 's3torchconnector' 7 | __version__ = importlib.metadata.version(__package__) 8 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dataloader/s3mapdataset.yaml: -------------------------------------------------------------------------------- 1 | kind: s3mapdataset 2 | batch_size: 128 3 | num_workers: 8 4 | s3reader: 5 | # s3reader type: sequential or range_based 6 | type: sequential # default 7 | # buffer_size (bytes): only used with range_based s3reader type 8 | buffer_size: 8*1024*1024 # default 9 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/test_version.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | from s3torchconnector import __version__ 4 | 5 | 6 | def test_connector_version(): 7 | assert isinstance(__version__, str) 8 | assert __version__ > "1.0.0" 9 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dataloader/s3iterabledataset.yaml: -------------------------------------------------------------------------------- 1 | kind: s3iterabledataset 2 | batch_size: 128 3 | num_workers: 8 4 | s3reader: 5 | # s3reader type: sequential or range_based 6 | type: sequential # default 7 | # buffer_size (bytes): only used with range_based s3reader type 8 | buffer_size: 8*1024*1024 # default 9 | -------------------------------------------------------------------------------- /s3torchconnectorclient/python/src/s3torchconnectorclient/_logger_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | 6 | TRACE = 5 7 | 8 | 9 | def _install_trace_logging(): 10 | logging.addLevelName(TRACE, "TRACE") 11 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/python_structs.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | pub(crate) mod py_list_object_result; 7 | pub(crate) mod py_object_info; 8 | pub(crate) mod py_restore_status; 9 | pub(crate) mod py_head_object_result; 10 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | # Get a nice error message if lightning isn't available. 5 | import lightning 6 | 7 | from .s3_lightning_checkpoint import S3LightningCheckpoint 8 | 9 | __all__ = [ 10 | "S3LightningCheckpoint", 11 | ] 12 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_s3client/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from .s3client_config import S3ClientConfig 5 | from ._s3client import S3Client 6 | from ._mock_s3client import MockS3Client 7 | 8 | __all__ = [ 9 | "S3ClientConfig", 10 | "S3Client", 11 | "MockS3Client", 12 | ] 13 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/build_info.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | // Information from build, made available by built crate. 7 | mod built { 8 | include!(concat!(env!("OUT_DIR"), "/built.rs")); 9 | } 10 | 11 | pub const PACKAGE_NAME: &str = built::PKG_NAME; 12 | pub const FULL_VERSION: &str = built::PKG_VERSION; 13 | -------------------------------------------------------------------------------- /.github/workflows/ci-integration-pr.yml: -------------------------------------------------------------------------------- 1 | name: Integration tests (PR) 2 | 3 | on: 4 | pull_request_target: 5 | branches: [ "main", "feature/*" ] 6 | 7 | permissions: 8 | id-token: write 9 | contents: read 10 | 11 | 12 | jobs: 13 | integration: 14 | name: Integration 15 | uses: ./.github/workflows/python-integration.yml 16 | with: 17 | environment: "integration-tests" 18 | ref: ${{ github.event.pull_request.head.sha }} 19 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/run_dcp_ddp_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Run PyTorch's Distributed Checkpointing (DCP) benchmarks using DistributedDataParallel (DDP) training. 4 | # Usage: 5 | # ./run_dcp_ddp_benchmarks.sh # Run save benchmarks (default) 6 | # ./run_dcp_ddp_benchmarks.sh --save # Run save benchmarks (explicit) 7 | # ./run_dcp_ddp_benchmarks.sh --load # Run load benchmarks 8 | 9 | ./utils/run_benchmarks.sh -s dcp_ddp -d ./nvme/ "$@" 10 | -------------------------------------------------------------------------------- /.github/workflows/ci-integration-main.yml: -------------------------------------------------------------------------------- 1 | name: Integration tests (Main) 2 | 3 | on: 4 | push: 5 | branches: [ "main", "feature/*", "workflow/*" ] 6 | merge_group: 7 | types: [ "checks_requested" ] 8 | workflow_dispatch: 9 | 10 | permissions: 11 | id-token: write 12 | contents: read 13 | 14 | jobs: 15 | integration: 16 | name: Integration 17 | uses: ./.github/workflows/python-integration.yml 18 | with: 19 | ref: ${{ github.event.after }} 20 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/run_dcp_fsdp_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Run PyTorch’s Distributed Checkpointing (DCP) benchmarks using Fully Sharded Data Parallel (FSDP) training. 
4 | # Usage: 5 | # ./run_dcp_fsdp_benchmarks.sh # Run save benchmarks (default) 6 | # ./run_dcp_fsdp_benchmarks.sh --save # Run save benchmarks (explicit) 7 | # ./run_dcp_fsdp_benchmarks.sh --load # Run load benchmarks 8 | ./utils/run_benchmarks.sh -s dcp_fsdp -d ./nvme/ "$@" 9 | 10 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ "main", "feature/*" ] 6 | pull_request: 7 | branches: [ "main", "feature/*" ] 8 | merge_group: 9 | types: [ "checks_requested" ] 10 | 11 | permissions: 12 | contents: read 13 | 14 | jobs: 15 | rust-checks: 16 | name: Rust Checks 17 | uses: ./.github/workflows/rust-checks.yml 18 | 19 | python-checks: 20 | name: Python Checks 21 | uses: ./.github/workflows/python-checks.yml 22 | 23 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_s3bucket_key_data.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | from typing import NamedTuple, Optional 4 | 5 | from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo 6 | 7 | 8 | class S3BucketKeyData(NamedTuple): 9 | """Read-only information about object stored in S3.""" 10 | 11 | bucket: str 12 | key: str 13 | object_info: Optional[ObjectInfo] = None 14 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/s3reader/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from .s3reader import S3Reader 5 | from .constructor import S3ReaderConstructor 6 | from .sequential import SequentialS3Reader 7 | from .ranged import RangedS3Reader 8 | from .protocol import GetStreamCallable, S3ReaderConstructorProtocol 9 | 10 | __all__ = [ 11 | "S3Reader", 12 | "S3ReaderConstructor", 13 | "SequentialS3Reader", 14 | "RangedS3Reader", 15 | ] 16 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Submit a feature request. This might be a feature you would like to see added, or one you would like to contribute. 3 | labels: ["enhancement"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thank you for taking the time to share ideas for new features! 9 | - type: textarea 10 | id: feature-desc 11 | attributes: 12 | label: Tell us more about this new feature. 13 | placeholder: I would like to ... 
14 | validations: 15 | required: true 16 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | updates: 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: weekly 8 | - package-ecosystem: pip 9 | directory: "/s3torchconnectorclient" 10 | schedule: 11 | interval: weekly 12 | groups: 13 | python-packages: 14 | patterns: 15 | - "*" 16 | - package-ecosystem: pip 17 | directory: "/s3torchconnector" 18 | schedule: 19 | interval: weekly 20 | groups: 21 | python-packages: 22 | patterns: 23 | - "*" 24 | -------------------------------------------------------------------------------- /s3torchconnectorclient/python/src/s3torchconnectorclient/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import copyreg 5 | 6 | from ._logger_patch import TRACE as LOG_TRACE 7 | from ._logger_patch import _install_trace_logging 8 | from ._mountpoint_s3_client import S3Exception, __version__ 9 | 10 | _install_trace_logging() 11 | 12 | 13 | def _s3exception_reduce(exc: S3Exception): 14 | return S3Exception, exc.args 15 | 16 | 17 | copyreg.pickle(S3Exception, _s3exception_reduce) 18 | 19 | __all__ = ["LOG_TRACE", "S3Exception", "__version__"] 20 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/test_lightning_missing.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | from importlib.util import find_spec 4 | 5 | import pytest 6 | 7 | 8 | # Skip if lightning is installed 9 | @pytest.mark.skipif( 10 | find_spec("lightning"), 11 | reason="Test verifies error message if lightning extension is used without installation", 12 | ) 13 | def test_lightning_not_installed(): 14 | with pytest.raises(ModuleNotFoundError) as e: 15 | import s3torchconnector.lightning 16 | assert str(e.value) == "No module named 'lightning'" 17 | -------------------------------------------------------------------------------- /.github/workflows/notify-issue.yml: -------------------------------------------------------------------------------- 1 | name: Issue Slack Notifier 2 | 3 | on: 4 | issues: 5 | types: [opened, reopened, edited] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | notify: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Send notification to Slack 15 | uses: slackapi/slack-github-action@v2.1.1 16 | with: 17 | webhook: ${{ secrets.SLACK_WEBHOOK_URL_ISSUE }} 18 | webhook-type: incoming-webhook 19 | payload: | 20 | { 21 | "action": "${{ github.event.action }}", 22 | "issue_url": "${{ github.event.issue.html_url }}" 23 | } 24 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import os 5 | from datetime import timedelta 6 | 7 | import pytest 8 | from hypothesis import settings 9 | 10 | settings.register_profile("ci", max_examples=1000, deadline=timedelta(seconds=1)) 11 | settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "default")) 12 | 13 | per_test_timeout = int(os.getenv("TEST_TIMEOUT", "120")) 14 | 15 | 16 | def pytest_collection_modifyitems(items): 17 | for item in items: 18 | if item.get_closest_marker("timeout") is None: 19 | item.add_marker(pytest.mark.timeout(per_test_timeout)) 20 | -------------------------------------------------------------------------------- /s3torchconnectorclient/deny.toml: -------------------------------------------------------------------------------- 1 | [sources] 2 | unknown-registry = "deny" 3 | unknown-git = "deny" 4 | 5 | [licenses] 6 | # TODO: Enable after publishing 7 | #unlicensed = "deny" 8 | allow = [ 9 | "Apache-2.0", 10 | "Apache-2.0 WITH LLVM-exception", 11 | "BSD-2-Clause", 12 | "BSD-3-Clause", 13 | "ISC", 14 | "MIT", 15 | "OpenSSL", 16 | "Unicode-DFS-2016", 17 | "Unicode-3.0", 18 | "Zlib", 19 | ] 20 | 21 | [[licenses.clarify]] 22 | name = "ring" 23 | version = ">= 0.13.0, < 0.18.0" 24 | expression = "MIT AND ISC AND OpenSSL" 25 | license-files = [ 26 | { path = "LICENSE", hash = 0xbd0eed23 } 27 | ] 28 | 29 | [advisories] 30 | version = 2 31 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/prepare_nvme.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Mount an NVMe drive (by default, at `./nvme/`). Script assumes that it is run on a DLAMI-based EC2 instance. 4 | 5 | nvme_dir=${1:-"./nvme/"} # default value 6 | 7 | if ! mountpoint -q "$nvme_dir"; then 8 | rm -rf "$nvme_dir" 9 | mkdir -p "$nvme_dir" 10 | 11 | if grep -q 'NAME="Amazon Linux"' /etc/os-release; then 12 | sudo mkfs -t xfs /dev/nvme1n1 13 | sudo mount /dev/nvme1n1 "$nvme_dir" 14 | elif grep -q 'NAME="Ubuntu"' /etc/os-release; then 15 | sudo /opt/aws/dlami/bin/nvme_ephemeral_drives.sh 16 | sudo mount /dev/vg.01/lv_ephemeral "$nvme_dir" 17 | fi 18 | 19 | sudo chmod 777 "$nvme_dir" 20 | fi 21 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/dcp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from .s3_file_system import S3FileSystem, S3StorageReader, S3StorageWriter 5 | from .s3_prefix_strategy import ( 6 | S3PrefixStrategyBase, 7 | DefaultPrefixStrategy, 8 | NumericPrefixStrategy, 9 | BinaryPrefixStrategy, 10 | HexPrefixStrategy, 11 | ) 12 | 13 | __all__ = [ 14 | "S3FileSystem", 15 | "S3StorageReader", 16 | "S3StorageWriter", 17 | "S3PrefixStrategyBase", 18 | "DefaultPrefixStrategy", 19 | "NumericPrefixStrategy", 20 | "BinaryPrefixStrategy", 21 | "HexPrefixStrategy", 22 | ] 23 | -------------------------------------------------------------------------------- /.github/workflows/notify-pr.yml: -------------------------------------------------------------------------------- 1 | name: Pull Request Slack Notifier 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened, reopened, synchronize] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | notify: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Send notification to Slack 15 | uses: slackapi/slack-github-action@v2.1.1 16 | with: 17 | webhook: ${{ secrets.SLACK_WEBHOOK_URL_PR }} 18 | webhook-type: incoming-webhook 19 | payload: | 20 | { 21 | "action": "${{ github.event.action }}", 22 | "url": "${{ github.event.pull_request.html_url || github.event.head_commit.url }}" 23 | } 24 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/lightning_checkpointing/sample_counter.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from lightning import Callback 4 | import lightning.pytorch as pl 5 | 6 | 7 | class SampleCounter(Callback): 8 | def __init__(self) -> None: 9 | super().__init__() 10 | self._count = 0 11 | 12 | def on_train_batch_start( 13 | self, 14 | trainer: "pl.Trainer", 15 | pl_module: "pl.LightningModule", 16 | batch: Any, 17 | batch_idx: int, 18 | ) -> None: 19 | super().on_train_batch_start(trainer, pl_module, batch, batch_idx) 20 | self._count += len(batch) 21 | 22 | @property 23 | def count(self): 24 | return self._count 25 | -------------------------------------------------------------------------------- /s3torchconnector/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.github/workflows/notify-comment.yml: -------------------------------------------------------------------------------- 1 | name: Comment Slack Notifier 2 | 3 | on: 4 | issue_comment: 5 | types: [created, edited] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | notify: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - name: Send notification to Slack 15 | uses: slackapi/slack-github-action@v2.1.1 16 | with: 17 | webhook: ${{ secrets.SLACK_WEBHOOK_URL_ISSUE_COMMENT }} 18 | webhook-type: incoming-webhook 19 | payload: | 20 | { 21 | "action": "${{ github.event.action }}", 22 | "comment_url": "${{ github.event.comment.html_url }}", 23 | "content": ${{ toJSON(github.event.comment.body) }} 24 | } 25 | -------------------------------------------------------------------------------- /s3torchconnectorclient/python/tst/unit/test_s3exception.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import pickle 5 | 6 | from hypothesis import given 7 | from hypothesis.strategies import text 8 | 9 | from s3torchconnectorclient import S3Exception 10 | 11 | 12 | @given(text()) 13 | def test_pickles(message): 14 | exc = S3Exception(message) 15 | assert exc.args[0] == message 16 | unpickled = pickle.loads(pickle.dumps(exc)) 17 | assert unpickled.args[0] == message 18 | 19 | 20 | def test_multiple_arguments(): 21 | args = ("foo", 1) 22 | exc = S3Exception(*args) 23 | assert exc.args == args 24 | unpickled = pickle.loads(pickle.dumps(exc)) 25 | assert unpickled.args == args 26 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/s3reader/protocol.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from typing import Protocol, Callable, Optional, Union 5 | from .s3reader import S3Reader 6 | from s3torchconnectorclient._mountpoint_s3_client import ( 7 | ObjectInfo, 8 | GetObjectStream, 9 | HeadObjectResult, 10 | ) 11 | 12 | 13 | class GetStreamCallable(Protocol): 14 | def __call__( 15 | self, start: Optional[int] = None, end: Optional[int] = None 16 | ) -> GetObjectStream: ... 17 | 18 | 19 | class S3ReaderConstructorProtocol(Protocol): 20 | def __call__( 21 | self, 22 | bucket: str, 23 | key: str, 24 | get_object_info: Callable[[], Union[ObjectInfo, HeadObjectResult]], 25 | get_stream: GetStreamCallable, 26 | ) -> S3Reader: ... 27 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/_hypothesis_python_primitives.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from hypothesis.strategies import ( 5 | integers, 6 | characters, 7 | floats, 8 | booleans, 9 | deferred, 10 | tuples, 11 | dictionaries, 12 | lists, 13 | text, 14 | ) 15 | 16 | scalars = ( 17 | booleans() 18 | | integers() 19 | # Disallow nan as it doesn't have self-equality 20 | | floats(allow_nan=False) 21 | | characters() 22 | | text(max_size=10) 23 | ) 24 | 25 | hashable = deferred(lambda: (scalars | tuples(hashable))) 26 | 27 | python_primitives = deferred( 28 | lambda: ( 29 | hashable 30 | | lists(python_primitives, max_size=5) 31 | | dictionaries(keys=hashable, values=python_primitives, max_size=3) 32 | ) 33 | ) 34 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/dcp/test_s3_storage_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import pytest 5 | from s3torchconnector.dcp import S3StorageWriter 6 | 7 | TEST_REGION = "eu-east-1" 8 | TEST_BUCKET = "test-bucket" 9 | TEST_KEY = "test-key.txt" 10 | TEST_PATH = f"s3://{TEST_BUCKET}/{TEST_KEY}" 11 | 12 | 13 | @pytest.mark.parametrize("thread_count", [1, 2, 4, 8, 16]) 14 | def test_s3storage_writer_thread_count(thread_count): 15 | storage_writer = S3StorageWriter( 16 | region=TEST_REGION, path=TEST_PATH, thread_count=thread_count 17 | ) 18 | assert storage_writer.thread_count == thread_count 19 | 20 | 21 | def test_s3storage_writer_thread_count_defaults_to_one(): 22 | storage_writer = S3StorageWriter(region=TEST_REGION, path=TEST_PATH) 23 | assert storage_writer.thread_count == 1 24 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | from s3torchconnectorclient import S3Exception 4 | 5 | # The order of these imports is the same in which they will be rendered 6 | # in the API docs generated with Sphinx. 7 | 8 | from .s3reader import S3Reader, S3ReaderConstructor 9 | from .s3writer import S3Writer 10 | from .s3iterable_dataset import S3IterableDataset 11 | from .s3map_dataset import S3MapDataset 12 | from .s3checkpoint import S3Checkpoint 13 | from ._version import __version__ 14 | from ._s3client import S3ClientConfig 15 | 16 | __all__ = [ 17 | "S3IterableDataset", 18 | "S3MapDataset", 19 | "S3Checkpoint", 20 | "S3Reader", 21 | "S3ReaderConstructor", 22 | "S3Writer", 23 | "S3Exception", 24 | "S3ClientConfig", 25 | "__version__", 26 | ] 27 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/lightning_checkpointing.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/callbacks/collate_results 3 | - aws/dynamodb # save run results to DynamoDB -- comment me if not required 4 | - _self_ 5 | 6 | # S3 bucket to use to save checkpoints. 7 | # NOTE: a non-existing bucket will fail the benchmarks. 8 | s3: 9 | region: ??? # e.g., eu-west-1 10 | uri: ??? # e.g., s3://my-bucket/ 11 | # Number of iterations for "saving" a model's checkpoint. 12 | epochs: 5 13 | # Number of training steps between checkpoints. 
14 | save_one_in: 1 15 | 16 | hydra: 17 | mode: MULTIRUN 18 | sweep: 19 | dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S} 20 | sweeper: 21 | params: 22 | # Short name of a pre-trained model (from Hugging Face), listed in `models.py`. 23 | +model: clip-vit, T0_3B, T0pp 24 | # Checkpoint storage location (valid options: "disk", "s3"). 25 | +checkpoint.storage: disk, s3 26 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/pytorch_checkpointing.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/callbacks/collate_results 3 | - aws/dynamodb # save run results to DynamoDB -- comment me if not required 4 | - _self_ 5 | 6 | # S3 bucket to use to save checkpoints. 7 | # NOTE: a non-existing bucket will fail the benchmarks. 8 | s3: 9 | region: ??? # e.g., eu-west-1 10 | uri: ??? # e.g., s3://my-bucket/ 11 | # Number of iterations for "saving" a model's checkpoint. 12 | epochs: 5 13 | # Number of training steps between checkpoints. 14 | save_one_in: 1 15 | 16 | hydra: 17 | mode: MULTIRUN 18 | sweep: 19 | dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S} 20 | sweeper: 21 | params: 22 | # Short name of a pre-trained model (from Hugging Face), listed in `models.py`. 23 | +model: vit-base, whisper, clip-vit, T0_3B, T0pp 24 | # Checkpoint storage location (valid options: "disk", "s3"). 25 | +checkpoint.storage: disk, s3 26 | -------------------------------------------------------------------------------- /s3torchconnector/docs/make.bat: -------------------------------------------------------------------------------- 1 | pushd %~dp0 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set SOURCEDIR=. 9 | set BUILDDIR=_build 10 | 11 | %SPHINXBUILD% >NUL 2>NUL 12 | if errorlevel 9009 ( 13 | echo. 14 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 15 | echo.installed, then set the SPHINXBUILD environment variable to point 16 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 17 | echo.may add the Sphinx directory to PATH. 18 | echo. 19 | echo.If you don't have Sphinx installed, grab it from 20 | echo.https://www.sphinx-doc.org/ 21 | exit /b 1 22 | ) 23 | 24 | if "%1" == "" goto help 25 | 26 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 27 | goto end 28 | 29 | :help 30 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 31 | 32 | :end 33 | popd 34 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/_checkpoint_byteorder_patch.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from contextlib import contextmanager 5 | from unittest.mock import patch 6 | import torch 7 | from hypothesis.strategies import one_of, just 8 | 9 | byteorders = one_of(just("little"), just("big")) 10 | 11 | 12 | @contextmanager 13 | def _patch_byteorder(byteorder: str): 14 | with patch("torch.serialization.sys") as mock_sys: 15 | mock_sys.byteorder = byteorder 16 | yield 17 | 18 | 19 | def save_with_byteorder(data, fobj, byteorder: str, use_modern_pytorch_format: bool): 20 | with _patch_byteorder(byteorder): 21 | torch.save(data, fobj, _use_new_zipfile_serialization=use_modern_pytorch_format) 22 | 23 | 24 | def load_with_byteorder(fobj, byteorder): 25 | with _patch_byteorder(byteorder): 26 | return torch.load(fobj, weights_only=True) 27 | -------------------------------------------------------------------------------- /.github/workflows/notify-wheel-failure.yml: -------------------------------------------------------------------------------- 1 | name: Workflow Failure Slack Notifier 2 | 3 | on: 4 | workflow_run: 5 | workflows: 6 | - Build Wheels 7 | types: [completed] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | on-failure: 14 | runs-on: ubuntu-latest 15 | if: ${{ github.event.workflow_run.conclusion == 'failure' }} 16 | steps: 17 | - name: Send notification to Slack 18 | uses: slackapi/slack-github-action@v2.1.1 19 | with: 20 | webhook: ${{ secrets.SLACK_WEBHOOK_URL_WORKFLOW_FAILURE }} 21 | webhook-type: incoming-webhook 22 | payload: | 23 | { 24 | "action": "${{ github.event.action }}", 25 | "conclusion": "${{ github.event.workflow_run.conclusion }}", 26 | "workflow_name": "${{ github.event.workflow.name }}", 27 | "workflow_run_url": "${{ github.event.workflow_run.html_url }}" 28 | } 29 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/python_structs/py_list_object_result.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use pyo3::{pyclass, pymethods}; 7 | 8 | use crate::python_structs::py_object_info::PyObjectInfo; 9 | 10 | #[pyclass( 11 | name = "ListObjectResult", 12 | module = "s3torchconnectorclient._mountpoint_s3_client" 13 | )] 14 | #[derive(Debug)] 15 | pub struct PyListObjectResult { 16 | #[pyo3(get)] 17 | object_info: Vec, 18 | #[pyo3(get)] 19 | common_prefixes: Vec, 20 | } 21 | 22 | impl PyListObjectResult { 23 | pub(crate) fn new(object_info: Vec, common_prefixes: Vec) -> Self { 24 | Self { 25 | object_info, 26 | common_prefixes, 27 | } 28 | } 29 | } 30 | 31 | #[pymethods] 32 | impl PyListObjectResult { 33 | fn __repr__(&self) -> String { 34 | format!("{:?}", self) 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/generate_datasets_files.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Check if the list of names is provided as an argument 3 | if [ "$#" -lt 4 ]; then 4 | echo "Usage: $0 [name2] [name3] ..." 
5 | exit 1 6 | fi 7 | 8 | PATH_TO_STORE_DATASETS=$1 9 | BUCKET_NAME=$2 10 | REGION_NAME=$3 11 | shift 3 12 | 13 | # Create an array from the remaining arguments (the datasets) 14 | datasets=("$@") 15 | 16 | mkdir -p "${PATH_TO_STORE_DATASETS}" 17 | 18 | for dataset in "${datasets[@]}" 19 | do 20 | file_name="${dataset}.yaml" 21 | has_shards=$(echo "${dataset}" | grep -c "shards") 22 | if [ "$has_shards" -gt 0 ]; then 23 | sharding="TAR" 24 | else 25 | sharding="null" 26 | fi 27 | 28 | echo "prefix_uri: s3://${BUCKET_NAME}/${dataset}/" > "${PATH_TO_STORE_DATASETS}/${file_name}" 29 | echo "region: ${REGION_NAME}" >> "${PATH_TO_STORE_DATASETS}/${file_name}" 30 | echo "sharding: ${sharding}" >> "${PATH_TO_STORE_DATASETS}/${file_name}" 31 | done 32 | -------------------------------------------------------------------------------- /s3torchbenchmarking/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "s3torchbenchmarking" 7 | version = "0.0.1" 8 | description = "Tools to run and compare benchmarks against various PyTorch connectors like the s3torchconnector." 9 | requires-python = ">=3.8,<3.14" 10 | readme = "README.md" 11 | dependencies = [ 12 | "s3torchconnector[lightning,dcp]", 13 | "boto3", 14 | "click", 15 | "hydra-core", 16 | "pandas", 17 | "pillow", 18 | "prefixed", 19 | "psutil", 20 | "pynvml", 21 | "requests", 22 | "s3fs>=2024", # prevents "UserWarning: Your installed version of s3fs is very old" type of warnings 23 | "torchdata<0.10.0", # we have dependency on deprecated DataPipes, which were removed in 0.10.0 24 | "torchvision", 25 | "transformers", 26 | ] 27 | 28 | [project.optional-dependencies] 29 | test = [ 30 | "pytest" 31 | ] 32 | 33 | [project.scripts] 34 | s3torch-datagen = "s3torchbenchmarking.datagen:synthesize_dataset" 35 | -------------------------------------------------------------------------------- /s3torchconnectorclient/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "s3torchconnectorclient" 3 | version = "1.4.3" 4 | edition = "2021" 5 | publish = false 6 | license = "BSD-3-Clause" 7 | build = "rust/build.rs" 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | 10 | [lib] 11 | name = "_mountpoint_s3_client" 12 | crate-type = ["cdylib"] 13 | path = "rust/src/lib.rs" 14 | 15 | [build-dependencies] 16 | built = "0.7" 17 | 18 | [dependencies] 19 | pyo3 = "0.24.1" 20 | futures = "0.3.28" 21 | mountpoint-s3-client = { version = "0.14.1", features = ["mock"] } 22 | mountpoint-s3-crt-sys = { version = "0.13.0" } 23 | log = "0.4.20" 24 | tracing = { version = "0.1.40", default-features = false, features = ["std", "log"] } 25 | tracing-subscriber = { version = "0.3.20", features = ["fmt", "env-filter"]} 26 | nix = { version = "0.27.1", features = ["process"] } 27 | rusty-fork = "0.3.0" 28 | tracing-appender = "0.2.3" 29 | 30 | [features] 31 | default = ["extension-module"] 32 | extension-module = ["pyo3/extension-module"] 33 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_user_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | from typing import List, Optional 4 | import platform 5 | 6 | from ._version import __version__ 7 | 8 | # https://www.rfc-editor.org/rfc/rfc9110#name-user-agent 9 | 10 | 11 | class UserAgent: 12 | def __init__(self, comments: Optional[List[str]] = None): 13 | if comments is not None and not isinstance(comments, list): 14 | raise ValueError("Argument comments must be a List[str]") 15 | python_version = platform.python_version() 16 | self._user_agent_prefix = ( 17 | f"{__package__}/{__version__} ua/2.0 lang/python#{python_version}" 18 | ) 19 | self._comments = comments or [] 20 | 21 | @property 22 | def prefix(self): 23 | comments_str = "; ".join(filter(None, self._comments)) 24 | if comments_str: 25 | return f"{self._user_agent_prefix} ({comments_str})" 26 | return self._user_agent_prefix 27 | -------------------------------------------------------------------------------- /examples/lightning/checkpoint_reading.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from lightning import Trainer 5 | from lightning.pytorch.demos import WikiText2, LightningTransformer 6 | from torch.utils.data import DataLoader 7 | 8 | from s3torchconnector.lightning import S3LightningCheckpoint 9 | 10 | 11 | def main(region: str, checkpoint_path: str): 12 | dataset = WikiText2() 13 | dataloader = DataLoader(dataset, num_workers=3) 14 | 15 | model = LightningTransformer(vocab_size=dataset.vocab_size) 16 | s3_lightning_checkpoint = S3LightningCheckpoint(region) 17 | 18 | trainer = Trainer( 19 | plugins=[s3_lightning_checkpoint], 20 | min_epochs=4, 21 | max_epochs=5, 22 | max_steps=3, 23 | ) 24 | # Load the checkpoint in `ckpt_path` before training 25 | trainer.fit(model, dataloader, ckpt_path=checkpoint_path) 26 | 27 | 28 | if __name__ == "__main__": 29 | import os 30 | 31 | main(os.getenv("REGION"), os.getenv("CHECKPOINT_PATH")) 32 | -------------------------------------------------------------------------------- /s3torchconnector/tst/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import pytest 5 | 6 | from s3torchconnector.s3reader import ( 7 | S3ReaderConstructor, 8 | S3ReaderConstructorProtocol, 9 | ) 10 | 11 | # Shared reader constructors for parametrized tests 12 | # TODO: use this variable in test_distributed_training.py and test_multiprocess_dataloading.py 13 | READER_CONSTRUCTORS = [ 14 | S3ReaderConstructor.sequential(), # Sequential Reader 15 | S3ReaderConstructor.range_based(), # Default range-based reader, with buffer 16 | S3ReaderConstructor.range_based(buffer_size=0), # range-based reader, no buffer 17 | ] 18 | 19 | 20 | @pytest.fixture( 21 | params=READER_CONSTRUCTORS, 22 | ids=["sequential", "range_based_with_buffer", "range_based_no_buffer"], 23 | scope="module", 24 | ) 25 | def reader_constructor(request) -> S3ReaderConstructorProtocol: 26 | """Provide reader constructor (partial(S3Reader)) instances for all supported reader types.""" 27 | return request.param 28 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/models/lightning_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html 3 | Source: https://github.com/Lightning-AI/pytorch-lightning/blob/master/docs/source-pytorch/common/lightning_module.rst 4 | License: https://github.com/Lightning-AI/pytorch-lightning/blob/master/LICENSE 5 | """ 6 | 7 | import lightning as L 8 | import torch 9 | 10 | from lightning.pytorch.demos import Transformer 11 | 12 | 13 | class LightningTransformer(L.LightningModule): 14 | def __init__(self, vocab_size): 15 | super().__init__() 16 | self.model = Transformer(vocab_size=vocab_size) 17 | 18 | def forward(self, inputs, target): 19 | return self.model(inputs, target) 20 | 21 | def training_step(self, batch, batch_idx): 22 | inputs, target = batch 23 | output = self(inputs, target) 24 | loss = torch.nn.functional.nll_loss(output, target.view(-1)) 25 | return loss 26 | 27 | def configure_optimizers(self): 28 | return torch.optim.SGD(self.model.parameters(), lr=0.1) 29 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | 5 | 6 | ## Additional context 7 | 8 | 9 | 10 | 11 | - [ ] I have updated the CHANGELOG or README if appropriate 12 | 13 | ## Related items 14 | 15 | 16 | ## Testing 17 | 18 | 19 | -------- 20 | By submitting this pull request, I confirm that my contribution is made under the terms of BSD 3-Clause License and I agree to the terms of the [LICENSE](https://github.com/awslabs/s3-connector-for-pytorch/blob/main/LICENSE). 21 | -------------------------------------------------------------------------------- /s3torchconnector/docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Amazon S3 Connector for PyTorch's documentation! 2 | =========================================================== 3 | The Amazon S3 Connector for PyTorch delivers high throughput for PyTorch training jobs that access or store data in Amazon S3. Using the S3 Connector for PyTorch 4 | automatically optimizes performance when downloading training data from and writing checkpoints to Amazon S3, eliminating the need to write your own code to list S3 buckets and manage concurrent requests. 
5 | 6 | 7 | Amazon S3 Connector for PyTorch provides implementations of PyTorch's dataset primitives that you can use to load training data from Amazon S3. 8 | It supports both map-style datasets for random data access patterns and iterable-style datasets for streaming sequential data access patterns. 9 | The S3 Connector for PyTorch also includes a checkpointing interface to save and load checkpoints directly to Amazon S3, without first saving to local storage. 10 | 11 | .. toctree:: 12 | :maxdepth: 3 13 | :caption: Contents: 14 | 15 | Indices and tables 16 | ================== 17 | 18 | * :ref:`genindex` 19 | * :ref:`modindex` 20 | * :ref:`search` 21 | -------------------------------------------------------------------------------- /examples/lightning/checkpoint_manual_save.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from lightning import Trainer 5 | from lightning.pytorch.demos import WikiText2, LightningTransformer 6 | from torch.utils.data import DataLoader 7 | 8 | from s3torchconnector.lightning import S3LightningCheckpoint 9 | 10 | 11 | def main(region: str, checkpoint_path: str): 12 | dataset = WikiText2() 13 | dataloader = DataLoader(dataset, num_workers=3) 14 | 15 | model = LightningTransformer(vocab_size=dataset.vocab_size) 16 | s3_lightning_checkpoint = S3LightningCheckpoint(region) 17 | 18 | # No automatic checkpointing set up here. 19 | trainer = Trainer( 20 | plugins=[s3_lightning_checkpoint], 21 | enable_checkpointing=False, 22 | min_epochs=4, 23 | max_epochs=5, 24 | max_steps=3, 25 | ) 26 | trainer.fit(model, dataloader) 27 | # Manually create checkpoint to the desired location 28 | trainer.save_checkpoint(checkpoint_path) 29 | 30 | 31 | if __name__ == "__main__": 32 | import os 33 | 34 | main(os.getenv("REGION"), os.getenv("CHECKPOINT_PATH")) 35 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dcp_ddp_save.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/callbacks/collate_results 3 | - aws/dynamodb # save run results to DynamoDB -- comment me if not required 4 | - _self_ 5 | 6 | # S3 bucket to use to save checkpoints. 7 | # NOTE: a non-existing bucket will fail the benchmarks. 8 | s3: 9 | region: ??? # e.g., eu-west-1 10 | uri: ??? # e.g., s3://my-bucket/ 11 | # Number of iterations for "saving" a model's checkpoint. 12 | # NOTE: this does not affect model training, as no actual training occurs in these benchmarks. 13 | epochs: 4 14 | 15 | hydra: 16 | mode: MULTIRUN 17 | sweep: 18 | dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S} 19 | sweeper: 20 | params: 21 | # Short name of a pre-trained model (from Hugging Face), listed in `models.py`. 22 | +model: vit-base, T0_3B 23 | # Type of Torch distributed backend (valid options: "nccl", "gloo"). 24 | +backend: nccl 25 | # Number of workers. 26 | +world_size: 8 27 | # Number of threads to use for saving the checkpoints. 28 | +thread_count: 8 29 | # Checkpoint storage location (valid options: "disk", "s3"). 
30 | +checkpoint.storage: disk, s3 31 | 32 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dcp_ddp_load.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/callbacks/collate_results 3 | - aws/dynamodb # save run results to DynamoDB -- comment me if not required 4 | - _self_ 5 | 6 | # S3 bucket to use to save checkpoints. 7 | # NOTE: a non-existing bucket will fail the benchmarks. 8 | s3: 9 | region: ??? # e.g., eu-west-1 10 | uri: ??? # e.g., s3://my-bucket/ 11 | # Number of iterations for "saving" a model's checkpoint. 12 | # NOTE: this does not affect model training, as no actual training occurs in these benchmarks. 13 | epochs: 4 14 | 15 | hydra: 16 | mode: MULTIRUN 17 | sweep: 18 | dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S} 19 | sweeper: 20 | params: 21 | # Short name of a pre-trained model (from Hugging Face), listed in `models.py`. 22 | +model: ??? 23 | # Type of Torch distributed backend (valid options: "nccl", "gloo"). 24 | +backend: nccl 25 | # Number of workers. 26 | +world_size: 8 27 | # Checkpoint storage location (valid options: "disk", "s3"). 28 | +checkpoint.storage: disk, s3 29 | # Checkpoint storage suffix location generated by save benchmarks, e.g., 2025-09-23-11-05-zmuZ/ 30 | +checkpoint.suffix: ??? 31 | 32 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/lightning_checkpointing/checkpoint_profiler.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | from typing import Dict, Any, Optional 3 | 4 | from lightning.fabric.plugins import CheckpointIO 5 | from lightning.fabric.utilities.types import _PATH 6 | 7 | 8 | class CheckpointProfiler(CheckpointIO): 9 | def __init__(self, delegate: CheckpointIO) -> None: 10 | super().__init__() 11 | self.delegate = delegate 12 | self.save_times = [] 13 | 14 | def load_checkpoint( 15 | self, path: _PATH, map_location: Optional[Any] = None 16 | ) -> Dict[str, Any]: 17 | return self.delegate.load_checkpoint(path, map_location) 18 | 19 | def remove_checkpoint(self, path: _PATH) -> None: 20 | self.delegate.remove_checkpoint(path) 21 | 22 | def save_checkpoint( 23 | self, 24 | checkpoint: Dict[str, Any], 25 | path: _PATH, 26 | storage_options: Optional[Any] = None, 27 | ) -> None: 28 | # TODO: should we profile other operations as well? 
29 | start_time = perf_counter() 30 | self.delegate.save_checkpoint(checkpoint, path, storage_options) 31 | elapsed_time = perf_counter() - start_time 32 | self.save_times.append(elapsed_time) 33 | -------------------------------------------------------------------------------- /.github/workflows/documentation.yml: -------------------------------------------------------------------------------- 1 | name: documentation 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | 7 | permissions: 8 | contents: write 9 | 10 | jobs: 11 | docs: 12 | runs-on: ubuntu-22.04 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v6 16 | 17 | - name: Setup Python 18 | uses: actions/setup-python@v6 19 | 20 | - name: Install local packages to update versions 21 | run: | 22 | python -m pip install --upgrade pip 23 | # Manually install CPU-only version of torch so we're not carrying around giant GPU drivers/kernels 24 | python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 25 | python -m pip install -e "s3torchconnectorclient" 26 | python -m pip install -e "s3torchconnector" 27 | 28 | - name: Install dependencies 29 | run: | 30 | pip install sphinx sphinx_rtd_theme sphinx-autoapi ghp-import 31 | 32 | - name: Sphinx build 33 | run: | 34 | cd s3torchconnector/docs 35 | sphinx-build -b html . _build/html 36 | 37 | - name: Import docs 38 | run: | 39 | ghp-import -n -p -f s3torchconnector/docs/_build/html -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/models/net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html 3 | Source: https://github.com/pytorch/tutorials/blob/main/beginner_source/blitz/neural_networks_tutorial.py 4 | License: https://github.com/pytorch/tutorials/blob/main/LICENSE 5 | """ 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | 12 | class Net(nn.Module): 13 | def __init__(self): 14 | super(Net, self).__init__() 15 | self.conv1 = nn.Conv2d(3, 6, 5) 16 | self.pool = nn.MaxPool2d(2, 2) 17 | self.conv2 = nn.Conv2d(6, 16, 5) 18 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 19 | self.fc2 = nn.Linear(120, 84) 20 | self.fc3 = nn.Linear(84, 10) 21 | 22 | def forward(self, x): 23 | x = self.pool(F.relu(self.conv1(x))) 24 | x = self.pool(F.relu(self.conv2(x))) 25 | x = x.view(-1, 16 * 5 * 5) 26 | x = F.relu(self.fc1(x)) 27 | x = F.relu(self.fc2(x)) 28 | x = self.fc3(x) 29 | return x 30 | 31 | def equals(self, other_model: nn.Module) -> bool: 32 | for key_item_1, key_item_2 in zip( 33 | self.state_dict().items(), other_model.state_dict().items() 34 | ): 35 | if not torch.equal(key_item_1[1], key_item_2[1]): 36 | return False 37 | return True 38 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/exception.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use log::error; 7 | use std::error::Error; 8 | use std::fmt::Write; 9 | 10 | use pyo3::exceptions::PyException; 11 | use pyo3::PyErr; 12 | 13 | pyo3::create_exception!( 14 | s3torchconnectorclient._mountpoint_s3_client, 15 | S3Exception, 16 | PyException 17 | ); 18 | 19 | fn log_error(message: &str) { 20 | error!("ERROR: {}", message); 21 | } 22 | 23 | pub fn python_exception(error: impl Error) -> PyErr { 24 | let mut s = String::new(); 25 | let mut error: &dyn Error = &error; 26 | 27 | write!(&mut s, "{}", error).unwrap(); 28 | while let Some(next) = error.source() { 29 | error = next; 30 | write!(&mut s, ": {}", error).unwrap(); 31 | } 32 | 33 | let py_err = S3Exception::new_err(s); 34 | let py_err_str = format!("{}", py_err); 35 | log_error(&py_err_str); 36 | py_err 37 | } 38 | 39 | #[cfg(test)] 40 | mod tests { 41 | use std::io; 42 | 43 | use crate::exception::python_exception; 44 | 45 | #[test] 46 | fn test_python_exception() { 47 | pyo3::prepare_freethreaded_python(); 48 | 49 | let err = io::Error::new(io::ErrorKind::InvalidData, "Test message"); 50 | let pyerr = python_exception(err); 51 | 52 | assert_eq!(pyerr.to_string(), "S3Exception: Test message"); 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_s3client/s3client_config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | from dataclasses import dataclass 4 | from typing import Optional 5 | 6 | 7 | @dataclass(frozen=True) 8 | class S3ClientConfig: 9 | """A dataclass exposing configurable parameters for the S3 client. 10 | 11 | Args: 12 | throughput_target_gbps(float): Throughput target in Gigabits per second (Gbps) that we are trying to reach. 13 | 10.0 Gbps by default (may change in future). 14 | part_size(int): Size (bytes) of file parts that will be uploaded/downloaded. 15 | Note: for saving checkpoints, the inner client will adjust the part size to meet the service limits. 16 | (max number of parts per upload is 10,000, minimum upload part size is 5 MiB). 17 | Part size must be between 5 MiB and 5 GiB. 18 | 8 MiB by default (may change in future). 19 | unsigned(bool): Set to true to disable signing S3 requests. 20 | force_path_style(bool): Set to true to force path-style addressing for the S3 client. 21 | max_attempts(int): Maximum number of retry attempts for retryable errors. 22 | profile(str): Profile name to use for S3 authentication. 23 | """ 24 | 25 | throughput_target_gbps: float = 10.0 26 | part_size: int = 8 * 1024 * 1024 27 | unsigned: bool = False 28 | force_path_style: bool = False 29 | max_attempts: int = 10 30 | profile: Optional[str] = None 31 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/test_user_agent.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # // SPDX-License-Identifier: BSD 3 | from __future__ import annotations 4 | 5 | from typing import List 6 | import platform 7 | 8 | import pytest 9 | 10 | from s3torchconnector._version import __version__ 11 | from s3torchconnector._user_agent import UserAgent 12 | 13 | PYTHON_VERSION = platform.python_version() 14 | DEFAULT_PREFIX = f"s3torchconnector/{__version__} ua/2.0 lang/python#{PYTHON_VERSION}" 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "comments, expected_prefix", 19 | [ 20 | (None, DEFAULT_PREFIX), 21 | ([], DEFAULT_PREFIX), 22 | ([""], DEFAULT_PREFIX), 23 | (["", ""], DEFAULT_PREFIX), 24 | ( 25 | ["component/version", "metadata"], 26 | f"{DEFAULT_PREFIX} (component/version; metadata)", 27 | ), 28 | ], 29 | ) 30 | def test_user_agent_creation(comments: List[str] | None, expected_prefix: str): 31 | user_agent = UserAgent(comments) 32 | assert user_agent.prefix == expected_prefix 33 | 34 | 35 | def test_default_user_agent_creation(): 36 | user_agent = UserAgent() 37 | assert user_agent.prefix == DEFAULT_PREFIX 38 | 39 | 40 | @pytest.mark.parametrize("invalid_comment", [0, "string"]) 41 | def test_invalid_comments_argument(invalid_comment): 42 | with pytest.raises(ValueError, match=r"Argument comments must be a List\[str\]"): 43 | UserAgent(invalid_comment) 44 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from typing import TypedDict, Union, Any, List 5 | 6 | JOB_RESULTS_FILENAME = "job_results.json" 7 | RUN_RESULTS_FILENAME = "run_results.json" 8 | 9 | # URLs for EC2 metadata retrieval (IMDSv2) 10 | URL_IMDS_TOKEN = "http://169.254.169.254/latest/api/token" 11 | URL_IMDS_DOCUMENT = "http://169.254.169.254/latest/dynamic/instance-identity/document" 12 | 13 | 14 | class JobResults(TypedDict): 15 | """Results from a Hydra job.""" 16 | 17 | config: Any 18 | metrics: Any 19 | 20 | 21 | class Versions(TypedDict): 22 | """Version numbers (Python, PyTorch, and other libraries).""" 23 | 24 | python: str 25 | pytorch: str 26 | hydra: str 27 | 28 | 29 | class EC2Metadata(TypedDict): 30 | """EC2 metadata (fetched from IMDSv2).""" 31 | 32 | architecture: str 33 | image_id: str 34 | instance_type: str 35 | region: str 36 | 37 | 38 | class RunResults(TypedDict): 39 | """Results from a Hydra run. 40 | 41 | Note: 42 | An instance of :class:`RunResults` will be inserted as-is in DynamoDB. 43 | """ 44 | 45 | s3torchconnector_version: str # PK (Partition Key) 46 | timestamp_utc: float # SK (Sort Key) 47 | scenario: str 48 | disambiguator: Union[str, None] # helps to identify multi-instances benchmarks 49 | run_elapsed_time_s: float 50 | versions: Versions 51 | ec2_metadata: Union[EC2Metadata, None] 52 | job_results: List[JobResults] 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Amazon.com, Inc. or its affiliates. All Rights Reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 
10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /examples/lightning/async_checkpoint_writing.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from lightning import Trainer 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | from lightning.pytorch.demos import WikiText2, LightningTransformer 7 | from lightning.pytorch.plugins import AsyncCheckpointIO 8 | from torch.utils.data import DataLoader 9 | 10 | from s3torchconnector.lightning import S3LightningCheckpoint 11 | 12 | 13 | def main(region: str, checkpoint_path: str): 14 | dataset = WikiText2() 15 | dataloader = DataLoader(dataset, num_workers=3) 16 | 17 | model = LightningTransformer(vocab_size=dataset.vocab_size) 18 | s3_lightning_checkpoint = S3LightningCheckpoint(region) 19 | async_checkpoint = AsyncCheckpointIO(s3_lightning_checkpoint) 20 | 21 | # This will create one checkpoint per 'step', which we define later to be 8. 22 | # To checkpoint more or less often, change `every_n_train_steps`. 23 | checkpoint_callback = ModelCheckpoint( 24 | dirpath=checkpoint_path, 25 | save_top_k=-1, 26 | every_n_train_steps=1, 27 | filename="checkpoint-{epoch:02d}-{step:02d}", 28 | enable_version_counter=True, 29 | ) 30 | 31 | trainer = Trainer( 32 | plugins=[async_checkpoint], 33 | callbacks=[checkpoint_callback], 34 | min_epochs=4, 35 | max_epochs=8, 36 | max_steps=8, 37 | ) 38 | trainer.fit(model, dataloader) 39 | 40 | 41 | if __name__ == "__main__": 42 | import os 43 | 44 | main(os.getenv("REGION"), os.getenv("CHECKPOINT_PATH")) 45 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dcp_fsdp_save.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/callbacks/collate_results 3 | - aws/dynamodb # save run results to DynamoDB -- comment me if not required 4 | - _self_ 5 | 6 | # S3 bucket to use to save checkpoints. 7 | # NOTE: a non-existing bucket will fail the benchmarks. 8 | s3: 9 | region: ??? # e.g., eu-west-1 10 | uri: ??? 
# e.g., s3://my-bucket/ 11 | # Number of iterations for "saving" a model's checkpoint. 12 | # NOTE: this does not affect model training, as no actual training occurs in these benchmarks. 13 | epochs: 4 14 | 15 | hydra: 16 | mode: MULTIRUN 17 | sweep: 18 | dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S} 19 | sweeper: 20 | params: 21 | # Short name of a pre-trained llama v2 model (valid options: "L7b", "L13b", "L30b", "L65b", "L70b"). 22 | +model: L7b, L13b, L30b 23 | # Type of Torch distributed backend (valid options: "nccl", "gloo"). 24 | +backend: nccl 25 | # Number of workers. 26 | +world_size: 8 27 | # Number of threads to use for saving the checkpoints. 28 | +thread_count: 8 29 | # Checkpoint storage location (valid options: "disk", "s3"). 30 | +checkpoint.storage: disk, s3 31 | # Sharding strategy (valid options: "full", "hybrid"). 32 | +checkpoint.sharding_strategy: full 33 | # Controls whether files are forcibly synced to disk (only relevant for "disk" storage). 34 | # NOTE: We disabled this option to improve performance since FSDP checkpointing with 35 | # forced syncing (maximum durability) was significantly slower than storage throughput. 36 | # This setting has no effect when using "s3" storage. 37 | +checkpoint.sync_files: false 38 | 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python .gitignore (https://github.com/github/gitignore/blob/main/Python.gitignore) -- cherry-picked ################## 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # Unit test / coverage reports 32 | htmlcov/ 33 | .tox/ 34 | .nox/ 35 | .coverage 36 | .coverage.* 37 | .cache 38 | nosetests.xml 39 | coverage.xml 40 | *.cover 41 | *.py,cover 42 | .hypothesis/ 43 | .pytest_cache/ 44 | cover/ 45 | 46 | # Jupyter Notebook 47 | .ipynb_checkpoints 48 | 49 | # Environments 50 | .env 51 | .venv 52 | env/ 53 | venv/ 54 | ENV/ 55 | env.bak/ 56 | venv.bak/ 57 | 58 | # mypy 59 | .mypy_cache/ 60 | .dmypy.json 61 | dmypy.json 62 | 63 | # PyTorch benchmarks: Hydra, NVMe directory, and CSV results 64 | s3torchbenchmarking/**/multirun/ 65 | s3torchbenchmarking/**/nvme/ 66 | s3torchbenchmarking/**/*.csv 67 | 68 | # Rust .gitignore (https://github.com/github/gitignore/blob/main/Rust.gitignore) -- cherry-picked ###################### 69 | 70 | # Generated by Cargo 71 | # will have compiled files and executables 72 | s3torchconnectorclient/debug/ 73 | s3torchconnectorclient/target/ 74 | 75 | # These are backup files generated by rustfmt 76 | s3torchconnectorclient/**/*.rs.bk 77 | 78 | # Other ################################################################################################################ 79 | 80 | # JetBrains 81 | .idea/ 82 | 83 | # Sphinx documentation 84 | s3torchconnector/docs/ -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dcp_fsdp_load.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/callbacks/collate_results 3 | - aws/dynamodb # save run results to DynamoDB 
-- comment me if not required 4 | - _self_ 5 | 6 | # S3 bucket to use to save checkpoints. 7 | # NOTE: a non-existing bucket will fail the benchmarks. 8 | s3: 9 | region: ??? # e.g., eu-west-1 10 | uri: ??? # e.g., s3://my-bucket/ 11 | # Number of iterations for "saving" a model's checkpoint. 12 | # NOTE: this does not affect model training, as no actual training occurs in these benchmarks. 13 | epochs: 4 14 | 15 | hydra: 16 | mode: MULTIRUN 17 | sweep: 18 | dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S} 19 | sweeper: 20 | params: 21 | # Short name of a pre-trained llama v2 model (valid options: "L7b", "L13b", "L30b", "L65b", "L70b"). 22 | +model: ??? 23 | # Type of Torch distributed backend (valid options: "nccl", "gloo"). 24 | +backend: nccl 25 | # Number of workers. 26 | +world_size: 8 27 | # Checkpoint storage location (valid options: "disk", "s3"). 28 | +checkpoint.storage: disk, s3 29 | # Sharding strategy (valid options: "full", "hybrid"). 30 | +checkpoint.sharding_strategy: full 31 | # Controls whether files are forcibly synced to disk (only relevant for "disk" storage). 32 | # NOTE: We disabled this option to improve performance since FSDP checkpointing with 33 | # forced syncing (maximum durability) was significantly slower than storage throughput. 34 | # This setting has no effect when using "s3" storage. 35 | +checkpoint.sync_files: false 36 | # Checkpoint storage suffix location generated by save benchmarks, e.g., 2025-09-23-11-05-zmuZ/ 37 | +checkpoint.suffix: ??? 38 | 39 | -------------------------------------------------------------------------------- /examples/lightning/checkpoint_writing.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from lightning import Trainer 5 | from lightning.pytorch.callbacks import ModelCheckpoint 6 | from lightning.pytorch.demos import WikiText2, LightningTransformer 7 | from torch.utils.data import DataLoader 8 | 9 | from s3torchconnector.lightning import S3LightningCheckpoint 10 | 11 | 12 | def main(region: str, checkpoint_path: str, save_only_latest: bool): 13 | dataset = WikiText2() 14 | dataloader = DataLoader(dataset, num_workers=3) 15 | 16 | model = LightningTransformer(vocab_size=dataset.vocab_size) 17 | s3_lightning_checkpoint = S3LightningCheckpoint(region) 18 | 19 | # Save once per step, and if `save_only_latest`, replace the last checkpoint each time. 20 | # Replacing is implemented by saving the new checkpoint, and then deleting the previous one. 21 | # If `save_only_latest` is False, a new checkpoint is created for each step. 
22 | checkpoint_callback = ModelCheckpoint( 23 | dirpath=checkpoint_path, 24 | save_top_k=1 if save_only_latest else -1, 25 | every_n_train_steps=1, 26 | filename="checkpoint-{epoch:02d}-{step:02d}", 27 | enable_version_counter=True, 28 | ) 29 | 30 | trainer = Trainer( 31 | plugins=[s3_lightning_checkpoint], 32 | callbacks=[checkpoint_callback], 33 | min_epochs=4, 34 | max_epochs=5, 35 | max_steps=3, 36 | ) 37 | trainer.fit(model, dataloader) 38 | 39 | 40 | if __name__ == "__main__": 41 | import os 42 | 43 | main( 44 | os.getenv("REGION"), 45 | os.getenv("CHECKPOINT_PATH"), 46 | os.getenv("LATEST_CHECKPOINT_ONLY") == "1", 47 | ) 48 | -------------------------------------------------------------------------------- /s3torchbenchmarking/conf/dataset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - hydra/callbacks/collate_results 3 | - aws/dynamodb # save run results to DynamoDB -- comment me if not required 4 | - _self_ 5 | 6 | # S3 bucket where the dataset is stored. 7 | # NOTE: a non-existing bucket will fail the benchmarks. 8 | s3: 9 | region: ??? # e.g., eu-west-1 10 | bucket: ??? # e.g., my-bucket (*not* an S3 URI) 11 | # Boolean flag to tell whether the dataset is sharded or not. 12 | sharding: True 13 | # Number of iterations for training a model. 14 | epochs: 10 15 | checkpoint: 16 | # Number of training steps between checkpoints. 17 | save_one_in: 0 18 | # Checkpoint storage location. 19 | destination: disk 20 | # Path for checkpoint saving (local disk or S3 URI). 21 | uri: ./nvme/checkpoints/ 22 | # S3 region. 23 | region: eu-west-2 24 | 25 | hydra: 26 | mode: MULTIRUN 27 | sweep: 28 | dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S} 29 | sweeper: 30 | params: 31 | # Name of a model (valid options: "entitlement", "vit"). 32 | +model: entitlement 33 | # Kind of the dataloader (valid options: "fsspec", "s3iterabledataset", "mountpoint", "mountpointcache"). 34 | # For dataloader kind specific options, see specific conf/dataloader/{dataloader-kind}.yaml 35 | +dataloader: fsspec, s3iterabledataset, mountpoint, mountpointcache 36 | # Dataset name (corresponds to the name of a folder in S3); will be used to build an S3 URI 37 | +dataset: 100k_496x387_images 38 | # S3 reader sweeps (only applies to s3iterabledataset/s3mapdataset) 39 | # s3reader type: sequential or range_based 40 | dataloader.s3reader.type: sequential 41 | # buffer_size (bytes): only used with range_based s3reader 42 | dataloader.s3reader.buffer_size: 8*1024*1024 -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_s3client/_mock_s3client.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from typing import Optional 5 | 6 | from s3torchconnectorclient._mountpoint_s3_client import ( 7 | MockMountpointS3Client, 8 | MountpointS3Client, 9 | ) 10 | 11 | from . import S3Client 12 | from .._user_agent import UserAgent 13 | from .s3client_config import S3ClientConfig 14 | 15 | """ 16 | _mock_s3client.py 17 | Internal client wrapper mock class for unit testing. 
18 | """ 19 | 20 | 21 | class MockS3Client(S3Client): 22 | def __init__( 23 | self, 24 | region: str, 25 | bucket: str, 26 | user_agent: Optional[UserAgent] = None, 27 | s3client_config: Optional[S3ClientConfig] = None, 28 | ): 29 | super().__init__( 30 | region, 31 | user_agent=user_agent, 32 | s3client_config=s3client_config, 33 | ) 34 | self._mock_client = MockMountpointS3Client( 35 | region, 36 | bucket, 37 | throughput_target_gbps=self.s3client_config.throughput_target_gbps, 38 | part_size=self.s3client_config.part_size, 39 | user_agent_prefix=self.user_agent_prefix, 40 | unsigned=self.s3client_config.unsigned, 41 | force_path_style=self.s3client_config.force_path_style, 42 | max_attempts=self.s3client_config.max_attempts, 43 | ) 44 | 45 | def add_object(self, key: str, data: bytes) -> None: 46 | self._mock_client.add_object(key, data) 47 | 48 | def remove_object(self, key: str) -> None: 49 | self._mock_client.remove_object(key) 50 | 51 | def _client_builder(self) -> MountpointS3Client: 52 | return self._mock_client.create_mocked_client() 53 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/s3reader/s3reader.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import io 5 | from io import SEEK_SET 6 | from abc import ABC, abstractmethod 7 | from typing import Optional 8 | 9 | 10 | class S3Reader(ABC, io.BufferedIOBase): 11 | """An abstract base class for read-only, file-like representation of a single object stored in S3. 12 | 13 | This class defines the interface for S3 readers. Concrete implementations (SequentialS3Reader or 14 | RangedS3Reader extend this class. S3ReaderConstructor creates partial functions of these 15 | implementations, which are then completed by S3Client with the remaining required parameters. 16 | """ 17 | 18 | @property 19 | @abstractmethod 20 | def bucket(self) -> str: 21 | pass 22 | 23 | @property 24 | @abstractmethod 25 | def key(self) -> str: 26 | pass 27 | 28 | @abstractmethod 29 | def read(self, size: Optional[int] = None) -> bytes: 30 | pass 31 | 32 | @abstractmethod 33 | def seek(self, offset: int, whence: int = SEEK_SET, /) -> int: 34 | pass 35 | 36 | @abstractmethod 37 | def tell(self) -> int: 38 | pass 39 | 40 | @abstractmethod 41 | def readinto(self, buf) -> int: 42 | pass 43 | 44 | def seekable(self) -> bool: 45 | """ 46 | Returns: 47 | bool: Return whether object supports seek operations. 48 | """ 49 | return True 50 | 51 | def readable(self) -> bool: 52 | """ 53 | Returns: 54 | bool: Return whether object was opened for reading. 55 | """ 56 | return True 57 | 58 | def writable(self) -> bool: 59 | """ 60 | Returns: 61 | bool: Return whether object was opened for writing. 
62 | """ 63 | return False 64 | -------------------------------------------------------------------------------- /.github/workflows/generate_third_party_licenses.yml: -------------------------------------------------------------------------------- 1 | name: Generate THIRD-PARTY-LICENSES 2 | 3 | on: 4 | push: 5 | tags: [ "v[0-9]+.[0-9]+.[0-9]+" ] 6 | branches: [ "dependabot/*", "main", "workflow/*" ] 7 | workflow_call: 8 | outputs: 9 | artifact_name: 10 | description: "The created artifact name for ORT results" 11 | value: ${{ jobs.generate_third_party_licenses.outputs.artifact_name }} 12 | 13 | permissions: 14 | contents: read 15 | 16 | jobs: 17 | generate_third_party_licenses: 18 | name: Generate NOTICE_DEFAULT file 19 | runs-on: ubuntu-24.04 20 | 21 | outputs: 22 | artifact_name: ${{ steps.artifact_name.outputs.artifact_name }} 23 | 24 | steps: 25 | - uses: actions/checkout@v6 26 | - name: Install Python dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install pipreqs safety 30 | # Added to fix AttributeError: module 'lib' has no attribute 'X509_V_FLAG_CB_ISSUER_CHECK' 31 | python -m pip install -U urllib3 requests 32 | 33 | - name: Generate requirements 34 | run: | 35 | pipreqs s3torchconnectorclient 36 | pipreqs s3torchconnector 37 | 38 | - name: Generate NOTICE_DEFAULT file 39 | id: ort-action 40 | # https://github.com/oss-review-toolkit/ort-ci-github-action/issues/28 41 | uses: oss-review-toolkit/ort-ci-github-action@1805edcf1f4f55f35ae6e4d2d9795ccfb29b6021 42 | with: 43 | ort-cli-report-args: -f PlainTextTemplate 44 | run: > 45 | cache-dependencies, 46 | labels, 47 | analyzer, 48 | reporter, 49 | upload-results 50 | sw-version: "-" 51 | 52 | - name: Export artifact name 53 | id: artifact_name 54 | run: | 55 | echo "artifact_name=${ORT_RESULTS_ARTIFACT_NAME}" >> $GITHUB_OUTPUT 56 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/run_benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Template script to run other benchmarks (not meant to be used directly). 
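# Illustrative usage (assumed invocation; the flags correspond to the parsing below and are
# normally supplied by the wrapper scripts under utils/, e.g. run_dcp_ddp_benchmarks.sh):
#   ./utils/run_benchmarks.sh -s dcp_ddp -d ./nvme            # DCP save benchmarks
#   ./utils/run_benchmarks.sh -s dcp_ddp -d ./nvme --load     # DCP load benchmarks
#   ./utils/run_benchmarks.sh -s lightning_checkpointing -d ./nvme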
4 | 5 | set -euo pipefail 6 | 7 | # Check for --save and --load flags before getopts 8 | load_mode=false 9 | save_mode=false 10 | filtered_args=() 11 | for arg in "$@"; do 12 | case $arg in 13 | --load) load_mode=true ;; 14 | --save) save_mode=true ;; 15 | *) filtered_args+=("$arg") ;; 16 | esac 17 | done 18 | 19 | # Set filtered arguments 20 | set -- "${filtered_args[@]}" 21 | 22 | while getopts "s:d:" opt; do 23 | case $opt in 24 | s) scenario=$OPTARG ;; # name of the scenario 25 | d) nvme_dir=$OPTARG ;; # mount point dir for saving checkpoints (will use NVMe drive) 26 | *) ;; 27 | esac 28 | done 29 | 30 | shift $((OPTIND - 1)) # remove all processed positional arguments from "$@" 31 | 32 | # Prepare NVMe drive mount 33 | if [[ -n $nvme_dir ]]; then 34 | ./utils/prepare_nvme.sh "$nvme_dir" 35 | fi 36 | 37 | 38 | # Run benchmarks; will write to DynamoDB table, if specified in the config (in `conf/aws/dynamodb.yaml`) 39 | # Use load_benchmark.py if -l flag is provided, otherwise use default benchmark.py 40 | if [[ "${load_mode:-}" == "true" ]]; then 41 | python ./src/s3torchbenchmarking/"$scenario"/load_benchmark.py -cd conf -cn "${scenario}_load" +path="$nvme_dir" "$@" 42 | elif [[ "${save_mode:-}" == "true" ]]; then 43 | python ./src/s3torchbenchmarking/"$scenario"/save_benchmark.py -cd conf -cn "${scenario}_save" +path="$nvme_dir" "$@" 44 | elif [[ "$scenario" == "dcp_ddp" || "$scenario" == "dcp_fsdp" ]]; then 45 | echo "No flags detected, running DCP save benchmarks. To run DCP load benchmarks, use the --load flag" 46 | python ./src/s3torchbenchmarking/"$scenario"/save_benchmark.py -cd conf -cn "${scenario}_save" +path="$nvme_dir" "$@" 47 | else 48 | python ./src/s3torchbenchmarking/"$scenario"/benchmark.py -cd conf -cn "$scenario" +path="$nvme_dir" "$@" 49 | fi -------------------------------------------------------------------------------- /s3torchconnector/docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Path setup -------------------------------------------------------------- 7 | 8 | # If extensions (or modules to document with autodoc) are in another directory, 9 | # add these directories to sys.path here. If the directory is relative to the 10 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
11 | # 12 | import os 13 | import sys 14 | from s3torchconnector import __version__ 15 | 16 | sys.path.insert(0, os.path.abspath("..")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 21 | 22 | project = "Amazon S3 Connector for PyTorch" 23 | copyright = "2023, Amazon S3" 24 | author = "Amazon S3" 25 | release = __version__ 26 | 27 | # -- General configuration --------------------------------------------------- 28 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 29 | 30 | extensions = ["sphinx.ext.napoleon", "sphinx.ext.viewcode", "autoapi.extension"] 31 | 32 | templates_path = ["_templates"] 33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 34 | 35 | 36 | # -- Options for HTML output ------------------------------------------------- 37 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 38 | 39 | html_theme = "sphinx_rtd_theme" 40 | html_static_path = ["_static"] 41 | 42 | 43 | autoapi_dirs = ["../src/s3torchconnector"] 44 | autoapi_type = "python" 45 | autoapi_options = [ 46 | "members", 47 | "undoc-members", 48 | "show-inheritance", 49 | "show-module-summary", 50 | "imported-members", 51 | ] 52 | autoapi_keep_files = True 53 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_s3dataset_common.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from typing import Iterable, Union, Tuple 5 | 6 | from ._s3_bucket_iterable import S3BucketIterable 7 | from ._s3client import S3Client 8 | from . import S3Reader 9 | from ._s3bucket_key_data import S3BucketKeyData 10 | 11 | """ 12 | _s3dataset_common.py 13 | Collection of common methods for S3 datasets, containing logic for URIs parsing and objects listing. 14 | """ 15 | 16 | 17 | def identity(obj: S3Reader) -> S3Reader: 18 | return obj 19 | 20 | 21 | # TODO: Check boto3 implementation for this 22 | def parse_s3_uri(uri: str) -> Tuple[str, str]: 23 | if not uri or not uri.startswith("s3://"): 24 | raise ValueError("Only s3:// URIs are supported") 25 | uri = uri[len("s3://") :] 26 | if not uri: 27 | raise ValueError("Bucket name must be non-empty") 28 | split = uri.split("/", maxsplit=1) 29 | if len(split) == 1: 30 | bucket = split[0] 31 | prefix = "" 32 | else: 33 | bucket, prefix = split 34 | if not bucket: 35 | raise ValueError("Bucket name must be non-empty") 36 | return bucket, prefix 37 | 38 | 39 | def get_objects_from_uris( 40 | object_uris: Union[str, Iterable[str]], client: S3Client 41 | ) -> Iterable[S3BucketKeyData]: 42 | if isinstance(object_uris, str): 43 | object_uris = [object_uris] 44 | # TODO: We should be consistent with URIs parsing. Revise if we want to do this upfront or lazily. 
45 | bucket_key_pairs = [parse_s3_uri(uri) for uri in object_uris] 46 | 47 | return (S3BucketKeyData(bucket, key) for bucket, key in bucket_key_pairs) 48 | 49 | 50 | def get_objects_from_prefix(s3_uri: str, client: S3Client) -> Iterable[S3BucketKeyData]: 51 | bucket, prefix = parse_s3_uri(s3_uri) 52 | return iter(S3BucketIterable(client, bucket, prefix)) 53 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/python_structs/py_restore_status.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use mountpoint_s3_client::types::RestoreStatus; 7 | use pyo3::types::PyTuple; 8 | use pyo3::{IntoPyObject, IntoPyObjectExt}; 9 | use pyo3::{pyclass, pymethods}; 10 | use pyo3::{Bound, PyResult}; 11 | 12 | use crate::PyRef; 13 | 14 | #[pyclass( 15 | name = "RestoreStatus", 16 | module = "s3torchconnectorclient._mountpoint_s3_client" 17 | )] 18 | #[derive(Debug, Clone)] 19 | pub struct PyRestoreStatus { 20 | #[pyo3(get)] 21 | in_progress: bool, 22 | #[pyo3(get)] 23 | expiry: Option<u128>, 24 | } 25 | 26 | impl PyRestoreStatus { 27 | pub(crate) fn from_restore_status(restore_status: RestoreStatus) -> Self { 28 | match restore_status { 29 | RestoreStatus::InProgress => PyRestoreStatus::new(true, None), 30 | RestoreStatus::Restored { expiry } => { 31 | let expiry = expiry 32 | .duration_since(std::time::UNIX_EPOCH) 33 | .expect("Expired before unix epoch!") 34 | .as_millis(); 35 | PyRestoreStatus::new(false, Some(expiry)) 36 | } 37 | } 38 | } 39 | } 40 | 41 | #[pymethods] 42 | impl PyRestoreStatus { 43 | #[new] 44 | #[pyo3(signature = (in_progress, expiry=None))] 45 | pub fn new(in_progress: bool, expiry: Option<u128>) -> Self { 46 | Self { 47 | in_progress, 48 | expiry, 49 | } 50 | } 51 | 52 | pub fn __getnewargs__(slf: PyRef<'_, Self>) -> PyResult<Bound<'_, PyTuple>> { 53 | let py = slf.py(); 54 | let state = [ 55 | slf.in_progress.into_py_any(py)?.bind(py).to_owned(), 56 | slf.expiry.into_pyobject(py)?.into_any() 57 | ]; 58 | PyTuple::new(py, state) 59 | } 60 | 61 | fn __repr__(&self) -> String { 62 | format!("{:?}", self) 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /run_cibuildwheel_on_ec2.sh: -------------------------------------------------------------------------------- 1 | if [ $# -ne 10 ]; then 2 | echo "Invalid number of parameters, you need to provide role name, region name, bucket name, prefix, express region name and express bucket name, custom endpoint for s3 standard, auth profile arn and buckets names for testing auth profile" 3 | echo "Usage: $0 S3RoleName us-west-2 s3torchconnector-test-bucket-name prefix-name/ us-east-1 s3torchconnectorclient-express-bucket-name https://s3.amazon.com arn:aws:iam::XXXXXXXXXXX:role/RoleName profile-test-bucket-name profile-test-express-bucket-name " 4 | exit 1 5 | fi 6 | 7 | ROLE_NAME=$1 8 | REGION_NAME=$2 9 | BUCKET_NAME=$3 10 | PREFIX=$4 11 | EXPRESS_REGION_NAME=$5 12 | EXPRESS_BUCKET_NAME=$6 13 | S3_CUSTOM_ENDPOINT_URL=$7 14 | PROFILE_IAM_ROLE=$8 15 | S3_PROFILE_BUCKET=$9 16 | S3_EXPRESS_PROFILE_BUCKET=${10} 17 | 18 | FILE_NAME="tmp_cred.json" 19 | # Set metadata token TTL to 6 hours 20 | TOKEN=`curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 21600"` 21 | # Retrieve temporary credentials and save to file 22 | curl -H "X-aws-ec2-metadata-token: $TOKEN"
http://169.254.169.254/latest/meta-data/iam/security-credentials/${ROLE_NAME} >> ${FILE_NAME} 23 | # Expose temporary credentials to use from cibuildwheel container 24 | export AWS_ACCESS_KEY_ID=`cat ${FILE_NAME} | jq -r '.AccessKeyId'` 25 | export AWS_SECRET_ACCESS_KEY=`cat ${FILE_NAME} | jq -r '.SecretAccessKey'` 26 | export AWS_SESSION_TOKEN=`cat ${FILE_NAME} | jq -r '.Token'` 27 | rm ${FILE_NAME} 28 | 29 | # Expose settings for integration tests to use from cibuildwheel container 30 | export S3_REGION=${REGION_NAME} 31 | export S3_BUCKET=${BUCKET_NAME} 32 | export S3_PREFIX=${PREFIX} 33 | export S3_EXPRESS_REGION=${EXPRESS_REGION_NAME} 34 | export S3_EXPRESS_BUCKET=${EXPRESS_BUCKET_NAME} 35 | export S3_CUSTOM_ENDPOINT_URL=${S3_CUSTOM_ENDPOINT_URL} 36 | export PROFILE_IAM_ROLE=${PROFILE_IAM_ROLE} 37 | export S3_PROFILE_BUCKET=${S3_PROFILE_BUCKET} 38 | export S3_EXPRESS_PROFILE_BUCKET=${S3_EXPRESS_PROFILE_BUCKET} 39 | 40 | cibuildwheel --output-dir wheelhouse --platform linux s3torchconnectorclient 41 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/lib.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use crate::exception::S3Exception; 7 | use crate::get_object_stream::GetObjectStream; 8 | use crate::list_object_stream::ListObjectStream; 9 | use crate::mock_client::PyMockClient; 10 | use crate::mountpoint_s3_client::join_all_managed_threads; 11 | use crate::mountpoint_s3_client::MountpointS3Client; 12 | use crate::put_object_stream::PutObjectStream; 13 | use crate::python_structs::py_head_object_result::PyHeadObjectResult; 14 | use crate::python_structs::py_list_object_result::PyListObjectResult; 15 | use crate::python_structs::py_object_info::PyObjectInfo; 16 | use crate::python_structs::py_restore_status::PyRestoreStatus; 17 | use pyo3::prelude::*; 18 | 19 | mod build_info; 20 | mod exception; 21 | mod get_object_stream; 22 | mod list_object_stream; 23 | mod logger_setup; 24 | mod mock_client; 25 | mod mountpoint_s3_client; 26 | mod mountpoint_s3_client_inner; 27 | mod put_object_stream; 28 | mod python_structs; 29 | 30 | #[pymodule] 31 | #[pyo3(name = "_mountpoint_s3_client")] 32 | fn make_lib(py: Python, mountpoint_s3_client: &Bound<'_, PyModule>) -> PyResult<()> { 33 | logger_setup::setup_logging()?; 34 | mountpoint_s3_client.add_class::<MountpointS3Client>()?; 35 | mountpoint_s3_client.add_class::<PyMockClient>()?; 36 | mountpoint_s3_client.add_class::<GetObjectStream>()?; 37 | mountpoint_s3_client.add_class::<ListObjectStream>()?; 38 | mountpoint_s3_client.add_class::<PutObjectStream>()?; 39 | mountpoint_s3_client.add_class::<PyListObjectResult>()?; 40 | mountpoint_s3_client.add_class::<PyObjectInfo>()?; 41 | mountpoint_s3_client.add_class::<PyRestoreStatus>()?; 42 | mountpoint_s3_client.add_class::<PyHeadObjectResult>()?; 43 | mountpoint_s3_client.add("S3Exception", py.get_type::<S3Exception>())?; 44 | mountpoint_s3_client.add("__version__", build_info::FULL_VERSION)?; 45 | mountpoint_s3_client.add_function(wrap_pyfunction!( 46 | join_all_managed_threads, 47 | mountpoint_s3_client 48 | )?)?; 49 | 50 | Ok(()) 51 | } 52 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/test_e2e_s3checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import torch 5 | import pytest 6 | 7 | from s3torchconnector import S3Checkpoint 8 | from models.net import Net 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "tensor_dimensions", 13 | [[3, 2], [10, 1024, 1024]], 14 | ) 15 | def test_general_checkpointing(checkpoint_directory, tensor_dimensions): 16 | tensor = torch.rand(tensor_dimensions) 17 | checkpoint_name = "general_checkpoint.pt" 18 | checkpoint = S3Checkpoint(region=checkpoint_directory.region) 19 | s3_uri = f"{checkpoint_directory.s3_uri}/{checkpoint_name}" 20 | with checkpoint.writer(s3_uri) as writer: 21 | torch.save(tensor, writer) 22 | 23 | loaded = torch.load(checkpoint.reader(s3_uri), weights_only=True) 24 | 25 | assert torch.equal(tensor, loaded) 26 | 27 | 28 | def test_nn_checkpointing(checkpoint_directory): 29 | nn_model = Net() 30 | checkpoint_name = "neural_network_model.pt" 31 | checkpoint = S3Checkpoint(region=checkpoint_directory.region) 32 | 33 | epoch = 5 34 | s3_uri = f"{checkpoint_directory.s3_uri}/{checkpoint_name}" 35 | loss = 0.4 36 | 37 | with checkpoint.writer(s3_uri) as writer: 38 | torch.save( 39 | { 40 | "epoch": epoch, 41 | "model_state_dict": nn_model.state_dict(), 42 | "loss": loss, 43 | }, 44 | writer, 45 | ) 46 | 47 | loaded_nn_model = Net() 48 | 49 | # assert models are not equal before loading from checkpoint 50 | assert not nn_model.equals(loaded_nn_model) 51 | 52 | loaded_checkpoint = torch.load(checkpoint.reader(s3_uri), weights_only=True) 53 | loaded_nn_model.load_state_dict(loaded_checkpoint["model_state_dict"]) 54 | assert nn_model.equals(loaded_nn_model) 55 | 56 | loaded_epoch = loaded_checkpoint["epoch"] 57 | loaded_loss = loaded_checkpoint["loss"] 58 | assert loss == loaded_loss 59 | assert epoch == loaded_epoch 60 | 61 | # Assert that eval and train do not raise 62 | loaded_nn_model.eval() 63 | loaded_nn_model.train() 64 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/s3checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | from typing import Optional 4 | 5 | from ._s3dataset_common import parse_s3_uri 6 | from ._s3client import S3Client, S3ClientConfig 7 | from . import S3Reader, S3Writer 8 | 9 | 10 | class S3Checkpoint: 11 | """A checkpoint manager for S3. 12 | 13 | To read a checkpoint from S3, users need to create an S3Reader 14 | by providing the s3_uri of the checkpoint stored in S3. Similarly, to save a 15 | checkpoint to S3, users need to create an S3Writer by providing an s3_uri. 16 | S3Reader and S3Writer implement io.BufferedIOBase; therefore, they can be passed to 17 | torch.load and torch.save. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | region: str, 23 | endpoint: Optional[str] = None, 24 | s3client_config: Optional[S3ClientConfig] = None, 25 | ): 26 | self.region = region 27 | self.endpoint = endpoint 28 | self._client = S3Client( 29 | region, endpoint=endpoint, s3client_config=s3client_config 30 | ) 31 | 32 | def reader(self, s3_uri: str) -> S3Reader: 33 | """Creates an S3Reader from a given s3_uri. 34 | 35 | Args: 36 | s3_uri (str): A valid s3_uri. (i.e. s3://<bucket>/<key>) 37 | 38 | Returns: 39 | S3Reader: a read-only binary stream of the S3 object's contents, specified by the s3_uri. 40 | 41 | Raises: 42 | S3Exception: An error occurred accessing S3.
43 | """ 44 | bucket, key = parse_s3_uri(s3_uri) 45 | return self._client.get_object(bucket, key) 46 | 47 | def writer(self, s3_uri: str) -> S3Writer: 48 | """Creates an S3Writer from a given s3_uri. 49 | 50 | Args: 51 | s3_uri (str): A valid s3_uri. (i.e. s3://<bucket>/<key>) 52 | 53 | Returns: 54 | S3Writer: a write-only binary stream. The content is saved to S3 using the specified s3_uri. 55 | 56 | Raises: 57 | S3Exception: An error occurred accessing S3. 58 | """ 59 | bucket, key = parse_s3_uri(s3_uri) 60 | return self._client.put_object(bucket, key) 61 | -------------------------------------------------------------------------------- /s3torchconnector/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "build"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "s3torchconnector" 7 | version = "1.4.3" 8 | description = "S3 connector integration for PyTorch" 9 | requires-python = ">=3.8,<3.14" 10 | readme = "README.md" 11 | classifiers = [ 12 | "Development Status :: 5 - Production/Stable", 13 | "Programming Language :: Python :: 3", 14 | "Programming Language :: Python :: 3.8", 15 | "Programming Language :: Python :: 3.9", 16 | "Programming Language :: Python :: 3.10", 17 | "Programming Language :: Python :: 3.11", 18 | "Programming Language :: Python :: 3.12", 19 | "Programming Language :: Python :: 3.13", 20 | "License :: OSI Approved :: BSD License", 21 | "Operating System :: OS Independent", 22 | "Topic :: Utilities" 23 | ] 24 | 25 | dependencies = [ 26 | "torch >= 2.0.1, != 2.5.0", 27 | "s3torchconnectorclient == 1.4.3", 28 | ] 29 | 30 | [project.optional-dependencies] 31 | test = [ 32 | "pytest", 33 | "pytest-timeout", 34 | "hypothesis", 35 | "flake8", 36 | "black", 37 | "mypy" 38 | ] 39 | 40 | e2e = [ 41 | "torchdata<=0.9.0", # we have a dependency on deprecated DataPipes, which were removed in 0.10.0 42 | "Pillow>=10.3.0", 43 | "boto3<1.37.2", # prevent conflict caused by aiobotocore, which restricts the botocore version 44 | "numpy < 2", 45 | "pytest-xdist", 46 | "fsspec==2025.3.0; python_version == '3.8'", # pin fsspec version for Python 3.8 to prevent dataset e2e test failures 47 | ] 48 | 49 | lightning = [ 50 | "lightning >= 2.0", 51 | "packaging", 52 | ] 53 | 54 | lightning-tests = [ 55 | "s3torchconnector[lightning]", 56 | "s3fs", 57 | "torchmetrics != 1.7.0, != 1.7.1", # versions 1.7.0 and 1.7.1 break lightning checkpoints e2e tests during "lightning" module import 58 | ] 59 | 60 | dcp = [ 61 | "tenacity", 62 | "torch >= 2.3, != 2.5.0", 63 | ] 64 | 65 | dcp-test = [ 66 | "s3torchconnector[dcp]", 67 | "pytest", 68 | "importlib_metadata; python_version == '3.9'", 69 | ] 70 | 71 | [tool.setuptools.packages] 72 | # Pure Python packages/modules 73 | find = { where = ["src"] } 74 | 75 | [tool.setuptools] 76 | license-files = ["LICENSE", "THIRD-PARTY-LICENSES", "NOTICE"] 77 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/test_common.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import platform 5 | import torch 6 | from s3torchconnector import S3Reader 7 | import boto3 8 | 9 | from typing import Tuple, List 10 | 11 | 12 | def _get_fork_methods() -> List[str]: 13 | """Get a set of valid start methods for PyTorch's multiprocessing.
14 | On macOS, the 'fork' and 'forkserver' start methods are known to crash, 15 | despite being reported as usable by PyTorch. This function filters out 16 | those methods for macOS systems. 17 | 18 | Returns: 19 | List[str]: A set of valid start methods for the current platform. 20 | """ 21 | methods = set(torch.multiprocessing.get_all_start_methods()) 22 | 23 | if platform.system() == "Darwin": 24 | # fork and forkserver crash on MacOS, even though it's reported as usable. 25 | # https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods 26 | # https://bugs.python.org/issue?@action=redirect&bpo=33725 27 | methods -= {"fork", "forkserver"} 28 | return [method for method in methods] 29 | 30 | 31 | def _set_start_method(start_method: str): 32 | torch.multiprocessing.set_start_method(start_method, force=True) 33 | 34 | 35 | def _read_data(s3reader: S3Reader) -> Tuple[str, bytes]: 36 | return s3reader.key, s3reader.read() 37 | 38 | 39 | def _list_folders_in_bucket(bucket_name, prefix=""): 40 | if prefix and not prefix.endswith("/"): 41 | prefix += "/" 42 | 43 | s3_client = boto3.client("s3") 44 | paginator = s3_client.get_paginator("list_objects_v2") 45 | 46 | pages = paginator.paginate(Bucket=bucket_name, Delimiter="/", Prefix=prefix) 47 | 48 | folders = [] 49 | for page in pages: 50 | # Common prefixes are the folders 51 | if "CommonPrefixes" in page: 52 | for obj in page["CommonPrefixes"]: 53 | folder_name = obj["Prefix"] 54 | if prefix: 55 | # Remove the prefix from the folder name if it exists 56 | folder_name = folder_name[len(prefix) :] 57 | if folder_name: # Avoid empty folder names 58 | folders.append(folder_name.rstrip("/")) 59 | return folders 60 | -------------------------------------------------------------------------------- /s3torchbenchmarking/utils/collect_and_write_to_dynamodb.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | """Collect and write collated results to DynamoDB table. 5 | 6 | This script collects all "collated_results.json" files in a given directory, and write them (in batch) to the specified 7 | DynamoDB table. 8 | 9 | Requires AWS credentials to be correctly set beforehand (env var). 
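Example invocation (argument values are illustrative placeholders, not defaults):

    python utils/collect_and_write_to_dynamodb.py ./multirun eu-west-1 my-benchmark-results-table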
10 | """ 11 | 12 | import argparse 13 | import json 14 | import logging 15 | from decimal import Decimal 16 | from pathlib import Path 17 | from typing import List, Any 18 | 19 | import boto3 20 | from botocore.exceptions import ClientError 21 | 22 | _RUN_FILENAME = "run.json" 23 | 24 | logger = logging.getLogger(__name__) 25 | logging.basicConfig(level=logging.INFO) 26 | 27 | 28 | def collect_collated_results(results_dir: str) -> List[Any]: 29 | all_job_results = [] 30 | for entry in Path(results_dir).glob(f"**/{_RUN_FILENAME}"): 31 | if entry.is_file(): 32 | with open(entry, "r") as f: 33 | all_job_results.append(json.load(f, parse_float=Decimal)) 34 | 35 | logger.info("Collected %i job results", len(all_job_results)) 36 | return all_job_results 37 | 38 | 39 | def write_collated_results_to_dynamodb( 40 | collated_results: List[Any], region: str, table: str 41 | ) -> None: 42 | dynamodb = boto3.resource("dynamodb", region_name=region) 43 | 44 | try: 45 | with dynamodb.Table(table).batch_writer() as writer: 46 | for collated_result in collated_results: 47 | writer.put_item(Item=collated_result) 48 | except ClientError: 49 | logger.error("Couldn't load data into table %s", table, exc_info=True) 50 | 51 | 52 | def main(): 53 | parser = argparse.ArgumentParser( 54 | description="Collect and write collated results to DynamoDB table" 55 | ) 56 | parser.add_argument( 57 | "collated_results_dir", help="directory where the collated results are" 58 | ) 59 | parser.add_argument("region", help="region where the DynamoDB table is") 60 | parser.add_argument("table", help="DynamoDB table name") 61 | args = parser.parse_args() 62 | 63 | collated_results = collect_collated_results(args.collated_results_dir) 64 | write_collated_results_to_dynamodb(collated_results, args.region, args.table) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/test_s3_client.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import os 5 | import tempfile 6 | import pytest 7 | from s3torchconnectorclient import S3Exception 8 | 9 | from s3torchconnector._s3client import S3Client, S3ClientConfig 10 | 11 | HELLO_WORLD_DATA = b"Hello, World!\n" 12 | TEST_PROFILE_NAME = "test-profile" 13 | 14 | 15 | def test_no_access_objects_without_profile(empty_directory): 16 | if empty_directory.profile_bucket is None: 17 | pytest.skip("No profile bucket configured") 18 | 19 | client = S3Client( 20 | empty_directory.region, 21 | ) 22 | filename = f"{empty_directory.prefix}hello_world.txt" 23 | 24 | with pytest.raises(S3Exception): 25 | put_stream = client.put_object( 26 | empty_directory.profile_bucket, 27 | filename, 28 | ) 29 | put_stream.write(HELLO_WORLD_DATA) 30 | 31 | 32 | def test_access_objects_with_profile(empty_directory): 33 | if empty_directory.profile_bucket is None: 34 | pytest.skip("No profile bucket configured") 35 | 36 | try: 37 | tmp_file = tempfile.NamedTemporaryFile() 38 | tmp_file.write( 39 | f"""[profile default] 40 | aws_access_key_id = {os.getenv("AWS_ACCESS_KEY_ID")} 41 | aws_secret_access_key = {os.getenv("AWS_SECRET_ACCESS_KEY")} 42 | aws_session_token = {os.getenv("AWS_SESSION_TOKEN")} 43 | 44 | [profile {TEST_PROFILE_NAME}] 45 | role_arn = {empty_directory.profile_arn} 46 | region = {empty_directory.region} 47 | source_profile = default""".encode() 48 | ) 49 | tmp_file.flush() 50 | os.environ["AWS_CONFIG_FILE"] = tmp_file.name 51 | 52 | client = S3Client( 53 | empty_directory.region, 54 | s3client_config=S3ClientConfig(profile=TEST_PROFILE_NAME), 55 | ) 56 | filename = f"{empty_directory.prefix}hello_world.txt" 57 | 58 | put_stream = client.put_object( 59 | empty_directory.profile_bucket, 60 | filename, 61 | ) 62 | 63 | put_stream.write(HELLO_WORLD_DATA) 64 | put_stream.close() 65 | 66 | get_stream = client.get_object( 67 | empty_directory.profile_bucket, 68 | filename, 69 | ) 70 | assert b"".join(get_stream) == HELLO_WORLD_DATA 71 | finally: 72 | os.environ["AWS_CONFIG_FILE"] = "" 73 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/dcp/test_s3_storage_reader.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from unittest.mock import Mock 5 | from hypothesis import given 6 | from hypothesis.strategies import composite, integers, lists 7 | 8 | from torch.distributed.checkpoint.planner import LoadPlan, ReadItem 9 | 10 | from s3torchconnector.dcp import S3StorageReader 11 | 12 | TEST_REGION = "eu-east-1" 13 | TEST_PATH = "s3://test-bucket/test-checkpoint/" 14 | 15 | 16 | @composite 17 | def load_plan_with_offsets(draw): 18 | """Generate LoadPlan with random offsets.""" 19 | offsets = draw(lists(integers(0, 10_000_000), min_size=1, max_size=10_000)) 20 | 21 | storage_data = {} 22 | items = [] 23 | 24 | for i, offset in enumerate(offsets): 25 | storage_index = f"item{i}" 26 | storage_data[storage_index] = Mock(offset=offset) 27 | items.append(Mock(spec=ReadItem, storage_index=storage_index)) 28 | 29 | return LoadPlan(items), storage_data 30 | 31 | 32 | def test_s3storage_reader_prepare_local_plan_empty(): 33 | """Test prepare_local_plan handles empty plans.""" 34 | s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH) 35 | 36 | sorted_plan = s3_storage_reader.prepare_local_plan(LoadPlan([])) 37 | # Output: LoadPlan(items=[], storage_data=None, planner_data=None) 38 | 39 | assert isinstance(sorted_plan, LoadPlan) 40 | assert len(sorted_plan.items) == 0 41 | 42 | 43 | @given(load_plan_with_offsets()) 44 | def test_s3storage_reader_prepare_local_plan(loadplan_and_storagedata): 45 | """Test prepare local plan sorts items by storage_data offset.""" 46 | load_plan, storage_data = loadplan_and_storagedata 47 | 48 | s3_storage_reader = S3StorageReader(TEST_REGION, TEST_PATH) 49 | s3_storage_reader.storage_data = storage_data 50 | 51 | sorted_plan = s3_storage_reader.prepare_local_plan(load_plan) 52 | sorted_offsets = [ 53 | storage_data[item.storage_index].offset for item in sorted_plan.items 54 | ] 55 | 56 | # Verify return type 57 | assert isinstance(sorted_plan, LoadPlan) 58 | 59 | # Verify Load Ordering sorts offsets 60 | assert sorted_offsets == sorted(sorted_offsets) 61 | 62 | # Verify Load Ordering keeps items the same 63 | assert len(sorted_plan.items) == len(load_plan.items) 64 | assert set(sorted_plan.items) == set(load_plan.items) 65 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/test_s3_client_config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | from hypothesis import given, example 4 | from hypothesis.strategies import integers, floats 5 | 6 | from s3torchconnector import S3ClientConfig 7 | from .test_s3_client import MiB, GiB 8 | 9 | 10 | def test_default(): 11 | config = S3ClientConfig() 12 | assert config.part_size == 8 * MiB 13 | assert config.throughput_target_gbps == 10.0 14 | assert config.force_path_style is False 15 | assert config.max_attempts == 10 16 | assert config.profile is None 17 | 18 | 19 | def test_enable_force_path_style(): 20 | config = S3ClientConfig(force_path_style=True) 21 | assert config.force_path_style is True 22 | 23 | 24 | def test_change_profile(): 25 | config = S3ClientConfig(profile="test_profile") 26 | assert config.profile == "test_profile" 27 | 28 | 29 | @given(part_size=integers(min_value=5 * MiB, max_value=5 * GiB)) 30 | def test_part_size_setup(part_size: int): 31 | config = S3ClientConfig(part_size=part_size) 32 | assert config.part_size == part_size 33 | assert config.throughput_target_gbps == 10.0 34 | 35 | 36 | @given(throughput_target_gbps=floats(min_value=1.0, max_value=100.0)) 37 | def test_throughput_target_gbps_setup(throughput_target_gbps: float): 38 | config = S3ClientConfig(throughput_target_gbps=throughput_target_gbps) 39 | assert config.part_size == 8 * 1024 * 1024 40 | assert config.throughput_target_gbps == throughput_target_gbps 41 | 42 | 43 | @given(max_attempts=integers(min_value=1, max_value=10)) 44 | def test_max_attempts_setup(max_attempts: int): 45 | config = S3ClientConfig(max_attempts=max_attempts) 46 | assert config.max_attempts == max_attempts 47 | 48 | 49 | @given( 50 | part_size=integers(min_value=5 * MiB, max_value=5 * GiB), 51 | throughput_target_gbps=floats(min_value=1.0, max_value=100.0), 52 | max_attempts=integers(min_value=1, max_value=10), 53 | ) 54 | @example(part_size=5 * MiB, throughput_target_gbps=10.0, max_attempts=2) 55 | @example(part_size=5 * GiB, throughput_target_gbps=15.0, max_attempts=8) 56 | def test_custom_setup(part_size: int, throughput_target_gbps: float, max_attempts: int): 57 | config = S3ClientConfig( 58 | part_size=part_size, 59 | throughput_target_gbps=throughput_target_gbps, 60 | max_attempts=max_attempts, 61 | ) 62 | assert config.part_size == part_size 63 | assert config.throughput_target_gbps == throughput_target_gbps 64 | assert config.max_attempts == max_attempts 65 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_fsdp/llama_model_config.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from dataclasses import dataclass 5 | from transformers import LlamaConfig, AutoModelForCausalLM 6 | 7 | 8 | @dataclass 9 | class LlamaModelConfig: 10 | hidden_size: int 11 | intermediate_size: int 12 | num_hidden_layers: int 13 | num_attention_heads: int 14 | num_key_value_heads: int 15 | 16 | 17 | # LlamaModelParams is a class that takes a model size as input and returns the corresponding model configuration 18 | class LlamaModelParams: 19 | def __init__(self, model_size: str): 20 | configs = { 21 | "L7b": LlamaModelConfig(4096, 11008, 32, 32, 32), 22 | "L13b": LlamaModelConfig(5120, 13824, 40, 40, 40), 23 | # "mixed": LlamaModelConfig(8192, 20000, 40, 64, 8), # 101194 24 | "L30b": LlamaModelConfig(6656, 17920, 60, 52, 52), # 125024 25 | "L65b": LlamaModelConfig(8192, 22016, 80, 64, 64), 26 | "L70b": LlamaModelConfig(8192, 28672, 80, 64, 8), 27 | } 28 | 29 | if model_size not in configs: 30 | raise ValueError(f"Invalid model size. Choose from: {list(configs.keys())}") 31 | 32 | config = configs[model_size] 33 | self.hidden_size = config.hidden_size 34 | self.intermediate_size = config.intermediate_size 35 | self.num_hidden_layers = config.num_hidden_layers 36 | self.num_attention_heads = config.num_attention_heads 37 | self.num_key_value_heads = config.num_key_value_heads 38 | 39 | 40 | # create a function that returns a llama model config 41 | def get_llama_model_config(model_name: str): 42 | params = LlamaModelParams(model_name) 43 | model_config = LlamaConfig( 44 | vocab_size=50432, 45 | hidden_size=params.hidden_size, 46 | intermediate_size=params.intermediate_size, 47 | num_hidden_layers=params.num_hidden_layers, 48 | num_attention_heads=params.num_attention_heads, 49 | num_key_value_heads=params.num_key_value_heads, 50 | max_position_embeddings=4096, 51 | rms_norm_eps=1e-5, 52 | use_cache=False, 53 | pretraining_tp=1, 54 | tie_word_embeddings=False, 55 | rope_scaling=None, 56 | ) 57 | return model_config 58 | 59 | 60 | # create a function that returns a llama model 61 | def get_llama_model(model_name: str): 62 | model_config = get_llama_model_config(model_name) 63 | model = AutoModelForCausalLM.from_config(model_config) 64 | return model 65 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/python_structs/py_head_object_result.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
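A short usage sketch for the Llama helpers above (requires the `transformers` package; only the config lookup is shown, since instantiating a full model allocates billions of parameters):

```python
from s3torchbenchmarking.dcp_fsdp.llama_model_config import (
    LlamaModelParams,
    get_llama_model_config,
)

# Preset dimensions for the 7B variant.
params = LlamaModelParams("L7b")
print(params.hidden_size, params.num_hidden_layers)  # 4096 32

# Corresponding transformers LlamaConfig (no weights are allocated yet).
config = get_llama_model_config("L7b")
print(config.num_attention_heads)  # 32

# Unknown sizes raise: "Invalid model size. Choose from: ['L7b', 'L13b', ...]"
try:
    LlamaModelParams("L3b")
except ValueError as err:
    print(err)
```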
3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use mountpoint_s3_client::types::HeadObjectResult; 7 | use pyo3::types::PyTuple; 8 | use pyo3::{pyclass, pymethods}; 9 | use pyo3::{Bound}; 10 | use pyo3::{IntoPyObject, PyResult}; 11 | 12 | 13 | use crate::python_structs::py_restore_status::PyRestoreStatus; 14 | use crate::PyRef; 15 | 16 | #[pyclass( 17 | name = "HeadObjectResult", 18 | module = "s3torchconnectorclient._mountpoint_s3_client", 19 | frozen 20 | )] 21 | #[derive(Debug, Clone)] 22 | pub struct PyHeadObjectResult { 23 | #[pyo3(get)] 24 | etag: String, 25 | #[pyo3(get)] 26 | size: u64, 27 | #[pyo3(get)] 28 | last_modified: i64, 29 | #[pyo3(get)] 30 | storage_class: Option, 31 | #[pyo3(get)] 32 | restore_status: Option, 33 | } 34 | 35 | impl PyHeadObjectResult { 36 | pub(crate) fn from_head_object_result(head_object_result: HeadObjectResult) -> Self { 37 | PyHeadObjectResult::new( 38 | head_object_result.etag.into_inner(), 39 | head_object_result.size, 40 | head_object_result.last_modified.unix_timestamp(), 41 | head_object_result.storage_class, 42 | head_object_result 43 | .restore_status 44 | .map(PyRestoreStatus::from_restore_status), 45 | ) 46 | } 47 | } 48 | 49 | #[pymethods] 50 | impl PyHeadObjectResult { 51 | #[new] 52 | #[pyo3(signature = (etag, size, last_modified, storage_class=None, restore_status=None))] 53 | pub fn new( 54 | etag: String, 55 | size: u64, 56 | last_modified: i64, 57 | storage_class: Option, 58 | restore_status: Option, 59 | ) -> Self { 60 | Self { 61 | etag, 62 | size, 63 | last_modified, 64 | storage_class, 65 | restore_status, 66 | } 67 | } 68 | 69 | pub fn __getnewargs__(slf: PyRef<'_, Self>) -> PyResult> { 70 | let py = slf.py(); 71 | 72 | let state = [ 73 | slf.etag.clone().into_pyobject(py)?.into_any(), 74 | slf.size.into_pyobject(py)?.into_any(), 75 | slf.last_modified.into_pyobject(py)?.into_any(), 76 | slf.storage_class.clone().into_pyobject(py)?.into_any(), 77 | slf.restore_status.clone().into_pyobject(py)?.into_any(), 78 | ]; 79 | 80 | PyTuple::new(py, &state) 81 | } 82 | 83 | fn __repr__(&self) -> String { 84 | format!("{:?}", self) 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_ddp/README.md: -------------------------------------------------------------------------------- 1 | ## PyTorch's Distributed Checkpoint (DCP) benchmarks using DistributedDataParallel (DDP) training 2 | 3 | The `dcp` Python package provides a suite of benchmarks designed to evaluate and measure the performance 4 | of [PyTorch's Distributed Checkpointing (DCP)][DCP] feature in comparison to the `s3torchconnector` library. 5 | 6 | These benchmarks specifically use DistributedDataParallel (DDP), which is PyTorch's standard approach 7 | for distributed training where the model is replicated across multiple GPUs/processes. 8 | With DDP, each process maintains a complete copy of the model parameters, making it suitable for scenarios 9 | where memory requirements per GPU are manageable. 10 | 11 | ### Purpose 12 | 13 | These benchmarks test both "save" and "load" mechanisms of PyTorch DCP (`torch.distributed.checkpoint.save` and `torch.distributed.checkpoint.load`). 
The primary objectives are to evaluate the `s3torchconnector` library's performance against other libraries and local storage options, by measuring the following metrics: 14 | 15 | **Save Benchmarks:** 16 | - Checkpoint saving throughput (in MiB/s) 17 | - Checkpoint "corrected" save durations (in seconds), which exclude the influence of model load duration on the device 18 | 19 | **Load Benchmarks:** 20 | - Checkpoint loading throughput (in MiB/s) 21 | - Checkpoint "corrected" load durations (in seconds), which exclude the influence of process setup and model loading to device 22 | 23 | ### Configuration 24 | 25 | The benchmark runs can be customized through configuration files: 26 | 27 | - **Save benchmarks**: [`dcp_ddp_save.yaml`](../../../conf/dcp_ddp_save.yaml) 28 | - **Load benchmarks**: [`dcp_ddp_load.yaml`](../../../conf/dcp_ddp_load.yaml) 29 | 30 | The load configuration includes a `checkpoint.suffix` parameter that specifies which saved checkpoint to load. 31 | 32 | > [!IMPORTANT] 33 | > A `+path` option is passed to the running script ([`run_dcp_ddp_benchmarks.sh`](../../../utils/run_dcp_ddp_benchmarks.sh)), 34 | > and will be used only if the `checkpoint.storage` key includes `disk`. 35 | 36 | ### Usage 37 | 38 | **Save benchmarks (default):** 39 | ```bash 40 | ./utils/run_dcp_ddp_benchmarks.sh 41 | ./utils/run_dcp_ddp_benchmarks.sh --save 42 | ``` 43 | 44 | **Load benchmarks:** 45 | ```bash 46 | ./utils/run_dcp_ddp_benchmarks.sh --load 47 | ``` 48 | 49 | ### References 50 | 51 | - https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html 52 | - https://pytorch.org/docs/stable/elastic/run.html 53 | - https://pytorch.org/tutorials/intermediate/ddp_tutorial.html 54 | 55 | [DCP]: https://pytorch.org/docs/stable/distributed.checkpoint.html 56 | 57 | [multirun]: https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run/ -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_ddp/save_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | from multiprocessing.queues import Queue 6 | from time import perf_counter 7 | from typing import Tuple 8 | 9 | import hydra 10 | import torch 11 | import torch.distributed as dist 12 | import torch.distributed.checkpoint as dcp 13 | from omegaconf import DictConfig 14 | from torch.nn.parallel import DistributedDataParallel 15 | 16 | from s3torchbenchmarking.dcp_common import setup, get_writer, benchmark_common_runner 17 | from s3torchbenchmarking.models import get_benchmark_model, BenchmarkModel 18 | 19 | Timestamps = Tuple[float, float] 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | # TODO: add Structured Config (https://hydra.cc/docs/tutorials/structured_config/intro/) 24 | @hydra.main(version_base=None) 25 | def run_benchmark(cfg: DictConfig) -> dict: 26 | """DCP benchmarks entry point.""" 27 | benchmark_model = get_benchmark_model(cfg.model) 28 | 29 | return benchmark_common_runner(cfg, run_ddp_save, (cfg, benchmark_model)) 30 | 31 | 32 | def run_ddp_save( 33 | rank: int, # needs to be passed first (provided by `multiprocessing.spawn` automatically) 34 | cfg: DictConfig, 35 | proxy_model: BenchmarkModel, 36 | suffix: str, 37 | save_timestamps: Queue, 38 | ) -> None: 39 | """Execute the actual code for checkpoint saving.
40 | 41 | This function is meant to be executed in subprocesses.""" 42 | begin_process = perf_counter() 43 | 44 | storage_writer = get_writer(cfg, suffix) 45 | model_size = proxy_model.size 46 | model = proxy_model.model 47 | 48 | setup(cfg.backend, world_size=cfg.world_size, rank=rank) 49 | if cfg.backend == "nccl": 50 | device_id = rank % torch.cuda.device_count() 51 | torch.cuda.set_device(device_id) 52 | model.to(device_id) 53 | model = DistributedDataParallel(model, device_ids=[device_id]) 54 | else: 55 | device_id = rank % torch.cpu.device_count() 56 | torch.cpu.set_device(device_id) 57 | model.to(device=torch.device("cpu")) 58 | model = DistributedDataParallel(model) 59 | 60 | state_dict = model.state_dict() 61 | 62 | begin_save = perf_counter() # also "end_process" 63 | dcp.save(state_dict, storage_writer=storage_writer) 64 | end_save = perf_counter() 65 | 66 | # Record the save times excluding the influence of the process setup and model loading to device. 67 | save_timestamps.put( 68 | (begin_process, end_save - (begin_save - begin_process), model_size) 69 | ) 70 | 71 | dist.destroy_process_group() 72 | 73 | 74 | if __name__ == "__main__": 75 | run_benchmark() 76 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_fsdp/README.md: -------------------------------------------------------------------------------- 1 | ## PyTorch's Distributed Checkpoint (DCP) benchmarks using Fully Sharded Data Parallel (FSDP) training 2 | 3 | The `dcp` Python package provides a suite of benchmarks designed to evaluate and measure the performance 4 | of [PyTorch's Distributed Checkpointing (DCP)][DCP] feature in comparison to the `s3torchconnector` library. 5 | 6 | These benchmarks specifically use Fully Sharded Data Parallel (FSDP), which is PyTorch's memory-efficient 7 | distributed training approach where model parameters are sharded across GPUs/processes. 8 | Unlike DDP, FSDP distributes model parameters across processes, making it particularly suitable 9 | for training large models that wouldn't fit in a single GPU's memory. 10 | 11 | ### Purpose 12 | 13 | These benchmarks test both "save" and "load" mechanisms of PyTorch DCP (`torch.distributed.checkpoint.save` and `torch.distributed.checkpoint.load`). The primary objectives are to evaluate the `s3torchconnector` library's performance against other libraries and local storage options, by measuring the following metrics: 14 | 15 | **Save Benchmarks:** 16 | - Checkpoint saving throughput (in MiB/s) 17 | - Checkpoint "corrected" save durations (in seconds), which exclude the influence of model load duration on the device 18 | 19 | **Load Benchmarks:** 20 | - Checkpoint loading throughput (in MiB/s) 21 | - Checkpoint "corrected" load durations (in seconds), which exclude the influence of process setup and model loading to device 22 | 23 | ### Configuration 24 | 25 | The benchmark runs can be customized through configuration files: 26 | 27 | - **Save benchmarks**: [`dcp_fsdp_save.yaml`](../../../conf/dcp_fsdp_save.yaml) 28 | - **Load benchmarks**: [`dcp_fsdp_load.yaml`](../../../conf/dcp_fsdp_load.yaml) 29 | 30 | The load configuration includes a `checkpoint.suffix` parameter that specifies which saved checkpoint to load. 31 | 32 | > [!IMPORTANT] 33 | > A `+path` option is passed to the running script ([`run_dcp_fsdp_benchmarks.sh`](../../../utils/run_dcp_fsdp_benchmarks.sh)), 34 | > and will be used only if the `checkpoint.storage` key includes `disk`.
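For reference, a rough sketch of how the "corrected" durations and throughput described above can be derived from the timestamps each worker emits — the tuple layout mirrors the `save_timestamps.put(...)` call in the DDP save benchmark earlier on this page, while the helper names and the MiB unit for the model size are assumptions of this sketch:

```python
from time import perf_counter

def corrected_duration_s(begin_process: float, corrected_end: float) -> float:
    # Workers report (begin_process, end_save - (begin_save - begin_process), model_size);
    # subtracting begin_process leaves only the time spent inside dcp.save()/dcp.load().
    return corrected_end - begin_process

def throughput_mibs(model_size_mib: float, duration_s: float) -> float:
    return model_size_mib / duration_s

# Made-up numbers: 30 s of process setup, then 12 s inside dcp.save().
begin_process = perf_counter()
begin_save = begin_process + 30.0
end_save = begin_save + 12.0
corrected_end = end_save - (begin_save - begin_process)

duration = corrected_duration_s(begin_process, corrected_end)  # 12.0 s
print(throughput_mibs(8192, duration))                         # ~682.7 MiB/s
```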
35 | 36 | ### Usage 37 | 38 | **Save benchmarks (default):** 39 | ```bash 40 | ./utils/run_dcp_fsdp_benchmarks.sh 41 | ./utils/run_dcp_fsdp_benchmarks.sh --save 42 | ``` 43 | 44 | **Load benchmarks:** 45 | ```bash 46 | ./utils/run_dcp_fsdp_benchmarks.sh --load 47 | ``` 48 | 49 | ### References 50 | 51 | - https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html 52 | - https://pytorch.org/docs/stable/elastic/run.html 53 | - https://pytorch.org/tutorials/intermediate/ddp_tutorial.html 54 | 55 | [DCP]: https://pytorch.org/docs/stable/distributed.checkpoint.html 56 | 57 | [multirun]: https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run/ 58 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/python_structs/py_object_info.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use mountpoint_s3_client::types::ObjectInfo; 7 | use pyo3::types::PyTuple; 8 | use pyo3::{pyclass, pymethods}; 9 | use pyo3::{Bound}; 10 | use pyo3::{IntoPyObject, PyResult}; 11 | 12 | use crate::python_structs::py_restore_status::PyRestoreStatus; 13 | use crate::PyRef; 14 | 15 | #[pyclass( 16 | name = "ObjectInfo", 17 | module = "s3torchconnectorclient._mountpoint_s3_client", 18 | frozen 19 | )] 20 | #[derive(Debug, Clone)] 21 | pub struct PyObjectInfo { 22 | #[pyo3(get)] 23 | key: String, 24 | #[pyo3(get)] 25 | etag: String, 26 | #[pyo3(get)] 27 | size: u64, 28 | #[pyo3(get)] 29 | last_modified: i64, 30 | #[pyo3(get)] 31 | storage_class: Option, 32 | #[pyo3(get)] 33 | restore_status: Option, 34 | } 35 | 36 | impl PyObjectInfo { 37 | pub(crate) fn from_object_info(object_info: ObjectInfo) -> Self { 38 | PyObjectInfo::new( 39 | object_info.key, 40 | object_info.etag, 41 | object_info.size, 42 | object_info.last_modified.unix_timestamp(), 43 | object_info.storage_class, 44 | object_info 45 | .restore_status 46 | .map(PyRestoreStatus::from_restore_status), 47 | ) 48 | } 49 | } 50 | 51 | #[pymethods] 52 | impl PyObjectInfo { 53 | #[new] 54 | #[pyo3(signature = (key, etag, size, last_modified, storage_class=None, restore_status=None))] 55 | pub fn new( 56 | key: String, 57 | etag: String, 58 | size: u64, 59 | last_modified: i64, 60 | storage_class: Option, 61 | restore_status: Option, 62 | ) -> Self { 63 | Self { 64 | key, 65 | etag, 66 | size, 67 | last_modified, 68 | storage_class, 69 | restore_status, 70 | } 71 | } 72 | 73 | pub fn __getnewargs__(slf: PyRef<'_, Self>) -> PyResult> { 74 | let py = slf.py(); 75 | let state = [ 76 | slf.key.clone().into_pyobject(py)?.into_any(), 77 | slf.etag.clone().into_pyobject(py)?.into_any(), 78 | slf.size.into_pyobject(py)?.into_any(), 79 | slf.last_modified.into_pyobject(py)?.into_any(), 80 | slf.storage_class.clone().into_pyobject(py)?.into_any(), 81 | slf.restore_status.clone().into_pyobject(py)?.into_any(), 82 | ]; 83 | PyTuple::new(py, state) 84 | } 85 | 86 | fn __repr__(&self) -> String { 87 | format!("{:?}", self) 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_ddp/load_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
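The `__getnewargs__` implementation above is what lets `ObjectInfo` survive pickling, e.g. when samples cross `DataLoader` worker boundaries. A small round-trip sketch (the key, etag and timestamp values are made up):

```python
import pickle

from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo, RestoreStatus

# Positional arguments mirror the #[new] signature above:
# (key, etag, size, last_modified, storage_class=None, restore_status=None)
info = ObjectInfo(
    "datasets/img_0001.jpg",
    "0123456789abcdef",
    1024,
    1_700_000_000,
    "STANDARD",
    RestoreStatus(False, None),
)

restored = pickle.loads(pickle.dumps(info))
assert restored.key == info.key
assert restored.size == info.size
assert restored.restore_status.in_progress is False
```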
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | from multiprocessing.queues import Queue 6 | from time import perf_counter 7 | from typing import Tuple 8 | 9 | import hydra 10 | import torch 11 | import torch.distributed as dist 12 | import torch.distributed.checkpoint as dcp 13 | from omegaconf import DictConfig 14 | from torch.nn.parallel import DistributedDataParallel 15 | 16 | from s3torchbenchmarking.dcp_common import setup, get_reader, benchmark_common_runner 17 | from s3torchbenchmarking.models import get_benchmark_model, BenchmarkModel 18 | 19 | Timestamps = Tuple[float, float] 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | # TODO: add Structured Config (https://hydra.cc/docs/tutorials/structured_config/intro/) 24 | @hydra.main(version_base=None) 25 | def run_benchmark(cfg: DictConfig) -> dict: 26 | """DCP benchmarks entry point.""" 27 | benchmark_model = get_benchmark_model(cfg.model) 28 | 29 | return benchmark_common_runner(cfg, run_ddp_load, (cfg, benchmark_model)) 30 | 31 | 32 | def run_ddp_load( 33 | rank: int, # needs to be passed first (provided by `multiprocessing.spawn` automatically) 34 | cfg: DictConfig, 35 | proxy_model: BenchmarkModel, 36 | suffix: str, 37 | load_timestamps: Queue, 38 | ) -> None: 39 | """Execute the actual code for checkpoint loading. 40 | 41 | This function is meant to be executed in subprocesses.""" 42 | begin_process = perf_counter() 43 | # Override random suffix with suffix from config 44 | storage_reader = get_reader(cfg) 45 | model_size = proxy_model.size 46 | model = proxy_model.model 47 | 48 | setup(cfg.backend, world_size=cfg.world_size, rank=rank) 49 | if cfg.backend == "nccl": 50 | device_id = rank % torch.cuda.device_count() 51 | torch.cuda.set_device(device_id) 52 | model.to(device_id) 53 | model = DistributedDataParallel(model, device_ids=[device_id]) 54 | else: 55 | device_id = rank % torch.cpu.device_count() 56 | torch.cpu.set_device(device_id) 57 | model.to(device=torch.device("cpu")) 58 | model = DistributedDataParallel(model) 59 | 60 | state_dict = model.state_dict() 61 | 62 | begin_load = perf_counter() # also "end_process" 63 | dcp.load(state_dict, storage_reader=storage_reader) 64 | end_load = perf_counter() 65 | 66 | # Record the load times excluding the influence of the process setup and model loading to device. 67 | load_timestamps.put( 68 | (begin_process, end_load - (begin_load - begin_process), model_size) 69 | ) 70 | 71 | dist.destroy_process_group() 72 | 73 | 74 | if __name__ == "__main__": 75 | run_benchmark() 76 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/_s3_bucket_iterable.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
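The DDP load benchmark above drives `dcp.load` through `multiprocessing.spawn`; outside the harness, the same save/load path can be exercised in a single process. A minimal sketch (bucket name and region are placeholders):

```python
import torch.nn as nn
import torch.distributed.checkpoint as dcp

from s3torchconnector.dcp import S3StorageReader, S3StorageWriter

REGION = "us-east-1"                      # placeholder
URI = "s3://my-bucket/checkpoints/demo/"  # placeholder

model = nn.Linear(16, 4)

# Save the state_dict as a DCP checkpoint directly to S3.
writer = S3StorageWriter(region=REGION, path=URI, overwrite=True)
dcp.save(model.state_dict(), storage_writer=writer)

# Load it back into a freshly initialised copy of the model.
fresh = nn.Linear(16, 4)
state_dict = fresh.state_dict()
dcp.load(state_dict, storage_reader=S3StorageReader(region=REGION, path=URI))
fresh.load_state_dict(state_dict)
```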
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from functools import partial 5 | from itertools import chain 6 | from typing import Iterator, List 7 | 8 | from s3torchconnectorclient._mountpoint_s3_client import ( 9 | ObjectInfo, 10 | ListObjectResult, 11 | ListObjectStream, 12 | ) 13 | 14 | from ._s3bucket_key_data import S3BucketKeyData 15 | from ._s3client import S3Client 16 | 17 | 18 | class S3BucketIterable: 19 | def __init__(self, client: S3Client, bucket: str, prefix: str): 20 | self._client = client 21 | self._bucket = bucket 22 | self._prefix = prefix 23 | 24 | def __iter__(self) -> Iterator[S3BucketKeyData]: 25 | # This allows us to iterate multiple times by re-creating the `_list_stream` 26 | return iter(S3BucketIterator(self._client, self._bucket, self._prefix)) 27 | 28 | 29 | class S3BucketIterator: 30 | def __init__(self, client: S3Client, bucket: str, prefix: str): 31 | self._client = client 32 | self._bucket = bucket 33 | self._list_stream = _PickleableListObjectStream(client, bucket, prefix) 34 | 35 | def __iter__(self) -> Iterator[S3BucketKeyData]: 36 | return chain.from_iterable( 37 | map(partial(_extract_list_results, self._bucket), self._list_stream) 38 | ) 39 | 40 | 41 | class _PickleableListObjectStream: 42 | def __init__(self, client: S3Client, bucket: str, prefix: str): 43 | self._client = client 44 | self._list_stream = iter(client.list_objects(bucket, prefix)) 45 | 46 | def __iter__(self): 47 | return self 48 | 49 | def __next__(self) -> ListObjectResult: 50 | return next(self._list_stream) 51 | 52 | def __getstate__(self): 53 | return { 54 | "client": self._client, 55 | "bucket": self._list_stream.bucket, 56 | "prefix": self._list_stream.prefix, 57 | "delimiter": self._list_stream.delimiter, 58 | "max_keys": self._list_stream.max_keys, 59 | "continuation_token": self._list_stream.continuation_token, 60 | "complete": self._list_stream.complete, 61 | } 62 | 63 | def __setstate__(self, state): 64 | self._client = state["client"] 65 | self._list_stream = ListObjectStream._from_state(**state) 66 | 67 | 68 | def _extract_list_results( 69 | bucket: str, list_result: ListObjectResult 70 | ) -> Iterator[S3BucketKeyData]: 71 | return map(partial(_extract_object_info, bucket), list_result.object_info) 72 | 73 | 74 | def _extract_object_info(bucket: str, object_info: ObjectInfo) -> S3BucketKeyData: 75 | return S3BucketKeyData(bucket=bucket, key=object_info.key, object_info=object_info) 76 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: File a bug report 3 | labels: ["bug"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thank you for taking the time to submit a bug report! 9 | - type: markdown 10 | attributes: 11 | value: | 12 | **Note: Security issues should not be reported here.** 13 | Please follow the [security policy for this repository](https://github.com/awslabs/s3-connector-for-pytorch/security/policy). 14 | - type: input 15 | id: s3torchconnector-version 16 | attributes: 17 | label: s3torchconnector version 18 | description: | 19 | Which version of s3torchconnector are you using? 20 | If you are building from source or a fork, please state that. 
21 | placeholder: s3torchconnector-x.y 22 | validations: 23 | required: true 24 | - type: input 25 | id: s3torchconnectorclient-version 26 | attributes: 27 | label: s3torchconnectorclient version 28 | description: | 29 | Which version of s3torchconnectorclient are you using? 30 | If you are building from source or a fork, please state that. 31 | placeholder: s3torchconnectorclient-x.y 32 | validations: 33 | required: true 34 | - type: input 35 | id: region 36 | attributes: 37 | label: AWS Region 38 | description: Which AWS region did you experience the bug in? 39 | placeholder: us-east-1 40 | validations: 41 | required: false 42 | - type: textarea 43 | id: environment 44 | attributes: 45 | label: Describe the running environment 46 | description: | 47 | What else can you share about the environment you are running the project? 48 | For example, was this using Amazon EC2? Which type/OS version/architecture? 49 | placeholder: Running in EC2 on Amazon Linux 2. 50 | validations: 51 | required: true 52 | - type: textarea 53 | id: behavior 54 | attributes: 55 | label: What happened? 56 | description: Please also tell us what you expected to happen. 57 | placeholder: The connector failed to load my checkpoint from S3. 58 | validations: 59 | required: true 60 | - type: textarea 61 | id: logs 62 | attributes: 63 | label: Relevant log output 64 | description: Please copy and paste any relevant log output. This will be automatically formatted into code, so no need for backticks. 65 | render: shell 66 | validations: 67 | required: false 68 | - type: checkboxes 69 | id: terms 70 | attributes: 71 | label: Code of Conduct 72 | description: By submitting this issue, you agree to follow our [Code of Conduct](https://github.com/awslabs/s3-connector-for-pytorch/blob/main/CODE_OF_CONDUCT.md) 73 | options: 74 | - label: I agree to follow this project's Code of Conduct 75 | required: true 76 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/test_s3reader_constructor.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import pytest 5 | 6 | from s3torchconnector import S3ReaderConstructor 7 | from s3torchconnector.s3reader import SequentialS3Reader, RangedS3Reader 8 | from s3torchconnector.s3reader.ranged import DEFAULT_BUFFER_SIZE 9 | 10 | TEST_BUCKET = "test-bucket" 11 | TEST_KEY = "test-key" 12 | 13 | 14 | def test_s3readerconstructor_sequential_constructor(): 15 | """Test sequential reader construction""" 16 | constructor = S3ReaderConstructor.sequential() 17 | s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) 18 | assert isinstance(s3reader, SequentialS3Reader) 19 | 20 | 21 | def test_s3readerconstructor_range_based_constructor(): 22 | """Test range-based reader construction""" 23 | constructor = S3ReaderConstructor.range_based() 24 | s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) 25 | assert isinstance(s3reader, RangedS3Reader) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "buffer_size, expected_buffer_size, expected_enable_buffering", 30 | [ 31 | (None, DEFAULT_BUFFER_SIZE, True), # Default buffer size 32 | (16 * 1024 * 1024, 16 * 1024 * 1024, True), # Custom buffer size 33 | (0, 0, False), # Disabled buffering 34 | ], 35 | ) 36 | def test_s3readerconstructor_range_based_constructor_buffer_configurations( 37 | buffer_size, expected_buffer_size, expected_enable_buffering 38 | ): 39 | """Test range-based reader construction with different buffer configurations""" 40 | constructor = S3ReaderConstructor.range_based(buffer_size=buffer_size) 41 | s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) 42 | 43 | assert isinstance(s3reader, RangedS3Reader) 44 | assert s3reader._buffer_size == expected_buffer_size 45 | assert s3reader._enable_buffering is expected_enable_buffering 46 | 47 | 48 | def test_s3readerconstructor_default_constructor(): 49 | """Test default constructor returns sequential reader""" 50 | constructor = S3ReaderConstructor.default() 51 | s3reader = constructor(TEST_BUCKET, TEST_KEY, lambda: None, lambda: iter([])) 52 | assert isinstance(s3reader, SequentialS3Reader) 53 | 54 | 55 | def test_s3readerconstructor_get_reader_type_string(): 56 | """Test reader type string generation""" 57 | assert ( 58 | S3ReaderConstructor.get_reader_type_string(S3ReaderConstructor.sequential()) 59 | == "sequential" 60 | ) 61 | assert ( 62 | S3ReaderConstructor.get_reader_type_string(S3ReaderConstructor.range_based()) 63 | == "range_based" 64 | ) 65 | assert S3ReaderConstructor.get_reader_type_string(None) == "sequential" 66 | assert ( 67 | S3ReaderConstructor.get_reader_type_string(S3ReaderConstructor.default()) 68 | == "sequential" 69 | ) 70 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/mock_client.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
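The constructor tests above cover both reader flavours; in application code the constructor is usually passed to a higher-level component rather than invoked by hand. A short sketch (region, bucket and prefix are placeholders):

```python
from s3torchconnector import S3ReaderConstructor
from s3torchconnector.dcp import S3StorageReader

# Range-based reader with a 16 MiB buffer: suited to sparse, partial reads.
reader_constructor = S3ReaderConstructor.range_based(buffer_size=16 * 1024 * 1024)
print(S3ReaderConstructor.get_reader_type_string(reader_constructor))  # "range_based"

storage_reader = S3StorageReader(
    region="us-east-1",                       # placeholder
    path="s3://my-bucket/checkpoints/demo/",  # placeholder
    reader_constructor=reader_constructor,
)

# The sequential reader (also the default) streams objects front to back.
sequential = S3ReaderConstructor.sequential()
print(S3ReaderConstructor.get_reader_type_string(sequential))  # "sequential"
```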
3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use std::sync::Arc; 7 | 8 | use mountpoint_s3_client::mock_client::{MockClient, MockClientConfig, MockObject}; 9 | use pyo3::{pyclass, pymethods}; 10 | 11 | use crate::MountpointS3Client; 12 | 13 | #[derive(Clone)] 14 | #[pyclass( 15 | name = "MockMountpointS3Client", 16 | module = "s3torchconnectorclient._mountpoint_s3_client", 17 | frozen 18 | )] 19 | pub struct PyMockClient { 20 | mock_client: Arc, 21 | #[pyo3(get)] 22 | pub(crate) throughput_target_gbps: f64, 23 | #[pyo3(get)] 24 | pub(crate) region: String, 25 | #[pyo3(get)] 26 | pub(crate) part_size: usize, 27 | #[pyo3(get)] 28 | pub(crate) user_agent_prefix: String, 29 | #[pyo3(get)] 30 | pub(crate) unsigned: bool, 31 | #[pyo3(get)] 32 | pub(crate) force_path_style: bool, 33 | #[pyo3(get)] 34 | max_attempts: usize, 35 | } 36 | 37 | #[pymethods] 38 | impl PyMockClient { 39 | #[new] 40 | #[pyo3(signature = (region, bucket, throughput_target_gbps = 10.0, part_size = 8 * 1024 * 1024, user_agent_prefix="mock_client".to_string(), unsigned=false, force_path_style=false, max_attempts=10))] 41 | #[allow(clippy::too_many_arguments)] 42 | pub fn new( 43 | region: String, 44 | bucket: String, 45 | throughput_target_gbps: f64, 46 | part_size: usize, 47 | user_agent_prefix: String, 48 | unsigned: bool, 49 | force_path_style: bool, 50 | max_attempts: usize, 51 | ) -> PyMockClient { 52 | let unordered_list_seed: Option = None; 53 | let config = MockClientConfig { 54 | bucket, 55 | part_size, 56 | unordered_list_seed, 57 | ..Default::default() 58 | }; 59 | let mock_client = Arc::new(MockClient::new(config)); 60 | 61 | PyMockClient { 62 | mock_client, 63 | region, 64 | throughput_target_gbps, 65 | part_size, 66 | user_agent_prefix, 67 | unsigned, 68 | force_path_style, 69 | max_attempts 70 | } 71 | } 72 | 73 | fn create_mocked_client(&self) -> MountpointS3Client { 74 | MountpointS3Client::new( 75 | self.region.clone(), 76 | self.user_agent_prefix.to_string(), 77 | self.throughput_target_gbps, 78 | self.part_size, 79 | None, 80 | self.unsigned, 81 | self.force_path_style, 82 | self.max_attempts, 83 | self.mock_client.clone(), 84 | None, 85 | ) 86 | } 87 | 88 | fn add_object(&self, key: String, data: Vec) { 89 | self.mock_client.add_object(&key, MockObject::from(data)); 90 | } 91 | 92 | fn remove_object(&self, key: String) { 93 | self.mock_client.remove_object(&key); 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/lightning_checkpointing/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
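The mock client above is exposed to Python for exercising S3 code paths without credentials or network access; a small sketch mirroring the embedded Rust doc-test:

```python
from s3torchconnectorclient._mountpoint_s3_client import MockMountpointS3Client

# In-memory bucket; nothing leaves the process.
mock_client = MockMountpointS3Client("us-east-1", "mock-bucket")
client = mock_client.create_mocked_client()

mock_client.add_object("key", b"data")

stream = client.get_object("mock-bucket", "key")
assert b"".join(stream) == b"data"

mock_client.remove_object("key")
```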
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | from pathlib import Path 6 | 7 | import hydra 8 | import pandas as pd 9 | from lightning import Trainer 10 | from lightning.pytorch import callbacks 11 | from lightning.pytorch.strategies import SingleDeviceStrategy 12 | from omegaconf import DictConfig 13 | from torch.utils.data import DataLoader 14 | from torchdata.datapipes.iter import IterableWrapper # type: ignore 15 | 16 | from s3torchbenchmarking.benchmark_utils import ( 17 | ResourceMonitor, 18 | build_checkpoint_uri, 19 | build_random_suffix, 20 | ) 21 | from s3torchbenchmarking.lightning_checkpointing.checkpoint_profiler import ( 22 | CheckpointProfiler, 23 | ) 24 | from s3torchbenchmarking.models import get_benchmark_model, LightningAdapter 25 | from s3torchconnector.lightning import S3LightningCheckpoint 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | # TODO: add Structured Config (https://hydra.cc/docs/tutorials/structured_config/intro/) 31 | @hydra.main(version_base=None) 32 | def run_benchmark(config: DictConfig) -> dict: 33 | """Lightning benchmarks entry point.""" 34 | benchmark_model = get_benchmark_model(config.model) 35 | 36 | strategy = SingleDeviceStrategy() 37 | 38 | if config.checkpoint.storage == "disk": 39 | checkpoint_callback = callbacks.ModelCheckpoint(dirpath=config.path) 40 | checkpoint_io = strategy.checkpoint_io 41 | else: 42 | checkpoint_callback = callbacks.ModelCheckpoint(dirpath=config.s3.uri) 43 | checkpoint_io = S3LightningCheckpoint(config.s3.region) 44 | 45 | profiling_checkpointer = CheckpointProfiler(checkpoint_io) 46 | trainer = Trainer( 47 | logger=False, plugins=[profiling_checkpointer], callbacks=[checkpoint_callback] 48 | ) 49 | dataloader = DataLoader(IterableWrapper([]), num_workers=8) 50 | trainer.fit( 51 | LightningAdapter.DelegateModule(benchmark_model.model), 52 | train_dataloaders=dataloader, 53 | ) 54 | 55 | suffix = build_random_suffix() 56 | with ResourceMonitor() as monitor: 57 | for i in range(config.epochs): 58 | filepath = f"{suffix}/{config.model}-{i}.ckpt" 59 | if config.checkpoint.storage == "disk": 60 | checkpoint_path = Path(config.path) / filepath 61 | else: 62 | checkpoint_path = build_checkpoint_uri(config.s3.uri, filepath) 63 | trainer.save_checkpoint(checkpoint_path) 64 | 65 | save_times_s = pd.Series(profiling_checkpointer.save_times) 66 | throughput_mibs = benchmark_model.size / save_times_s 67 | 68 | metrics = { 69 | "throughput_mibs": throughput_mibs.dropna().to_list(), 70 | "save_times_s": save_times_s.dropna().to_list(), 71 | "utilization": {k: v.summarize() for k, v in monitor.resource_data.items()}, 72 | } 73 | return {"metrics": metrics} 74 | 75 | 76 | if __name__ == "__main__": 77 | run_benchmark() 78 | -------------------------------------------------------------------------------- /.github/workflows/rust-checks.yml: -------------------------------------------------------------------------------- 1 | name: Rust Checks 2 | 3 | on: 4 | workflow_call: 5 | 6 | permissions: 7 | contents: read 8 | 9 | env: 10 | RUST_BACKTRACE: 1 11 | CARGO_TERM_COLOR: always 12 | CARGO_INCREMENTAL: 0 13 | RUSTFLAGS: "-Dwarnings" 14 | RUST_TOOLCHAIN: 1.84.1 15 | 16 | jobs: 17 | deny: 18 | runs-on: ubuntu-22.04 19 | name: Licenses 20 | strategy: 21 | matrix: 22 | checks: 23 | # The advisories check is used to detect issues for crates by looking in an advisory database. 
24 | - advisories 25 | # The bans check is used to deny (or allow) specific crates, as well as detect and handle multiple 26 | # versions of the same crate. 27 | # The licenses check is used to verify that every crate you use has license terms you find acceptable. 28 | # The sources check ensures crates only come from sources you trust. 29 | - bans licenses sources 30 | steps: 31 | - name: Checkout code 32 | uses: actions/checkout@v6 33 | 34 | - name: Set up stable Rust 35 | uses: dtolnay/rust-toolchain@stable 36 | with: 37 | toolchain: ${{env.RUST_TOOLCHAIN }} 38 | 39 | - name: Run cargo deny 40 | uses: EmbarkStudios/cargo-deny-action@v2 41 | with: 42 | command: check ${{ matrix.checks }} 43 | manifest-path: s3torchconnectorclient/Cargo.toml 44 | 45 | clippy: 46 | runs-on: ubuntu-22.04 47 | name: Clippy 48 | steps: 49 | - name: Checkout code 50 | uses: actions/checkout@v6 51 | 52 | - name: Set up stable Rust 53 | uses: dtolnay/rust-toolchain@stable 54 | with: 55 | toolchain: ${{env.RUST_TOOLCHAIN }} 56 | components: clippy 57 | 58 | - name: Cargo cache 59 | uses: actions/cache@v5 60 | with: 61 | path: | 62 | ~/.cargo/registry/index/ 63 | ~/.cargo/registry/cache/ 64 | ~/.cargo/git/db/ 65 | target/ 66 | key: ${{ runner.os }}-${{ github.job }}-cargo-${{ hashFiles('**/Cargo.lock') }} 67 | 68 | - name: Lint with clippy 69 | run: cargo clippy --all-targets --all-features --manifest-path s3torchconnectorclient/Cargo.toml 70 | 71 | tests: 72 | runs-on: ${{ matrix.runner }} 73 | name: Rust tests 74 | strategy: 75 | matrix: 76 | # macos-15-intel planned to end August 2027 77 | runner: [ubuntu-22.04, macos-15-intel] 78 | steps: 79 | - name: Checkout code 80 | uses: actions/checkout@v6 81 | 82 | # Use Python 3.13 since Python 3.14 requires pyo3 version bump from 0.24.0 83 | - name: Set up Python 3.13 84 | uses: actions/setup-python@v6 85 | with: 86 | python-version: "3.13" 87 | 88 | - name: Set up stable Rust 89 | uses: dtolnay/rust-toolchain@stable 90 | with: 91 | toolchain: ${{env.RUST_TOOLCHAIN }} 92 | 93 | - name: Build Rust tests 94 | run: cargo test --no-default-features --no-run --manifest-path s3torchconnectorclient/Cargo.toml 95 | 96 | - name: Run Rust tests 97 | run: cargo test --no-default-features --manifest-path s3torchconnectorclient/Cargo.toml 98 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/s3writer.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import io 5 | from typing import Union 6 | import threading 7 | import logging 8 | 9 | from s3torchconnectorclient._mountpoint_s3_client import PutObjectStream 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class S3Writer(io.BufferedIOBase): 15 | """A write-only, file like representation of a single object stored in S3.""" 16 | 17 | def __init__(self, stream: PutObjectStream): 18 | self.stream = stream 19 | self._position = 0 20 | self._closed = False 21 | self._lock = threading.Lock() 22 | 23 | def __enter__(self): 24 | self._position = 0 25 | return self 26 | 27 | def __exit__(self, exc_type, exc_val, exc_tb): 28 | """Close stream on normal exit, log any exceptions that occurred.""" 29 | if exc_type is not None: 30 | try: 31 | logger.info( 32 | f"Exception occurred before closing stream: {exc_type.__name__}: {exc_val}" 33 | ) 34 | except: 35 | pass 36 | else: 37 | self.close() 38 | 39 | def write( 40 | self, 41 | # Ignoring the type for this as we don't currently support the Buffer protocol 42 | data: Union[bytes, memoryview], # type: ignore 43 | ) -> int: 44 | """Write bytes to S3 Object specified by bucket and key 45 | 46 | Args: 47 | data (bytes | memoryview): bytes to write 48 | 49 | Returns: 50 | int: Number of bytes written 51 | 52 | Raises: 53 | S3Exception: An error occurred accessing S3. 54 | ValueError: If the writer is closed. 55 | """ 56 | self._checkClosed() # from python/cpython/Lib/_pyio.py 57 | if isinstance(data, memoryview): 58 | data = data.tobytes() 59 | self.stream.write(data) 60 | self._position += len(data) 61 | return len(data) 62 | 63 | def close(self): 64 | """Close write-stream to S3. Ensures all bytes are written successfully. 65 | 66 | Raises: 67 | S3Exception: An error occurred accessing S3. 68 | """ 69 | with self._lock: 70 | if not self._closed: 71 | self._closed = True 72 | self.stream.close() 73 | 74 | @property 75 | def closed(self) -> bool: 76 | """ 77 | Returns: 78 | bool: Return whether the object is closed. 79 | """ 80 | return self._closed 81 | 82 | def flush(self): 83 | """No-op""" 84 | pass 85 | 86 | def readable(self) -> bool: 87 | """ 88 | Returns: 89 | bool: Return whether object was opened for reading. 90 | """ 91 | return False 92 | 93 | def writable(self) -> bool: 94 | """ 95 | Returns: 96 | bool: Return whether object is open for writing. 97 | """ 98 | return not self.closed 99 | 100 | def tell(self) -> int: 101 | """ 102 | Returns: 103 | int: Current stream position. 104 | """ 105 | return self._position 106 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/dcp/test_e2e_s3_storage_reader.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
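`S3Writer` above is rarely constructed directly; it is the object returned by helpers such as `S3Checkpoint.writer()`. A minimal sketch of the intended usage (region and URI are placeholders):

```python
import torch
import torch.nn as nn

from s3torchconnector import S3Checkpoint

checkpoint = S3Checkpoint(region="us-east-1")  # placeholder region
model = nn.Linear(8, 2)

# Exiting the context closes the writer, which completes the upload to S3.
with checkpoint.writer("s3://my-bucket/checkpoints/model.ckpt") as writer:
    torch.save(model.state_dict(), writer)
    print(writer.tell())  # bytes written so far

print(writer.closed)      # True: __exit__ called close()
```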
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import pytest 5 | from unittest.mock import patch 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.distributed.checkpoint as dcp 10 | 11 | from s3torchconnector import S3ReaderConstructor 12 | from s3torchconnector.dcp import S3StorageWriter, S3StorageReader 13 | from s3torchconnector.s3reader.sequential import SequentialS3Reader 14 | 15 | 16 | SIMPLE_MODEL = torch.nn.Sequential( 17 | nn.Linear(5, 5), 18 | nn.Linear(20, 20), 19 | nn.Linear(10, 10), 20 | ) 21 | 22 | 23 | class NeuralNetwork(nn.Module): 24 | """NeuralNetwork from PyTorch quickstart tutorial.""" 25 | 26 | def __init__(self): 27 | super().__init__() 28 | self.flatten = nn.Flatten() 29 | self.linear_relu_stack = nn.Sequential( 30 | nn.Linear(28 * 28, 512), 31 | nn.ReLU(), 32 | nn.Linear(512, 512), 33 | nn.ReLU(), 34 | nn.Linear(512, 10), 35 | ) 36 | 37 | 38 | LARGER_MODEL = NeuralNetwork() 39 | 40 | 41 | @pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL]) 42 | def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model): 43 | """ 44 | Test that prepare_local_plan allows dcp.load() to read items in offset order. 45 | 46 | This does not prevent backwards seek, since torch.load() would still call 47 | backwards seek operations. 48 | 49 | pytorch/torch/serialization.py load() function will call _is_zipfile(), which 50 | includes this read() call: f.read(len(local_header_magic_number)). This is 51 | followed by readinto() calls on the actual tensor. 52 | 53 | Hence we can track read() call positions to determine if load ordering is 54 | being applied correctly. 55 | """ 56 | region = checkpoint_directory.region 57 | s3_uri = checkpoint_directory.s3_uri 58 | 59 | state_dict = model.state_dict() 60 | storage_writer = S3StorageWriter(region=region, path=s3_uri, overwrite=True) 61 | dcp.save(state_dict, storage_writer=storage_writer) 62 | 63 | read_positions = [] 64 | 65 | original_read = SequentialS3Reader.read 66 | 67 | def track_reads(self, size=None): 68 | if not self.key.endswith(".metadata"): 69 | read_positions.append(self._position) 70 | return original_read(self, size) 71 | 72 | # Load with position tracking on read() (called at the start of each torch.load()) 73 | with patch.object(SequentialS3Reader, "read", track_reads): 74 | loaded_state_dict = {k: torch.empty_like(v) for k, v in state_dict.items()} 75 | storage_reader = S3StorageReader( 76 | region=region, 77 | path=s3_uri, 78 | reader_constructor=S3ReaderConstructor.sequential(), 79 | ) 80 | dcp.load(loaded_state_dict, storage_reader=storage_reader) 81 | 82 | print(f"Read positions: {read_positions}") 83 | 84 | # Assert load ordering works (read() calls should be in sorted order) 85 | assert read_positions == sorted(read_positions) 86 | 87 | # Assert all tensors are correctly loaded 88 | assert len(loaded_state_dict) == len(state_dict) 89 | assert loaded_state_dict.keys() == state_dict.keys() 90 | for key in state_dict: 91 | assert torch.equal(loaded_state_dict[key], state_dict[key]) 92 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 
5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. When writing a Git commit message, follow these [guidelines](https://chris.beams.io/posts/git-commit/). 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | ## Finding contributions to work on 43 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 44 | 45 | 46 | ## Code of Conduct 47 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 48 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 49 | opensource-codeofconduct@amazon.com with any additional questions or comments. 50 | 51 | 52 | ## Security issue notifications 53 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 54 | 55 | 56 | ## Licensing 57 | 58 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 
59 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/pytorch_checkpointing/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | from pathlib import Path 6 | from time import perf_counter 7 | from typing import Dict, Any 8 | 9 | import hydra 10 | import pandas as pd 11 | import torch 12 | from omegaconf import DictConfig 13 | 14 | from s3torchbenchmarking.benchmark_utils import ( 15 | ResourceMonitor, 16 | build_checkpoint_uri, 17 | build_random_suffix, 18 | ) 19 | from s3torchbenchmarking.models import get_benchmark_model 20 | from s3torchconnector import S3Checkpoint 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def create_checkpoint_dir(path: str, suffix: str) -> Path: 26 | """Create and return the checkpoint directory.""" 27 | parent_folder = Path(path) / suffix 28 | parent_folder.mkdir(parents=True, exist_ok=True) 29 | 30 | 31 | def save_checkpoint( 32 | model: torch.nn.Module, path: str, checkpoint: S3Checkpoint = None 33 | ) -> float: 34 | """Save checkpoint and return the time taken.""" 35 | start_time = perf_counter() 36 | if checkpoint: 37 | with checkpoint.writer(path) as writer: 38 | torch.save(model.state_dict(), writer) 39 | else: 40 | torch.save(model.state_dict(), path) 41 | end_time = perf_counter() 42 | return end_time - start_time 43 | 44 | 45 | def calculate_metrics( 46 | save_times: list, model_size: float, monitor: ResourceMonitor 47 | ) -> Dict[str, Any]: 48 | """Calculate and return benchmark metrics.""" 49 | save_times_s = pd.Series(save_times) 50 | throughput_mibs = model_size / save_times_s 51 | return { 52 | "throughput_mibs": throughput_mibs.dropna().tolist(), 53 | "save_times_s": save_times_s.dropna().tolist(), 54 | "utilization": {k: v.summarize() for k, v in monitor.resource_data.items()}, 55 | } 56 | 57 | 58 | @hydra.main(version_base=None) 59 | def run_benchmark(config: DictConfig) -> Dict[str, Any]: 60 | """Checkpoint benchmarks entry point.""" 61 | logger.info("Starting Checkpoint benchmark run") 62 | 63 | try: 64 | benchmark_model = get_benchmark_model(config.model) 65 | checkpoint = ( 66 | None 67 | if config.checkpoint.storage == "disk" 68 | else S3Checkpoint(region=config.s3.region) 69 | ) 70 | 71 | suffix = build_random_suffix() 72 | if config.checkpoint.storage == "disk": 73 | create_checkpoint_dir(config.path, suffix) 74 | 75 | save_times = [] 76 | with ResourceMonitor() as monitor: 77 | for i in range(config.epochs): 78 | filepath = f"{suffix}/{config.model}-{i}.ckpt" 79 | if config.checkpoint.storage == "disk": 80 | checkpoint_path = Path(config.path) / filepath 81 | else: 82 | checkpoint_path = build_checkpoint_uri(config.s3.uri, filepath) 83 | 84 | logger.info(f"Saving checkpoint to {checkpoint_path}") 85 | save_time = save_checkpoint( 86 | benchmark_model.model, str(checkpoint_path), checkpoint 87 | ) 88 | save_times.append(save_time) 89 | 90 | metrics = calculate_metrics(save_times, benchmark_model.size, monitor) 91 | logger.info("Benchmark run completed successfully") 92 | return {"metrics": metrics} 93 | 94 | except Exception as e: 95 | logger.exception(f"An error occurred during the benchmark: {str(e)}") 96 | return {"error": str(e)} 97 | 98 | 99 | if __name__ == "__main__": 100 | run_benchmark() 101 | 
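The helpers above can also be driven outside Hydra for a quick one-off measurement; a sketch under the assumption that the `s3torchbenchmarking` package is installed (the S3 URI and region are placeholders):

```python
import torch.nn as nn

from s3torchbenchmarking.pytorch_checkpointing.benchmark import save_checkpoint
from s3torchconnector import S3Checkpoint

model = nn.Linear(128, 128)

# Local-disk baseline: plain torch.save() to a file path.
disk_seconds = save_checkpoint(model, "/tmp/model-0.ckpt")

# S3 target, through the same code path the benchmark uses.
s3_seconds = save_checkpoint(
    model,
    "s3://my-bucket/benchmark/model-0.ckpt",      # placeholder URI
    checkpoint=S3Checkpoint(region="us-east-1"),  # placeholder region
)

print(f"disk: {disk_seconds:.3f}s, s3: {s3_seconds:.3f}s")
```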
-------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/get_object_stream.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use pyo3::types::PyBytes; 7 | use pyo3::{pyclass, pymethods, Bound, PyRef, PyRefMut, PyResult}; 8 | use mountpoint_s3_client::types::GetBodyPart; 9 | 10 | use crate::exception::S3Exception; 11 | use crate::mountpoint_s3_client_inner::MPGetObjectClosure; 12 | 13 | #[pyclass( 14 | name = "GetObjectStream", 15 | module = "s3torchconnectorclient._mountpoint_s3_client" 16 | )] 17 | pub struct GetObjectStream { 18 | next_part: MPGetObjectClosure, 19 | offset: u64, 20 | #[pyo3(get)] 21 | bucket: String, 22 | #[pyo3(get)] 23 | key: String, 24 | } 25 | 26 | impl GetObjectStream { 27 | pub(crate) fn new(next_part: MPGetObjectClosure, bucket: String, key: String, start_offset: Option) -> Self { 28 | Self { 29 | next_part, 30 | offset: start_offset.unwrap_or(0), 31 | bucket, 32 | key, 33 | } 34 | } 35 | } 36 | 37 | #[pymethods] 38 | impl GetObjectStream { 39 | pub fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { 40 | slf 41 | } 42 | 43 | pub fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult>> { 44 | let py = slf.py(); 45 | 46 | let body_part = (slf.next_part)(py)?; 47 | match body_part { 48 | None => Ok(None), 49 | Some(GetBodyPart { offset, data }) => { 50 | if offset != slf.offset { 51 | return Err(S3Exception::new_err( 52 | "Data from S3 was returned out of order!", 53 | )); 54 | } 55 | slf.offset += data.len() as u64; 56 | let data = PyBytes::new(py, data.as_ref()); 57 | Ok(Some(data)) 58 | } 59 | } 60 | } 61 | 62 | pub fn tell(slf: PyRef<'_, Self>) -> u64 { 63 | slf.offset 64 | } 65 | } 66 | 67 | #[cfg(test)] 68 | mod tests { 69 | use pyo3::types::IntoPyDict; 70 | use pyo3::{py_run, PyResult, Python}; 71 | use tracing_subscriber::layer::SubscriberExt; 72 | use tracing_subscriber::util::SubscriberInitExt; 73 | 74 | use crate::mock_client::PyMockClient; 75 | use crate::mountpoint_s3_client::MountpointS3Client; 76 | 77 | #[test] 78 | fn test_get_object() -> PyResult<()> { 79 | let layer = tracing_subscriber::fmt::layer().with_ansi(true); 80 | let registry = tracing_subscriber::registry().with(layer); 81 | let _ = registry.try_init(); 82 | 83 | pyo3::prepare_freethreaded_python(); 84 | 85 | Python::with_gil(|py| { 86 | let locals = [ 87 | ( 88 | "MountpointS3Client", 89 | py.get_type::(), 90 | ), 91 | ( 92 | "MockMountpointS3Client", 93 | py.get_type::(), 94 | ), 95 | ]; 96 | 97 | py_run!( 98 | py, 99 | *locals.into_py_dict(py).unwrap(), 100 | r#" 101 | mock_client = MockMountpointS3Client("us-east-1", "mock-bucket") 102 | client = mock_client.create_mocked_client() 103 | 104 | mock_client.add_object("key", b"data") 105 | stream = client.get_object("mock-bucket", "key") 106 | 107 | returned_data = b''.join(stream) 108 | assert returned_data == b"data" 109 | "# 110 | ); 111 | }); 112 | 113 | Ok(()) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/test_s3writer.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from io import BytesIO 5 | from typing import List, Tuple 6 | from unittest.mock import Mock 7 | import threading 8 | 9 | import pytest 10 | from hypothesis import given 11 | from hypothesis.strategies import lists, binary, composite 12 | from s3torchconnectorclient._mountpoint_s3_client import ObjectInfo, PutObjectStream 13 | 14 | from s3torchconnector import S3Writer 15 | 16 | MOCK_OBJECT_INFO = Mock(ObjectInfo) 17 | MOCK_STREAM = Mock(PutObjectStream) 18 | 19 | 20 | @composite 21 | def bytestream_and_lengths(draw): 22 | byte_array = draw(lists(binary(min_size=1, max_size=5000))) 23 | lengths = [len(b) for b in byte_array] 24 | return byte_array, lengths 25 | 26 | 27 | def test_s3writer_creation(): 28 | s3writer = S3Writer(MOCK_STREAM) 29 | assert s3writer 30 | assert isinstance(s3writer.stream, PutObjectStream) 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "stream", 35 | [ 36 | [b"1", b"2", b"3"], 37 | [], 38 | [b"hello!"], 39 | ], 40 | ) 41 | def test_s3writer_write(stream): 42 | s3writer = S3Writer(MOCK_STREAM) 43 | s3writer.write(stream) 44 | s3writer.close() 45 | MOCK_STREAM.write.assert_called_with(stream) 46 | 47 | 48 | @given(bytestream_and_lengths()) 49 | def test_s3writer_tell(stream_and_lengths: Tuple[List[bytes], List[int]]): 50 | with S3Writer(MOCK_STREAM) as s3writer, BytesIO() as bytewriter: 51 | for data, length in zip(*stream_and_lengths): 52 | b_length = s3writer.write(data) 53 | bytewriter.write(data) 54 | 55 | assert b_length == length 56 | assert bytewriter.tell() == s3writer.tell() 57 | 58 | 59 | def test_s3writer_closed_property(): 60 | """Test that closed property reflects the writer's state.""" 61 | writer = S3Writer(MOCK_STREAM) 62 | assert not writer.closed 63 | writer.close() 64 | assert writer.closed 65 | 66 | 67 | def test_s3writer_not_writable_after_closed(): 68 | """Test that writable() returns False after writer is closed.""" 69 | writer = S3Writer(MOCK_STREAM) 70 | assert writer.writable() 71 | writer.close() 72 | assert not writer.writable() 73 | 74 | 75 | def test_s3writer_write_when_closed(): 76 | """Test that write() raises ValueError via _checkClosed() when writer is closed""" 77 | writer = S3Writer(MOCK_STREAM) 78 | writer.close() 79 | with pytest.raises(ValueError): 80 | writer.write(b"test") 81 | 82 | 83 | def test_multiple_close_calls(): 84 | """Test that multiple calls to close() only close the stream once.""" 85 | MOCK_STREAM.reset_mock() 86 | 87 | writer = S3Writer(MOCK_STREAM) 88 | 89 | writer.close() 90 | writer.close() 91 | writer.close() 92 | 93 | MOCK_STREAM.close.assert_called_once() 94 | assert writer.closed 95 | 96 | 97 | def test_concurrent_close_calls(): 98 | """Test that concurrent calls to close() only close the stream once.""" 99 | MOCK_STREAM.reset_mock() 100 | 101 | writer = S3Writer(MOCK_STREAM) 102 | threads = [] 103 | 104 | for _ in range(5): 105 | thread = threading.Thread(target=writer.close) 106 | threads.append(thread) 107 | thread.start() 108 | 109 | for thread in threads: 110 | thread.join() 111 | 112 | MOCK_STREAM.close.assert_called_once() 113 | assert writer.closed 114 | 115 | 116 | def test_exit_without_exception(): 117 | """Test __exit__ method when no exception occurs.""" 118 | MOCK_STREAM.reset_mock() 119 | 120 | writer = S3Writer(MOCK_STREAM) 121 | writer.__exit__(None, None, None) 122 | 123 | MOCK_STREAM.close.assert_called_once() 124 | 125 | 126 | def test_exit_with_exception(caplog): 127 | """Test __exit__ method when an exception occurs.""" 128 | 
MOCK_STREAM.reset_mock() 129 | 130 | writer = S3Writer(MOCK_STREAM) 131 | test_exception = ValueError("Test exception") 132 | writer.__exit__(ValueError, test_exception, None) 133 | 134 | # Stream should not be closed on exception 135 | MOCK_STREAM.close.assert_not_called() 136 | -------------------------------------------------------------------------------- /s3torchconnectorclient/python/tst/unit/test_structs.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import pickle 5 | from typing import Optional 6 | 7 | from hypothesis import given, example 8 | from hypothesis.strategies import booleans, integers, none, one_of, builds, text 9 | from s3torchconnectorclient._mountpoint_s3_client import RestoreStatus, ObjectInfo 10 | 11 | restore_status_args = ( 12 | booleans(), 13 | one_of(none(), integers(min_value=0, max_value=2**128 - 1)), 14 | ) 15 | 16 | restore_status = builds(RestoreStatus, *restore_status_args) 17 | 18 | object_info_args = { 19 | "key": text(), 20 | "etag": text(), 21 | "size": integers(min_value=0, max_value=2**64 - 1), 22 | "last_modified": integers(min_value=-(2**63), max_value=2**63 - 1), 23 | "storage_class": one_of(none(), text()), 24 | "restore_status": one_of(none(), restore_status), 25 | } 26 | 27 | 28 | @given(*restore_status_args) 29 | @example(False, 2**128 - 1) 30 | def test_restore_status_constructor(in_progress: bool, expiry: Optional[int]): 31 | restore_status = RestoreStatus(in_progress, expiry) 32 | assert restore_status.in_progress is in_progress 33 | assert restore_status.expiry == expiry 34 | 35 | 36 | @given(*restore_status_args) 37 | def test_restore_status_unpickles(in_progress: bool, expiry: Optional[int]): 38 | restore_status = RestoreStatus(in_progress, expiry) 39 | unpickled: RestoreStatus = pickle.loads(pickle.dumps(restore_status)) 40 | 41 | assert type(unpickled) is RestoreStatus 42 | assert restore_status.in_progress is unpickled.in_progress is in_progress 43 | assert restore_status.expiry == unpickled.expiry == expiry 44 | 45 | 46 | @given(**object_info_args) 47 | @example("", "", 0, 0, None, None) 48 | @example("", "", 2**64 - 1, 0, None, None) 49 | @example("", "", 0, -(2**63), None, None) 50 | @example("", "", 0, 2**63 - 1, None, None) 51 | def test_object_info_constructor( 52 | key: str, 53 | etag: str, 54 | size: int, 55 | last_modified: int, 56 | storage_class: Optional[str], 57 | restore_status: Optional[RestoreStatus], 58 | ): 59 | object_info = ObjectInfo( 60 | key, 61 | etag, 62 | size, 63 | last_modified, 64 | storage_class, 65 | restore_status, 66 | ) 67 | assert object_info.key == key 68 | assert object_info.etag == etag 69 | assert object_info.size == size 70 | assert object_info.last_modified == last_modified 71 | assert object_info.storage_class == storage_class 72 | 73 | if restore_status is None: 74 | assert object_info.restore_status is None 75 | else: 76 | assert object_info.restore_status.expiry == restore_status.expiry 77 | assert object_info.restore_status.in_progress == restore_status.in_progress 78 | 79 | 80 | @given(**object_info_args) 81 | def test_object_info_pickles( 82 | key: str, 83 | etag: str, 84 | size: int, 85 | last_modified: int, 86 | storage_class: Optional[str], 87 | restore_status: Optional[RestoreStatus], 88 | ): 89 | object_info = ObjectInfo( 90 | key, 91 | etag, 92 | size, 93 | last_modified, 94 | storage_class, 95 | restore_status, 96 | ) 97 | 98 
| unpickled: ObjectInfo = pickle.loads(pickle.dumps(object_info)) 99 | 100 | assert type(unpickled) is ObjectInfo 101 | 102 | assert object_info.key == unpickled.key == key 103 | assert object_info.etag == unpickled.etag == etag 104 | assert object_info.size == unpickled.size == size 105 | assert object_info.last_modified == unpickled.last_modified == last_modified 106 | assert object_info.storage_class == unpickled.storage_class == storage_class 107 | if restore_status is None: 108 | assert object_info.restore_status is unpickled.restore_status is None 109 | else: 110 | assert ( 111 | object_info.restore_status.expiry 112 | == unpickled.restore_status.expiry 113 | == restore_status.expiry 114 | ) 115 | assert ( 116 | object_info.restore_status.in_progress 117 | == unpickled.restore_status.in_progress 118 | == restore_status.in_progress 119 | ) 120 | assert ( 121 | object_info.restore_status 122 | is not unpickled.restore_status 123 | is not restore_status 124 | ) 125 | -------------------------------------------------------------------------------- /s3torchconnector/tst/e2e/test_mountpoint_client_parallel_access.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import threading 4 | import pytest 5 | from s3torchconnector._s3client import S3Client 6 | from s3torchconnectorclient._mountpoint_s3_client import MountpointS3Client 7 | 8 | from test_common import _get_fork_methods 9 | from conftest import getenv 10 | 11 | 12 | NATIVE_S3_CLIENT = None 13 | 14 | 15 | class S3ClientWithoutLock(S3Client): 16 | @property 17 | def _client(self) -> MountpointS3Client: 18 | global NATIVE_S3_CLIENT 19 | if self._client_pid is None or self._client_pid != os.getpid(): 20 | self._client_pid = os.getpid() 21 | # `MountpointS3Client` does not survive forking, so re-create it if the PID has changed. 
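# This test-only override rebuilds the native client without holding a lock, so concurrent threads can race on the assignment below; the "without lock" test is designed to surface that race.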
22 | NATIVE_S3_CLIENT = self._client_builder() 23 | assert NATIVE_S3_CLIENT is not None 24 | return NATIVE_S3_CLIENT 25 | 26 | def _client_builder(self): 27 | time.sleep(1) 28 | return super()._client_builder() 29 | 30 | 31 | class S3ClientWithLock(S3Client): 32 | def _client_builder(self): 33 | time.sleep(1) 34 | return super()._client_builder() 35 | 36 | 37 | def test_s3_client_reset_after_fork(): 38 | methods = _get_fork_methods() 39 | if "fork" not in methods: 40 | pytest.skip("fork is not supported") 41 | region = getenv("CI_REGION") 42 | s3_client1 = S3Client(region=region) 43 | s3_client2 = S3Client(region=region) 44 | 45 | if ( 46 | s3_client1._client is None 47 | or s3_client2._client is None 48 | or s3_client1._native_client is None 49 | or s3_client2._native_client is None 50 | ): 51 | pytest.fail("Native client is not initialized") 52 | # fork process to clean-up clients 53 | # Fork process 54 | pid = os.fork() 55 | 56 | if pid == 0: 57 | # Child process 58 | try: 59 | if ( 60 | s3_client1._native_client is not None 61 | or s3_client2._native_client is not None 62 | ): 63 | os._exit(1) # Fail if clients are not reset 64 | os._exit(0) # Success 65 | finally: 66 | os._exit(1) # Ensure child exits in case of any other error 67 | else: 68 | # Parent process 69 | _, status = os.waitpid(pid, 0) 70 | exit_code = os.WEXITSTATUS(status) 71 | 72 | if ( 73 | exit_code != 0 74 | or s3_client1._native_client is not None 75 | or s3_client2._native_client is not None 76 | ): 77 | pytest.fail("Native client is not reset after fork") 78 | 79 | 80 | def access_client(client, error_event): 81 | try: 82 | if not error_event.is_set(): 83 | client._client 84 | print(f"Successfully accessed by thread {threading.current_thread().name}") 85 | except AssertionError as e: 86 | print(f"AssertionError in thread {threading.current_thread().name}: {e}") 87 | error_event.set() 88 | 89 | 90 | def test_multiple_thread_accessing_mountpoint_client_in_parallel_without_lock(): 91 | print("Running test without lock...") 92 | client = S3ClientWithoutLock("us-west-2") 93 | if not access_mountpoint_client_in_parallel(client): 94 | pytest.fail( 95 | "Test failed as AssertionError did not happen in one of the threads." 96 | ) 97 | 98 | 99 | def test_multiple_thread_accessing_mountpoint_client_in_parallel_with_lock(): 100 | print("Running test with lock...") 101 | client = S3ClientWithLock("us-west-2") 102 | if access_mountpoint_client_in_parallel(client): 103 | pytest.fail("Test failed as AssertionError happened in one of the threads.") 104 | 105 | 106 | def access_mountpoint_client_in_parallel(client): 107 | 108 | error_event = threading.Event() 109 | # Create and start multiple threads 110 | accessor_threads = [] 111 | num_accessor_threads = 10 112 | 113 | for i in range(num_accessor_threads): 114 | if error_event.is_set(): 115 | break 116 | accessor_thread = threading.Thread( 117 | target=access_client, 118 | args=( 119 | client, 120 | error_event, 121 | ), 122 | name=f"Accessor-{i + 1}", 123 | ) 124 | accessor_threads.append(accessor_thread) 125 | accessor_thread.start() 126 | 127 | for thread in accessor_threads: 128 | thread.join() 129 | 130 | return error_event.is_set() 131 | -------------------------------------------------------------------------------- /s3torchconnector/src/s3torchconnector/s3reader/constructor.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | 4 | from functools import partial 5 | from typing import Optional 6 | 7 | from .protocol import S3ReaderConstructorProtocol 8 | from .sequential import SequentialS3Reader 9 | from .ranged import RangedS3Reader 10 | 11 | 12 | class S3ReaderConstructor: 13 | """Constructor for creating ``partial(S3Reader)`` instances. 14 | 15 | Creates partial ``S3Reader`` instances that will be completed by ``S3Client`` with the 16 | remaining required parameters (e.g. ``bucket``, ``key``, ``get_object_info``, ``get_stream``). 17 | 18 | The constructor provides factory methods for different reader types: 19 | 20 | - ``sequential()``: Creates a constructor for sequential readers that buffer the entire object. 21 | Best for full reads and repeated access. 22 | - ``range_based()``: Creates a constructor for range-based readers that fetch specific byte ranges. 23 | Suitable for sparse partial reads for large objects. 24 | """ 25 | 26 | @staticmethod 27 | def sequential() -> S3ReaderConstructorProtocol: 28 | """Creates a constructor for sequential readers 29 | 30 | Returns: 31 | S3ReaderConstructorProtocol: Partial constructor for SequentialS3Reader 32 | 33 | Example:: 34 | 35 | reader_constructor = S3ReaderConstructor.sequential() 36 | 37 | """ 38 | return partial(SequentialS3Reader) 39 | 40 | @staticmethod 41 | def range_based(buffer_size: Optional[int] = None) -> S3ReaderConstructorProtocol: 42 | """Creates a constructor for range-based readers 43 | 44 | Args: 45 | buffer_size: Internal buffer size in bytes. If None, uses default 8MB. 46 | Set to 0 to disable buffering. 47 | 48 | Returns: 49 | S3ReaderConstructorProtocol: Partial constructor for RangedS3Reader 50 | 51 | Range-based reader performs byte-range requests to read specific portions of S3 objects without 52 | downloading the entire file. 53 | 54 | Buffer size affects read performance: 55 | 56 | * Small reads (< ``buffer_size``): Loads ``buffer_size`` bytes to buffer to reduce S3 API calls for small, sequential reads 57 | * Large reads (≥ ``buffer_size``): bypass the buffer for direct transfer from S3 58 | * Forward overlap reads: Reuses buffered data when reading ranges that extend beyond current buffer, and processes remaining 59 | data according to size with logic above. 
60 | 61 | Configuration Guide: 62 | 63 | * Use larger buffer sizes for workloads with many small, sequential reads of nearby bytes 64 | * Use smaller buffer sizes or disable buffering for sparse partial reads 65 | * Buffer can be disabled by setting ``buffer_size`` to 0 66 | * If ``buffer_size`` is None, uses default 8MB buffer 67 | 68 | Examples:: 69 | 70 | # Range-based reader with default 8MB buffer 71 | reader_constructor = S3ReaderConstructor.range_based() 72 | 73 | # Range-based reader with custom buffer size 74 | reader_constructor = S3ReaderConstructor.range_based(buffer_size=16*1024*1024) 75 | 76 | # Range-based reader with buffering disabled 77 | reader_constructor = S3ReaderConstructor.range_based(buffer_size=0) 78 | """ 79 | return partial(RangedS3Reader, buffer_size=buffer_size) 80 | 81 | @staticmethod 82 | def default() -> S3ReaderConstructorProtocol: 83 | """Creates default reader constructor (sequential) 84 | 85 | Returns: 86 | S3ReaderConstructorProtocol: Partial constructor for SequentialS3Reader 87 | """ 88 | return S3ReaderConstructor.sequential() 89 | 90 | @staticmethod 91 | def get_reader_type_string( 92 | constructor: Optional[S3ReaderConstructorProtocol], 93 | ) -> str: 94 | """Returns the reader type string for the given constructor.""" 95 | if constructor is None: 96 | return S3ReaderConstructor.get_reader_type_string( 97 | S3ReaderConstructor.default() 98 | ) 99 | 100 | if not isinstance(constructor, partial): 101 | return "unknown" 102 | 103 | if constructor.func == RangedS3Reader: 104 | return "range_based" 105 | elif constructor.func == SequentialS3Reader: 106 | return "sequential" 107 | else: 108 | return "unknown" 109 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/benchmark_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # // SPDX-License-Identifier: BSD 3 | import random 4 | import string 5 | import threading 6 | import time 7 | from collections import defaultdict 8 | from collections import deque 9 | from datetime import datetime 10 | from pathlib import Path 11 | from typing import Dict, Optional, List, TypedDict 12 | 13 | import numpy as np 14 | import psutil 15 | import torch.cuda 16 | from pynvml import ( # type: ignore 17 | nvmlInit, 18 | nvmlDeviceGetUtilizationRates, 19 | nvmlDeviceGetHandleByIndex, 20 | nvmlDeviceGetMemoryInfo, 21 | ) 22 | from torchvision.transforms import v2 # type: ignore 23 | 24 | monitor_gpu = False 25 | if torch.cuda.is_available(): 26 | monitor_gpu = True 27 | nvmlInit() 28 | 29 | 30 | class Distribution: 31 | def __init__(self, initial_capacity: int, precision: int = 4): 32 | self.initial_capacity = initial_capacity 33 | self._values = deque(maxlen=None) 34 | self.precision = precision 35 | 36 | def add(self, val: float): 37 | self._values.append(val) 38 | 39 | def summarize(self) -> dict: 40 | if not self._values: 41 | return {} 42 | window = np.array(self._values) 43 | return { 44 | "n": len(window), 45 | "mean": round(float(window.mean()), self.precision), 46 | "min": round(np.percentile(window, 0), self.precision), 47 | "p50": round(np.percentile(window, 50), self.precision), 48 | "p75": round(np.percentile(window, 75), self.precision), 49 | "p90": round(np.percentile(window, 90), self.precision), 50 | "max": round(np.percentile(window, 100), self.precision), 51 | } 52 | 53 | 54 | class ExperimentResult(TypedDict, total=False): 55 | training_duration_s: float 56 | epoch_durations_s: List[float] 57 | volume: int 58 | checkpoint_times: Optional[List[float]] 59 | utilization: Dict[str, Distribution] 60 | 61 | 62 | class ResourceMonitor: 63 | """ 64 | Monitors CPU, GPU usage and memory. 65 | Set sleep_time_s carefully to avoid perf degradations. 
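    Typical usage (illustrative sketch, not taken from the benchmarks themselves)::

        with ResourceMonitor() as monitor:
            ...  # run the workload being measured
        summary = {name: dist.summarize() for name, dist in monitor.resource_data.items()}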
66 | """ 67 | 68 | def __init__( 69 | self, sleep_time_s: float = 0.05, gpu_device: int = 0, chunk_size: int = 25_000 70 | ): 71 | self.monitor_thread = None 72 | self._utilization: Dict[str, Distribution] = defaultdict( 73 | lambda: Distribution(chunk_size) 74 | ) 75 | self.stop_event = threading.Event() 76 | self.sleep_time_s = sleep_time_s 77 | self.gpu_device = gpu_device 78 | self.chunk_size = chunk_size 79 | 80 | def _monitor(self): 81 | while not self.stop_event.is_set(): 82 | self._utilization["cpu_util"].add(psutil.cpu_percent()) 83 | self._utilization["cpu_mem"].add(psutil.virtual_memory().percent) 84 | 85 | if monitor_gpu: 86 | gpu_info = nvmlDeviceGetUtilizationRates( 87 | nvmlDeviceGetHandleByIndex(self.gpu_device) 88 | ) 89 | gpu_mem_info = nvmlDeviceGetMemoryInfo( 90 | nvmlDeviceGetHandleByIndex(self.gpu_device) 91 | ) 92 | self._utilization["gpu_util"].add(gpu_info.gpu) 93 | self._utilization["gpu_mem"].add( 94 | gpu_mem_info.used / gpu_mem_info.total * 100 95 | ) 96 | time.sleep(self.sleep_time_s) 97 | 98 | @property 99 | def resource_data(self): 100 | return dict(self._utilization) 101 | 102 | def __enter__(self): 103 | self.start() 104 | return self 105 | 106 | def __exit__(self, exc_type, exc_val, exc_tb): 107 | self.stop() 108 | 109 | def start(self): 110 | self.monitor_thread = threading.Thread(target=self._monitor) 111 | self.monitor_thread.start() 112 | 113 | def stop(self): 114 | self.stop_event.set() 115 | self.monitor_thread.join() 116 | 117 | 118 | def build_random_suffix() -> str: 119 | """Generates a unique suffix combining timestamp with random characters for use in filepaths or S3 URIs.""" 120 | timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M") 121 | random_suffix = "".join(random.choices(string.ascii_letters, k=4)) 122 | return f"{timestamp}-{random_suffix}" 123 | 124 | 125 | def build_checkpoint_path(path: str, suffix: str) -> str: 126 | return str(Path(path) / suffix) 127 | 128 | 129 | def build_checkpoint_uri(uri: str, suffix: str) -> str: 130 | return uri.removesuffix("/") + "/" + suffix.removeprefix("/") 131 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/put_object_stream.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | 6 | use futures::executor::block_on; 7 | use mountpoint_s3_client::PutObjectRequest; 8 | use pyo3::{pyclass, pymethods, PyRefMut, PyResult, Python}; 9 | 10 | use crate::exception::{python_exception, S3Exception}; 11 | 12 | #[pyclass( 13 | name = "PutObjectStream", 14 | module = "s3torchconnectorclient._mountpoint_s3_client" 15 | )] 16 | pub struct PutObjectStream { 17 | request: Box<dyn PutObjectRequestWrapper + Send + Sync>, 18 | #[pyo3(get)] 19 | bucket: String, 20 | #[pyo3(get)] 21 | key: String, 22 | } 23 | 24 | impl PutObjectStream { 25 | pub(crate) fn new<T: PutObjectRequest + Send + Sync + 'static>( 26 | request: T, 27 | bucket: String, 28 | key: String, 29 | ) -> Self { 30 | let request = Box::new(PutObjectRequestWrapperImpl::new(request)); 31 | Self { 32 | request, 33 | bucket, 34 | key, 35 | } 36 | } 37 | } 38 | 39 | #[pymethods] 40 | impl PutObjectStream { 41 | pub fn write(mut slf: PyRefMut<'_, Self>, data: &[u8]) -> PyResult<()> { 42 | let py = slf.py(); 43 | slf.request.write(py, data) 44 | } 45 | 46 | pub fn close(mut slf: PyRefMut<'_, Self>) -> PyResult<()> { 47 | let py = slf.py(); 48 | slf.request.complete(py) 49 | } 50 | } 51 | 52 | pub trait PutObjectRequestWrapper { 53 | fn write(&mut self, py: Python, data: &[u8]) -> PyResult<()>; 54 | fn complete(&mut self, py: Python) -> PyResult<()>; 55 | } 56 | 57 | pub struct PutObjectRequestWrapperImpl<T: PutObjectRequest> { 58 | request: Option<T>, 59 | } 60 | 61 | impl<T: PutObjectRequest> PutObjectRequestWrapperImpl<T> { 62 | pub fn new(request: T) -> PutObjectRequestWrapperImpl<T> { 63 | PutObjectRequestWrapperImpl { 64 | request: Some(request), 65 | } 66 | } 67 | } 68 | 69 | impl<T: PutObjectRequest> PutObjectRequestWrapper for PutObjectRequestWrapperImpl<T> { 70 | fn write(&mut self, py: Python, data: &[u8]) -> PyResult<()> { 71 | if let Some(request) = self.request.as_mut() { 72 | py.allow_threads(|| block_on(request.write(data)).map_err(python_exception)) 73 | } else { 74 | Err(S3Exception::new_err("Cannot write to closed object")) 75 | } 76 | } 77 | 78 | fn complete(&mut self, py: Python) -> PyResult<()> { 79 | if let Some(request) = self.request.take() { 80 | py.allow_threads(|| block_on(request.complete()).map_err(python_exception))?; 81 | Ok(()) 82 | } else { 83 | Err(S3Exception::new_err("Cannot close object more than once")) 84 | } 85 | } 86 | } 87 | 88 | #[cfg(test)] 89 | mod tests { 90 | use pyo3::types::IntoPyDict; 91 | use pyo3::{py_run, PyResult, Python}; 92 | use tracing_subscriber::layer::SubscriberExt; 93 | use tracing_subscriber::util::SubscriberInitExt; 94 | 95 | use crate::mock_client::PyMockClient; 96 | use crate::mountpoint_s3_client::MountpointS3Client; 97 | 98 | #[test] 99 | fn test_put_object() -> PyResult<()> { 100 | let layer = tracing_subscriber::fmt::layer().with_ansi(true); 101 | let registry = tracing_subscriber::registry().with(layer); 102 | let _ = registry.try_init(); 103 | 104 | pyo3::prepare_freethreaded_python(); 105 | 106 | Python::with_gil(|py| { 107 | let locals = [ 108 | ( 109 | "MountpointS3Client", 110 | py.get_type::<MountpointS3Client>(), 111 | ), 112 | ( 113 | "MockMountpointS3Client", 114 | py.get_type::<PyMockClient>(), 115 | ), 116 | ]; 117 | 118 | py_run!( 119 | py, 120 | *locals.into_py_dict(py).unwrap(), 121 | r#" 122 | data_to_write = b"Hello!"
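# Illustrative note: MockMountpointS3Client keeps objects in memory, so this put/get round trip needs no AWS credentials or network access.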
123 | 124 | mock_client = MockMountpointS3Client("us-east-1", "mock-bucket") 125 | client = mock_client.create_mocked_client() 126 | 127 | put_stream = client.put_object("mock-bucket", "key") 128 | put_stream.write(data_to_write) 129 | put_stream.close() 130 | 131 | get_stream = client.get_object("mock-bucket", "key") 132 | assert b''.join(get_stream) == data_to_write 133 | "# 134 | ); 135 | }); 136 | 137 | Ok(()) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_common.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | import os 6 | from multiprocessing.queues import Queue 7 | from pathlib import Path 8 | from time import perf_counter 9 | from typing import List, Tuple 10 | 11 | import hydra 12 | import pandas as pd 13 | import torch.distributed as dist 14 | import torch.distributed.checkpoint as dcp 15 | from omegaconf import DictConfig 16 | from torch import multiprocessing as mp 17 | from torch.distributed.checkpoint import FileSystemWriter, FileSystemReader 18 | 19 | 20 | from s3torchbenchmarking.benchmark_utils import ( 21 | build_random_suffix, 22 | build_checkpoint_uri, 23 | ) 24 | from s3torchconnector.dcp import S3StorageWriter, S3StorageReader 25 | 26 | Timestamps = Tuple[float, float] 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def setup(backend: str, world_size: int, rank: int) -> None: 31 | os.environ["MASTER_ADDR"] = "localhost" 32 | os.environ["MASTER_PORT"] = "12355" 33 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 34 | dist.init_process_group(backend, world_size=world_size, rank=rank) 35 | 36 | 37 | def get_writer(cfg: DictConfig, suffix: str) -> FileSystemWriter: 38 | """Instantiate a checkpoint writer based on the input config.""" 39 | if cfg.checkpoint.storage == "disk": 40 | local_path = Path(cfg.path) / suffix 41 | logger.info("Saving checkpoint to %s (disk)...", local_path) 42 | sync_files = getattr(cfg.checkpoint, "sync_files", False) 43 | return dcp.FileSystemWriter( 44 | local_path, sync_files=sync_files, thread_count=cfg.thread_count 45 | ) 46 | elif cfg.checkpoint.storage == "s3": 47 | uri = build_checkpoint_uri(cfg.s3.uri, suffix) 48 | logger.info("Saving checkpoint to %s (S3)...", uri) 49 | return S3StorageWriter(cfg.s3.region, uri, thread_count=cfg.thread_count) 50 | raise ValueError(f"Storage writer {cfg.checkpoint.storage} not supported") 51 | 52 | 53 | def get_reader(cfg: DictConfig) -> FileSystemReader: 54 | """Instantiate a checkpoint reader based on the input config.""" 55 | suffix = cfg.checkpoint.suffix 56 | if cfg.checkpoint.storage == "disk": 57 | local_path = Path(cfg.path) / suffix 58 | logger.info("Loading checkpoint from %s (disk)...", local_path) 59 | return dcp.FileSystemReader(local_path) 60 | elif cfg.checkpoint.storage == "s3": 61 | uri = build_checkpoint_uri(cfg.s3.uri, suffix) 62 | logger.info("Loading checkpoint from %s (S3)...", uri) 63 | return S3StorageReader(cfg.s3.region, uri) 64 | raise ValueError(f"Storage reader {cfg.checkpoint.storage} not supported") 65 | 66 | 67 | def benchmark_common_runner( 68 | cfg: DictConfig, 69 | run_fn, 70 | run_args: tuple, 71 | ) -> dict: 72 | manager = mp.Manager() 73 | corrected_save_timestamps: Queue[Timestamps] = manager.Queue() 74 | processing_timestamps: List[Timestamps] = [] 75 | 76 | for 
epoch in range(cfg.epochs): 77 | suffix = build_random_suffix() 78 | logger.info("Executing epoch #%i / %i...", epoch + 1, cfg.epochs) 79 | begin_mp = perf_counter() 80 | mp.spawn( 81 | run_fn, 82 | run_args 83 | + ( 84 | suffix, 85 | corrected_save_timestamps, 86 | ), 87 | nprocs=cfg.world_size, 88 | join=True, 89 | ) 90 | end_mp = perf_counter() 91 | processing_timestamps.append((begin_mp, end_mp)) 92 | 93 | return process_timestamps(corrected_save_timestamps, processing_timestamps) 94 | 95 | 96 | def process_timestamps( 97 | corrected_save_timestamps: Queue, 98 | processing_timestamps: List[Timestamps], 99 | ) -> dict: 100 | """Process and return metrics from timestamps.""" 101 | collector: List[Timestamps] = [] 102 | while not corrected_save_timestamps.empty(): 103 | collector.append(corrected_save_timestamps.get()) 104 | 105 | cst = pd.DataFrame(collector, columns=["begin", "end", "size"]) 106 | pt = pd.DataFrame(processing_timestamps, columns=["begin", "end"]) 107 | 108 | corrected_save_durations_s = cst["end"] - cst["begin"] 109 | processing_durations_s = pt["end"] - pt["begin"] 110 | throughput_mibs = cst["size"] / corrected_save_durations_s 111 | 112 | return { 113 | "metrics": { 114 | "throughput_mibs": throughput_mibs.dropna().to_list(), 115 | "corrected_save_durations_s": corrected_save_durations_s.dropna().to_list(), 116 | "processing_durations_s": processing_durations_s.dropna().to_list(), 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /s3torchconnectorclient/rust/src/logger_setup.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 3 | * // SPDX-License-Identifier: BSD 4 | */ 5 | use crate::exception::python_exception; 6 | use mountpoint_s3_client::config::RustLogAdapter; 7 | use pyo3::PyResult; 8 | use std::env; 9 | use tracing_subscriber::filter::EnvFilter; 10 | use tracing_subscriber::util::SubscriberInitExt; 11 | 12 | pub const S3_TORCH_CONNECTOR_DEBUG_LOGS_ENV_VAR: &str = "S3_TORCH_CONNECTOR_DEBUG_LOGS"; 13 | pub const S3_TORCH_CONNECTOR_LOGS_DIR_PATH_ENV_VAR: &str = "S3_TORCH_CONNECTOR_LOGS_DIR_PATH"; 14 | pub const LOG_FILE_PREFIX: &str = "s3torchconnectorclient.log"; 15 | 16 | pub fn setup_logging() -> PyResult<()> { 17 | let enable_logs = env::var(S3_TORCH_CONNECTOR_DEBUG_LOGS_ENV_VAR); 18 | 19 | if enable_logs.is_ok() { 20 | let filter = EnvFilter::try_from_env(S3_TORCH_CONNECTOR_DEBUG_LOGS_ENV_VAR) 21 | .map_err(python_exception)?; 22 | let debug_logs_path = env::var(S3_TORCH_CONNECTOR_LOGS_DIR_PATH_ENV_VAR).ok(); 23 | 24 | RustLogAdapter::try_init().map_err(python_exception)?; 25 | 26 | match debug_logs_path { 27 | Some(logs_path) => { 28 | enable_file_logging(filter, logs_path)?; 29 | } 30 | None => { 31 | enable_default_logging(filter)?; 32 | } 33 | } 34 | } 35 | 36 | Ok(()) 37 | } 38 | 39 | fn enable_file_logging(filter: EnvFilter, logs_path: String) -> PyResult<()> { 40 | let logfile = tracing_appender::rolling::hourly(logs_path, LOG_FILE_PREFIX); 41 | let subscriber_builder = tracing_subscriber::fmt() 42 | .with_writer(logfile) 43 | .with_env_filter(filter) 44 | .with_ansi(false); 45 | subscriber_builder 46 | .finish() 47 | .try_init() 48 | .map_err(python_exception)?; 49 | 50 | Ok(()) 51 | } 52 | 53 | fn enable_default_logging(filter: EnvFilter) -> PyResult<()> { 54 | let subscriber_builder = tracing_subscriber::fmt() 55 | .with_env_filter(filter) 56 | .with_ansi(false); 57 | 
subscriber_builder 58 | .finish() 59 | .try_init() 60 | .map_err(python_exception)?; 61 | 62 | Ok(()) 63 | } 64 | 65 | #[cfg(test)] 66 | mod tests { 67 | use crate::logger_setup::{ 68 | setup_logging, S3_TORCH_CONNECTOR_DEBUG_LOGS_ENV_VAR, 69 | S3_TORCH_CONNECTOR_LOGS_DIR_PATH_ENV_VAR, 70 | }; 71 | use pyo3::PyResult; 72 | use rusty_fork::rusty_fork_test; 73 | use std::env; 74 | 75 | fn check_valid_log_level(log_level: &str) { 76 | pyo3::prepare_freethreaded_python(); 77 | env::set_var(S3_TORCH_CONNECTOR_DEBUG_LOGS_ENV_VAR, log_level); 78 | let result: PyResult<()> = setup_logging(); 79 | assert!(result.is_ok()); 80 | } 81 | 82 | rusty_fork_test! { 83 | #[test] 84 | fn test_debug_log_environment_variable_unset() { 85 | pyo3::prepare_freethreaded_python(); 86 | env::remove_var(S3_TORCH_CONNECTOR_DEBUG_LOGS_ENV_VAR); 87 | let result: PyResult<()> = setup_logging(); 88 | assert!(result.is_ok()); 89 | } 90 | 91 | #[test] 92 | fn test_logs_dir_environment_variable_unset() { 93 | pyo3::prepare_freethreaded_python(); 94 | env::remove_var(S3_TORCH_CONNECTOR_LOGS_DIR_PATH_ENV_VAR); 95 | let result: PyResult<()> = setup_logging(); 96 | assert!(result.is_ok()); 97 | } 98 | 99 | #[test] 100 | fn test_debug_logging_off() { 101 | check_valid_log_level("OFF"); 102 | } 103 | 104 | #[test] 105 | fn test_debug_logging_level_error() { 106 | check_valid_log_level("ERROR"); 107 | } 108 | 109 | #[test] 110 | fn test_debug_logging_level_warn() { 111 | check_valid_log_level("WARN"); 112 | } 113 | 114 | #[test] 115 | fn test_debug_logging_level_info() { 116 | check_valid_log_level("INFO"); 117 | } 118 | 119 | #[test] 120 | fn test_debug_logging_level_debug() { 121 | check_valid_log_level("debug"); 122 | } 123 | 124 | #[test] 125 | fn test_debug_logging_level_trace() { 126 | check_valid_log_level("trace"); 127 | } 128 | 129 | #[test] 130 | #[ignore = "tracing-subscriber 0.3.20 EnvFilter parsing regression - see tokio-rs/tracing#3371"] 131 | fn test_invalid_logging_level() { 132 | pyo3::prepare_freethreaded_python(); 133 | env::set_var(S3_TORCH_CONNECTOR_DEBUG_LOGS_ENV_VAR, "invalid123.&/?"); 134 | let result: PyResult<()> = setup_logging(); 135 | assert!(result.is_err()); 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /s3torchconnectorclient/python/src/s3torchconnectorclient/_mountpoint_s3_client.pyi: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | from typing import List, Optional 5 | 6 | # This interface is unstable! 7 | 8 | class MountpointS3Client: 9 | throughput_target_gbps: float 10 | region: str 11 | part_size: int 12 | profile: Optional[str] 13 | unsigned: Optional[bool] 14 | force_path_style: Optional[bool] 15 | max_attempts: int 16 | user_agent_prefix: str 17 | endpoint: str 18 | 19 | def __init__( 20 | self, 21 | region: str, 22 | user_agent_prefix: str = "", 23 | throughput_target_gbps: float = 10.0, 24 | part_size: int = 8 * 1024 * 1024, 25 | profile: Optional[str] = None, 26 | unsigned: Optional[bool] = False, 27 | endpoint: Optional[str] = None, 28 | force_path_style: Optional[bool] = False, 29 | max_attempts: int = 10, 30 | ): ... 31 | def get_object( 32 | self, 33 | bucket: str, 34 | key: str, 35 | start: Optional[int] = None, 36 | end: Optional[int] = None, 37 | ) -> GetObjectStream: ... 
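    # `start`/`end` request a byte sub-range of the object (as used by range-based readers); leave both unset to stream the whole object.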
38 | def put_object( 39 | self, bucket: str, key: str, storage_class: Optional[str] = None 40 | ) -> PutObjectStream: ... 41 | def list_objects( 42 | self, bucket: str, prefix: str = "", delimiter: str = "", max_keys: int = 1000 43 | ) -> ListObjectStream: ... 44 | def head_object(self, bucket: str, key: str) -> HeadObjectResult: ... 45 | def delete_object(self, bucket: str, key: str) -> None: ... 46 | def copy_object( 47 | self, src_bucket: str, src_key: str, dst_bucket: str, dst_key: str 48 | ) -> None: ... 49 | 50 | class MockMountpointS3Client: 51 | throughput_target_gbps: float 52 | region: str 53 | part_size: int 54 | user_agent_prefix: str 55 | unsigned: bool 56 | force_path_style: bool 57 | max_attempts: int 58 | 59 | def __init__( 60 | self, 61 | region: str, 62 | bucket: str, 63 | endpoint: str = "", 64 | throughput_target_gbps: float = 10.0, 65 | part_size: int = 8 * 1024 * 1024, 66 | user_agent_prefix: str = "mock_client", 67 | unsigned: bool = False, 68 | force_path_style: bool = False, 69 | max_attempts: int = 10, 70 | ): ... 71 | def create_mocked_client(self) -> MountpointS3Client: ... 72 | def add_object(self, key: str, data: bytes) -> None: ... 73 | def remove_object(self, key: str) -> None: ... 74 | 75 | class GetObjectStream: 76 | bucket: str 77 | key: str 78 | 79 | def __iter__(self) -> GetObjectStream: ... 80 | def __next__(self) -> bytes: ... 81 | def tell(self) -> int: ... 82 | 83 | class PutObjectStream: 84 | bucket: str 85 | key: str 86 | def write(self, data: bytes) -> None: ... 87 | def close(self) -> None: ... 88 | 89 | class RestoreStatus: 90 | in_progress: bool 91 | expiry: Optional[int] 92 | 93 | def __init__(self, in_progress: bool, expiry: Optional[int]): ... 94 | 95 | class ObjectInfo: 96 | key: str 97 | etag: str 98 | size: int 99 | last_modified: int 100 | storage_class: Optional[str] 101 | restore_status: Optional[RestoreStatus] 102 | 103 | def __init__( 104 | self, 105 | key: str, 106 | etag: str, 107 | size: int, 108 | last_modified: int, 109 | storage_class: Optional[str], 110 | restore_status: Optional[RestoreStatus], 111 | ): ... 112 | 113 | class HeadObjectResult: 114 | etag: str 115 | size: int 116 | last_modified: int 117 | storage_class: Optional[str] 118 | restore_status: Optional[RestoreStatus] 119 | 120 | def __init__( 121 | self, 122 | etag: str, 123 | size: int, 124 | last_modified: int, 125 | storage_class: Optional[str], 126 | restore_status: Optional[RestoreStatus], 127 | ): ... 128 | 129 | class ListObjectResult: 130 | object_info: List[ObjectInfo] 131 | common_prefixes: List[str] 132 | 133 | class ListObjectStream: 134 | bucket: str 135 | continuation_token: Optional[str] 136 | complete: bool 137 | prefix: str 138 | delimiter: str 139 | max_keys: int 140 | 141 | def __iter__(self) -> ListObjectStream: ... 142 | def __next__(self) -> ListObjectResult: ... 143 | @staticmethod 144 | def _from_state( 145 | client: MountpointS3Client, 146 | bucket: str, 147 | prefix: str, 148 | delimiter: str, 149 | max_keys: int, 150 | continuation_token: Optional[str], 151 | complete: bool, 152 | ) -> ListObjectStream: ... 153 | 154 | class S3Exception(Exception): 155 | pass 156 | 157 | __version__: str 158 | 159 | def join_all_managed_threads(timeout_secs: float) -> None: ... 160 | -------------------------------------------------------------------------------- /s3torchbenchmarking/src/s3torchbenchmarking/dcp_fsdp/load_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. 
or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | import functools 6 | from multiprocessing.queues import Queue 7 | from time import perf_counter 8 | from typing import Tuple 9 | 10 | import hydra 11 | import torch.distributed.checkpoint as dcp 12 | from omegaconf import DictConfig 13 | import torch 14 | import torch.distributed as dist 15 | import os 16 | 17 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 18 | from torch.distributed.fsdp import ShardingStrategy 19 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 20 | from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType 21 | from transformers.models.llama.modeling_llama import LlamaDecoderLayer 22 | 23 | from s3torchbenchmarking.dcp_common import setup, benchmark_common_runner, get_reader 24 | from s3torchbenchmarking.models import get_benchmark_model 25 | 26 | Timestamps = Tuple[float, float] 27 | logger = logging.getLogger(__name__) 28 | import sys 29 | 30 | 31 | @hydra.main(version_base=None) 32 | def run_benchmark(cfg: DictConfig) -> dict: 33 | """DCP load benchmarks entry point.""" 34 | return benchmark_common_runner(cfg, run_fsdp_load, (cfg,)) 35 | 36 | 37 | def run_fsdp_load( 38 | rank: int, 39 | cfg: DictConfig, 40 | suffix: str, 41 | load_timestamps: Queue, 42 | ): 43 | """Execute the actual code for checkpoint loading. 44 | 45 | This function is meant to be executed in subprocesses.""" 46 | setup(cfg.backend, world_size=cfg.world_size, rank=rank) 47 | 48 | if rank == 0: 49 | logger.info("Creating Model") 50 | 51 | # Instantiate model on CPU on rank=0 only to prevent CPU OOM 52 | # (e.g. 70B * 4 bytes * 8 processes > 2T RAM available on P5) 53 | if rank == 0: 54 | model_proxy = get_benchmark_model(cfg.model) 55 | model = model_proxy.model 56 | else: 57 | with torch.device("meta"): 58 | # Instantiating model on `meta` device doesn't consume CPU memory, 59 | # but requires specifying `param_init_fn=...` 60 | # and `sync_module_states=True` in FSDP c-tor.
61 | model_proxy = get_benchmark_model(cfg.model) 62 | model = model_proxy.model 63 | 64 | model_size = model_proxy.size 65 | 66 | transformer_layer = LlamaDecoderLayer 67 | gpt_auto_wrap_policy = functools.partial( 68 | transformer_auto_wrap_policy, 69 | transformer_layer_cls={ 70 | transformer_layer, 71 | }, 72 | ) 73 | 74 | if cfg.backend == "nccl": 75 | device_id = rank % torch.cuda.device_count() 76 | torch.cuda.set_device(device_id) 77 | param_init_fn = lambda module: module.to_empty( 78 | device=torch.device("cuda"), recurse=False 79 | ) 80 | else: 81 | device_id = rank % torch.cpu.device_count() 82 | torch.cpu.set_device(device_id) 83 | param_init_fn = lambda module: module.to_empty( 84 | device=torch.device("cpu"), recurse=False 85 | ) 86 | 87 | if cfg.checkpoint.sharding_strategy == "full": 88 | sharding_strategy = ShardingStrategy.FULL_SHARD 89 | elif cfg.checkpoint.sharding_strategy == "hybrid": 90 | sharding_strategy = ShardingStrategy.HYBRID_SHARD 91 | else: 92 | raise NotImplementedError("Available sharding strategies are full and hybrid") 93 | 94 | model = FSDP( 95 | model, 96 | auto_wrap_policy=gpt_auto_wrap_policy, 97 | device_id=( 98 | torch.cuda.current_device() 99 | if cfg.backend == "nccl" 100 | else torch.cpu.current_device() 101 | ), 102 | use_orig_params=False, 103 | sharding_strategy=sharding_strategy, 104 | sync_module_states=True if cfg.backend == "nccl" else False, 105 | param_init_fn=param_init_fn if rank != 0 else None, 106 | ) 107 | 108 | if rank == 0: 109 | logger.info("Wrapped model with FSDP") 110 | 111 | # Prepare state dict for loading 112 | with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): 113 | state_dict = { 114 | "model": model.state_dict(), 115 | } 116 | 117 | storage_reader = get_reader(cfg) 118 | 119 | # Align all workers to start loading at the same time 120 | dist.barrier() 121 | begin_load = perf_counter() 122 | dcp.load(state_dict, storage_reader=storage_reader) 123 | end_load = perf_counter() 124 | 125 | if rank == 0: 126 | logger.info(f"The total size of model is {model_size}") 127 | # Record the load times excluding the influence of the process setup and model loading to device. 128 | load_timestamps.put((begin_load, end_load, model_size)) 129 | dist.destroy_process_group() 130 | 131 | 132 | if __name__ == "__main__": 133 | run_benchmark() 134 | -------------------------------------------------------------------------------- /s3torchconnector/tst/unit/test_s3dataset_common.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # // SPDX-License-Identifier: BSD 3 | 4 | import logging 5 | from typing import Iterable, Union, Sequence 6 | 7 | import pytest 8 | 9 | from s3torchconnector import S3Exception 10 | from s3torchconnector.s3reader import SequentialS3Reader, RangedS3Reader 11 | from s3torchconnector._s3client import MockS3Client 12 | 13 | from s3torchconnector._s3dataset_common import ( 14 | parse_s3_uri, 15 | get_objects_from_prefix, 16 | get_objects_from_uris, 17 | ) 18 | 19 | logging.basicConfig( 20 | format="%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s" 21 | ) 22 | logging.getLogger().setLevel(1) 23 | 24 | log = logging.getLogger(__name__) 25 | 26 | TEST_BUCKET = "test-bucket" 27 | TEST_KEY = "test-key" 28 | TEST_REGION = "us-east-1" 29 | S3_PREFIX = f"s3://{TEST_BUCKET}" 30 | TEST_ENDPOINT = "https://s3.us-east-1.amazonaws.com" 31 | READER_TYPE_STRING_TO_CLASS = { 32 | "sequential": SequentialS3Reader, 33 | "range_based": RangedS3Reader, 34 | } 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "uri, expected_bucket, expected_key", 39 | [ 40 | (f"s3://bucket/key", "bucket", "key"), 41 | (f"s3://bucket", "bucket", ""), 42 | (f"s3://bucket/key/inner-key", "bucket", "key/inner-key"), 43 | ], 44 | ) 45 | def test_s3dataset_base_parse_s3_uri_success(uri, expected_bucket, expected_key): 46 | bucket, key = parse_s3_uri(uri) 47 | assert bucket == expected_bucket 48 | assert key == expected_key 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "uri, error_msg", 53 | [ 54 | (None, "Only s3:// URIs are supported"), 55 | ("", "Only s3:// URIs are supported"), 56 | ("s3a://bucket/key", "Only s3:// URIs are supported"), 57 | ("s3://", "Bucket name must be non-empty"), 58 | ("s3:///key", "Bucket name must be non-empty"), 59 | ], 60 | ) 61 | def test_s3dataset_base_parse_s3_uri_fail(uri, error_msg): 62 | with pytest.raises(ValueError, match=f"^{error_msg}$"): 63 | parse_s3_uri(uri) 64 | 65 | 66 | @pytest.mark.parametrize( 67 | "prefix, keys, expected_count", 68 | [ 69 | ("", ["obj1", "obj2", "obj3", "test", "test2"], 5), 70 | ("obj", ["obj1", "obj2", "obj3", "test", "test2"], 3), 71 | ], 72 | ) 73 | def test_get_objects_from_prefix(prefix: str, keys: Sequence[str], expected_count: int): 74 | mock_client = _create_mock_client_with_dummy_objects(TEST_BUCKET, keys) 75 | bucket_key_pairs = get_objects_from_prefix(f"{S3_PREFIX}/{prefix}", mock_client) 76 | count = 0 77 | for index, bucket_key_pair in enumerate(bucket_key_pairs): 78 | count += 1 79 | assert bucket_key_pair is not None 80 | assert bucket_key_pair.bucket == TEST_BUCKET 81 | assert bucket_key_pair.key == keys[index] 82 | assert count == expected_count 83 | 84 | 85 | def test_list_objects_for_bucket_invalid(): 86 | mock_client = _create_mock_client_with_dummy_objects(TEST_BUCKET, []) 87 | with pytest.raises(S3Exception, match="Service error: The bucket does not exist"): 88 | objects = get_objects_from_prefix( 89 | "s3://DIFFERENT_BUCKET", 90 | mock_client, 91 | ) 92 | next(iter(objects)) 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "object_uris, expected_keys", 97 | [([], []), ([f"{S3_PREFIX}/obj1", f"{S3_PREFIX}/obj2"], ["obj1", "obj2"])], 98 | ) 99 | def test_get_objects_from_uris_success( 100 | object_uris: Sequence[str], expected_keys: Sequence[str] 101 | ): 102 | mock_client = _create_mock_client_with_dummy_objects(TEST_BUCKET, expected_keys) 103 | bucket_key_pairs = get_objects_from_uris(object_uris, mock_client) 104 | count = 0 105 | for index, bucket_key_pair in enumerate(bucket_key_pairs): 106 | count += 1 107 | assert 
bucket_key_pair is not None 108 | assert bucket_key_pair.bucket == TEST_BUCKET 109 | assert bucket_key_pair.key == expected_keys[index] 110 | assert count == len(expected_keys) 111 | 112 | 113 | @pytest.mark.parametrize( 114 | "uri, error_msg", 115 | [ 116 | ("", "Only s3:// URIs are supported"), 117 | ("s3a://bucket/key", "Only s3:// URIs are supported"), 118 | ("s3://", "Bucket name must be non-empty"), 119 | ("s3:///key", "Bucket name must be non-empty"), 120 | ], 121 | ) 122 | def test_get_objects_from_uris_fail(uri, error_msg): 123 | mock_client = _create_mock_client_with_dummy_objects(TEST_BUCKET, []) 124 | with pytest.raises(ValueError, match=f"^{error_msg}$"): 125 | get_objects_from_uris(uri, mock_client) 126 | 127 | 128 | def _create_mock_client_with_dummy_objects( 129 | bucket: str, keys: Union[str, Iterable[str]] 130 | ): 131 | mock_client = MockS3Client(TEST_REGION, bucket) 132 | for key in keys: 133 | content = f"{bucket}-{key}-dummyData".encode() 134 | mock_client.add_object(key, content) 135 | return mock_client 136 | -------------------------------------------------------------------------------- /s3torchconnectorclient/python/tst/integration/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # // SPDX-License-Identifier: BSD 3 | 4 | import io 5 | import os 6 | import random 7 | from datetime import datetime 8 | from dataclasses import dataclass, field 9 | from typing import Optional 10 | 11 | import boto3 12 | import numpy as np 13 | from PIL import Image 14 | import pytest 15 | 16 | SESSION_DATETIME = datetime.now().strftime("%Y%m%dT%H%M%S") 17 | 18 | 19 | def getenv(var: str, optional: bool = False) -> str: 20 | v = os.getenv(var) 21 | if v is None and not optional: 22 | raise Exception(f"Required environment variable {var} is not set") 23 | return v 24 | 25 | 26 | @dataclass 27 | class BucketPrefixFixture: 28 | """An S3 bucket/prefix and its contents for use in a single unit test. 
The prefix will be unique 29 | to this instance, so other concurrent tests won't affect its state.""" 30 | 31 | name: str 32 | 33 | region: str = getenv("CI_REGION") 34 | bucket: str = getenv("CI_BUCKET") 35 | prefix: str = getenv("CI_PREFIX") 36 | storage_class: Optional[str] = getenv("CI_STORAGE_CLASS", optional=True) 37 | endpoint_url: Optional[str] = getenv("CI_CUSTOM_ENDPOINT_URL", optional=True) 38 | contents: dict = field(default_factory=dict) 39 | profile_arn: Optional[str] = getenv("CI_PROFILE_ROLE", optional=True) 40 | profile_bucket: Optional[str] = getenv("CI_PROFILE_BUCKET", optional=True) 41 | 42 | def __post_init__(self): 43 | assert self.prefix == "" or self.prefix.endswith("/") 44 | session = boto3.Session(region_name=self.region) 45 | self.s3 = session.client("s3") 46 | 47 | nonce = random.randrange(2**64) 48 | self.prefix = f"{self.prefix}{self.name}/{SESSION_DATETIME}-{nonce}/" 49 | 50 | @property 51 | def s3_uri(self): 52 | return f"s3://{self.bucket}/{self.prefix}" 53 | 54 | def add(self, key: str, contents: bytes, **kwargs): 55 | """Upload an S3 object to this prefix of the bucket.""" 56 | full_key = f"{self.prefix}{key}" 57 | self.s3.put_object(Bucket=self.bucket, Key=full_key, Body=contents, **kwargs) 58 | self.contents[full_key] = contents 59 | 60 | def remove(self, key: str): 61 | full_key = f"{self.prefix}{key}" 62 | self.s3.delete_object(Bucket=self.bucket, Key=full_key) 63 | 64 | def __getitem__(self, index): 65 | return self.contents[index] 66 | 67 | def __iter__(self): 68 | return iter(self.contents) 69 | 70 | 71 | @dataclass 72 | class CopyBucketFixture(BucketPrefixFixture): 73 | src_key: str = "src.txt" 74 | dst_key: str = "dst.txt" 75 | 76 | @property 77 | def full_src_key(self): 78 | return self.prefix + self.src_key 79 | 80 | @property 81 | def full_dst_key(self): 82 | return self.prefix + self.dst_key 83 | 84 | 85 | def get_test_copy_bucket_fixture(name: str) -> CopyBucketFixture: 86 | copy_bucket_fixture = CopyBucketFixture(name=name) 87 | 88 | # set up / teardown 89 | copy_bucket_fixture.add(copy_bucket_fixture.src_key, b"Hello, World!\n") 90 | copy_bucket_fixture.remove(copy_bucket_fixture.dst_key) 91 | 92 | return copy_bucket_fixture 93 | 94 | 95 | @pytest.fixture(scope="session") 96 | def image_directory() -> BucketPrefixFixture: 97 | """Create a bucket/prefix fixture that contains a directory of random JPG image files.""" 98 | NUM_IMAGES = 10 99 | IMAGE_SIZE = 100 100 | fixture = BucketPrefixFixture(f"image_directory_client") 101 | for i in range(NUM_IMAGES): 102 | data = np.random.randint(0, 256, IMAGE_SIZE * IMAGE_SIZE * 3, np.uint8) 103 | data = data.reshape(IMAGE_SIZE, IMAGE_SIZE, 3) 104 | image = Image.fromarray(data, "RGB") 105 | image_bytes = io.BytesIO() 106 | image.save(image_bytes, "jpeg") 107 | image_bytes.seek(0) 108 | image_bytes = image_bytes.read() 109 | 110 | key = f"img{i:03d}.jpg" 111 | fixture.add(key, image_bytes) 112 | 113 | return fixture 114 | 115 | 116 | @pytest.fixture 117 | def sample_directory(request) -> BucketPrefixFixture: 118 | fixture = BucketPrefixFixture(f"{request.node.name}-sample_files") 119 | fixture.add("hello_world.txt", b"Hello, World!\n") 120 | return fixture 121 | 122 | 123 | @pytest.fixture 124 | def put_object_tests_directory(request) -> BucketPrefixFixture: 125 | fixture = BucketPrefixFixture(f"{request.node.name}-put_integration_tests") 126 | fixture.add("to_overwrite.txt", b"before") 127 | return fixture 128 | 129 | 130 | @pytest.fixture 131 | def checkpoint_directory(request) -> 
BucketPrefixFixture: 132 | return BucketPrefixFixture(f"{request.node.name}-checkpoint_directory") 133 | 134 | 135 | @pytest.fixture 136 | def empty_directory(request) -> BucketPrefixFixture: 137 | return BucketPrefixFixture(f"{request.node.name}-empty_directory") 138 | 139 | 140 | @pytest.fixture 141 | def copy_directory(request) -> CopyBucketFixture: 142 | return get_test_copy_bucket_fixture(f"{request.node.name}-copy_directory") 143 | --------------------------------------------------------------------------------
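Illustrative sketch (hypothetical, not part of the repository): a minimal integration test built on the fixtures above, using the low-level client from this package. It assumes the CI_* environment variables read by BucketPrefixFixture are set.

from s3torchconnectorclient._mountpoint_s3_client import MountpointS3Client


def test_round_trips_hello_world(sample_directory):
    # `sample_directory` (defined above) uploads hello_world.txt under a unique per-test prefix.
    client = MountpointS3Client(region=sample_directory.region)
    key = f"{sample_directory.prefix}hello_world.txt"
    stream = client.get_object(sample_directory.bucket, key)
    # GetObjectStream yields byte chunks; join them to compare against the uploaded payload.
    assert b"".join(stream) == b"Hello, World!\n"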