├── torchsnapshot
│   ├── py.typed
│   ├── tricks
│   │   ├── __init__.py
│   │   ├── ddp.py
│   │   ├── fsdp.py
│   │   └── deepspeed.py
│   ├── io_preparers
│   │   ├── __init__.py
│   │   ├── object.py
│   │   └── chunked_tensor.py
│   ├── storage_plugins
│   │   ├── __init__.py
│   │   ├── fs.py
│   │   └── s3.py
│   ├── version.py
│   ├── stateful.py
│   ├── __init__.py
│   ├── event.py
│   ├── state_dict.py
│   ├── rng_state.py
│   ├── uvm_tensor.py
│   ├── event_handlers.py
│   ├── rss_profiler.py
│   ├── dtensor_utils.py
│   ├── storage_plugin.py
│   ├── memoryview_stream.py
│   ├── pg_wrapper.py
│   ├── io_types.py
│   ├── manifest_utils.py
│   ├── knobs.py
│   ├── asyncio_utils.py
│   └── dist_store.py
├── docs
│   ├── .gitignore
│   ├── requirements.txt
│   ├── source
│   │   ├── api_reference.rst
│   │   ├── index.rst
│   │   └── conf.py
│   ├── license_header.txt
│   └── Makefile
├── pyproject.toml
├── requirements.txt
├── pytest.ini
├── dev-requirements.txt
├── .coveragerc
├── .github
│   ├── workflows
│   │   ├── pre_commit.yaml
│   │   ├── release_build.yaml
│   │   ├── build_docs.yaml
│   │   └── nightly_build_cpu.yaml
│   ├── PULL_REQUEST_TEMPLATE.md
│   └── ISSUE_TEMPLATE
│       ├── help-support.yml
│       ├── documentation.yml
│       ├── feature-request.yml
│       └── bug-report.yml
├── benchmarks
│   ├── ddp
│   │   ├── run.slurm
│   │   ├── README.md
│   │   └── main.py
│   ├── fsdp
│   │   ├── run.slurm
│   │   └── main.py
│   ├── torchrec
│   │   └── run.slurm
│   ├── deepspeed_opt
│   │   ├── run.slurm
│   │   └── main.py
│   └── load_tensor
│       └── main.py
├── examples
│   ├── torchrec
│   │   └── run.slurm
│   ├── simple_example.py
│   └── ddp_example.py
├── .flake8
├── tests
│   ├── test_rss_profiler.py
│   ├── conftest.py
│   ├── test_uvm_tensor.py
│   ├── test_rng_state.py
│   ├── test_ddp_replication_glob.py
│   ├── test_pg_wrapper.py
│   ├── gpu_tests
│   │   ├── test_state_dict_fsdp.py
│   │   ├── test_dtensor_utils.py
│   │   ├── test_manifest_utils.py
│   │   ├── test_snapshot_fsdp.py
│   │   ├── test_snapshot_dtensor.py
│   │   ├── test_dtensor_io_preparer.py
│   │   └── test_partitioner_dtensor.py
│   ├── test_memoryview_stream.py
│   ├── test_fs_storage_plugin.py
│   ├── test_state_dict.py
│   ├── test_s3_storage_plugin.py
│   ├── test_replication_glob.py
│   ├── test_sharded_tensor_resharding.py
│   ├── test_serialization.py
│   ├── test_ddp_infer_replication.py
│   ├── test_gcs_storage_plugin.py
│   ├── test_async_take.py
│   ├── test_test_utils.py
│   └── test_read_object.py
├── .pre-commit-config.yaml
├── LICENSE
├── CONTRIBUTING.md
├── setup.py
├── README.md
├── CODE_OF_CONDUCT.md
└── .gitignore

/torchsnapshot/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | src/pytorch-sphinx-theme/
2 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.usort]
2 |
3 | first_party_detection = false
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==5.0.1
2 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | PyYAML
2 | aiofiles
3 | aiohttp
4 | importlib-metadata
5 | psutil
6 | pyre_extensions
7 | torch
8 | typing-extensions
--------------------------------------------------------------------------------
/pytest.ini:
-------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --strict-markers 3 | timeout = 300 4 | markers = 5 | cpu_and_gpu 6 | gcs_integration_test 7 | gpu_only 8 | s3_integration_test 9 | -------------------------------------------------------------------------------- /dev-requirements.txt: -------------------------------------------------------------------------------- 1 | aiobotocore 2 | boto3 3 | expecttest 4 | google-cloud-storage 5 | google-resumable-media 6 | numpy 7 | pre-commit 8 | pytest==8.1.1 9 | pytest-asyncio 10 | pytest-cov 11 | pytest-timeout 12 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | setup.py 4 | *tests/* 5 | *examples/* 6 | *benchmarks/* 7 | 8 | [report] 9 | omit = 10 | setup.py 11 | *tests/* 12 | *examples/* 13 | *benchmarks/* 14 | -------------------------------------------------------------------------------- /docs/source/api_reference.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ************* 3 | 4 | .. autoclass:: torchsnapshot.Snapshot 5 | :members: 6 | 7 | .. autoclass:: torchsnapshot.StateDict 8 | 9 | .. autoclass:: torchsnapshot.RNGState 10 | -------------------------------------------------------------------------------- /docs/license_header.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the BSD-style license found in the 5 | LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchsnapshot/tricks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchsnapshot/io_preparers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /torchsnapshot/storage_plugins/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-strict 8 | -------------------------------------------------------------------------------- /.github/workflows/pre_commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: actions/setup-python@v3 14 | - uses: pre-commit/action@v3.0.0 15 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Please read through our [contribution guide](https://github.com/pytorch/torchsnapshot/blob/main/CONTRIBUTING.md) prior to creating your pull request. 2 | 3 | Summary: 4 | 5 | 6 | Test plan: 7 | 8 | 9 | Fixes #{issue number} 10 | 11 | -------------------------------------------------------------------------------- /benchmarks/ddp/run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --exclusive 4 | #SBATCH --nodes 2 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --gpus-per-task=8 7 | 8 | RDZV_ENDPOINT=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 9 | 10 | srun python3 -m torch.distributed.run --nnodes=$SLURM_NNODES --nproc_per_node=$SLURM_GPUS_PER_TASK --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT --max_restarts 0 main.py $@ 11 | -------------------------------------------------------------------------------- /benchmarks/fsdp/run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --exclusive 4 | #SBATCH --nodes 2 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --gpus-per-task=8 7 | 8 | RDZV_ENDPOINT=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 9 | 10 | srun python3 -m torch.distributed.run --nnodes=$SLURM_NNODES --nproc_per_node=$SLURM_GPUS_PER_TASK --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT --max_restarts 0 main.py $@ 11 | -------------------------------------------------------------------------------- /examples/torchrec/run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --exclusive 4 | #SBATCH --nodes 2 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --gpus-per-task=8 7 | 8 | RDZV_ENDPOINT=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 9 | 10 | srun python3 -m torch.distributed.run --nnodes=$SLURM_NNODES --nproc_per_node=$SLURM_GPUS_PER_TASK --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT --max_restarts 0 main.py $@ 11 | -------------------------------------------------------------------------------- /benchmarks/torchrec/run.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | #SBATCH --exclusive 4 | #SBATCH --nodes 2 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --gpus-per-task=8 7 | 8 | RDZV_ENDPOINT=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 9 | 10 | srun python3 -m torch.distributed.run --nnodes=$SLURM_NNODES --nproc_per_node=$SLURM_GPUS_PER_TASK --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT --max_restarts 0 main.py $@ 11 | -------------------------------------------------------------------------------- /benchmarks/deepspeed_opt/run.slurm: 
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | #SBATCH --exclusive
4 | #SBATCH --nodes 2
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --gpus-per-task=8
7 |
8 | RDZV_ENDPOINT=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
9 |
10 | srun python3 -m torch.distributed.run --nnodes=$SLURM_NNODES --nproc_per_node=$SLURM_GPUS_PER_TASK --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$RDZV_ENDPOINT --max_restarts 0 main.py $@
11 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/help-support.yml:
--------------------------------------------------------------------------------
1 | name: 📚 Help Support
2 | description: Do you need help/support? Send us your questions.
3 |
4 | body:
5 |   - type: textarea
6 |     attributes:
7 |       label: 📚 Question
8 |       description: >
9 |         Description of your question or what you need support with.
10 |     validations:
11 |       required: true
12 |   - type: markdown
13 |     attributes:
14 |       value: >
15 |         Thanks for contributing 🎉!
16 |
--------------------------------------------------------------------------------
/torchsnapshot/version.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | # Follows PEP-0440 version scheme guidelines
11 | # https://www.python.org/dev/peps/pep-0440/#version-scheme
12 | #
13 | # Examples:
14 | # 0.1.0.devN  # Developmental release
15 | # 0.1.0aN  # Alpha release
16 | # 0.1.0bN  # Beta release
17 | # 0.1.0rcN  # Release Candidate
18 | # 0.1.0  # Final release
19 | __version__: str = "0.1.0"
20 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | # Suggested config from pytorch that we can adopt
3 | select = B,C,E,F,P,T4,W,B9
4 | max-line-length = 120
5 | # C408 ignored because we like the dict keyword argument syntax
6 | # E501 is not flexible enough, we're using B950 instead
7 | ignore =
8 |     E203,E305,E402,E501,E704,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,
9 |     # shebang has extra meaning in fbcode lints, so I think it's not worth trying
10 |     # to line this up with executable bit
11 |     EXE001,
12 | optional-ascii-coding = True
13 | exclude =
14 |     ./.git,
15 |     ./docs
16 |     ./build
17 |     ./scripts,
18 |     ./venv,
19 |     *.pyi
20 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.yml:
--------------------------------------------------------------------------------
1 | name: 📚 Documentation
2 | description: Report an issue related to inline documentation
3 |
4 | body:
5 |   - type: textarea
6 |     attributes:
7 |       label: 📚 The doc issue
8 |       description: >
9 |         A clear and concise description of what content in torchsnapshot is an issue.
10 |     validations:
11 |       required: true
12 |   - type: textarea
13 |     attributes:
14 |       label: Suggest a potential alternative/fix
15 |       description: >
16 |         Tell us how we could improve the documentation in this regard.
17 |   - type: markdown
18 |     attributes:
19 |       value: >
20 |         Thanks for contributing 🎉!
21 | -------------------------------------------------------------------------------- /tests/test_rss_profiler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import time 11 | import unittest 12 | 13 | import torch 14 | from torchsnapshot.rss_profiler import measure_rss_deltas 15 | 16 | 17 | class RSSProfilerTest(unittest.TestCase): 18 | def test_rss_profiler(self) -> None: 19 | rss_deltas = [] 20 | with measure_rss_deltas(rss_deltas=rss_deltas): 21 | torch.randn(5000, 5000) 22 | time.sleep(2) 23 | -------------------------------------------------------------------------------- /torchsnapshot/stateful.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Any, Dict, TypeVar 11 | 12 | from typing_extensions import Protocol, runtime_checkable 13 | 14 | 15 | @runtime_checkable 16 | class Stateful(Protocol): 17 | def state_dict(self) -> Dict[str, Any]: ... 18 | 19 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... 20 | 21 | 22 | T = TypeVar("T", bound=Stateful) 23 | AppState = Dict[str, T] 24 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Generator 11 | 12 | import pytest 13 | from _pytest.fixtures import SubRequest # @manual 14 | from torchsnapshot.knobs import override_is_batching_disabled 15 | 16 | 17 | @pytest.fixture(params=["batching_on", "batching_off"]) 18 | def toggle_batching(request: SubRequest) -> Generator[None, None, None]: 19 | with override_is_batching_disabled(request.param == "batching_off"): 20 | yield 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/torchsnapshot/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | "A lightweight library for adding fault tolerance to large-scale PyTorch distributed training workloads"
11 |
12 | from .rng_state import RNGState
13 | from .snapshot import Snapshot
14 | from .state_dict import StateDict
15 | from .stateful import Stateful
16 | from .version import __version__
17 |
18 | __all__ = [
19 |     "__version__",
20 |     "Snapshot",
21 |     "Stateful",
22 |     "StateDict",
23 |     "RNGState",
24 | ]
25 |
--------------------------------------------------------------------------------
/benchmarks/ddp/README.md:
--------------------------------------------------------------------------------
1 | ## Running with SLURM
2 |
3 | ```
4 | sbatch --partition=[PARTITION] --nodes=[NUM_NODES] --gpus-per-task=[NUM_GPUS_PER_NODE] run.slurm
5 | ```
6 |
7 | ## Benchmark
8 |
9 | PyTorch version: 1.13.0.dev20220915+cu113
10 |
11 | Benchmark environment: p4d.24xlarge
12 |
13 | Model size: 20GB
14 |
15 | | Storage Type   | Nodes x GPUs | torch.save | torchsnapshot |
16 | | -------------- | ------------ | ---------- | ------------- |
17 | | Local FS       | 1 x 1        | ~32s       | ~13.91s       |
18 | | Local FS       | 1 x 8        | ~32s       | ~3.38s        |
19 | | Local FS       | 2 x 8        | ~32s       | ~2.02s        |
20 | | Local FS       | 4 x 8        | ~32s       | ~1.29s        |
21 | | FSx for Lustre | 1 x 1        | ~38s       | ~14.52s       |
22 | | FSx for Lustre | 1 x 8        | ~38s       | ~7.61s        |
23 | | FSx for Lustre | 2 x 8        | ~38s       | ~4.61s        |
24 | | FSx for Lustre | 4 x 8        | ~38s       | ~2.68s        |
25 |
--------------------------------------------------------------------------------
/torchsnapshot/event.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # pyre-strict
8 |
9 | from dataclasses import dataclass, field
10 | from typing import Dict, Union
11 |
12 | EventMetadataValue = Union[str, int, float, bool, None]
13 |
14 |
15 | @dataclass
16 | class Event:
17 |     """
18 |     The class represents the generic event that occurs during TorchSnapshot
19 |     execution. The event can be any kind of meaningful action.
20 |
21 |     Args:
22 |         name: event name.
23 |         metadata: additional data that is associated with the event.
24 |     """
25 |
26 |     name: str
27 |     metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)
28 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to the TorchSnapshot documentation!
2 | ===========================================
3 |
4 | TorchSnapshot is a PyTorch library for adding fault tolerance to large-scale PyTorch distributed training workloads.
5 |
6 | `Installation instructions `_
7 |
8 |
9 | TorchSnapshot API
10 | -----------------
11 |
12 | .. toctree::
13 |    :maxdepth: 2
14 |    :caption: Contents:
15 |
16 |    getting_started.rst
17 |    api_reference.rst
18 |
19 | Examples
20 | --------
21 |
22 | * `Simple example `_
23 | * `Using TorchSnapshot with DistributedDataParallel (DDP) `_
24 | * `Using TorchSnapshot with TorchRec `_
25 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 |   python: python3
3 |
4 | repos:
5 |   - repo: https://github.com/pre-commit/pre-commit-hooks
6 |     rev: v4.1.0
7 |     hooks:
8 |       - id: trailing-whitespace
9 |       - id: check-ast
10 |       - id: check-merge-conflict
11 |       - id: check-added-large-files
12 |         args: ['--maxkb=500']
13 |       - id: end-of-file-fixer
14 |
15 |   - repo: https://github.com/Lucas-C/pre-commit-hooks
16 |     rev: v1.1.7
17 |     hooks:
18 |       - id: insert-license
19 |         files: \.py$
20 |         args:
21 |           - --license-filepath
22 |           - docs/license_header.txt
23 |
24 |   - repo: https://github.com/pycqa/flake8
25 |     rev: 6.1.0
26 |     hooks:
27 |       - id: flake8
28 |         args:
29 |           - --config=.flake8
30 |
31 |   - repo: https://github.com/omnilib/ufmt
32 |     rev: v2.5.1
33 |     hooks:
34 |       - id: ufmt
35 |         additional_dependencies:
36 |           - black == 24.2.0
37 |           - usort == 1.0.2
38 |
--------------------------------------------------------------------------------
/torchsnapshot/state_dict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | from collections import UserDict
11 | from typing import Any, Dict
12 |
13 |
14 | # pyre-fixme[24]: Python <3.9 doesn't support typing on UserDict
15 | class StateDict(UserDict):
16 |     """
17 |     A dictionary that exposes ``.state_dict()`` and ``.load_state_dict()``
18 |     methods.
19 |
20 |     It can be used to capture objects that do not expose ``.state_dict()`` and
21 |     ``.load_state_dict()`` methods (e.g. Tensors, Python primitive types) as
22 |     part of the application state.
23 |     """
24 |
25 |     def state_dict(self) -> Dict[str, Any]:
26 |         return self.data
27 |
28 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
29 |         self.data.update(state_dict)
30 |
--------------------------------------------------------------------------------
/tests/test_uvm_tensor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
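# Context for this test (see torchsnapshot/uvm_tensor.py): when CUDA and
# fbgemm_gpu are available, new_managed_tensor allocates a UVM (CUDA unified
# virtual memory) tensor and uvm_to_cpu returns a CPU-accessible tensor;
# otherwise the fallback stubs make is_uvm_tensor always return False, which
# is the branch exercised on CPU-only hosts.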
7 | 8 | # pyre-strict 9 | 10 | # pyre-ignore-all-errors[56] 11 | 12 | import pytest 13 | import torch 14 | from torchsnapshot.uvm_tensor import ( 15 | _UVM_TENSOR_AVAILABLE, 16 | is_uvm_tensor, 17 | new_managed_tensor, 18 | uvm_to_cpu, 19 | ) 20 | 21 | 22 | @pytest.mark.cpu_and_gpu 23 | def test_uvm_tensor() -> None: 24 | if torch.cuda.is_available() and _UVM_TENSOR_AVAILABLE: 25 | uvm_tensor = torch.rand( 26 | (64, 64), 27 | out=new_managed_tensor( 28 | torch.empty(0, dtype=torch.float32, device="cuda:0"), 29 | [64, 64], 30 | ), 31 | ) 32 | assert is_uvm_tensor(uvm_tensor) 33 | cpu_tensor = uvm_to_cpu(uvm_tensor) 34 | assert not is_uvm_tensor(cpu_tensor) 35 | else: 36 | tensor = torch.rand(64, 64) 37 | assert not is_uvm_tensor(tensor) 38 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: 🚀 Feature request 2 | description: Submit a proposal/request for a new TorchSnapshot feature 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: 🚀 The feature 8 | description: > 9 | A clear and concise description of the feature proposal 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Motivation, pitch 15 | description: > 16 | Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., 17 | *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link 18 | here too. 19 | validations: 20 | required: true 21 | - type: textarea 22 | attributes: 23 | label: Alternatives 24 | description: > 25 | A description of any alternative solutions or features you've considered, if any. 26 | - type: textarea 27 | attributes: 28 | label: Additional context 29 | description: > 30 | Add any other context or screenshots about the feature request. 31 | - type: markdown 32 | attributes: 33 | value: > 34 | Thanks for contributing 🎉! 35 | -------------------------------------------------------------------------------- /tests/test_rng_state.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
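# StatefulWithRNGSideEffect below deliberately consumes the global RNG inside
# both state_dict() and load_state_dict(). The test checks that RNGState
# compensates for such side effects: the torch.rand(1) drawn right after
# Snapshot.take() must equal the one drawn right after restore().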
7 | 8 | # pyre-strict 9 | 10 | import tempfile 11 | import unittest 12 | from typing import Any, Dict 13 | 14 | import torch 15 | import torchsnapshot 16 | 17 | 18 | class StatefulWithRNGSideEffect: 19 | def state_dict(self) -> Dict[str, Any]: 20 | torch.rand([2]) 21 | return {} 22 | 23 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 24 | torch.rand([3]) 25 | 26 | 27 | class RNGStateTest(unittest.TestCase): 28 | def test_rng_state(self) -> None: 29 | app_state = { 30 | "rng_state": torchsnapshot.RNGState(), 31 | "effectful": StatefulWithRNGSideEffect(), 32 | } 33 | 34 | with tempfile.TemporaryDirectory() as tmp_dir: 35 | snapshot = torchsnapshot.Snapshot.take(path=tmp_dir, app_state=app_state) 36 | after_take = torch.rand(1) 37 | snapshot.restore(app_state) 38 | after_restore = torch.rand(1) 39 | torch.testing.assert_close(after_take, after_restore) 40 | -------------------------------------------------------------------------------- /.github/workflows/release_build.yaml: -------------------------------------------------------------------------------- 1 | name: Push Release to PyPi 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | run_tests: 8 | uses: ./.github/workflows/run_tests.yaml 9 | secrets: inherit 10 | 11 | upload_to_pypi: 12 | needs: run_tests 13 | runs-on: ubuntu-latest 14 | steps: 15 | - name: Check out repo 16 | uses: actions/checkout@v2 17 | - name: Setup conda env 18 | uses: conda-incubator/setup-miniconda@v2 19 | with: 20 | miniconda-version: "latest" 21 | activate-environment: test 22 | python-version: 3.8 23 | - name: Install dependencies 24 | shell: bash -l {0} 25 | run: | 26 | set -eux 27 | conda activate test 28 | conda install pytorch cpuonly -c pytorch-nightly 29 | pip install -r requirements.txt 30 | pip install -r dev-requirements.txt 31 | pip install --no-build-isolation -e ".[dev]" 32 | - name: Upload to PyPI 33 | shell: bash -l {0} 34 | env: 35 | PYPI_USER: ${{ secrets.PYPI_USER_RELEASE }} 36 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN_RELEASE }} 37 | run: | 38 | set -eux 39 | conda activate test 40 | pip install twine 41 | python setup.py sdist bdist_wheel 42 | twine upload --username "$PYPI_USER" --password "$PYPI_TOKEN" dist/* --verbose 43 | -------------------------------------------------------------------------------- /torchsnapshot/rng_state.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Dict 11 | 12 | import torch 13 | 14 | 15 | class RNGState: 16 | """ 17 | A special stateful object for saving and restoring global RNG state. 18 | 19 | When captured in the application state, it is guaranteed that the global 20 | RNG state is set to the same values after restoring from the snapshot as it 21 | was after taking the snapshot. 
22 | 23 | Example: 24 | 25 | :: 26 | 27 | >>> Snapshot.take( 28 | >>> path="foo/bar", 29 | >>> app_state={"rng_state": RNGState()}, 30 | >>> ) 31 | >>> after_take = torch.rand(1) 32 | 33 | >>> # In the same process or in another process 34 | >>> snapshot = Snapshot(path="foo/bar") 35 | >>> snapshot.restore(app_state) 36 | >>> after_restore = torch.rand(1) 37 | 38 | >>> torch.testing.assert_close(after_take, after_restore) 39 | """ 40 | 41 | # TODO: augment this to capture rng states other than torch.get_rng_state() 42 | 43 | def state_dict(self) -> Dict[str, torch.Tensor]: 44 | return {"rng_state": torch.get_rng_state()} 45 | 46 | def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: 47 | torch.set_rng_state(state_dict["rng_state"]) 48 | -------------------------------------------------------------------------------- /.github/workflows/build_docs.yaml: -------------------------------------------------------------------------------- 1 | name: Build and Update Docs 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | 7 | # Allow one concurrent deployment 8 | concurrency: 9 | group: "pages" 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | build_docs: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Check out repo 17 | uses: actions/checkout@v2 18 | - name: Setup conda env 19 | uses: conda-incubator/setup-miniconda@v2 20 | with: 21 | miniconda-version: "latest" 22 | activate-environment: test 23 | - name: Install dependencies 24 | shell: bash -l {0} 25 | run: | 26 | set -eux 27 | conda activate test 28 | conda install pytorch cpuonly -c pytorch-nightly 29 | pip install -r requirements.txt 30 | pip install -r dev-requirements.txt 31 | python setup.py sdist bdist_wheel 32 | pip install dist/*.whl 33 | - name: Build docs 34 | shell: bash -l {0} 35 | run: | 36 | set -eux 37 | conda activate test 38 | cd docs 39 | pip install -r requirements.txt 40 | sphinx-build -b html source build/html/main 41 | touch build/html/main/.nojekyll 42 | cd .. 43 | - name: Deploy docs to Github pages 44 | uses: JamesIves/github-pages-deploy-action@v4.4.1 45 | with: 46 | branch: gh-pages # The branch the action should deploy to. 47 | folder: docs/build/html/main # The folder the action should deploy. 48 | target-folder: main 49 | -------------------------------------------------------------------------------- /torchsnapshot/tricks/ddp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Any, Dict 11 | 12 | import torch 13 | 14 | DDP_STATE_DICT_PREFIX: str = "module." 15 | 16 | 17 | class DistributedDataParallelAdapter: 18 | """ 19 | A convenience class to load a module's state dict saved from a DistributedDataParallel-wrapped module into a module that is not wrapped with DDP. 
20 | 21 | Example:: 22 | 23 | >>> module = torch.nn.Linear(2, 2) 24 | >>> ddp_module = DistributedDataParallel(module) 25 | >>> Snapshot.take( 26 | >>> path="foo/bar", 27 | >>> app_state={"module": ddp_module}, 28 | >>> ) 29 | 30 | >>> # Restore the state 31 | >>> snapshot = Snapshot(path="foo/bar") 32 | >>> adapter = DistributedDataParallelAdapter(module) 33 | >>> snapshot.restore({"module": adapter}) 34 | >>> module = adapter.module 35 | """ 36 | 37 | def __init__(self, module: torch.nn.Module) -> None: 38 | self.module = module 39 | 40 | def state_dict(self) -> Dict[str, Any]: 41 | return self.module.state_dict() 42 | 43 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 44 | torch.nn.modules.utils.consume_prefix_in_state_dict_if_present( 45 | state_dict, DDP_STATE_DICT_PREFIX 46 | ) 47 | self.module.load_state_dict(state_dict) 48 | -------------------------------------------------------------------------------- /torchsnapshot/uvm_tensor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | from typing import List 10 | 11 | import torch 12 | 13 | _UVM_TENSOR_AVAILABLE = False 14 | 15 | try: 16 | # pyre-fixme[21]: Could not find module `fbgemm_gpu`. 17 | import fbgemm_gpu # @manual # noqa 18 | except Exception: 19 | pass 20 | 21 | try: 22 | torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:cumem_utils") 23 | except Exception: 24 | pass 25 | 26 | 27 | try: 28 | # pyre-fixme[9]: new_managed_tensor has type `(t: Tensor, sizes: List[int]) -> 29 | # Tensor`; used as `OpOverloadPacket`. 30 | new_managed_tensor = torch.ops.fbgemm.new_managed_tensor 31 | # pyre-fixme[9]: is_uvm_tensor has type `(t: Tensor) -> bool`; used as 32 | # `OpOverloadPacket`. 33 | is_uvm_tensor = torch.ops.fbgemm.is_uvm_tensor 34 | # pyre-fixme[9]: uvm_to_cpu has type `(t: Tensor) -> Tensor`; used as 35 | # `OpOverloadPacket`. 36 | uvm_to_cpu = torch.ops.fbgemm.uvm_to_cpu 37 | 38 | _UVM_TENSOR_AVAILABLE = True 39 | except AttributeError: 40 | 41 | def new_managed_tensor(t: torch.Tensor, sizes: List[int]) -> torch.Tensor: 42 | raise NotImplementedError() 43 | 44 | def is_uvm_tensor(t: torch.Tensor) -> bool: 45 | return False 46 | 47 | def uvm_to_cpu(t: torch.Tensor) -> torch.Tensor: 48 | return t 49 | 50 | 51 | __all__ = ["is_uvm_tensor", "uvm_to_cpu"] 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For torchsnapshot software 4 | 5 | Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 
16 | 17 | * Neither the name Meta nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to torchsnapshot 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints: 16 | - using *pre-commit* (only need to do this once) 17 | - install pre-commit: `pip install pre-commit` 18 | - add it as a git hook: `pre-commit install` 19 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 20 | 21 | ## Contributor License Agreement ("CLA") 22 | In order to accept your pull request, we need you to submit a CLA. You only need 23 | to do this once to work on any of Meta's open source projects. 24 | 25 | Complete your CLA here: 26 | 27 | ## Issues 28 | We use GitHub issues to track public bugs. Please ensure your description is 29 | clear and has sufficient instructions to be able to reproduce the issue. 30 | 31 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 32 | disclosure of security bugs. In those cases, please go through the process 33 | outlined on that page and do not file a public issue. 34 | 35 | 36 | ## License 37 | By contributing to torchsnapshot, you agree that your contributions will be licensed 38 | under the LICENSE file in the root directory of this source tree. 
39 |
--------------------------------------------------------------------------------
/.github/workflows/nightly_build_cpu.yaml:
--------------------------------------------------------------------------------
1 | name: Push CPU Binary Nightly
2 |
3 | on:
4 |   # run every day at 11:15am
5 |   schedule:
6 |     - cron: '15 11 * * *'
7 |   # or manually trigger it
8 |   workflow_dispatch:
9 |     inputs:
10 |       append_to_version:
11 |         description: "Optional value to append to version string"
12 |
13 | jobs:
14 |   run_tests:
15 |     permissions:
16 |       id-token: write
17 |       contents: read
18 |     uses: ./.github/workflows/run_tests.yaml
19 |     secrets: inherit
20 |
21 |   upload_to_pypi:
22 |     needs: run_tests
23 |     runs-on: ubuntu-latest
24 |     steps:
25 |       - name: Check out repo
26 |         uses: actions/checkout@v2
27 |       - name: Setup conda env
28 |         uses: conda-incubator/setup-miniconda@v2
29 |         with:
30 |           miniconda-version: "latest"
31 |           activate-environment: test
32 |           python-version: 3.8
33 |       - name: Install dependencies
34 |         shell: bash -l {0}
35 |         run: |
36 |           set -eux
37 |           conda activate test
38 |           conda install pytorch cpuonly -c pytorch-nightly
39 |           pip install -r requirements.txt
40 |           pip install -r dev-requirements.txt
41 |           pip install --no-build-isolation -e ".[dev]"
42 |       - name: Upload to PyPI
43 |         shell: bash -l {0}
44 |         env:
45 |           PYPI_USER: ${{ secrets.PYPI_USER }}
46 |           PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
47 |         run: |
48 |           set -eux
49 |           conda activate test
50 |           pip install twine
51 |           python setup.py --nightly --append-to-version=${{ github.event.inputs.append_to_version }} sdist bdist_wheel
52 |           twine upload --username "$PYPI_USER" --password "$PYPI_TOKEN" dist/* --verbose
53 |
--------------------------------------------------------------------------------
/torchsnapshot/event_handlers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # pyre-strict
4 |
5 | # Copyright (c) Meta Platforms, Inc. and affiliates.
6 | # All rights reserved.
7 | #
8 | # This source code is licensed under the BSD-style license found in the
9 | # LICENSE file in the root directory of this source tree.
10 |
11 | import logging
12 | from functools import lru_cache
13 | from typing import List
14 |
15 | from importlib_metadata import entry_points
16 | from typing_extensions import Protocol, runtime_checkable
17 |
18 | from .event import Event
19 |
20 | logger: logging.Logger = logging.getLogger(__name__)
21 |
22 |
23 | @runtime_checkable
24 | class EventHandler(Protocol):
25 |     def handle_event(self, event: Event) -> None: ...
26 |
27 |
28 | _log_handlers: List[EventHandler] = []
29 |
30 |
31 | @lru_cache(maxsize=None)
32 | def get_event_handlers() -> List[EventHandler]:
33 |     global _log_handlers
34 |
35 |     # Registered event handlers through entry points
36 |     eps = entry_points(group="event_handlers")
37 |     for entry in eps:
38 |         logger.debug(
39 |             f"Attempting to register event handler {entry.name}: {entry.value}"
40 |         )
41 |         factory = entry.load()
42 |         handler = factory()
43 |
44 |         if not isinstance(handler, EventHandler):
45 |             raise RuntimeError(
46 |                 f"The factory function for {entry.value} "
47 |                 "did not return an EventHandler object."
48 |             )
49 |         _log_handlers.append(handler)
50 |     return _log_handlers
51 |
52 |
53 | def log_event(event: Event) -> None:
54 |     """
55 |     Handle an event.
56 |     Args:
57 |         event: The event to handle.
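    Example (illustrative; the event name and metadata values below are
    made-up values, not part of the API)::

        >>> log_event(Event(name="custom_action", metadata={"duration_ms": 42}))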
58 | """ 59 | for handler in get_event_handlers(): 60 | handler.handle_event(event) 61 | -------------------------------------------------------------------------------- /torchsnapshot/rss_profiler.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | 11 | import time 12 | from contextlib import contextmanager 13 | from datetime import timedelta 14 | from threading import Event, Thread 15 | from typing import Generator, List 16 | 17 | import psutil 18 | 19 | _DEFAULT_MEASURE_INTERVAL = timedelta(milliseconds=100) 20 | 21 | 22 | def _measure( 23 | rss_deltas: List[int], 24 | interval: timedelta, 25 | baseline_rss_bytes: int, 26 | stop_event: Event, 27 | ) -> None: 28 | p = psutil.Process() 29 | while not stop_event.is_set(): 30 | rss_deltas.append(p.memory_info().rss - baseline_rss_bytes) 31 | time.sleep(interval.total_seconds()) 32 | 33 | 34 | @contextmanager 35 | def measure_rss_deltas( 36 | rss_deltas: List[int], interval: timedelta = _DEFAULT_MEASURE_INTERVAL 37 | ) -> Generator[None, None, None]: 38 | """ 39 | A context manager that periodically measures RSS (resident set size) delta. 40 | 41 | The baseline RSS is measured when the context manager is initialized. 42 | 43 | Args: 44 | rss_deltas: The list to which the measured RSS deltas (measured in 45 | bytes) are appended. 46 | interval: The interval at which RSS deltas are measured. 47 | """ 48 | baseline_rss_bytes = psutil.Process().memory_info().rss 49 | stop_event = Event() 50 | thread = Thread( 51 | target=_measure, args=(rss_deltas, interval, baseline_rss_bytes, stop_event) 52 | ) 53 | thread.start() 54 | try: 55 | yield 56 | finally: 57 | stop_event.set() 58 | thread.join() 59 | -------------------------------------------------------------------------------- /tests/test_ddp_replication_glob.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
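# With replicated=[] or replicated=None, Snapshot infers replication from DDP
# wrapping, so only the DDP-wrapped module's entries are expected to be marked
# fully replicated; replicated=["**"] force-marks every entry, including the
# plain nn.Linear, as the parametrization below expects.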
7 | 8 | # pyre-strict 9 | 10 | from pathlib import Path 11 | from typing import List, Optional 12 | 13 | import pytest 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torch.nn.parallel import DistributedDataParallel as DDP 18 | from torchsnapshot import Snapshot 19 | from torchsnapshot.manifest_utils import is_fully_replicated_entry 20 | from torchsnapshot.stateful import AppState 21 | from torchsnapshot.test_utils import run_with_pet 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "replication_globs, expected_replicated_paths", 26 | [ 27 | ([], ["0/ddp/module.weight", "0/ddp/module.bias"]), 28 | (None, ["0/ddp/module.weight", "0/ddp/module.bias"]), 29 | ( 30 | ["**"], 31 | [ 32 | "0/ddp/module.weight", 33 | "0/ddp/module.bias", 34 | "0/nonddp/weight", 35 | "0/nonddp/bias", 36 | ], 37 | ), 38 | ], 39 | ) 40 | @run_with_pet(nproc=2) 41 | def test_ddp_replication_glob( 42 | replication_globs: Optional[List[str]], 43 | expected_replicated_paths: List[str], 44 | tmp_path: Path, 45 | ) -> None: 46 | dist.init_process_group(backend="gloo") 47 | app_state: AppState = { 48 | "ddp": DDP(torch.nn.Linear(4, 3)), 49 | "nonddp": torch.nn.Linear(3, 2), 50 | } 51 | snapshot = Snapshot.take( 52 | path=str(tmp_path), 53 | app_state=app_state, 54 | replicated=replication_globs, 55 | ) 56 | replicated_paths = [ 57 | path 58 | for path, entry in snapshot.get_manifest().items() 59 | if is_fully_replicated_entry(entry) 60 | ] 61 | assert set(replicated_paths) == set(expected_replicated_paths) 62 | -------------------------------------------------------------------------------- /torchsnapshot/tricks/fsdp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from typing import Any, Dict 11 | 12 | import torch 13 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 14 | 15 | 16 | class FSDPOptimizerAdapter: 17 | """ 18 | Wrapper for FSDP optimizer to call specific FSDP optimizer state checkpointing APIs. 
19 |
20 |     Example::
21 |
22 |         >>> module = torch.nn.Linear(2, 2)
23 |         >>> fsdp_module = FullyShardedDataParallel(module)
24 |         >>> optimizer = torch.optim.SGD(fsdp_module.parameters(), lr=0.1)
25 |         >>> Snapshot.take(
26 |         >>>     path="foo/bar",
27 |         >>>     app_state={"module": fsdp_module, "optim": FSDPOptimizerAdapter(fsdp_module, optimizer)},
28 |         >>> )
29 |
30 |         >>> # Restore the state
31 |         >>> snapshot = Snapshot(path="foo/bar")
32 |         >>> module = torch.nn.Linear(2, 2)
33 |         >>> fsdp_module = FullyShardedDataParallel(module)
34 |         >>> optimizer = torch.optim.SGD(fsdp_module.parameters(), lr=0.1)
35 |         >>> adapter = FSDPOptimizerAdapter(fsdp_module, optimizer)
36 |         >>> snapshot.restore({"module": fsdp_module, "optim": adapter})
37 |     """
38 |
39 |     def __init__(self, module: FSDP, optimizer: torch.optim.Optimizer) -> None:
40 |         self.module = module
41 |         self.optimizer = optimizer
42 |
43 |     def state_dict(self) -> Dict[str, Any]:
44 |         optim_state_dict = FSDP.optim_state_dict(self.module, self.optimizer)
45 |         return optim_state_dict
46 |
47 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
48 |         optim_state_dict = FSDP.optim_state_dict_to_load(
49 |             self.module, self.optimizer, state_dict
50 |         )
51 |         self.optimizer.load_state_dict(optim_state_dict)
52 |
--------------------------------------------------------------------------------
/tests/test_pg_wrapper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | # pyre-ignore-all-errors[56]
11 |
12 | import os
13 | import unittest
14 |
15 | import torch
16 |
17 | import torch.distributed as dist
18 | import torch.distributed.launcher as pet
19 | from torchsnapshot.pg_wrapper import PGWrapper
20 | from torchsnapshot.test_utils import get_pet_launch_config
21 |
22 |
23 | class TestPGWrapper(unittest.TestCase):
24 |     @staticmethod
25 |     def _worker(backend: str) -> None:
26 |         tc = unittest.TestCase()
27 |         dist.init_process_group(backend=backend)
28 |         if backend == "nccl":
29 |             local_rank = int(os.environ["LOCAL_RANK"])
30 |             torch.cuda.set_device(torch.device(f"cuda:{local_rank}"))
31 |         pg_wrapper = PGWrapper(pg=None)
32 |         output_list = [None]
33 |         input_list = [["foo"], ["bar"], ["quaz"]]
34 |         pg_wrapper.scatter_object_list(output_list=output_list, input_list=input_list)
35 |         rank = dist.get_rank()
36 |         tc.assertEqual(output_list, [input_list[rank]])
37 |
38 |     def test_scatter_obj_list_gloo(self) -> None:
39 |         lc = get_pet_launch_config(nproc=3)
40 |         pet.elastic_launch(lc, entrypoint=self._worker)("gloo")
41 |
42 |     @unittest.skipUnless(torch.cuda.is_available(), "This test requires GPU to run.")
43 |     def test_scatter_obj_list_nccl(self) -> None:
44 |         lc = get_pet_launch_config(nproc=3)
45 |         pet.elastic_launch(lc, entrypoint=self._worker)("nccl")
46 |
47 |     def test_scatter_obj_list_dist_uninitialized(self) -> None:
48 |         pg_wrapper = PGWrapper(pg=None)
49 |         output_list = [None]
50 |         input_list = [["foo"]]
51 |         pg_wrapper.scatter_object_list(output_list=output_list, input_list=input_list)
52 |         self.assertEqual(output_list, [input_list[0]])
53 |
--------------------------------------------------------------------------------
/tests/gpu_tests/test_state_dict_fsdp.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import os 11 | from pathlib import Path 12 | 13 | import pytest 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType 18 | from torchsnapshot import Snapshot 19 | from torchsnapshot.test_utils import check_state_dict_eq, run_with_pet 20 | 21 | 22 | def _create_fsdp_model( 23 | seed: int, 24 | device: torch.device, 25 | ) -> torch.nn.Module: 26 | torch.manual_seed(seed) 27 | model = torch.nn.Linear(32, 32) 28 | 29 | fsdp_model = FSDP( 30 | module=model, 31 | device_id=device, 32 | ) 33 | FSDP.set_state_dict_type(fsdp_model, StateDictType.SHARDED_STATE_DICT) 34 | return fsdp_model 35 | 36 | 37 | @pytest.mark.skipif( 38 | bool(not torch.cuda.is_available()), reason="The test requires GPUs to run." 39 | ) 40 | @pytest.mark.skipif( 41 | bool(torch.cuda.device_count() < 2), reason="At least two GPUs are required." 42 | ) 43 | @run_with_pet(nproc=2) 44 | def test_model_and_optim_fsdp(tmp_path: Path) -> None: 45 | dist.init_process_group(backend="nccl") 46 | local_rank = int(os.environ["LOCAL_RANK"]) 47 | device = torch.device(f"cuda:{local_rank}") 48 | torch.cuda.set_device(device) 49 | 50 | fsdp_model = _create_fsdp_model(17, device) 51 | 52 | snapshot = Snapshot.take( 53 | path=str(tmp_path), 54 | app_state={"fsdp_model": fsdp_model}, 55 | ) 56 | state_dict_from_method = snapshot.get_state_dict_for_key("fsdp_model") 57 | FSDP.set_state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT) 58 | 59 | full_state_dict = fsdp_model.state_dict() 60 | for k, v in full_state_dict.items(): 61 | full_state_dict[k] = v.cpu() 62 | 63 | assert check_state_dict_eq(full_state_dict, state_dict_from_method) 64 | -------------------------------------------------------------------------------- /torchsnapshot/storage_plugins/fs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
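# Minimal usage sketch (illustrative: the root path and the "0/model" key are
# made-up values, and a running asyncio event loop is assumed; ReadIO/WriteIO
# are the torchsnapshot.io_types containers used below):
#
#     plugin = FSStoragePlugin(root="/tmp/snapshot")
#     await plugin.write(WriteIO(path="0/model", buf=memoryview(b"payload")))
#     read_io = ReadIO(path="0/model")  # optionally byte_range=(start, end)
#     await plugin.read(read_io)        # read_io.buf now holds an io.BytesIO
#     await plugin.close()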
7 | 8 | # pyre-strict 9 | 10 | import io 11 | import os 12 | import pathlib 13 | from typing import Any, Dict, Optional, Set 14 | 15 | import aiofiles 16 | import aiofiles.os 17 | 18 | from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO 19 | 20 | 21 | class FSStoragePlugin(StoragePlugin): 22 | def __init__( 23 | self, root: str, storage_options: Optional[Dict[str, Any]] = None 24 | ) -> None: 25 | self.root = root 26 | self._dir_cache: Set[pathlib.Path] = set() 27 | 28 | async def write(self, write_io: WriteIO) -> None: 29 | path = os.path.join(self.root, write_io.path) 30 | 31 | dir_path = pathlib.Path(path).parent 32 | if dir_path not in self._dir_cache: 33 | dir_path.mkdir(parents=True, exist_ok=True) 34 | self._dir_cache.add(dir_path) 35 | 36 | async with aiofiles.open(path, "wb+") as f: 37 | # pyre-ignore: memoryview is actually supported 38 | await f.write(write_io.buf) 39 | 40 | async def read(self, read_io: ReadIO) -> None: 41 | path = os.path.join(self.root, read_io.path) 42 | byte_range = read_io.byte_range 43 | 44 | async with aiofiles.open(path, "rb") as f: 45 | if byte_range is None: 46 | read_io.buf = io.BytesIO(await f.read()) 47 | else: 48 | offset = byte_range[0] 49 | size = byte_range[1] - byte_range[0] 50 | await f.seek(offset) 51 | read_io.buf = io.BytesIO(await f.read(size)) 52 | 53 | async def delete(self, path: str) -> None: 54 | path = os.path.join(self.root, path) 55 | await aiofiles.os.remove(path) 56 | 57 | async def delete_dir(self, path: str) -> None: 58 | path = os.path.join(self.root, path) 59 | await aiofiles.os.rmdir(path) 60 | 61 | async def close(self) -> None: 62 | pass 63 | -------------------------------------------------------------------------------- /tests/test_memoryview_stream.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import io 11 | import unittest 12 | 13 | import torch 14 | from torchsnapshot.memoryview_stream import MemoryviewStream 15 | 16 | 17 | class MemoryviewStreamTest(unittest.TestCase): 18 | def test_memoryview_stream(self) -> None: 19 | tensor = torch.rand(1000) 20 | # pyre-fixme[6]: For 1st argument expected `Buffer` but got `ndarray[Any, Any]`. 
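        # A 1000-element float32 tensor occupies 4000 bytes; casting the numpy
        # buffer to format "b" yields a flat per-byte view, which is what
        # MemoryviewStream wraps and the io.BytesIO copy below mirrors.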
21 | mv = memoryview(tensor.numpy()).cast("b") 22 | self.assertEqual(len(mv), 4000) 23 | 24 | mvs = MemoryviewStream(mv=mv) 25 | bio = io.BytesIO(mv.tobytes()) 26 | 27 | self.assertTrue(mvs.readable()) 28 | self.assertTrue(bio.readable()) 29 | 30 | self.assertTrue(mvs.seekable()) 31 | self.assertTrue(bio.seekable()) 32 | 33 | buf = bytes(mvs.read(20)) 34 | self.assertEqual(len(buf), 20) 35 | self.assertEqual(buf, bio.read(20)) 36 | 37 | pos = mvs.tell() 38 | self.assertEqual(pos, 20) 39 | self.assertEqual(pos, bio.tell()) 40 | 41 | pos = mvs.seek(500) 42 | self.assertEqual(pos, 500) 43 | self.assertEqual(pos, bio.seek(500)) 44 | 45 | buf = bytes(mvs.read(20)) 46 | self.assertEqual(len(buf), 20) 47 | self.assertEqual(buf, bio.read(20)) 48 | 49 | pos = mvs.tell() 50 | self.assertEqual(pos, 520) 51 | self.assertEqual(pos, bio.tell()) 52 | 53 | buf = bytes(mvs.read(4000)) 54 | self.assertEqual(len(buf), 3480) 55 | self.assertEqual(buf, bio.read(4000)) 56 | 57 | pos = mvs.tell() 58 | self.assertEqual(pos, 4000) 59 | self.assertEqual(pos, bio.tell()) 60 | 61 | pos = mvs.seek(0) 62 | self.assertEqual(pos, 0) 63 | self.assertEqual(pos, bio.seek(0)) 64 | 65 | buf = bytes(mvs.read(4500)) 66 | self.assertEqual(len(buf), 4000) 67 | self.assertEqual(buf, bio.read(4500)) 68 | -------------------------------------------------------------------------------- /tests/gpu_tests/test_dtensor_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import torch 11 | 12 | import torch.distributed as dist 13 | from torch.distributed._shard import sharded_tensor 14 | from torch.distributed._shard.sharding_spec import ChunkShardingSpec 15 | 16 | from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate, Shard 17 | from torch.testing._internal.common_utils import instantiate_parametrized_tests 18 | from torch.testing._internal.distributed._tensor.common_dtensor import ( 19 | DTensorTestBase, 20 | skip_if_lt_x_gpu, 21 | with_comms, 22 | ) 23 | from torchsnapshot.dtensor_utils import is_replicated_dtensor, is_sharded 24 | 25 | WORLD_SIZE = 4 26 | 27 | 28 | @instantiate_parametrized_tests 29 | class TestDTensorUtils(DTensorTestBase): 30 | @with_comms 31 | @skip_if_lt_x_gpu(WORLD_SIZE) 32 | # pyre-fixme[3]: Return type must be annotated. 
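    # Placements below are indexed by mesh dimension of the 2x2 DeviceMesh:
    # e.g. [Replicate(), Shard(0)] replicates across mesh dim 0 and shards
    # tensor dim 0 across mesh dim 1, so such a DTensor is expected to be both
    # sharded and replicated.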
33 | def test_is_sharded_is_replicated(self): 34 | mesh = DeviceMesh("cuda", mesh=[[0, 1], [2, 3]]) 35 | placements = [Replicate(), Shard(0)] 36 | local_tensor = torch.rand((16, 16)) 37 | dtensor = distribute_tensor( 38 | tensor=local_tensor, device_mesh=mesh, placements=placements 39 | ) 40 | assert is_sharded(dtensor) 41 | assert is_replicated_dtensor(dtensor) 42 | 43 | placements = [Replicate(), Replicate()] 44 | dtensor = distribute_tensor( 45 | tensor=local_tensor, device_mesh=mesh, placements=placements 46 | ) 47 | assert not is_sharded(dtensor) 48 | assert is_replicated_dtensor(dtensor) 49 | 50 | # pyre-ignore 51 | spec = ChunkShardingSpec( 52 | dim=0, 53 | placements=[ 54 | f"rank:{rank}/cuda:{rank}" for rank in range(dist.get_world_size()) 55 | ], 56 | ) 57 | stensor = sharded_tensor.empty(spec, (16, 16)) 58 | assert is_sharded(stensor) 59 | 60 | placements = [Shard(0), Shard(1)] 61 | dtensor = distribute_tensor( 62 | tensor=local_tensor, device_mesh=mesh, placements=placements 63 | ) 64 | assert is_sharded(dtensor) 65 | assert not is_replicated_dtensor(dtensor) 66 | -------------------------------------------------------------------------------- /tests/test_fs_storage_plugin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | # pyre-ignore-all-errors[56] 11 | 12 | import io 13 | 14 | import logging 15 | import random 16 | from pathlib import Path 17 | 18 | import pytest 19 | 20 | import torch 21 | from torchsnapshot import Snapshot, StateDict 22 | from torchsnapshot.io_types import ReadIO, WriteIO 23 | from torchsnapshot.storage_plugins.fs import FSStoragePlugin 24 | 25 | logger: logging.Logger = logging.getLogger(__name__) 26 | 27 | _TENSOR_SZ = int(1_000_000 / 4) 28 | 29 | 30 | def test_fs_read_write_via_snapshot(tmp_path: Path) -> None: 31 | tensor = torch.rand((_TENSOR_SZ,)) 32 | app_state = {"state": StateDict(tensor=tensor)} 33 | snapshot = Snapshot.take(path=str(tmp_path), app_state=app_state) 34 | 35 | app_state["state"]["tensor"] = torch.rand((_TENSOR_SZ,)) 36 | assert not torch.allclose(tensor, app_state["state"]["tensor"]) 37 | 38 | snapshot.restore(app_state) 39 | assert torch.allclose(tensor, app_state["state"]["tensor"]) 40 | 41 | 42 | @pytest.mark.asyncio 43 | async def test_fs_write_read_delete(tmp_path: Path) -> None: 44 | plugin = FSStoragePlugin(root=str(tmp_path)) 45 | 46 | tensor = torch.rand((_TENSOR_SZ,)) 47 | buf = io.BytesIO() 48 | torch.save(tensor, buf) 49 | write_io = WriteIO(path="tensor", buf=buf.getbuffer()) 50 | 51 | await plugin.write(write_io=write_io) 52 | 53 | read_io = ReadIO(path="tensor") 54 | await plugin.read(read_io=read_io) 55 | loaded = torch.load(read_io.buf) 56 | assert torch.allclose(tensor, loaded) 57 | 58 | await plugin.delete(path="tensor") 59 | await plugin.close() 60 | 61 | 62 | @pytest.mark.asyncio 63 | async def test_fs_ranged_read(tmp_path: Path) -> None: 64 | plugin = FSStoragePlugin(root=str(tmp_path)) 65 | 66 | buf = bytes(random.getrandbits(8) for _ in range(2000)) 67 | write_io = WriteIO(path="rand_bytes", buf=memoryview(buf)) 68 | 69 | await plugin.write(write_io=write_io) 70 | 71 | read_io = ReadIO(path="rand_bytes", byte_range=(100, 200)) 72 | await plugin.read(read_io=read_io) 73 | assert 
len(read_io.buf.getvalue()) == 100
74 |     assert read_io.buf.getvalue() == buf[100:200]
75 | 
76 |     await plugin.close()
77 | 
--------------------------------------------------------------------------------
/benchmarks/ddp/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import argparse
8 | import os
9 | import time
10 | import uuid
11 | 
12 | import torch
13 | import torch.distributed as dist
14 | import torchsnapshot
15 | from torch.nn.parallel import DistributedDataParallel
16 | 
17 | 
18 | class Model(torch.nn.Module):
19 |     def __init__(self, param_size: int, num_params: int) -> None:
20 |         super().__init__()
21 |         for i in range(num_params):
22 |             self.register_parameter(
23 |                 f"param_{i}",
24 |                 torch.nn.Parameter(
25 |                     torch.rand(int(param_size / 4), device=torch.cuda.current_device())
26 |                 ),
27 |             )
28 | 
29 | 
30 | def rank_0_print(msg: str) -> None:
31 |     if dist.get_rank() == 0:
32 |         print(msg)
33 | 
34 | 
35 | if __name__ == "__main__":
36 |     parser = argparse.ArgumentParser()
37 |     parser.add_argument("--work-dir", default="/tmp")
38 |     parser.add_argument("--param-size", type=int, default=int(100_000_000))
39 |     parser.add_argument("--num-params", type=int, default=200)
40 |     args = parser.parse_args()
41 | 
42 |     local_rank = int(os.environ["LOCAL_RANK"])
43 |     device = torch.device(f"cuda:{local_rank}")
44 |     torch.cuda.set_device(device)
45 |     dist.init_process_group(backend="nccl")
46 | 
47 |     model = Model(param_size=args.param_size, num_params=args.num_params)
48 |     model = DistributedDataParallel(model, gradient_as_bucket_view=True)
49 | 
50 |     sz = sum(t.nelement() * t.element_size() for t in model.parameters())
51 |     rank_0_print(f"Model size: {sz / 1_000_000_000.0} GB")
52 | 
53 |     if dist.get_rank() == 0:
54 |         print("Saving the model with torch.save...")
55 |         t_begin = time.time()
56 |         with open(f"{args.work_dir}/{uuid.uuid4()}.pt", "wb+") as f:
57 |             torch.save(model.state_dict(), f)
58 |         print(f"Took {time.time() - t_begin} seconds with torch.save")
59 |     dist.barrier()
60 | 
61 |     rank_0_print("Saving the model with torchsnapshot...")
62 |     t_begin = time.time()
63 |     app_state = {"model": model}
64 |     snapshot = torchsnapshot.Snapshot.take(
65 |         path=f"{args.work_dir}/{uuid.uuid4()}",
66 |         app_state=app_state,
67 |         replicated=["**"],
68 |     )
69 |     rank_0_print(f"Snapshot path: {snapshot.path}")
70 |     rank_0_print(f"Took {time.time() - t_begin} seconds with torchsnapshot")
71 | 
--------------------------------------------------------------------------------
/torchsnapshot/dtensor_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
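# --- Added commentary (illustrative sketch, not part of the original module) ---
# What the helpers below report for a DTensor on a 2-D device mesh, mirroring
# the cases exercised in tests/gpu_tests/test_dtensor_utils.py:
#   placements [Replicate(), Shard(0)]    -> is_sharded(...) is True,
#                                            is_replicated_dtensor(...) is True
#   placements [Replicate(), Replicate()] -> is_sharded(...) is False,
#                                            is_replicated_dtensor(...) is True
#   placements [Shard(0), Shard(1)]       -> is_sharded(...) is True,
#                                            is_replicated_dtensor(...) is False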
7 | 8 | # pyre-strict 9 | 10 | from typing import Dict, Iterator, List, Set 11 | 12 | import torch 13 | from torch.distributed._shard.sharded_tensor import ShardedTensor 14 | from torch.distributed._tensor import DTensor, Replicate, Shard 15 | 16 | 17 | def is_sharded(tensor: torch.Tensor) -> bool: 18 | """ 19 | Returns true if tensor is a ShardedTensor or a DTensor that is partially 20 | or fully sharded 21 | """ 22 | if isinstance(tensor, ShardedTensor): 23 | return True 24 | elif isinstance(tensor, DTensor): 25 | for placement in tensor.placements: 26 | if isinstance(placement, Shard): 27 | return True 28 | return False 29 | 30 | 31 | def is_replicated_dtensor(dtensor: DTensor) -> bool: 32 | """ 33 | Returns true if DTensor is fully or partially replicated, false if fully sharded. 34 | """ 35 | for placement in dtensor.placements: 36 | if isinstance(placement, Replicate): 37 | return True 38 | return False 39 | 40 | 41 | class _ReplicatedShards: 42 | """ 43 | Utility class to collect ranks that each DTensor shard is replicated on 44 | in a convenient wrapper that allows efficient querying for all ranks 45 | that contain the same shard as a given rank. 46 | 47 | For example, a DTensor has a device mesh [[0,1,2],[3,4,5]]. It is replicated 48 | across mesh dim 1 and sharded across mesh dim 0. Thus, rank sets {0,1,2} and {3,4,5} 49 | would denote the two replicated shards. Then, the following queries would return: 50 | - 1 -> {0, 1, 2} 51 | - 5 -> {3, 4, 5} 52 | 53 | Attributes: 54 | replicated_ranks_for_shards (List[Set]): List of sets of ranks that each shard is 55 | replicated on. Length of list should be number of shards. 56 | """ 57 | 58 | def __init__(self, replicated_ranks_for_shards: List[Set[int]]) -> None: 59 | self.repranks = replicated_ranks_for_shards 60 | self.lookup: Dict[int, Set[int]] = {} 61 | for rankset in self.repranks: 62 | for rank in rankset: 63 | self.lookup[rank] = rankset 64 | 65 | def get_all_replicated_ranks(self, rank: int) -> Set[int]: 66 | return self.lookup.get(rank, set()) 67 | 68 | def __iter__(self) -> Iterator[Set[int]]: 69 | return iter(self.repranks) 70 | -------------------------------------------------------------------------------- /tests/test_state_dict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 
8 | # pyre-strict
9 | 
10 | import tempfile
11 | import unittest
12 | from typing import cast, Dict
13 | 
14 | import torch
15 | import torchsnapshot
16 | from torchsnapshot import Stateful
17 | 
18 | 
19 | class MyModule(torch.nn.Module):
20 |     def __init__(self) -> None:
21 |         super().__init__()
22 |         self.foo = torch.nn.Parameter(torch.randn(20, 20))
23 | 
24 | 
25 | class MyStateful(Stateful):
26 |     def __init__(self) -> None:
27 |         self.foo = 1
28 |         self.bar = "bar"
29 | 
30 |     def state_dict(self) -> Dict[str, object]:
31 |         return {"foo": self.foo, "bar": self.bar}
32 | 
33 |     def load_state_dict(self, state_dict: Dict[str, object]) -> None:
34 |         self.foo = cast(int, state_dict["foo"])
35 |         self.bar = cast(str, state_dict["bar"])
36 | 
37 | 
38 | class StateDictTest(unittest.TestCase):
39 |     def test_get_state_dict(self) -> None:
40 |         my_module = MyModule()
41 |         with tempfile.TemporaryDirectory() as path:
42 |             torchsnapshot.Snapshot.take(
43 |                 path=path,
44 |                 app_state={"my_module": my_module},
45 |             )
46 |             snapshot = torchsnapshot.Snapshot(path)
47 |             state_dict = snapshot.get_state_dict_for_key("my_module")
48 |             self.assertTrue(torch.allclose(state_dict["foo"], my_module.foo))
49 | 
50 |     def test_get_state_dict_with_invalid_key(self) -> None:
51 |         my_module = MyModule()
52 |         with tempfile.TemporaryDirectory() as path:
53 |             torchsnapshot.Snapshot.take(
54 |                 path=path,
55 |                 app_state={"my_module": my_module},
56 |             )
57 |             snapshot = torchsnapshot.Snapshot(path)
58 |             with self.assertRaisesRegex(
59 |                 AssertionError, "is absent in both manifest and flattened"
60 |             ):
61 |                 snapshot.get_state_dict_for_key("invalid_key")
62 | 
63 |     def test_generic_stateful(self) -> None:
64 |         my_stateful = MyStateful()
65 |         my_stateful.foo = 2
66 |         my_stateful.bar = "baz"
67 |         with tempfile.TemporaryDirectory() as path:
68 |             snapshot = torchsnapshot.Snapshot.take(
69 |                 path, app_state={"my_stateful": my_stateful})
70 |             state_dict = snapshot.get_state_dict_for_key("my_stateful")
71 |             self.assertDictEqual(state_dict, my_stateful.state_dict())
72 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
1 | name: 🐛 Bug Report
2 | description: Create a report to help us reproduce and fix the bug
3 | 
4 | body:
5 |   - type: markdown
6 |     attributes:
7 |       value: >
8 |         #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the
9 |         existing and past issues](https://github.com/pytorch/torchsnapshot/issues?q=is%3Aissue+sort%3Acreated-desc+).
10 |   - type: textarea
11 |     attributes:
12 |       label: 🐛 Describe the bug
13 |       description: |
14 |         Please provide a clear and concise description of what the bug is.
15 | 
16 |         If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc.
For example: 17 | 18 | ```python 19 | # All necessary imports at the beginning 20 | import torch 21 | import torchsnapshot 22 | 23 | # A succinct reproducing example trimmed down to the essential parts 24 | 25 | ``` 26 | 27 | If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. 28 | 29 | Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. 30 | placeholder: | 31 | A clear and concise description of what the bug is. 32 | 33 | ```python 34 | Sample code to reproduce the problem 35 | ``` 36 | 37 | ``` 38 | The error message you got, with the full traceback. 39 | ``` 40 | validations: 41 | required: true 42 | - type: textarea 43 | attributes: 44 | label: Versions 45 | description: | 46 | Please run the following and paste the output below. Make sure the version numbers of all relevant packages (e.g. torch, torchsnapshot, other domain packages) are included. 47 | ```sh 48 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 49 | # For security purposes, please check the contents of collect_env.py before running it. 50 | python collect_env.py 51 | ``` 52 | validations: 53 | required: true 54 | 55 | - type: markdown 56 | attributes: 57 | value: > 58 | Thanks for contributing 🎉! 59 | -------------------------------------------------------------------------------- /examples/simple_example.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
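# Usage sketch (illustrative paths; the flags are defined in main() below):
#   python simple_example.py --work-dir /tmp
#   python simple_example.py --restore-path /tmp/<snapshot-uuid>
# The second invocation restores the model, optimizer, progress counter, and
# RNG state from a previously taken snapshot and resumes training from there.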
7 | 8 | # pyre-strict 9 | 10 | import argparse 11 | import uuid 12 | from typing import Optional 13 | 14 | import torch 15 | import torchsnapshot 16 | from torchsnapshot.snapshot import Snapshot 17 | from torchsnapshot.stateful import AppState 18 | 19 | NUM_EPOCHS = 4 20 | EPOCH_SIZE = 16 21 | BATCH_SIZE = 8 22 | 23 | 24 | class Model(torch.nn.Module): 25 | def __init__(self) -> None: 26 | super().__init__() 27 | self.layers = torch.nn.Sequential( 28 | torch.nn.Linear(128, 64), 29 | torch.nn.ReLU(), 30 | torch.nn.Linear(64, 32), 31 | torch.nn.ReLU(), 32 | torch.nn.Linear(32, 1), 33 | ) 34 | 35 | def forward(self, X: torch.Tensor) -> torch.Tensor: 36 | return self.layers(X) 37 | 38 | 39 | def main() -> None: 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--work-dir", default="/tmp") 42 | parser.add_argument("--restore-path", default=None) 43 | args: argparse.Namespace = parser.parse_args() 44 | 45 | torch.random.manual_seed(42) 46 | 47 | model = Model() 48 | optim = torch.optim.Adagrad(model.parameters(), lr=0.01) 49 | loss_fn = torch.nn.BCEWithLogitsLoss() 50 | progress = torchsnapshot.StateDict(current_epoch=0) 51 | 52 | # torchsnapshot: define app state 53 | app_state: AppState = { 54 | "rng_state": torchsnapshot.RNGState(), 55 | "model": model, 56 | "optim": optim, 57 | "progress": progress, 58 | } 59 | snapshot: Optional[Snapshot] = None 60 | 61 | # torchsnapshot: restore app state 62 | if args.restore_path: 63 | snapshot = Snapshot(args.restore_path) 64 | print(f"Restoring snapshot from path: {snapshot.path}") 65 | snapshot.restore(app_state) 66 | 67 | while progress["current_epoch"] < NUM_EPOCHS: 68 | for _ in range(EPOCH_SIZE): 69 | X = torch.rand((BATCH_SIZE, 128)) 70 | label = torch.rand((BATCH_SIZE, 1)) 71 | pred = model(X) 72 | loss = loss_fn(pred, label) 73 | 74 | optim.zero_grad() 75 | loss.backward() 76 | optim.step() 77 | 78 | progress["current_epoch"] += 1 79 | 80 | # torchsnapshot: take snapshot 81 | snapshot = torchsnapshot.Snapshot.take( 82 | f"{args.work_dir}/{uuid.uuid4()}", app_state 83 | ) 84 | print(f"Snapshot path: {snapshot.path}") 85 | 86 | 87 | if __name__ == "__main__": 88 | main() # pragma: no cover 89 | -------------------------------------------------------------------------------- /torchsnapshot/storage_plugin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import asyncio 11 | from typing import Any, Dict, Optional 12 | 13 | from importlib_metadata import entry_points 14 | 15 | from .io_types import StoragePlugin 16 | from .storage_plugins.fs import FSStoragePlugin 17 | from .storage_plugins.s3 import S3StoragePlugin 18 | 19 | 20 | def url_to_storage_plugin( 21 | url_path: str, storage_options: Optional[Dict[str, Any]] = None 22 | ) -> StoragePlugin: 23 | """ 24 | Initialize storage plugin from url path. 25 | 26 | Args: 27 | url_path: The url path following the pattern [protocol]://[path]. 28 | The protocol defaults to `fs` if unspecified. 29 | storage_options: Additional keyword options for the StoragePlugin to use. 30 | See each StoragePlugin's documentation for customizations. 31 | 32 | Returns: 33 | The initialized storage plugin. 
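    Example (illustrative; the mapping follows the protocol dispatch below):

        url_to_storage_plugin("/tmp/snapshots")      # fs -> FSStoragePlugin
        url_to_storage_plugin("s3://bucket/ckpts")   # s3 -> S3StoragePlugin
        url_to_storage_plugin("gs://bucket/ckpts")   # gs -> GCSStoragePlugin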
34 | """ 35 | if "://" in url_path: 36 | protocol, path = url_path.split("://", 1) 37 | if len(protocol) == 0: 38 | protocol = "fs" 39 | else: 40 | protocol, path = "fs", url_path 41 | 42 | if storage_options is None: 43 | storage_options = {} 44 | 45 | # Built-in storage plugins 46 | if protocol == "fs": 47 | return FSStoragePlugin(root=path, storage_options=storage_options) 48 | elif protocol == "s3": 49 | return S3StoragePlugin(root=path, storage_options=storage_options) 50 | elif protocol == "gs": 51 | from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin 52 | 53 | return GCSStoragePlugin(root=path, storage_options=storage_options) 54 | 55 | # Registered storage plugins 56 | eps = entry_points(group="storage_plugins") 57 | registered_plugins = {ep.name: ep for ep in eps} 58 | if protocol in registered_plugins: 59 | entry = registered_plugins[protocol] 60 | factory = entry.load() 61 | plugin = factory(path, storage_options) 62 | if not isinstance(plugin, StoragePlugin): 63 | raise RuntimeError( 64 | f"The factory function for {protocol} ({entry.value}) " 65 | "did not return a StorgePlugin object." 66 | ) 67 | return plugin 68 | else: 69 | raise RuntimeError(f"Unsupported protocol: {protocol}.") 70 | 71 | 72 | def url_to_storage_plugin_in_event_loop( 73 | url_path: str, 74 | event_loop: asyncio.AbstractEventLoop, 75 | storage_options: Optional[Dict[str, Any]] = None, 76 | ) -> StoragePlugin: 77 | async def _url_to_storage_plugin() -> StoragePlugin: 78 | return url_to_storage_plugin(url_path=url_path, storage_options=storage_options) 79 | 80 | return event_loop.run_until_complete(_url_to_storage_plugin()) 81 | -------------------------------------------------------------------------------- /torchsnapshot/io_preparers/object.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | # pyre-ignore-all-errors[2]: Allow `Any` in type annotations 11 | 12 | import io 13 | import logging 14 | import sys 15 | from concurrent.futures import Executor 16 | from typing import Any, Generic, List, Optional, Tuple, TypeVar 17 | 18 | import torch 19 | 20 | from torchsnapshot.io_types import ( 21 | BufferConsumer, 22 | BufferStager, 23 | BufferType, 24 | Future, 25 | ReadReq, 26 | WriteReq, 27 | ) 28 | from torchsnapshot.manifest import ObjectEntry 29 | 30 | from torchsnapshot.serialization import Serializer 31 | 32 | logger: logging.Logger = logging.getLogger(__name__) 33 | 34 | T = TypeVar("T") 35 | 36 | 37 | class ObjectIOPreparer(Generic[T]): 38 | @staticmethod 39 | def prepare_write( 40 | storage_path: str, 41 | obj: T, 42 | ) -> Tuple[ObjectEntry, List[WriteReq]]: 43 | buffer_stager = ObjectBufferStager(obj=obj) 44 | return ( 45 | ObjectEntry( 46 | location=storage_path, 47 | serializer=Serializer.TORCH_SAVE.value, 48 | obj_type=type(obj).__module__ + "." 
+ type(obj).__name__, 49 | replicated=False, 50 | ), 51 | [WriteReq(path=storage_path, buffer_stager=buffer_stager)], 52 | ) 53 | 54 | @classmethod 55 | def prepare_read( 56 | cls, entry: ObjectEntry, obj_out: Optional[Any] 57 | ) -> Tuple[List[ReadReq], Future[T]]: 58 | # obj_out is only used for memory estimation 59 | fut = Future(obj=obj_out) 60 | buffer_consumer = ObjectBufferConsumer(fut=fut) 61 | return [ 62 | ReadReq( 63 | path=entry.location, 64 | buffer_consumer=buffer_consumer, 65 | ) 66 | ], fut 67 | 68 | 69 | class ObjectBufferStager(BufferStager): 70 | def __init__(self, obj: Any) -> None: 71 | self.obj = obj 72 | 73 | async def stage_buffer(self, executor: Optional[Executor] = None) -> BufferType: 74 | buf = io.BytesIO() 75 | torch.save(self.obj, buf) 76 | return buf.getvalue() 77 | 78 | def get_staging_cost_bytes(self) -> int: 79 | # TODO: this is not accurate 80 | return sys.getsizeof(self.obj) 81 | 82 | 83 | class ObjectBufferConsumer(BufferConsumer, Generic[T]): 84 | def __init__(self, fut: Future[T]) -> None: 85 | self.fut = fut 86 | 87 | async def consume_buffer( 88 | self, buf: bytes, executor: Optional[Executor] = None 89 | ) -> None: 90 | map_location = None if torch.cuda.is_available() else torch.device("cpu") 91 | obj: T = torch.load(io.BytesIO(buf), map_location=map_location) 92 | self.fut.obj = obj 93 | 94 | def get_consuming_cost_bytes(self) -> int: 95 | return sys.getsizeof(self.fut.obj) 96 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Configuration file for the Sphinx documentation builder. 9 | # 10 | # This file only contains a selection of the most common options. For a full 11 | # list see the documentation: 12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 13 | 14 | # -- Path setup -------------------------------------------------------------- 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | 23 | import pytorch_sphinx_theme 24 | from torchsnapshot import __version__ 25 | 26 | current_dir = os.path.dirname(__file__) 27 | target_dir = os.path.abspath(os.path.join(current_dir, "../..")) 28 | sys.path.insert(0, target_dir) 29 | print(target_dir) 30 | 31 | # -- Project information ----------------------------------------------------- 32 | 33 | project = "TorchSnapshot" 34 | copyright = "2022, Meta" 35 | author = "Meta" 36 | 37 | # The full version, including alpha/beta/rc tags 38 | if os.environ.get("RELEASE_BUILD", None): 39 | version = __version__ 40 | release = __version__ 41 | else: 42 | version = "main (unstable)" 43 | release = "main" 44 | 45 | 46 | # -- General configuration --------------------------------------------------- 47 | 48 | # Add any Sphinx extension module names here, as strings. They can be 49 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 50 | # ones. 
51 | extensions = ["sphinx.ext.napoleon", "sphinx.ext.autodoc", "sphinx.ext.intersphinx"] 52 | 53 | # Add any paths that contain templates here, relative to this directory. 54 | templates_path = ["_templates"] 55 | 56 | # List of patterns, relative to source directory, that match files and 57 | # directories to ignore when looking for source files. 58 | # This pattern also affects html_static_path and html_extra_path. 59 | exclude_patterns = [] 60 | 61 | 62 | # -- Options for HTML output ------------------------------------------------- 63 | 64 | # The theme to use for HTML and HTML Help pages. See the documentation for 65 | # a list of builtin themes. 66 | # 67 | html_theme = "pytorch_sphinx_theme" 68 | html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] 69 | 70 | html_theme_options = { 71 | "display_version": True, 72 | } 73 | 74 | # Add any paths that contain custom static files (such as style sheets) here, 75 | # relative to this directory. They are copied after the builtin static files, 76 | # so a file named "default.css" will overwrite the builtin "default.css". 77 | html_static_path = ["_static"] 78 | 79 | # where to find external docs 80 | intersphinx_mapping = { 81 | "torch": ("https://pytorch.org/docs/stable/", None), 82 | } 83 | 84 | add_module_names = False 85 | autodoc_member_order = "bysource" 86 | -------------------------------------------------------------------------------- /torchsnapshot/memoryview_stream.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import io 11 | from typing import Optional 12 | 13 | 14 | class MemoryviewStream(io.IOBase): 15 | # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. 16 | def __init__(self, mv: memoryview) -> None: 17 | # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. 18 | self._mv: memoryview = mv.cast("b") 19 | self._pos = 0 20 | 21 | # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. 22 | def read(self, size: Optional[int] = -1) -> memoryview: 23 | if self.closed: 24 | raise ValueError("read from closed file") 25 | if size is None: 26 | size = -1 27 | else: 28 | try: 29 | size_index = size.__index__ 30 | except AttributeError: 31 | raise TypeError(f"{size!r} is not an integer") 32 | else: 33 | size = size_index() 34 | if size < 0: 35 | size = len(self._mv) 36 | if len(self._mv) <= self._pos: 37 | return memoryview(b"") 38 | newpos = min(len(self._mv), self._pos + size) 39 | b = self._mv[self._pos : newpos] 40 | self._pos = newpos 41 | return b 42 | 43 | # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. 
44 | def read1(self, size: int = -1) -> memoryview: 45 | """This is the same as read.""" 46 | return self.read(size) 47 | 48 | def seek(self, pos: int, whence: int = 0) -> int: 49 | if self.closed: 50 | raise ValueError("seek on closed file") 51 | try: 52 | pos_index = pos.__index__ 53 | except AttributeError: 54 | raise TypeError(f"{pos!r} is not an integer") 55 | else: 56 | pos = pos_index() 57 | if whence == 0: 58 | if pos < 0: 59 | raise ValueError("negative seek position %r" % (pos,)) 60 | self._pos = pos 61 | elif whence == 1: 62 | self._pos = max(0, self._pos + pos) 63 | elif whence == 2: 64 | self._pos = max(0, len(self._mv) + pos) 65 | else: 66 | raise ValueError("unsupported whence value") 67 | return self._pos 68 | 69 | def tell(self) -> int: 70 | if self.closed: 71 | raise ValueError("tell on closed file") 72 | return self._pos 73 | 74 | def readable(self) -> bool: 75 | if self.closed: 76 | raise ValueError("I/O operation on closed file.") 77 | return True 78 | 79 | def writable(self) -> bool: 80 | if self.closed: 81 | raise ValueError("I/O operation on closed file.") 82 | return False 83 | 84 | def seekable(self) -> bool: 85 | if self.closed: 86 | raise ValueError("I/O operation on closed file.") 87 | return True 88 | -------------------------------------------------------------------------------- /torchsnapshot/pg_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | # pyre-ignore-all-errors[2] 11 | 12 | from typing import Any, List, Optional 13 | 14 | import torch.distributed as dist 15 | 16 | 17 | class PGWrapper: 18 | """ 19 | A wrapper around ProcessGroup that allows collectives to be issued in a 20 | consistent fashion regardless of the following scenarios: 21 | 22 | pg is None, distributed is initialized: use WORLD as pg 23 | pg is None, distributed is not initialized: single process app 24 | pg is not None: use pg 25 | """ 26 | 27 | def __init__(self, pg: Optional[dist.ProcessGroup]) -> None: 28 | if pg is None and dist.is_initialized(): 29 | # pyre-ignore 30 | self.pg = dist.group.WORLD 31 | else: 32 | self.pg = pg 33 | 34 | def get_rank(self) -> int: 35 | if self.pg is None: 36 | return 0 37 | return dist.get_rank(group=self.pg) 38 | 39 | def get_world_size(self) -> int: 40 | if self.pg is None: 41 | return 1 42 | return dist.get_world_size(group=self.pg) 43 | 44 | def barrier(self) -> None: 45 | if self.pg is None: 46 | return 47 | dist.barrier(group=self.pg) 48 | 49 | def broadcast_object_list(self, obj_list: List[Any], src: int = 0) -> None: 50 | if self.pg is None: 51 | return 52 | dist.broadcast_object_list(obj_list, src=src, group=self.pg) 53 | 54 | def all_gather_object(self, obj_list: List[Any], obj: Any) -> None: 55 | if self.pg is None: 56 | obj_list[0] = obj 57 | return 58 | dist.all_gather_object(obj_list, obj, group=self.pg) 59 | 60 | def scatter_object_list( 61 | self, 62 | output_list: List[Any], 63 | input_list: Optional[List[Any]], 64 | src: int = 0, 65 | ) -> None: 66 | rank = self.get_rank() 67 | world_size = self.get_world_size() 68 | if rank == src: 69 | if input_list is None: 70 | raise RuntimeError( 71 | "The src rank's input_list for scatter_object_list must not be None." 
72 | ) 73 | if len(input_list) != world_size: 74 | raise RuntimeError( 75 | f"The length of input_list {len(input_list)} for scatter_object_list " 76 | f"must be the same as the process group's world size ({world_size})." 77 | ) 78 | else: 79 | input_list = [None] * world_size 80 | 81 | if self.pg is None: 82 | output_list[0] = input_list[0] 83 | return 84 | 85 | # scatter_object_list does not yet support NCCL backend 86 | if dist.get_backend(self.pg) == "nccl": 87 | self.broadcast_object_list(obj_list=input_list, src=src) 88 | output_list[0] = input_list[rank] 89 | return 90 | 91 | dist.scatter_object_list(output_list, input_list, src=src, group=self.pg) 92 | -------------------------------------------------------------------------------- /torchsnapshot/storage_plugins/s3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import io 11 | import os 12 | from typing import Any, Dict, Optional 13 | 14 | from torchsnapshot.io_types import ReadIO, StoragePlugin, WriteIO 15 | from torchsnapshot.memoryview_stream import MemoryviewStream 16 | 17 | 18 | class S3StoragePlugin(StoragePlugin): 19 | def __init__( 20 | self, root: str, storage_options: Optional[Dict[str, Any]] = None 21 | ) -> None: 22 | try: 23 | from aiobotocore.session import get_session # @manual 24 | except ImportError: 25 | raise RuntimeError( 26 | "S3 support requires aiobotocore. " 27 | "Please make sure aiobotocore is installed." 28 | ) 29 | components = root.split("/") 30 | if len(components) < 2: 31 | raise RuntimeError( 32 | "The S3 root path must follow the following pattern: " 33 | f"[BUCKET]/[PATH] (got {root})" 34 | ) 35 | self.bucket: str = components[0] 36 | self.root: str = "/".join(components[1:]) 37 | # pyre-ignore 38 | # TODO: read AWS tokens from storage_options? 
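        # (Sketch of one possible approach, not implemented here: botocore-style
        # credentials passed via storage_options, e.g. aws_access_key_id and
        # aws_secret_access_key, could be forwarded to
        # self.session.create_client(...) in write()/read()/delete() below.)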
39 | self.session = get_session() 40 | 41 | async def write(self, write_io: WriteIO) -> None: 42 | if isinstance(write_io.buf, bytes): 43 | stream = io.BytesIO(write_io.buf) 44 | elif isinstance(write_io.buf, memoryview): 45 | stream = MemoryviewStream(write_io.buf) 46 | else: 47 | raise TypeError(f"Unrecognized buffer type: {type(write_io.buf)}") 48 | 49 | async with self.session.create_client("s3") as client: 50 | key = os.path.join(self.root, write_io.path) 51 | await client.put_object(Bucket=self.bucket, Key=key, Body=stream) 52 | 53 | async def read(self, read_io: ReadIO) -> None: 54 | async with self.session.create_client("s3") as client: 55 | key = os.path.join(self.root, read_io.path) 56 | byte_range = read_io.byte_range 57 | if byte_range is None: 58 | response = await client.get_object(Bucket=self.bucket, Key=key) 59 | else: 60 | response = await client.get_object( 61 | Bucket=self.bucket, 62 | Key=key, 63 | # HTTP Byte Range is inclusive: 64 | # https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 65 | Range=f"bytes={byte_range[0]}-{byte_range[1] - 1}", 66 | ) 67 | async with response["Body"] as stream: 68 | read_io.buf = io.BytesIO(await stream.read()) 69 | 70 | async def delete(self, path: str) -> None: 71 | async with self.session.create_client("s3") as client: 72 | key = os.path.join(self.root, path) 73 | await client.delete_object(Bucket=self.bucket, Key=key) 74 | 75 | async def delete_dir(self, path: str) -> None: 76 | raise NotImplementedError() 77 | 78 | async def close(self) -> None: 79 | pass 80 | -------------------------------------------------------------------------------- /torchsnapshot/io_types.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import abc 11 | import asyncio 12 | import io 13 | from concurrent.futures import Executor 14 | from dataclasses import dataclass, field 15 | from typing import Generic, Optional, Tuple, TypeVar, Union 16 | 17 | from .asyncio_utils import maybe_nested_loop 18 | 19 | 20 | # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. 
21 | BufferType = Union[bytes, memoryview] 22 | 23 | 24 | class BufferStager: 25 | @abc.abstractmethod 26 | async def stage_buffer(self, executor: Optional[Executor] = None) -> BufferType: 27 | pass 28 | 29 | @abc.abstractmethod 30 | def get_staging_cost_bytes(self) -> int: 31 | pass 32 | 33 | 34 | @dataclass 35 | class WriteReq: 36 | path: str 37 | buffer_stager: BufferStager 38 | 39 | 40 | class BufferConsumer: 41 | @abc.abstractmethod 42 | async def consume_buffer( 43 | self, buf: bytes, executor: Optional[Executor] = None 44 | ) -> None: 45 | pass 46 | 47 | @abc.abstractmethod 48 | def get_consuming_cost_bytes(self) -> int: 49 | pass 50 | 51 | 52 | @dataclass 53 | class ReadReq: 54 | path: str 55 | buffer_consumer: BufferConsumer 56 | byte_range: Optional[Tuple[int, int]] = None 57 | 58 | 59 | T = TypeVar("T") 60 | 61 | 62 | @dataclass 63 | class Future(Generic[T]): 64 | obj: Optional[T] = None 65 | 66 | 67 | @dataclass 68 | class WriteIO: 69 | path: str 70 | buf: BufferType 71 | 72 | 73 | @dataclass 74 | class ReadIO: 75 | path: str 76 | buf: io.BytesIO = field(default_factory=io.BytesIO) 77 | byte_range: Optional[Tuple[int, int]] = None 78 | 79 | 80 | class StoragePlugin(abc.ABC): 81 | @abc.abstractmethod 82 | async def write(self, write_io: WriteIO) -> None: 83 | pass 84 | 85 | @abc.abstractmethod 86 | async def read(self, read_io: ReadIO) -> None: 87 | pass 88 | 89 | @abc.abstractmethod 90 | async def delete(self, path: str) -> None: 91 | pass 92 | 93 | @abc.abstractmethod 94 | async def delete_dir(self, path: str) -> None: 95 | pass 96 | 97 | @abc.abstractmethod 98 | async def close(self) -> None: 99 | pass 100 | 101 | def sync_write( 102 | self, write_io: WriteIO, event_loop: Optional[asyncio.AbstractEventLoop] = None 103 | ) -> None: 104 | if event_loop is None: 105 | event_loop = maybe_nested_loop() 106 | event_loop.run_until_complete(self.write(write_io=write_io)) 107 | 108 | def sync_read( 109 | self, read_io: ReadIO, event_loop: Optional[asyncio.AbstractEventLoop] = None 110 | ) -> None: 111 | if event_loop is None: 112 | event_loop = maybe_nested_loop() 113 | event_loop.run_until_complete(self.read(read_io=read_io)) 114 | 115 | def sync_close( 116 | self, event_loop: Optional[asyncio.AbstractEventLoop] = None 117 | ) -> None: 118 | if event_loop is None: 119 | event_loop = maybe_nested_loop() 120 | event_loop.run_until_complete(self.close()) 121 | -------------------------------------------------------------------------------- /benchmarks/load_tensor/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
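# Added note (commentary only): this benchmark contrasts torch.load(), which
# materializes the entire serialized tensor in host memory, with
# Snapshot.read_object(), whose peak RSS can be capped via the
# memory_budget_bytes argument (see MEMORY_BUDGET_BYTES and
# benchmark_torchsnapshot() below).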
7 | 
8 | # pyre-unsafe
9 | 
10 | import argparse
11 | import logging
12 | import os
13 | import tempfile
14 | import time
15 | import uuid
16 | 
17 | import fsspec
18 | import torch
19 | from torchsnapshot import Snapshot, StateDict
20 | from torchsnapshot.rss_profiler import measure_rss_deltas
21 | 
22 | logging.basicConfig(level=logging.INFO)
23 | logger: logging.Logger = logging.getLogger(__name__)
24 | 
25 | 
26 | TENSOR_DIMS = (50000, 50000)
27 | MEMORY_BUDGET_BYTES = 100 * 1024**2
28 | 
29 | 
30 | def benchmark_torchsnapshot(work_dir: str, gpu_tensor: torch.Tensor) -> None:
31 |     app_state = {
32 |         "state": StateDict(
33 |             tensor=gpu_tensor,
34 |         )
35 |     }
36 |     snapshot = Snapshot.take(path=f"{work_dir}/{uuid.uuid4()}", app_state=app_state)
37 | 
38 |     ts_begin = time.monotonic()
39 |     rss_deltas = []
40 |     logger.info("Loading the tensor with torchsnapshot (without memory budget)...")
41 |     with measure_rss_deltas(rss_deltas=rss_deltas):
42 |         snapshot.read_object(path="0/state/tensor", obj_out=gpu_tensor)
43 |     logger.info(
44 |         f"Took {time.monotonic() - ts_begin:.2f} seconds. "
45 |         f"Peak RSS delta: {max(rss_deltas) // 1024**2}MB"
46 |     )
47 | 
48 |     ts_begin = time.monotonic()
49 |     rss_deltas = []
50 |     logger.info(
51 |         f"Loading the tensor with torchsnapshot "
52 |         f"(with a {MEMORY_BUDGET_BYTES // 1024**2}MB memory budget)..."
53 |     )
54 |     with measure_rss_deltas(rss_deltas=rss_deltas):
55 |         snapshot.read_object(
56 |             path="0/state/tensor",
57 |             obj_out=gpu_tensor,
58 |             memory_budget_bytes=MEMORY_BUDGET_BYTES,
59 |         )
60 |     logger.info(
61 |         f"Took {time.monotonic() - ts_begin:.2f} seconds. "
62 |         f"Peak RSS delta: {max(rss_deltas) // 1024**2}MB"
63 |     )
64 | 
65 | 
66 | def benchmark_torch_save_fsspec(work_dir: str, gpu_tensor: torch.Tensor) -> None:
67 |     path = os.path.join(work_dir, str(uuid.uuid4()))
68 |     with fsspec.open(path, "wb") as f:
69 |         torch.save(gpu_tensor, f)
70 | 
71 |     ts_begin = time.monotonic()
72 |     rss_deltas = []
73 |     logger.info("Loading the tensor with torch.load()...")
74 |     with measure_rss_deltas(rss_deltas=rss_deltas):
75 |         with fsspec.open(path, "rb") as f:
76 |             loaded = torch.load(f, map_location="cpu")
77 |         gpu_tensor.copy_(loaded)
78 | 
79 |     logger.info(
80 |         f"Took {time.monotonic() - ts_begin:.2f} seconds. "
81 |         f"Peak RSS delta: {max(rss_deltas) // 1024**2}MB"
82 |     )
83 | 
84 | 
85 | def main() -> None:
86 |     with tempfile.TemporaryDirectory() as path:
87 |         parser = argparse.ArgumentParser()
88 |         parser.add_argument("--work-dir", default=str(path))
89 |         args: argparse.Namespace = parser.parse_args()
90 | 
91 |         device = torch.device("cuda:0")
92 |         gpu_tensor = torch.rand(*TENSOR_DIMS, device=device)
93 |         benchmark_torch_save_fsspec(work_dir=args.work_dir, gpu_tensor=gpu_tensor)
94 |         benchmark_torchsnapshot(work_dir=args.work_dir, gpu_tensor=gpu_tensor)
95 | 
96 | 
97 | if __name__ == "__main__":
98 |     main()  # pragma: no cover
99 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
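# Invocation sketch (commands illustrative; the flags are defined in
# parse_args() below, and unrecognized arguments are forwarded to setuptools):
#   python setup.py sdist bdist_wheel                   # stable "torchsnapshot"
#   python setup.py --nightly bdist_wheel               # "torchsnapshot-nightly" with a date-based version
#   python setup.py --append-to-version=a1 bdist_wheel  # version suffixed with "a1"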
6 | 7 | import argparse 8 | import os 9 | import sys 10 | 11 | from datetime import date 12 | from typing import List 13 | 14 | from setuptools import find_packages, setup 15 | 16 | # using exec_file instead of the import to avoid having to install dependencies 17 | # when building the wheel 18 | exec(open("torchsnapshot/version.py").read()) 19 | 20 | 21 | def current_path(file_name: str) -> str: 22 | return os.path.abspath(os.path.join(__file__, os.path.pardir, file_name)) 23 | 24 | 25 | def read_requirements(file_name: str) -> List[str]: 26 | with open(current_path(file_name), encoding="utf8") as f: 27 | return [r for r in f.read().strip().split() if not r.startswith("-")] 28 | 29 | 30 | def get_nightly_version() -> str: 31 | return date.today().strftime("%Y.%m.%d") 32 | 33 | 34 | def parse_args() -> argparse.Namespace: 35 | parser = argparse.ArgumentParser(description="torchsnapshot setup") 36 | parser.add_argument( 37 | "--nightly", 38 | dest="nightly", 39 | action="store_true", 40 | help="enable settings for nightly package build", 41 | default=False, 42 | ) 43 | parser.add_argument( 44 | "--append-to-version", 45 | dest="append_version", 46 | help="append string to end of version number (e.g. a1)", 47 | ) 48 | return parser.parse_known_args() 49 | 50 | 51 | if __name__ == "__main__": 52 | with open(current_path("README.md"), encoding="utf8") as f: 53 | readme: str = f.read() 54 | 55 | custom_args, setup_args = parse_args() 56 | package_name = ( 57 | "torchsnapshot" if not custom_args.nightly else "torchsnapshot-nightly" 58 | ) 59 | version = __version__ if not custom_args.nightly else get_nightly_version() 60 | if custom_args.append_version: 61 | version = f"{version}{custom_args.append_version}" 62 | 63 | print(f"using package_name={package_name}, version={version}") 64 | 65 | sys.argv = [sys.argv[0]] + setup_args 66 | 67 | setup( 68 | name=package_name, 69 | version=version, 70 | author="torchsnapshot team", 71 | author_email="yifu@fb.com", 72 | description="A performant, memory-efficient checkpointing library for PyTorch applications, designed with large, complex distributed workloads in mind.", 73 | long_description=readme, 74 | long_description_content_type="text/markdown", 75 | url="https://github.com/pytorch/torchsnapshot", 76 | license="BSD-3", 77 | keywords=["pytorch", "snapshot", "checkpoint"], 78 | python_requires=">=3.7", 79 | install_requires=read_requirements("requirements.txt"), 80 | packages=find_packages(), 81 | package_data={"torchsnapshot": ["py.typed"]}, 82 | zip_safe=True, 83 | classifiers=[ 84 | "Development Status :: 2 - Pre-Alpha", 85 | "Intended Audience :: Developers", 86 | "Intended Audience :: Science/Research", 87 | "License :: OSI Approved :: BSD License", 88 | "Programming Language :: Python :: 3", 89 | "Programming Language :: Python :: 3.7", 90 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 91 | ], 92 | extras_require={"dev": read_requirements("dev-requirements.txt")}, 93 | ) 94 | -------------------------------------------------------------------------------- /tests/gpu_tests/test_manifest_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # pyre-strict 9 | 10 | import torch 11 | 12 | import torch.distributed as dist 13 | from torch.testing._internal.common_utils import ( 14 | instantiate_parametrized_tests, 15 | parametrize, 16 | ) 17 | from torch.testing._internal.distributed._tensor.common_dtensor import ( 18 | DTensorTestBase, 19 | skip_if_lt_x_gpu, 20 | with_comms, 21 | ) 22 | from torchsnapshot.manifest_utils import ( 23 | _get_replicated_ranks, 24 | is_partially_replicated_entry, 25 | ) 26 | from torchsnapshot.serialization import NCCL_SUPPORTED_DTYPES 27 | from torchsnapshot.test_utils import _dtensor_test_case, _tensor_test_case 28 | 29 | WORLD_SIZE = 4 30 | 31 | 32 | @instantiate_parametrized_tests 33 | class TestManifestUtils(DTensorTestBase): 34 | @parametrize("dtype", NCCL_SUPPORTED_DTYPES) 35 | @skip_if_lt_x_gpu(WORLD_SIZE) 36 | # pyre-fixme[56]: While applying decorator 37 | # `torch.testing._internal.distributed._tensor.common_dtensor.with_comms`: For 1st 38 | # argument expected `(object) -> object` but got `(self: TestManifestUtils, dtype: 39 | # dtype) -> Any`. 40 | @with_comms 41 | # pyre-fixme[3]: Return type must be annotated. 42 | def test_get_replicated_ranks(self, dtype: torch.dtype): 43 | logical_path = "foo" 44 | tensor, entry, wrs = _dtensor_test_case( 45 | dtype=dtype, 46 | shape=[16, 16], 47 | logical_path=logical_path, 48 | rank=dist.get_rank(), 49 | replicated=True, 50 | ) 51 | # pyre-fixme[6]: For 1st argument expected `DTensorEntry` but got `Entry`. 52 | actual_repranks = _get_replicated_ranks(entry=entry) 53 | expected_repranks = [[0, 2], [1, 3]] 54 | assert actual_repranks == expected_repranks 55 | 56 | @parametrize("dtype", NCCL_SUPPORTED_DTYPES) 57 | @skip_if_lt_x_gpu(WORLD_SIZE) 58 | # pyre-fixme[56]: While applying decorator 59 | # `torch.testing._internal.distributed._tensor.common_dtensor.with_comms`: For 1st 60 | # argument expected `(object) -> object` but got `(self: TestManifestUtils, dtype: 61 | # dtype) -> Any`. 62 | @with_comms 63 | # pyre-fixme[3]: Return type must be annotated. 64 | def test_is_partially_replicated(self, dtype: torch.dtype): 65 | logical_path = "foo" 66 | tensor, entry, wrs = _dtensor_test_case( 67 | dtype=dtype, 68 | shape=[16, 16], 69 | logical_path=logical_path, 70 | rank=dist.get_rank(), 71 | replicated=True, 72 | ) 73 | assert is_partially_replicated_entry(entry=entry) 74 | 75 | # Only replicated 76 | # pyre-fixme[16]: `Entry` has no attribute `dim_map`. 77 | entry.dim_map = [-1, -1] 78 | assert not is_partially_replicated_entry(entry=entry) 79 | 80 | # Only sharded 81 | entry.dim_map = [0, 1] 82 | assert not is_partially_replicated_entry(entry=entry) 83 | 84 | tensor, entry, wrs = _tensor_test_case( 85 | dtype=dtype, 86 | shape=[16, 16], 87 | logical_path=logical_path, 88 | rank=dist.get_rank(), 89 | replicated=False, 90 | ) 91 | 92 | assert not is_partially_replicated_entry(entry=entry) 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TorchSnapshot (Beta Release) 2 | 3 |

4 | build status 5 | pypi version 6 | conda version 7 | pypi nightly version 8 | codecov 9 | bsd license 10 | 11 | 12 | A performant, memory-efficient checkpointing library for PyTorch applications, designed with large, complex distributed workloads in mind. 13 | 14 | 15 | ## Install 16 | 17 | Requires Python >= 3.8 and PyTorch >= 2.0.0 18 | 19 | From pip: 20 | 21 | ```bash 22 | # Stable 23 | pip install torchsnapshot 24 | # Or, using conda 25 | conda install -c conda-forge torchsnapshot 26 | 27 | # Nightly 28 | pip install --pre torchsnapshot-nightly 29 | ``` 30 | 31 | 32 | From source: 33 | 34 | ```bash 35 | git clone https://github.com/pytorch/torchsnapshot 36 | cd torchsnapshot 37 | pip install -r requirements.txt 38 | python setup.py install 39 | ``` 40 | 41 | ## Why TorchSnapshot 42 | 43 | **Performance** 44 | - TorchSnapshot provides a fast checkpointing implementation employing various optimizations, including zero-copy serialization for most tensor types, overlapped device-to-host copy and storage I/O, parallelized storage I/O. 45 | - TorchSnapshot greatly speeds up checkpointing for DistributedDataParallel workloads by distributing the write load across all ranks ([benchmark](https://github.com/pytorch/torchsnapshot/tree/main/benchmarks/ddp)). 46 | - When host memory is abundant, TorchSnapshot allows training to resume before all storage I/O completes, reducing the time blocked by checkpoint saving. 47 | 48 | **Memory Usage** 49 | - TorchSnapshot's memory usage adapts to the host's available resources, greatly reducing the chance of out-of-memory issues when saving and loading checkpoints. 50 | - TorchSnapshot supports efficient random access to individual objects within a snapshot, even when the snapshot is stored in a cloud object storage. 51 | 52 | **Usability** 53 | - Simple APIs that are consistent between distributed and non-distributed workloads. 54 | - Out of the box integration with commonly used cloud object storage systems. 55 | - Automatic resharding (elasticity) on world size change for supported workloads ([more details](https://pytorch.org/torchsnapshot/getting_started.html#elasticity-experimental)). 56 | 57 | **Security** 58 | - Secure tensor serialization without pickle dependency [WIP]. 59 | 60 | 61 | ## Getting Started 62 | 63 | ```python 64 | from torchsnapshot import Snapshot 65 | 66 | # Taking a snapshot 67 | app_state = {"model": model, "optimizer": optimizer} 68 | snapshot = Snapshot.take(path="/path/to/snapshot", app_state=app_state) 69 | 70 | # Restoring from a snapshot 71 | snapshot.restore(app_state=app_state) 72 | ``` 73 | 74 | See the [documentation](https://pytorch.org/torchsnapshot/main/getting_started.html) for more details. 75 | 76 | 77 | ## License 78 | 79 | torchsnapshot is BSD licensed, as found in the [LICENSE](LICENSE) file. 80 | -------------------------------------------------------------------------------- /torchsnapshot/manifest_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
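# Added commentary (illustrative): DTensorEntry.dim_map records, for each
# tensor dimension, the device-mesh dimension(s) it is sharded over, with [-1]
# meaning "not sharded". A hypothetical entry with dim_map == [[0], [-1]] is
# sharded on its first tensor dimension and replicated on its second, so
# is_sharded_entry() and is_partially_replicated_entry() below both return
# True for it, while is_fully_replicated_entry() returns False.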
7 | 8 | # pyre-strict 9 | 10 | import itertools 11 | from typing import List, Set 12 | 13 | import numpy as np 14 | 15 | from torchsnapshot.manifest import ( 16 | DictEntry, 17 | DTensorEntry, 18 | Entry, 19 | ListEntry, 20 | OrderedDictEntry, 21 | ShardedTensorEntry, 22 | ) 23 | 24 | 25 | def is_dict_entry(entry: Entry) -> bool: 26 | return isinstance(entry, (DictEntry, OrderedDictEntry)) 27 | 28 | 29 | def is_replicated_entry(entry: Entry) -> bool: 30 | """ 31 | Returns true if entry is partially or fully replicated. 32 | """ 33 | return is_fully_replicated_entry(entry) or is_partially_replicated_entry(entry) 34 | 35 | 36 | def is_container_entry(entry: Entry) -> bool: 37 | return isinstance(entry, (ListEntry, DictEntry, OrderedDictEntry)) 38 | 39 | 40 | def is_sharded_entry(entry: Entry) -> bool: 41 | if isinstance(entry, DTensorEntry): 42 | return any(dims[0] != -1 for dims in entry.dim_map) 43 | return isinstance(entry, ShardedTensorEntry) 44 | 45 | 46 | def is_fully_replicated_entry(entry: Entry) -> bool: 47 | """ 48 | Return True for an entry that is fully replicated on all ranks 49 | """ 50 | if isinstance(entry, DTensorEntry): 51 | return all(dims[0] == -1 for dims in entry.dim_map) 52 | if not hasattr(entry, "replicated"): 53 | return False 54 | # pyre-ignore 55 | return entry.replicated 56 | 57 | 58 | def is_partially_replicated_entry(entry: Entry) -> bool: 59 | """ 60 | Return True for an entry that is both sharded and replicated, which only applies 61 | to DTensorEntries 62 | """ 63 | if isinstance(entry, DTensorEntry): 64 | return ( 65 | 0 < sum(1 for dims in entry.dim_map if dims[0] == -1) < len(entry.dim_map) 66 | ) 67 | return False 68 | 69 | 70 | def _get_replicated_ranks( 71 | entry: DTensorEntry, 72 | ) -> List[Set[int]]: 73 | """ 74 | Given a DTensorEntry across ranks, return a list of rank sets 75 | where each set denotes a replicated shard. 76 | """ 77 | 78 | mesh = entry.mesh 79 | mesh_shape = np.array(entry.mesh).shape 80 | dim_map = entry.dim_map 81 | shard_dims = [] 82 | for dims in dim_map: 83 | if dims[0] != -1: 84 | shard_dims.extend(dims) 85 | replicate_dims = set(range(len(mesh_shape))) - set(shard_dims) 86 | 87 | # Programmatically generate slices of the device mesh that represent 88 | # sets of replicated ranks. Iterate across sharded dims, taking the 89 | # whole slice of the replicated dim each time. 90 | # 91 | # Example: 92 | # 3D mesh = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], replicate on dim 0, shard on dims 1, 2 93 | # The sets of replicated ranks returned is [[0, 4], [1, 5], [2, 6], [3, 7]] 94 | slices_for_dims = [] 95 | mesh_shape = np.array(mesh).shape 96 | for dim, size in enumerate(mesh_shape): 97 | if dim in replicate_dims: 98 | # Take entire dimension 99 | slices_for_dims.append([slice(None)]) 100 | elif dim in shard_dims: 101 | # Take one element at a time 102 | slices_for_dims.append([slice(i, i + 1) for i in range(size)]) 103 | 104 | slice_combinations = list(itertools.product(*slices_for_dims)) 105 | # Gymnastics to take advantage of numpy's multidimensional slicing and squeeze 106 | return [set(np.array(mesh)[s].flatten()) for s in slice_combinations] 107 | -------------------------------------------------------------------------------- /tests/gpu_tests/test_snapshot_fsdp.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import os 11 | from pathlib import Path 12 | 13 | import pytest 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType 18 | from torchsnapshot import Snapshot 19 | from torchsnapshot.test_utils import check_state_dict_eq, run_with_pet 20 | from torchsnapshot.tricks.fsdp import FSDPOptimizerAdapter 21 | 22 | 23 | def _create_fsdp_model( 24 | seed: int, 25 | device: torch.device, 26 | state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT, 27 | ) -> torch.nn.Module: 28 | torch.manual_seed(seed) 29 | model = torch.nn.Sequential( 30 | torch.nn.Linear(128, 64), 31 | torch.nn.Linear(64, 32), 32 | torch.nn.Linear(32, 16), 33 | ) 34 | 35 | fsdp_model = FSDP( 36 | module=model, 37 | device_id=device, 38 | ) 39 | FSDP.set_state_dict_type(fsdp_model, state_dict_type) 40 | return fsdp_model 41 | 42 | 43 | @pytest.mark.skipif( 44 | not torch.cuda.is_available(), reason="The test requires GPUs to run." 45 | ) 46 | # pyre-fixme[56]: Pyre was not able to infer the type of the decorator 47 | # `pytest.mark.gpu_only`. 48 | @pytest.mark.gpu_only 49 | @pytest.mark.usefixtures("toggle_batching") 50 | # Sharded state dict will test ShardedTensors, full tests Tensors 51 | @pytest.mark.parametrize( 52 | "state_dict_type", [StateDictType.FULL_STATE_DICT, StateDictType.SHARDED_STATE_DICT] 53 | ) 54 | @run_with_pet(nproc=2) 55 | def test_model_and_optim_fsdp(tmp_path: Path, state_dict_type: StateDictType) -> None: 56 | dist.init_process_group(backend="nccl") 57 | local_rank = int(os.environ["LOCAL_RANK"]) 58 | device = torch.device(f"cuda:{local_rank}") 59 | torch.cuda.set_device(device) 60 | 61 | foo_fsdp = _create_fsdp_model( 62 | seed=42, 63 | device=device, 64 | state_dict_type=state_dict_type, 65 | ) 66 | bar_fsdp = _create_fsdp_model( 67 | seed=777 + dist.get_rank(), 68 | device=device, 69 | state_dict_type=state_dict_type, 70 | ) 71 | 72 | assert not check_state_dict_eq(foo_fsdp.state_dict(), bar_fsdp.state_dict()) 73 | 74 | # Need to step and zero_grad in order to initialize all the optimizer parameters 75 | foo_optim = torch.optim.AdamW(foo_fsdp.parameters(), lr=0.01) 76 | foo_optim.step(closure=None) 77 | foo_optim.zero_grad(set_to_none=True) 78 | 79 | bar_optim = torch.optim.AdamW(bar_fsdp.parameters(), lr=0.02) 80 | bar_optim.step(closure=None) 81 | bar_optim.zero_grad(set_to_none=True) 82 | 83 | # pyre-fixme[6]: For 1st argument expected `FullyShardedDataParallel` but got 84 | # `Module`. 85 | foo_fsdp_optim = FSDPOptimizerAdapter(foo_fsdp, foo_optim) 86 | # pyre-fixme[6]: For 1st argument expected `FullyShardedDataParallel` but got 87 | # `Module`. 
88 | bar_fsdp_optim = FSDPOptimizerAdapter(bar_fsdp, bar_optim) 89 | 90 | assert not check_state_dict_eq( 91 | foo_fsdp_optim.state_dict(), bar_fsdp_optim.state_dict() 92 | ) 93 | 94 | foo_app_state = {"foo": foo_fsdp, "optim": foo_fsdp_optim} 95 | 96 | snapshot = Snapshot.take(str(tmp_path), foo_app_state) 97 | snapshot.restore({"foo": bar_fsdp, "optim": bar_fsdp_optim}) 98 | 99 | assert check_state_dict_eq(foo_fsdp_optim.state_dict(), bar_fsdp_optim.state_dict()) 100 | assert check_state_dict_eq(foo_fsdp.state_dict(), bar_fsdp.state_dict()) 101 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /torchsnapshot/tricks/deepspeed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import logging 9 | from types import MethodType 10 | from typing import Any, Dict 11 | 12 | from deepspeed import DeepSpeedEngine, version 13 | from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 14 | from torchsnapshot import Snapshot, StateDict 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def _save_zero_checkpoint(self, save_path: str, tag: str) -> None: 20 | app_state = { 21 | "optimizer": self.optimizer, 22 | "objects": StateDict(ds_config=self.config, ds_version=version), 23 | } 24 | Snapshot.async_take(path=save_path, app_state=app_state) 25 | # TODO: demonstrate how torchsnapshot can help with zero_to_fp32.py 26 | if self.global_rank == 0: 27 | self._copy_recovery_script(save_path) 28 | 29 | 30 | class Zero3StateAdapter: 31 | """ 32 | Adapts DeepSpeedZeroOptimizer_Stage3 to expose conventional .state_dict() 33 | and .load_state_dict(). 
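    This lets torchsnapshot treat the ZeRO-3 optimizer like any other stateful
    object when taking or restoring a snapshot.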
34 | 
35 |     Usage:
36 | 
37 |     >>> app_state = {
38 |     >>>     "optimizer": Zero3StateAdapter(zero3_optimizer),
39 |     >>> }
40 |     >>> Snapshot.take(path=path, app_state=app_state)
41 |     """
42 | 
43 |     def __init__(
44 |         self,
45 |         optimizer: DeepSpeedZeroOptimizer_Stage3,
46 |         load_optimizer_states: bool = True,
47 |         load_from_fp32_weights: bool = False,
48 |     ) -> None:
49 |         self.optimizer = optimizer
50 |         self.load_optimizer_states = load_optimizer_states
51 |         self.load_from_fp32_weights = load_from_fp32_weights
52 | 
53 |     def state_dict(self) -> Dict[str, Any]:
54 |         return self.optimizer.state_dict()
55 | 
56 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
57 |         self.optimizer._rigid_load_state_dict(
58 |             state_dict=state_dict, load_optimizer_states=self.load_optimizer_states
59 |         )
60 |         if len(self.optimizer.persistent_parameters) > 0:
61 |             self.optimizer.persistent_parameters[0].partition(
62 |                 self.optimizer.persistent_parameters
63 |             )
64 |             self.optimizer.persistent_parameters[0].all_gather(
65 |                 self.optimizer.persistent_parameters
66 |             )
67 | 
68 | 
69 | def _load_zero_checkpoint(
70 |     self,
71 |     load_dir: str,
72 |     tag: str,
73 |     load_optimizer_states: bool = True,
74 | ) -> bool:
75 |     snapshot = Snapshot(path=load_dir)
76 |     app_state = {
77 |         "optimizer": Zero3StateAdapter(
78 |             optimizer=self.optimizer,
79 |             load_optimizer_states=load_optimizer_states,
80 |             load_from_fp32_weights=self.zero_load_from_fp32_weights(),
81 |         )
82 |     }
83 |     snapshot.restore(app_state=app_state)
84 |     return True
85 | 
86 | 
87 | def patch_engine_to_use_torchsnapshot(engine: DeepSpeedEngine) -> None:
88 |     """
89 |     Patch a DeepSpeedEngine to use torchsnapshot to save its optimizer states.
90 | 
91 |     Args:
92 |         engine: The DeepSpeedEngine to patch.
93 | 
94 |     WARNING: This function is not a proper integration with deepspeed. Its
95 |     purpose is to demonstrate/benchmark a potential integration. Use it at
96 |     your own risk.
97 |     """
98 |     if not isinstance(engine.optimizer, DeepSpeedZeroOptimizer_Stage3):
99 |         raise RuntimeError(
100 |             "patch_engine_to_use_torchsnapshot only supports DeepSpeedZeroOptimizer_Stage3."
101 |         )
102 |     engine._save_zero_checkpoint = MethodType(_save_zero_checkpoint, engine)
103 |     engine._load_zero_checkpoint = MethodType(_load_zero_checkpoint, engine)
104 | 
--------------------------------------------------------------------------------
/tests/test_s3_storage_plugin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | # pyre-ignore-all-errors[56]
11 | 
12 | import io
13 | import logging
14 | import os
15 | import random
16 | import uuid
17 | 
18 | import pytest
19 | 
20 | import torch
21 | import torchsnapshot
22 | from torchsnapshot.io_types import ReadIO, WriteIO
23 | from torchsnapshot.storage_plugins.s3 import S3StoragePlugin
24 | 
25 | logger: logging.Logger = logging.getLogger(__name__)
26 | 
27 | _TEST_BUCKET = "torchsnapshot-test"
28 | _TENSOR_SZ = int(1_000_000 / 4)
29 | 
30 | 
31 | @pytest.fixture
32 | def s3_health_check() -> None:
33 |     """
34 |     S3 access can be flaky on GitHub Actions. Only run the tests if the health
35 |     check passes.
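    The check round-trips a small object through the test bucket; any failure
    skips the test instead of failing it.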
36 | """ 37 | try: 38 | import boto3 # pyre-ignore # @manual 39 | 40 | s3 = boto3.client("s3") 41 | data = b"hello" 42 | key = str(uuid.uuid4()) 43 | s3.upload_fileobj(io.BytesIO(data), _TEST_BUCKET, key) 44 | s3.download_fileobj(_TEST_BUCKET, key, io.BytesIO()) 45 | except Exception as e: 46 | # pyre-ignore[29] 47 | pytest.skip(f"Skipping the test because s3 health check failed: {e}") 48 | 49 | 50 | @pytest.mark.s3_integration_test 51 | @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") 52 | @pytest.mark.usefixtures("s3_health_check") 53 | def test_s3_read_write_via_snapshot() -> None: 54 | path = f"s3://{_TEST_BUCKET}/{uuid.uuid4()}" 55 | logger.info(path) 56 | 57 | tensor = torch.rand((_TENSOR_SZ,)) 58 | app_state = {"state": torchsnapshot.StateDict(tensor=tensor)} 59 | snapshot = torchsnapshot.Snapshot.take(path=path, app_state=app_state) 60 | 61 | app_state["state"]["tensor"] = torch.rand((_TENSOR_SZ,)) 62 | assert not torch.allclose(tensor, app_state["state"]["tensor"]) 63 | 64 | snapshot.restore(app_state) 65 | assert torch.allclose(tensor, app_state["state"]["tensor"]) 66 | 67 | 68 | @pytest.mark.s3_integration_test 69 | @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") 70 | @pytest.mark.usefixtures("s3_health_check") 71 | @pytest.mark.asyncio 72 | async def test_s3_write_read_delete() -> None: 73 | path = f"{_TEST_BUCKET}/{uuid.uuid4()}" 74 | logger.info(path) 75 | plugin = S3StoragePlugin(root=path) 76 | 77 | tensor = torch.rand((_TENSOR_SZ,)) 78 | buf = io.BytesIO() 79 | torch.save(tensor, buf) 80 | write_io = WriteIO(path="tensor", buf=buf.getbuffer()) 81 | 82 | await plugin.write(write_io=write_io) 83 | 84 | read_io = ReadIO(path="tensor") 85 | await plugin.read(read_io=read_io) 86 | loaded = torch.load(read_io.buf) 87 | assert torch.allclose(tensor, loaded) 88 | 89 | await plugin.delete(path="tensor") 90 | await plugin.close() 91 | 92 | 93 | @pytest.mark.s3_integration_test 94 | @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_AWS_TEST") is None, reason="") 95 | @pytest.mark.usefixtures("s3_health_check") 96 | @pytest.mark.asyncio 97 | async def test_s3_ranged_read() -> None: 98 | path = f"{_TEST_BUCKET}/{uuid.uuid4()}" 99 | logger.info(path) 100 | plugin = S3StoragePlugin(root=path) 101 | 102 | buf = bytes(random.getrandbits(8) for _ in range(2000)) 103 | write_io = WriteIO(path="rand_bytes", buf=memoryview(buf)) 104 | 105 | await plugin.write(write_io=write_io) 106 | 107 | read_io = ReadIO(path="rand_bytes", byte_range=(100, 200)) 108 | await plugin.read(read_io=read_io) 109 | assert len(read_io.buf.getvalue()) == 100 110 | assert read_io.buf.getvalue(), buf[100:200] 111 | 112 | await plugin.close() 113 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # MacOS 86 | .DS_Store 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | -------------------------------------------------------------------------------- /tests/test_replication_glob.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from pathlib import Path 11 | from typing import Any, Dict, List 12 | 13 | import pytest 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torchsnapshot import Snapshot 18 | from torchsnapshot.manifest_utils import is_fully_replicated_entry 19 | from torchsnapshot.test_utils import run_with_pet 20 | 21 | _WORLD_SIZE: int = 2 22 | 23 | 24 | class _TestStateful: 25 | def state_dict(self) -> Dict[str, Any]: 26 | return { 27 | "foo": torch.Tensor(1), 28 | "bar": torch.Tensor(1), 29 | "baz": [torch.Tensor(1), torch.Tensor(1)], 30 | "qux": {"quux": torch.Tensor(1), "quuz": torch.Tensor(1)}, 31 | } 32 | 33 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 34 | raise NotImplementedError() 35 | 36 | 37 | @pytest.mark.parametrize( 38 | "replication_globs, expected_replicated_paths", 39 | [ 40 | ( 41 | [["**"]] * _WORLD_SIZE, 42 | [ 43 | "0/my_stateful/foo", 44 | "0/my_stateful/bar", 45 | "0/my_stateful/baz/0", 46 | "0/my_stateful/baz/1", 47 | "0/my_stateful/qux/quux", 48 | "0/my_stateful/qux/quuz", 49 | ], 50 | ), 51 | ( 52 | [["my_stateful/baz/*", "my_stateful/qux/*"]] * _WORLD_SIZE, 53 | [ 54 | "0/my_stateful/baz/0", 55 | "0/my_stateful/baz/1", 56 | "0/my_stateful/qux/quux", 57 | "0/my_stateful/qux/quuz", 58 | ], 59 | ), 60 | ( 61 | [ 62 | ["my_stateful/foo", "my_stateful/qux/*"], 63 | ["my_stateful/foo", "my_stateful/bax/*"], 64 | ], 65 | ["0/my_stateful/foo"], 66 | ), 67 | ], 68 | ) 69 | @run_with_pet(nproc=_WORLD_SIZE) 70 | def test_replication_glob( 71 | replication_globs: List[List[str]], 72 | expected_replicated_paths: List[str], 73 | tmp_path: Path, 74 | ) -> None: 75 | dist.init_process_group(backend="gloo") 76 | app_state = {"my_stateful": _TestStateful()} 77 | snapshot = Snapshot.take( 78 | path=str(tmp_path), 79 | app_state=app_state, 80 | replicated=replication_globs[dist.get_rank()], 81 | ) 82 | replicated_paths = [ 83 | path 84 | for path, entry in snapshot.get_manifest().items() 85 | if is_fully_replicated_entry(entry) 86 | ] 87 | assert set(replicated_paths) == set(expected_replicated_paths) 88 | 89 | 90 | @pytest.mark.parametrize( 91 | "global_replicated, expected_replicated", 92 | [ 93 | ( 94 | [ 95 | ["my_stateful/foo", "my_stateful/qux"], 96 | ["my_stateful/foo", "my_stateful/qux"], 97 | ], 98 | ["my_stateful/foo", "my_stateful/qux"], 99 | ), 100 | ( 101 | [ 102 | ["my_stateful/foo", "my_stateful/qux"], 103 | ["my_stateful/foo", "my_stateful/baz"], 104 | ], 105 | ["my_stateful/foo"], 106 | ), 107 | ( 108 | [ 109 | ["my_stateful/foo"], 110 | ["my_stateful/qux"], 111 | ], 112 | [], 113 | ), 114 | ], 115 | ) 116 | def test_coalesce_replicated( 117 | global_replicated: List[List[str]], 118 | expected_replicated: List[str], 119 | ) -> None: 120 | assert sorted( 121 | Snapshot._coalesce_replicated(global_replicated=global_replicated) 122 | ) == sorted(expected_replicated) 123 | -------------------------------------------------------------------------------- /tests/test_sharded_tensor_resharding.py: 
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | # pyre-ignore-all-errors[21, 56]: ignore pytest undefined import and invalid decoration
11 | import itertools
12 | import uuid
13 | from typing import cast, Generator, List
14 | 
15 | import pytest
16 | import torch
17 | import torch.distributed as dist
18 | from torch.distributed._shard import sharded_tensor
19 | from torch.distributed._shard.metadata import ShardMetadata
20 | from torch.distributed._shard.sharding_spec import (
21 |     ChunkShardingSpec,
22 |     EnumerableShardingSpec,
23 |     ShardingSpec,
24 | )
25 | from torchsnapshot.io_preparer import ShardedTensorIOPreparer
26 | 
27 | 
28 | @pytest.fixture
29 | def dummy_pg() -> Generator[None, None, None]:
30 |     dist.init_process_group(
31 |         backend="gloo", init_method=f"file:///tmp/{uuid.uuid4()}", rank=0, world_size=1
32 |     )
33 |     yield
34 |     dist.destroy_process_group()
35 | 
36 | 
37 | def sharding_specs() -> List[ShardingSpec]:
38 |     specs: List[ShardingSpec] = [
39 |         # pyre-ignore
40 |         ChunkShardingSpec(
41 |             dim=dim,
42 |             placements=[
43 |                 "rank:0/cpu",
44 |             ]
45 |             * num_shards,
46 |         )
47 |         for dim, num_shards in itertools.product([0, 1], [3, 5])
48 |     ]
49 |     specs.append(
50 |         EnumerableShardingSpec(
51 |             [
52 |                 ShardMetadata(
53 |                     shard_offsets=[0, 0],
54 |                     shard_sizes=[64, 64],
55 |                     placement="rank:0/cpu",
56 |                 ),
57 |                 ShardMetadata(
58 |                     shard_offsets=[0, 64],
59 |                     shard_sizes=[64, 64],
60 |                     placement="rank:0/cpu",
61 |                 ),
62 |                 ShardMetadata(
63 |                     shard_offsets=[64, 0],
64 |                     shard_sizes=[64, 64],
65 |                     placement="rank:0/cpu",
66 |                 ),
67 |                 ShardMetadata(
68 |                     shard_offsets=[64, 64],
69 |                     shard_sizes=[64, 64],
70 |                     placement="rank:0/cpu",
71 |                 ),
72 |             ]
73 |         )
74 |     )
75 |     return specs
76 | 
77 | 
78 | @pytest.mark.asyncio
79 | @pytest.mark.parametrize("src_spec", sharding_specs())
80 | @pytest.mark.parametrize("dst_spec", sharding_specs())
81 | async def test_sharded_tensor_resharding(
82 |     src_spec: ShardingSpec, dst_spec: ShardingSpec, dummy_pg: None
83 | ) -> None:
84 |     # Randomly initialize two sharded tensors
85 |     src = sharded_tensor.empty(src_spec, 128, 128)
86 |     dst = sharded_tensor.empty(dst_spec, 128, 128)
87 |     for st in [src, dst]:
88 |         for shard in st.local_shards():
89 |             shard.tensor.random_()
90 | 
91 |     # Verify that they are not the same
92 |     src_gathered = torch.empty(128, 128)
93 |     dst_gathered = torch.empty(128, 128)
94 |     src.gather(out=src_gathered)
95 |     dst.gather(out=dst_gathered)
96 |     assert not torch.allclose(src_gathered, dst_gathered)
97 | 
98 |     entry, write_reqs = ShardedTensorIOPreparer.prepare_write(
99 |         storage_path="src", obj=src
100 |     )
101 |     read_reqs, _ = ShardedTensorIOPreparer.prepare_read(entry=entry, obj_out=dst)
102 | 
103 |     # Fulfill the dst's read requests with src's write requests
104 |     path_to_buf = {wr.path: await wr.buffer_stager.stage_buffer() for wr in write_reqs}
105 |     for rr in read_reqs:
106 |         await rr.buffer_consumer.consume_buffer(buf=cast(bytes, path_to_buf[rr.path]))
107 | 
108 |     src.gather(out=src_gathered)
109 |     dst.gather(out=dst_gathered)
110 |     assert torch.allclose(src_gathered, dst_gathered)
111 | 
--------------------------------------------------------------------------------
/examples/ddp_example.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | import argparse 10 | import os 11 | import uuid 12 | from typing import Dict, Optional 13 | 14 | import torch 15 | 16 | import torch.distributed as dist 17 | import torch.distributed.launcher as pet 18 | import torchsnapshot 19 | from torch.nn.parallel import DistributedDataParallel as DDP 20 | 21 | from torchsnapshot import Snapshot, Stateful 22 | 23 | NUM_EPOCHS = 4 24 | EPOCH_SIZE = 16 25 | BATCH_SIZE = 8 26 | 27 | 28 | class Model(torch.nn.Module): 29 | def __init__(self) -> None: 30 | super().__init__() 31 | self.layers = torch.nn.Sequential( 32 | torch.nn.Linear(128, 64), 33 | torch.nn.ReLU(), 34 | torch.nn.Linear(64, 32), 35 | torch.nn.ReLU(), 36 | torch.nn.Linear(32, 1), 37 | ) 38 | 39 | def forward(self, X: torch.Tensor) -> torch.Tensor: 40 | return self.layers(X) 41 | 42 | 43 | def train( 44 | work_dir: str, 45 | snapshot_path: Optional[str] = None, 46 | ) -> None: 47 | # initialize the process group 48 | dist.init_process_group(backend="nccl") 49 | local_rank = int(os.environ["LOCAL_RANK"]) 50 | device = torch.device(f"cuda:{local_rank}") 51 | torch.cuda.set_device(device) 52 | 53 | torch.manual_seed(42) 54 | 55 | print(f"Running basic DDP example on device {device}.") 56 | model = Model().to(device) 57 | 58 | # DDP wrapper around model 59 | ddp_model = DDP(model) 60 | 61 | optim = torch.optim.Adagrad(ddp_model.parameters(), lr=0.01) 62 | loss_fn = torch.nn.BCEWithLogitsLoss() 63 | progress = torchsnapshot.StateDict(current_epoch=0) 64 | 65 | # torchsnapshot: define app state 66 | app_state: Dict[str, Stateful] = { 67 | "rng_state": torchsnapshot.RNGState(), 68 | "model": ddp_model, 69 | "optim": optim, 70 | "progress": progress, 71 | } 72 | snapshot: Optional[Snapshot] = None 73 | 74 | if snapshot_path is not None: 75 | # torchsnapshot: restore app state 76 | snapshot = torchsnapshot.Snapshot(path=snapshot_path) 77 | print(f"Restoring snapshot from path: {snapshot.path}") 78 | snapshot.restore(app_state=app_state) 79 | 80 | while progress["current_epoch"] < NUM_EPOCHS: 81 | for _ in range(EPOCH_SIZE): 82 | X = torch.rand((BATCH_SIZE, 128), device=device) 83 | pred = ddp_model(X) 84 | label = torch.rand((BATCH_SIZE, 1), device=device) 85 | loss = loss_fn(pred, label) 86 | 87 | optim.zero_grad() 88 | loss.backward() 89 | optim.step() 90 | 91 | progress["current_epoch"] += 1 92 | 93 | # torchsnapshot: take snapshot 94 | snapshot = torchsnapshot.Snapshot.take( 95 | f"{work_dir}/run-{uuid.uuid4()}-epoch-{progress['current_epoch']}", 96 | app_state, 97 | replicated=["**"], # this pattern treats all states as replicated 98 | ) 99 | 100 | print(f"Snapshot path: {snapshot.path}") 101 | dist.destroy_process_group() 102 | 103 | 104 | def main() -> None: 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument("--work-dir", default="/tmp") 107 | parser.add_argument("--num-processes", type=int, default=2) 108 | parser.add_argument("--snapshot-path") 109 | args: argparse.Namespace = parser.parse_args() 110 | 111 | lc = pet.LaunchConfig( 112 | min_nodes=1, 113 | max_nodes=1, 114 | nproc_per_node=args.num_processes, 115 | run_id=str(uuid.uuid4()), 116 | rdzv_backend="c10d", 117 | rdzv_endpoint="localhost:0", 118 | 
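        # Port 0 lets the c10d rendezvous backend pick a free port on localhost.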
max_restarts=0,
119 |         monitor_interval=1,
120 |     )
121 | 
122 |     pet.elastic_launch(lc, entrypoint=train)(args.work_dir, args.snapshot_path)
123 | 
124 | 
125 | if __name__ == "__main__":
126 |     main()  # pragma: no cover
127 | 
--------------------------------------------------------------------------------
/tests/test_serialization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | 
11 | from typing import Tuple
12 | 
13 | import pytest
14 | 
15 | import torch
16 | 
17 | from torchsnapshot.serialization import (
18 |     ALL_SUPPORTED_DTYPES,
19 |     BUFFER_PROTOCOL_SUPPORTED_DTYPES,
20 |     dtype_to_string,
21 |     per_channel_qtensor_as_bytes,
22 |     per_channel_qtensor_from_bytes,
23 |     per_tensor_qtensor_as_bytes,
24 |     per_tensor_qtensor_from_bytes,
25 |     string_to_dtype,
26 |     SUPPORTED_QUANTIZED_DTYPES,
27 |     tensor_as_memoryview,
28 |     tensor_from_memoryview,
29 | )
30 | from torchsnapshot.test_utils import rand_tensor, tensor_eq
31 | 
32 | 
33 | @pytest.mark.parametrize("dtype", BUFFER_PROTOCOL_SUPPORTED_DTYPES)
34 | def test_buffer_protocol(dtype: torch.dtype) -> None:
35 |     foo = rand_tensor(shape=(1000, 1000), dtype=dtype)
36 | 
37 |     serialized = tensor_as_memoryview(foo).tobytes()
38 |     dtype_str = dtype_to_string(foo.dtype)
39 |     shape = list(foo.shape)
40 | 
41 |     bar = tensor_from_memoryview(
42 |         memoryview(serialized),
43 |         dtype=string_to_dtype(dtype_str),
44 |         shape=shape,
45 |     )
46 |     assert torch.allclose(foo, bar)
47 | 
48 | 
49 | @pytest.mark.parametrize("dtype", ALL_SUPPORTED_DTYPES)
50 | def test_string_dtype_conversion(dtype: torch.dtype) -> None:
51 |     dtype_str = dtype_to_string(dtype)
52 |     restored = string_to_dtype(dtype_str)
53 |     assert restored == dtype
54 | 
55 | 
56 | @pytest.mark.parametrize("dtype", SUPPORTED_QUANTIZED_DTYPES)
57 | @pytest.mark.parametrize("shape", [(100, 100), (10, 11, 12)])
58 | def test_per_tensor_qtensor(dtype: torch.dtype, shape: Tuple[int, ...]) -> None:
59 |     qtensor = rand_tensor(shape=shape, dtype=dtype)
60 |     buf = per_tensor_qtensor_as_bytes(qtensor)
61 |     deserialized = per_tensor_qtensor_from_bytes(buf, dtype=dtype, shape=list(shape))
62 |     assert qtensor.dtype == deserialized.dtype
63 |     assert qtensor.is_quantized
64 |     assert deserialized.is_quantized
65 |     assert qtensor.qscheme() == deserialized.qscheme()
66 |     assert qtensor.q_scale() == deserialized.q_scale()
67 |     assert qtensor.q_zero_point() == deserialized.q_zero_point()
68 |     assert qtensor.stride() == deserialized.stride()
69 |     assert torch.allclose(qtensor.dequantize(), deserialized.dequantize())
70 | 
71 | 
72 | @pytest.mark.parametrize("dtype", SUPPORTED_QUANTIZED_DTYPES)
73 | @pytest.mark.parametrize("shape", [(100, 100), (10, 11, 12)])
74 | def test_per_channel_qtensor(dtype: torch.dtype, shape: Tuple[int, ...]) -> None:
75 |     for axis in range(len(shape)):
76 |         qtensor = rand_tensor(
77 |             shape=shape,
78 |             dtype=dtype,
79 |             qscheme=torch.per_channel_affine,
80 |             channel_axis=axis,
81 |         )
82 |         buf = per_channel_qtensor_as_bytes(qtensor)
83 |         deserialized = per_channel_qtensor_from_bytes(
84 |             buf, dtype=dtype, shape=list(shape)
85 |         )
86 |         assert qtensor.dtype == deserialized.dtype
87 |         assert qtensor.is_quantized
88 |         assert deserialized.is_quantized
89 |         assert qtensor.qscheme() == deserialized.qscheme()
90 |         assert torch.allclose(
91 |             qtensor.q_per_channel_scales(),
92 |             deserialized.q_per_channel_scales(),
93 |         )
94 |         assert torch.allclose(
95 |             qtensor.q_per_channel_zero_points(),
96 |             deserialized.q_per_channel_zero_points(),
97 |         )
98 |         assert qtensor.stride() == deserialized.stride()
99 |         assert torch.allclose(qtensor.dequantize(), deserialized.dequantize())
100 | 
101 | 
102 | @pytest.mark.parametrize("dtype", BUFFER_PROTOCOL_SUPPORTED_DTYPES)
103 | def test_tensor_as_memoryview_for_contiguous_view(dtype: torch.dtype) -> None:
104 |     """
105 |     Verify that tensor_as_memoryview() behaves correctly for contiguous views.
106 |     """
107 |     tensor = rand_tensor((64, 64), dtype=dtype)
108 |     cont_view = tensor[32:, :]
109 |     assert cont_view.is_contiguous()
110 | 
111 |     mv = tensor_as_memoryview(cont_view)
112 |     assert len(mv) == cont_view.nelement() * cont_view.element_size()
113 | 
114 |     deserialized_view = tensor_from_memoryview(mv=mv, dtype=dtype, shape=[32, 64])
115 |     assert tensor_eq(deserialized_view, cont_view)
116 | 
--------------------------------------------------------------------------------
/tests/test_ddp_infer_replication.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | import unittest
11 | from typing import Dict, List
12 | 
13 | import torch
14 | import torch.distributed as dist
15 | import torch.distributed.launcher as pet
16 | from torch.nn.parallel import DistributedDataParallel as DDP
17 | from torchsnapshot import Snapshot, Stateful
18 | from torchsnapshot.test_utils import get_pet_launch_config
19 | 
20 | 
21 | class DDPInferReplicatedTest(unittest.TestCase):
22 |     @staticmethod
23 |     def _worker_helper(replicated: List[str], expected_replicated: List[str]) -> None:
24 |         dist.init_process_group(backend="gloo")
25 |         model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Linear(2, 1))
26 |         inferred_replicated = Snapshot._infer_replicated(
27 |             replicated=replicated, app_state={"ddp": DDP(model), "nonddp": model}
28 |         )
29 | 
30 |         unittest.TestCase().assertCountEqual(expected_replicated, inferred_replicated)
31 | 
32 |     def test_with_no_glob(self) -> None:
33 |         lc = get_pet_launch_config(nproc=2)
34 |         replicated = []
35 |         expected_replicated = ["ddp/**"]
36 |         pet.elastic_launch(
37 |             lc,
38 |             entrypoint=DDPInferReplicatedTest._worker_helper,
39 |         )(replicated, expected_replicated)
40 | 
41 |     def test_with_all_glob(self) -> None:
42 |         lc = get_pet_launch_config(nproc=2)
43 |         replicated = ["**"]
44 |         expected_replicated = ["**"]
45 |         pet.elastic_launch(
46 |             lc,
47 |             entrypoint=DDPInferReplicatedTest._worker_helper,
48 |         )(replicated, expected_replicated)
49 | 
50 |     def test_with_nonddp_glob(self) -> None:
51 |         lc = get_pet_launch_config(nproc=2)
52 |         replicated = ["nonddp/**"]
53 |         expected_replicated = ["ddp/**", "nonddp/**"]
54 |         pet.elastic_launch(
55 |             lc,
56 |             entrypoint=DDPInferReplicatedTest._worker_helper,
57 |         )(replicated, expected_replicated)
58 | 
59 |     @staticmethod
60 |     def _worker_with_params_to_ignore(
61 |         replicated: List[str], expected_replicated: List[str]
62 |     ) -> None:
63 |         dist.init_process_group(backend="gloo")
64 |         model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Linear(2, 1))
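        # _set_params_and_buffers_to_ignore_for_model is a private DDP API; the
        # parameters named here are excluded from DDP synchronization, so
        # _infer_replicated should not report them as replicated.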
65 |         DDP._set_params_and_buffers_to_ignore_for_model(
66 |             model, ["module.0.bias", "module.0.weight"]
67 |         )
68 |         ddp_model = DDP(model)
69 |         app_state: Dict[str, Stateful] = {"ddp": ddp_model, "nonddp": model}
70 | 
71 |         inferred_replicated = Snapshot._infer_replicated(
72 |             replicated=replicated, app_state=app_state
73 |         )
74 |         unittest.TestCase().assertCountEqual(expected_replicated, inferred_replicated)
75 | 
76 |     def test_with_params_to_ignore(self) -> None:
77 |         lc = get_pet_launch_config(nproc=2)
78 |         replicated = []
79 |         expected_replicated = ["ddp/module.1.bias", "ddp/module.1.weight"]
80 |         pet.elastic_launch(
81 |             lc,
82 |             entrypoint=DDPInferReplicatedTest._worker_with_params_to_ignore,
83 |         )(replicated, expected_replicated)
84 | 
85 |     @staticmethod
86 |     def _worker_with_params_to_ignore_and_all_glob(
87 |         replicated: List[str], expected_replicated: List[str]
88 |     ) -> None:
89 |         dist.init_process_group(backend="gloo")
90 |         model = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Linear(2, 1))
91 |         DDP._set_params_and_buffers_to_ignore_for_model(
92 |             model, ["module.0.bias", "module.0.weight"]
93 |         )
94 |         ddp_model = DDP(model)
95 |         app_state: Dict[str, Stateful] = {"ddp": ddp_model, "nonddp": model}
96 |         inferred_replicated = Snapshot._infer_replicated(
97 |             replicated=replicated, app_state=app_state
98 |         )
99 |         unittest.TestCase().assertCountEqual(expected_replicated, inferred_replicated)
100 | 
101 |     def test_with_params_to_ignore_and_all_glob(self) -> None:
102 |         lc = get_pet_launch_config(nproc=2)
103 |         replicated = ["**"]
104 |         expected_replicated = ["**"]
105 |         pet.elastic_launch(
106 |             lc,
107 |             entrypoint=DDPInferReplicatedTest._worker_with_params_to_ignore_and_all_glob,
108 |         )(replicated, expected_replicated)
109 | 
--------------------------------------------------------------------------------
/tests/test_gcs_storage_plugin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | # pyre-ignore-all-errors[56]
11 | 
12 | import io
13 | import logging
14 | import os
15 | import random
16 | import tempfile
17 | import uuid
18 | from typing import Generator
19 | 
20 | import pytest
21 | 
22 | import torch
23 | import torchsnapshot
24 | from torchsnapshot.io_types import ReadIO, WriteIO
25 | 
26 | logger: logging.Logger = logging.getLogger(__name__)
27 | 
28 | _TEST_BUCKET = "torchsnapshot-benchmark"
29 | _TENSOR_SZ = int(1_000_000 / 4)
30 | 
31 | 
32 | @pytest.fixture
33 | def gcs_health_check() -> None:
34 |     """
35 |     GCS access can be flaky on GitHub Actions. Only run the tests if the health
36 |     check passes.
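    The check writes and reads back a small blob in the test bucket; any failure
    skips the test instead of failing it.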
37 | """ 38 | try: 39 | from google.cloud import storage # @manual # pyre-ignore 40 | 41 | bucket = storage.Client().bucket(_TEST_BUCKET) # pyre-ignore 42 | blob = bucket.blob(str(uuid.uuid4())) 43 | with blob.open("w") as f: 44 | f.write("hello") 45 | with blob.open("r") as f: 46 | f.read() 47 | 48 | except Exception as e: 49 | # pyre-ignore[29] 50 | pytest.skip(f"Skipping the test because gcs health check failed: {e}") 51 | 52 | 53 | @pytest.fixture 54 | def gcs_test_credential() -> Generator[None, None, None]: 55 | if "GOOGLE_APPLICATION_CREDENTIALS" in os.environ: 56 | yield 57 | return 58 | 59 | if "GOOGLE_APPLICATION_CREDENTIALS_JSON" in os.environ: 60 | with tempfile.NamedTemporaryFile("w") as f: 61 | f.write(os.environ["GOOGLE_APPLICATION_CREDENTIALS_JSON"]) 62 | f.flush() 63 | os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = f.name 64 | yield 65 | del os.environ["GOOGLE_APPLICATION_CREDENTIALS"] 66 | 67 | 68 | @pytest.mark.gcs_integration_test 69 | @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_GCP_TEST") is None, reason="") 70 | @pytest.mark.usefixtures("gcs_test_credential", "gcs_health_check") 71 | def test_gcs_read_write_via_snapshot() -> None: 72 | path = f"gs://{_TEST_BUCKET}/{uuid.uuid4()}" 73 | logger.info(path) 74 | 75 | tensor = torch.rand((_TENSOR_SZ,)) 76 | app_state = {"state": torchsnapshot.StateDict(tensor=tensor)} 77 | snapshot = torchsnapshot.Snapshot.take(path=path, app_state=app_state) 78 | 79 | app_state["state"]["tensor"] = torch.rand((_TENSOR_SZ,)) 80 | assert not torch.allclose(tensor, app_state["state"]["tensor"]) 81 | 82 | snapshot.restore(app_state) 83 | assert torch.allclose(tensor, app_state["state"]["tensor"]) 84 | 85 | 86 | @pytest.mark.gcs_integration_test 87 | @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_GCP_TEST") is None, reason="") 88 | @pytest.mark.usefixtures("gcs_test_credential", "gcs_health_check") 89 | @pytest.mark.asyncio 90 | async def test_gcs_write_read_delete() -> None: 91 | path = f"{_TEST_BUCKET}/{uuid.uuid4()}" 92 | logger.info(path) 93 | 94 | from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin 95 | 96 | plugin = GCSStoragePlugin(root=path) 97 | 98 | tensor = torch.rand((_TENSOR_SZ,)) 99 | buf = io.BytesIO() 100 | torch.save(tensor, buf) 101 | write_io = WriteIO(path="tensor", buf=memoryview(buf.getvalue())) 102 | await plugin.write(write_io=write_io) 103 | 104 | read_io = ReadIO(path="tensor") 105 | await plugin.read(read_io=read_io) 106 | loaded = torch.load(read_io.buf) 107 | assert torch.allclose(tensor, loaded) 108 | 109 | # TODO: bring this back 110 | # await plugin.delete(path="tensor") 111 | await plugin.close() 112 | 113 | 114 | @pytest.mark.gcs_integration_test 115 | @pytest.mark.skipif(os.environ.get("TORCHSNAPSHOT_ENABLE_GCP_TEST") is None, reason="") 116 | @pytest.mark.usefixtures("gcs_test_credential", "gcs_health_check") 117 | @pytest.mark.asyncio 118 | async def test_gcs_ranged_read() -> None: 119 | path = f"{_TEST_BUCKET}/{uuid.uuid4()}" 120 | logger.info(path) 121 | 122 | from torchsnapshot.storage_plugins.gcs import GCSStoragePlugin 123 | 124 | plugin = GCSStoragePlugin(root=path) 125 | 126 | buf = bytes(random.getrandbits(8) for _ in range(2000)) 127 | write_io = WriteIO(path="rand_bytes", buf=memoryview(buf)) 128 | 129 | await plugin.write(write_io=write_io) 130 | 131 | read_io = ReadIO(path="rand_bytes", byte_range=(100, 200)) 132 | await plugin.read(read_io=read_io) 133 | assert len(read_io.buf.getvalue()) == 100 134 | assert read_io.buf.getvalue() == buf[100:200] 135 | 136 | 
await plugin.close()
137 | 
--------------------------------------------------------------------------------
/torchsnapshot/knobs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | # pyre-ignore-all-errors[2]: Allow `Any` in type annotations
11 | import os
12 | from contextlib import contextmanager
13 | from typing import Any, Generator
14 | 
15 | # This file contains various non-user facing constants used throughout the
16 | # project, and utilities for overriding the constants for testing and debugging
17 | # purposes. Environment variables are chosen as the overriding mechanism since
18 | # they work well for unit tests, e2e tests, and real world use cases. Sometimes it
19 | # makes sense for the function consuming one of these constants to also
20 | # allow overriding it via function argument for the ease of unit testing. In
21 | # such cases, the convention is to let the function argument take precedence.
22 | 
23 | _MAX_CHUNK_SIZE_ENV_VAR = "TORCHSNAPSHOT_MAX_CHUNK_SIZE_BYTES_OVERRIDE"
24 | _MAX_SHARD_SIZE_ENV_VAR = "TORCHSNAPSHOT_MAX_SHARD_SIZE_BYTES_OVERRIDE"
25 | _SLAB_SIZE_THRESHOLD_ENV_VAR = "TORCHSNAPSHOT_SLAB_SIZE_THRESHOLD_BYTES_OVERRIDE"
26 | _MAX_PER_RANK_IO_CONCURRENCY_ENV_VAR = (
27 |     "TORCHSNAPSHOT_MAX_PER_RANK_IO_CONCURRENCY_OVERRIDE"
28 | )
29 | 
30 | _DEFAULT_MAX_CHUNK_SIZE_BYTES: int = 512 * 1024 * 1024
31 | _DEFAULT_MAX_SHARD_SIZE_BYTES: int = 512 * 1024 * 1024
32 | _DEFAULT_SLAB_SIZE_THRESHOLD_BYTES: int = 128 * 1024 * 1024
33 | _DISABLE_BATCHING_ENV_VAR: str = "TORCHSNAPSHOT_DISABLE_BATCHING"
34 | _ENABLE_SHARDED_TENSOR_ELASTICITY_ROOT_ENV_VAR: str = (
35 |     "TORCHSNAPSHOT_ENABLE_SHARDED_TENSOR_ELASTICITY_ROOT_ONLY"
36 | )
37 | 
38 | _DEFAULT_MAX_PER_RANK_IO_CONCURRENCY: int = 16
39 | 
40 | 
41 | def get_max_chunk_size_bytes() -> int:
42 |     override = os.environ.get(_MAX_CHUNK_SIZE_ENV_VAR)
43 |     if override is not None:
44 |         return int(override)
45 |     return _DEFAULT_MAX_CHUNK_SIZE_BYTES
46 | 
47 | 
48 | def get_max_shard_size_bytes() -> int:
49 |     override = os.environ.get(_MAX_SHARD_SIZE_ENV_VAR)
50 |     if override is not None:
51 |         return int(override)
52 |     return _DEFAULT_MAX_SHARD_SIZE_BYTES
53 | 
54 | 
55 | def get_slab_size_threshold_bytes() -> int:
56 |     override = os.environ.get(_SLAB_SIZE_THRESHOLD_ENV_VAR)
57 |     if override is not None:
58 |         return int(override)
59 |     return _DEFAULT_SLAB_SIZE_THRESHOLD_BYTES
60 | 
61 | 
62 | def get_max_per_rank_io_concurrency() -> int:
63 |     override = os.environ.get(_MAX_PER_RANK_IO_CONCURRENCY_ENV_VAR)
64 |     if override is not None:
65 |         return int(override)
66 |     return _DEFAULT_MAX_PER_RANK_IO_CONCURRENCY
67 | 
68 | 
69 | def is_batching_disabled() -> bool:
70 |     if os.getenv(_DISABLE_BATCHING_ENV_VAR, "False").lower() in ("true", "1"):
71 |         return True
72 |     return False
73 | 
74 | 
75 | def is_sharded_tensor_elasticity_enabled_at_root_only() -> bool:
76 |     if os.getenv(_ENABLE_SHARDED_TENSOR_ELASTICITY_ROOT_ENV_VAR, "False").lower() in (
77 |         "true",
78 |         "1",
79 |     ):
80 |         return True
81 |     return False
82 | 
83 | 
84 | @contextmanager
85 | def _override_env_var(env_var: str, value: Any) -> Generator[None, None, None]:
86 |     prev = os.environ.get(env_var)
87 |     os.environ[env_var] = str(value)
88 |     yield
89 |     if prev is None:
90 |         del os.environ[env_var]
91 |     else:
92 |         os.environ[env_var] = prev
93 | 
94 | 
95 | @contextmanager
96 | def override_max_chunk_size_bytes(
97 |     max_chunk_size_bytes: int,
98 | ) -> Generator[None, None, None]:
99 |     with _override_env_var(_MAX_CHUNK_SIZE_ENV_VAR, max_chunk_size_bytes):
100 |         yield
101 | 
102 | 
103 | @contextmanager
104 | def override_max_shard_size_bytes(
105 |     max_shard_size_bytes: int,
106 | ) -> Generator[None, None, None]:
107 |     with _override_env_var(_MAX_SHARD_SIZE_ENV_VAR, max_shard_size_bytes):
108 |         yield
109 | 
110 | 
111 | @contextmanager
112 | def override_is_batching_disabled(disabled: bool) -> Generator[None, None, None]:
113 |     with _override_env_var(_DISABLE_BATCHING_ENV_VAR, disabled):
114 |         yield
115 | 
116 | 
117 | @contextmanager
118 | def override_slab_size_threshold_bytes(
119 |     slab_size_threshold_bytes: int,
120 | ) -> Generator[None, None, None]:
121 |     with _override_env_var(_SLAB_SIZE_THRESHOLD_ENV_VAR, slab_size_threshold_bytes):
122 |         yield
123 | 
124 | 
125 | @contextmanager
126 | def override_max_per_rank_io_concurrency(
127 |     max_per_rank_io_concurrency: int,
128 | ) -> Generator[None, None, None]:
129 |     with _override_env_var(
130 |         _MAX_PER_RANK_IO_CONCURRENCY_ENV_VAR, max_per_rank_io_concurrency
131 |     ):
132 |         yield
133 | 
--------------------------------------------------------------------------------
/tests/test_async_take.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | # pyre-strict
9 | 
10 | import asyncio
11 | import os
12 | import tempfile
13 | import unittest
14 | from unittest.mock import patch
15 | 
16 | import torch
17 | import torch.distributed as dist
18 | import torch.distributed.launcher as pet
19 | import torchsnapshot
20 | 
21 | from torchsnapshot.io_types import WriteIO
22 | from torchsnapshot.snapshot import SNAPSHOT_METADATA_FNAME
23 | from torchsnapshot.storage_plugins.fs import FSStoragePlugin
24 | from torchsnapshot.test_utils import get_pet_launch_config
25 | 
26 | 
27 | class SlowFSStoragePlugin(FSStoragePlugin):
28 |     async def write(self, write_io: WriteIO) -> None:
29 |         await asyncio.sleep(5)
30 |         await super().write(write_io=write_io)
31 | 
32 | 
33 | class FaultyFSStoragePlugin(FSStoragePlugin):
34 |     async def write(self, write_io: WriteIO) -> None:
35 |         await asyncio.sleep(5)
36 |         if dist.get_world_size() == 1 or dist.get_rank() == 1:
37 |             raise Exception("sorry")
38 |         else:
39 |             await super().write(write_io=write_io)
40 | 
41 | 
42 | class AsyncTakeTest(unittest.TestCase):
43 |     @staticmethod
44 |     def _test_async_take_with_error(path: str) -> None:
45 |         tc = unittest.TestCase()
46 | 
47 |         dist.init_process_group(backend="gloo")
48 |         with patch(
49 |             "torchsnapshot.storage_plugin.FSStoragePlugin", FaultyFSStoragePlugin
50 |         ):
51 |             future = torchsnapshot.Snapshot.async_take(
52 |                 path, {"foo": torch.nn.Linear(128, 64)}
53 |             )
54 |             tc.assertFalse(future.done())
55 |             with tc.assertRaisesRegex(RuntimeError, "sorry"):
56 |                 future.wait()
57 | 
58 |     def test_async_take_with_error(self) -> None:
59 |         for nproc in [2, 4]:
60 |             with tempfile.TemporaryDirectory() as path:
61 |                 lc = get_pet_launch_config(nproc=nproc)
62 |                 pet.elastic_launch(lc, entrypoint=self._test_async_take_with_error)(
63 |                     path
64 |                 )
65 |                 metadata_path = os.path.join(path, SNAPSHOT_METADATA_FNAME)
66 |                 self.assertFalse(os.path.isfile(metadata_path))
67 | 
68 |     @staticmethod
69 |     def _test_unwaited_async_take(path: str) -> None:
70 |         tc = unittest.TestCase()
71 | 
72 |         dist.init_process_group(backend="gloo")
73 |         with patch("torchsnapshot.storage_plugin.FSStoragePlugin", SlowFSStoragePlugin):
74 |             future = torchsnapshot.Snapshot.async_take(
75 |                 path, {"foo": torch.nn.Linear(128, 64)}
76 |             )
77 |             tc.assertFalse(future.done())
78 | 
79 |     # In Python3.8, an unwaited async snapshot can complete during interpreter shutdown.
80 |     # In Python3.9, the following exception would occur:
81 |     #     RuntimeError: cannot schedule new futures after interpreter shutdown
82 |     #
83 |     # TODO: if it's not possible to allow unwaited async snapshot in Python3.9,
84 |     # we may need to require users to always explicitly wait for async snapshots.
85 |     @unittest.skip(
86 |         "Skipping due to inconsistent behavior between Python3.8 and Python3.9"
87 |     )
88 |     def test_unwaited_async_take(self) -> None:
89 |         for nproc in [1, 2, 4]:
90 |             with tempfile.TemporaryDirectory() as path:
91 |                 lc = get_pet_launch_config(nproc=nproc)
92 |                 pet.elastic_launch(lc, entrypoint=self._test_unwaited_async_take)(path)
93 |                 metadata_path = os.path.join(path, SNAPSHOT_METADATA_FNAME)
94 |                 self.assertTrue(os.path.isfile(metadata_path))
95 | 
96 |     @staticmethod
97 |     def _test_unwaited_async_take_with_error(path: str) -> None:
98 |         tc = unittest.TestCase()
99 | 
100 |         dist.init_process_group(backend="gloo")
101 |         with patch(
102 |             "torchsnapshot.storage_plugin.FSStoragePlugin", FaultyFSStoragePlugin
103 |         ):
104 |             future = torchsnapshot.Snapshot.async_take(
105 |                 path, {"foo": torch.nn.Linear(128, 64)}
106 |             )
107 |             tc.assertFalse(future.done())
108 | 
109 |     def test_unwaited_async_take_with_error(self) -> None:
110 |         for nproc in [1, 2, 4]:
111 |             with tempfile.TemporaryDirectory() as path:
112 |                 lc = get_pet_launch_config(nproc=nproc)
113 |                 pet.elastic_launch(
114 |                     lc, entrypoint=self._test_unwaited_async_take_with_error
115 |                 )(path)
116 |                 metadata_path = os.path.join(path, SNAPSHOT_METADATA_FNAME)
117 |                 self.assertFalse(os.path.isfile(metadata_path))
118 | 
--------------------------------------------------------------------------------
/torchsnapshot/io_preparers/chunked_tensor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
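# This module splits large tensors into chunks along a single dimension
# (bounded by the max chunk size knob) so that each chunk can be written
# and read independently.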
7 | 8 | # pyre-strict 9 | 10 | import logging 11 | import math 12 | from dataclasses import dataclass 13 | from typing import Callable, List, Optional, Tuple, Union 14 | 15 | import torch 16 | 17 | from torchsnapshot.io_preparers.tensor import TensorIOPreparer 18 | 19 | from torchsnapshot.io_types import Future, ReadReq, WriteReq 20 | from torchsnapshot.knobs import get_max_chunk_size_bytes 21 | from torchsnapshot.manifest import ChunkedTensorEntry, Shard 22 | 23 | from torchsnapshot.serialization import dtype_to_string 24 | 25 | logger: logging.Logger = logging.getLogger(__name__) 26 | 27 | 28 | @dataclass 29 | class Chunk: 30 | offsets: List[int] 31 | sizes: List[int] 32 | dtype: str 33 | 34 | 35 | class ChunkedTensorIOPreparer: 36 | @staticmethod 37 | def chunk_tensor( 38 | tensor: torch.Tensor, 39 | chunking_dim: int = 0, 40 | chunk_sz_bytes: Optional[int] = None, 41 | ) -> List[Chunk]: 42 | chunk_sz_bytes = chunk_sz_bytes or get_max_chunk_size_bytes() 43 | 44 | # for 0-d case, reshape to 1-d 45 | if tensor.ndim == 0: 46 | tensor = tensor.view(-1) 47 | 48 | tensor_sz_bytes = tensor.numel() * tensor.element_size() 49 | n_chunks = math.ceil(tensor_sz_bytes / chunk_sz_bytes) 50 | tensor_chunks = torch.chunk(tensor, chunks=n_chunks, dim=chunking_dim) 51 | 52 | curr_offsets = [0] * tensor.ndim 53 | chunking_instruction = [] 54 | for i in range(len(tensor_chunks)): 55 | tensor_chunk_sizes = list(tensor_chunks[i].shape) 56 | chunking_instruction.append( 57 | Chunk( 58 | offsets=curr_offsets[:], 59 | sizes=tensor_chunk_sizes, 60 | dtype=str(tensor.dtype), 61 | ) 62 | ) 63 | curr_offsets[chunking_dim] += tensor_chunk_sizes[chunking_dim] 64 | return chunking_instruction 65 | 66 | @staticmethod 67 | def _get_subtensor_view( 68 | tensor: torch.Tensor, chunk: Union[Shard, Chunk] 69 | ) -> torch.Tensor: 70 | # for 0-d case, reshape to 1-d 71 | result = tensor.view(-1) if tensor.ndim == 0 else tensor 72 | 73 | for d in range(len(chunk.sizes)): 74 | result = result.narrow(d, chunk.offsets[d], chunk.sizes[d]) 75 | return result 76 | 77 | @classmethod 78 | def prepare_write( 79 | cls, 80 | storage_path: str, 81 | tensor: torch.Tensor, 82 | chunking_instruction: List[Chunk], 83 | is_async_snapshot: bool = False, 84 | _tensor_prepare_func: Optional[ 85 | Callable[[torch.Tensor, bool], torch.Tensor] 86 | ] = None, 87 | ) -> Tuple[ChunkedTensorEntry, List[WriteReq]]: 88 | write_reqs = [] 89 | chunks = [] 90 | for chunk in chunking_instruction: 91 | suffix = "_".join(str(x) for x in chunk.offsets) 92 | chunk_entry, chunk_write_reqs = TensorIOPreparer.prepare_write( 93 | storage_path=f"{storage_path}_{suffix}", 94 | tensor=cls._get_subtensor_view(tensor, chunk), 95 | is_async_snapshot=is_async_snapshot, 96 | _tensor_prepare_func=_tensor_prepare_func, 97 | ) 98 | chunks.append( 99 | Shard(offsets=chunk.offsets, sizes=chunk.sizes, tensor=chunk_entry) 100 | ) 101 | write_reqs += chunk_write_reqs 102 | chunked_entry = ChunkedTensorEntry( 103 | dtype=dtype_to_string(tensor.dtype), 104 | shape=list(tensor.shape), 105 | chunks=chunks, 106 | replicated=False, 107 | ) 108 | return chunked_entry, write_reqs 109 | 110 | @classmethod 111 | def prepare_read( 112 | cls, 113 | entry: ChunkedTensorEntry, 114 | tensor_out: Optional[torch.Tensor] = None, 115 | buffer_size_limit_bytes: Optional[int] = None, 116 | ) -> Tuple[List[ReadReq], Future[torch.Tensor]]: 117 | if tensor_out is None or not TensorIOPreparer.can_load_inplace( 118 | entry=entry, obj=tensor_out 119 | ): 120 | tensor_out = 
TensorIOPreparer.empty_tensor_from_entry(entry) 121 | read_reqs = [] 122 | for chunk in entry.chunks: 123 | tensor_out_chunk = cls._get_subtensor_view(tensor_out, chunk) 124 | chunk_read_reqs, _ = TensorIOPreparer.prepare_read( 125 | chunk.tensor, tensor_out_chunk, buffer_size_limit_bytes 126 | ) 127 | read_reqs += chunk_read_reqs 128 | return read_reqs, Future(obj=tensor_out) 129 | -------------------------------------------------------------------------------- /tests/gpu_tests/test_snapshot_dtensor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | import logging 11 | import uuid 12 | from typing import Optional 13 | 14 | import torch 15 | from torch import distributed as dist, nn 16 | from torch.distributed import init_device_mesh 17 | from torch.distributed._tensor import DeviceMesh 18 | from torch.distributed.fsdp import ( 19 | FullyShardedDataParallel as FSDP, 20 | ShardingStrategy, 21 | StateDictType, 22 | ) 23 | from torch.distributed.fsdp.api import ( 24 | ShardedOptimStateDictConfig, 25 | ShardedStateDictConfig, 26 | ) 27 | from torch.testing._internal.distributed._tensor.common_dtensor import ( 28 | DTensorTestBase, 29 | skip_if_lt_x_gpu, 30 | with_comms, 31 | ) 32 | from torchsnapshot import Snapshot 33 | from torchsnapshot.test_utils import check_state_dict_eq 34 | from torchsnapshot.tricks.fsdp import FSDPOptimizerAdapter 35 | 36 | logger: logging.Logger = logging.getLogger(__name__) 37 | 38 | 39 | WORLD_SIZE: int = 4 40 | 41 | 42 | class DummyModel(torch.nn.Module): 43 | # pyre-fixme[3]: Return type must be annotated. 44 | def __init__(self): 45 | super().__init__() 46 | self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU()) 47 | self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU()) 48 | self.net3 = nn.Linear(32, 64) 49 | self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8)) 50 | 51 | # pyre-fixme[3]: Return type must be annotated. 52 | # pyre-fixme[2]: Parameter must be annotated. 53 | def forward(self, x): 54 | return self.net4(self.net3(self.net2(self.net1(x)))) 55 | 56 | # pyre-fixme[3]: Return type must be annotated. 57 | def get_input(self): 58 | return torch.rand(4, 8, device="cuda") 59 | 60 | 61 | # TODO: Test different world sizes (may require not using DTensorTestBase) 62 | # TODO: Test FSDP + TP once dim_map is updated for [Shard(0), Shard(0)] cases 63 | class TestSnapshotWithDTensor(DTensorTestBase): 64 | # pyre-fixme[3]: Return type must be annotated. 
65 | def _create_model( 66 | self, seed: int, optim_lr: float, device_mesh: Optional[DeviceMesh] = None 67 | ): 68 | torch.manual_seed(seed) 69 | # Using HSDP model as an example model that uses DTensor 70 | # This should create model with placements 71 | # [Replicate(), Shard(0)] 72 | if device_mesh: 73 | model = FSDP( 74 | DummyModel().cuda(), 75 | device_mesh=device_mesh, 76 | sharding_strategy=ShardingStrategy.HYBRID_SHARD, 77 | ) 78 | else: 79 | mesh_2d = init_device_mesh("cuda", (2, WORLD_SIZE // 2)) 80 | intra_node_pg = mesh_2d.get_group(mesh_dim=1) 81 | inter_node_pg = mesh_2d.get_group(mesh_dim=0) 82 | model = FSDP( 83 | DummyModel().cuda(), 84 | process_group=(intra_node_pg, inter_node_pg), 85 | sharding_strategy=ShardingStrategy.HYBRID_SHARD, 86 | ) 87 | 88 | FSDP.set_state_dict_type( 89 | model, 90 | StateDictType.SHARDED_STATE_DICT, 91 | state_dict_config=ShardedStateDictConfig(), 92 | optim_state_dict_config=ShardedOptimStateDictConfig(), 93 | ) 94 | 95 | # Need to step and zero_grad in order to initialize all the optimizer parameters 96 | optim = torch.optim.Adam(model.parameters(), lr=optim_lr) 97 | optim.step(closure=None) 98 | optim.zero_grad(set_to_none=True) 99 | 100 | optim = FSDPOptimizerAdapter(model, optim) 101 | 102 | return model, optim 103 | 104 | @with_comms 105 | @skip_if_lt_x_gpu(WORLD_SIZE) 106 | # pyre-fixme[3]: Return type must be annotated. 107 | def test_save_and_load_same_world_size(self): 108 | mesh_2d = init_device_mesh("cuda", (2, WORLD_SIZE // 2)) 109 | src_model, src_optim = self._create_model( 110 | seed=42, optim_lr=0.1, device_mesh=mesh_2d 111 | ) 112 | dst_model, dst_optim = self._create_model( 113 | seed=24, optim_lr=0.2, device_mesh=mesh_2d 114 | ) 115 | assert not check_state_dict_eq(src_model.state_dict(), dst_model.state_dict()) 116 | assert not check_state_dict_eq(src_optim.state_dict(), dst_optim.state_dict()) 117 | 118 | tmp_path = f"/tmp/{uuid.uuid4()}" 119 | if dist.get_rank() == 0: 120 | logger.info(f"Saving to {tmp_path}") 121 | 122 | snapshot = Snapshot.take( 123 | str(tmp_path), {"model": src_model, "optim": src_optim} 124 | ) 125 | snapshot.restore({"model": dst_model, "optim": dst_optim}) 126 | logging.info(f"{dst_model.state_dict()}") 127 | assert check_state_dict_eq(dst_model.state_dict(), src_model.state_dict()) 128 | assert check_state_dict_eq(dst_optim.state_dict(), src_optim.state_dict()) 129 | -------------------------------------------------------------------------------- /tests/gpu_tests/test_dtensor_io_preparer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
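# These tests exercise DTensorIOPreparer across 1D and 2D device meshes with
# placements that mix sharding and replication.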
7 |
8 | # pyre-strict
9 |
10 | from typing import List, Sequence, Set, Tuple
11 |
12 | import numpy as np
13 |
14 | import torch
15 |
16 | import torch.distributed as dist
17 |
18 | from torch.distributed._tensor import (
19 |     DeviceMesh,
20 |     distribute_tensor,
21 |     Placement,
22 |     Replicate,
23 |     Shard,
24 | )
25 | from torch.testing._internal.common_utils import (
26 |     instantiate_parametrized_tests,
27 |     parametrize,
28 | )
29 | from torch.testing._internal.distributed._tensor.common_dtensor import (
30 |     DTensorTestBase,
31 |     skip_if_lt_x_gpu,
32 |     with_comms,
33 | )
34 |
35 | from torchsnapshot.io_preparer import (
36 |     DTensorIOPreparer,
37 |     TensorBufferConsumer,
38 |     TensorIOPreparer,
39 | )
40 | from torchsnapshot.manifest import NestedList
41 |
42 | from torchsnapshot.test_utils import tensor_eq
43 |
44 | WORLD_SIZE = 4
45 | # pyre-fixme[5]: Global expression must be annotated.
46 | _DEVICE_MESH = [
47 |     list(range(WORLD_SIZE)),
48 |     np.arange(WORLD_SIZE).reshape(2, 2).tolist(),
49 | ]
50 | _PLACEMENTS = [
51 |     [Shard(0)],
52 |     [Shard(1)],
53 |     [Shard(0), Replicate()],
54 |     [Replicate()],
55 | ]
56 |
57 |
58 | @instantiate_parametrized_tests
59 | class TestDTensorIOPreparer(DTensorTestBase):
60 |     @parametrize("shape", [(16, 32), (32, 16)])
61 |     @parametrize("mesh", _DEVICE_MESH)
62 |     @parametrize("placements", _PLACEMENTS)
63 |     @skip_if_lt_x_gpu(WORLD_SIZE)
64 |     # pyre-fixme[56]: While applying decorator `torch.testing._internal.distributed._...
65 |     @with_comms
66 |     async def test_dtensor_io_preparer(
67 |         self,
68 |         shape: Tuple[int, int],
69 |         mesh: NestedList,
70 |         placements: Sequence[Placement],
71 |     ) -> None:
72 |         """
73 |         Verify the basic behavior of DTensorIOPreparer.prepare_write and prepare_read.
74 |         """
75 |         device_mesh = DeviceMesh("cuda", mesh=mesh)
76 |
77 |         if len(placements) > device_mesh.ndim:
78 |             return
79 |         src = distribute_tensor(
80 |             tensor=torch.rand(*shape, device="cuda"),
81 |             device_mesh=device_mesh,
82 |             placements=placements,
83 |         )
84 |         dst = distribute_tensor(
85 |             tensor=torch.rand(*shape, device="cuda"),
86 |             device_mesh=device_mesh,
87 |             placements=placements,
88 |         )
89 |
90 |         entry, write_reqs = DTensorIOPreparer.prepare_write(
91 |             storage_path="/foo",
92 |             obj=src,
93 |         )
94 |         assert len(entry.shards) == len(write_reqs)
95 |
96 |         # When subdivision is enabled, we have more write requests than local
97 |         # shards, and each write request corresponds to a subview of a local
98 |         # shard.
99 |         # (`src._spec.num_shards` below is an int: the number of shards
100 |         # described by the DTensor's sharding spec.)
101 |         assert src._spec.num_shards < len(write_reqs)
102 |         entry_total_size = 0
103 |         for shard_entry in entry.shards:
104 |             entry_total_size += TensorIOPreparer.get_tensor_size_from_entry(
105 |                 shard_entry.tensor
106 |             )
107 |         assert (
108 |             entry_total_size
109 |             == src.to_local().storage().size() * src.to_local().element_size()
110 |         )
111 |
112 |         # Verify no overlapping locations among local shards
113 |         locations = set()
114 |         for shard, wr in zip(entry.shards, write_reqs):
115 |             assert shard.tensor.location == wr.path
116 |             locations.add(wr.path)
117 |
118 |         assert len(locations) == len(write_reqs)
119 |
120 |         # Verify no overlapping locations among global shards
121 |         # pyre-ignore
122 |         obj_list: List[Set[str]] = [None] * dist.get_world_size()
123 |         dist.all_gather_object(obj_list, locations)
124 |         all_locations = [location for ls in obj_list for location in ls]
125 |         assert len(set(all_locations)) == len(all_locations)
126 |
127 |         location_to_buf = {
128 |             wr.path: bytes(await wr.buffer_stager.stage_buffer()) for wr in write_reqs
129 |         }
130 |
131 |         # Verify that the size of the storage of a persisted shard matches the
132 |         # shape of the shard (as opposed to the size of the in-memory shard's storage).
133 |         for idx, buf in enumerate(location_to_buf.values()):
134 |             deserialized = TensorBufferConsumer.deserialize_tensor(
135 |                 buf=buf, entry=entry.shards[idx].tensor
136 |             )
137 |             assert (
138 |                 deserialized.storage().size() * deserialized.element_size()
139 |                 == TensorIOPreparer.get_tensor_size_from_entry(entry.shards[idx].tensor)
140 |             )
141 |
142 |         # First verify that src != dst, then consume the buffers with dst
143 |         # and verify that src == dst
144 |         assert not tensor_eq(src, dst)
145 |         read_reqs, _ = DTensorIOPreparer.prepare_read(entry=entry, obj_out=dst)
146 |         for rr in read_reqs:
147 |             await rr.buffer_consumer.consume_buffer(buf=location_to_buf[rr.path])
148 |         assert tensor_eq(src, dst)
149 |
--------------------------------------------------------------------------------
/tests/test_test_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import unittest
11 |
12 | import torch
13 | import torch.distributed as dist
14 | import torch.distributed.launcher as pet
15 | from torch.distributed._shard.metadata import ShardMetadata
16 | from torch.distributed._shard.sharded_tensor import (
17 |     init_from_local_shards,
18 |     Shard as ShardedTensorShard,
19 |     ShardedTensor,
20 | )
21 | from torch.distributed._tensor import (
22 |     DeviceMesh,
23 |     distribute_tensor,
24 |     DTensor,
25 |     Replicate,
26 |     Shard,
27 | )
28 | from torchsnapshot.test_utils import (
29 |     assert_state_dict_eq,
30 |     check_state_dict_eq,
31 |     get_pet_launch_config,
32 | )
33 |
34 |
35 | class TestUtilsTest(unittest.TestCase):
36 |     """
37 |     Watch the watchmen.
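
    These tests exercise the test helpers themselves (assert_state_dict_eq and
    check_state_dict_eq) against plain tensors, ShardedTensors, and DTensors.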
38 | """ 39 | 40 | def test_assert_state_dict_eq(self) -> None: 41 | t0 = torch.rand(16, 16) 42 | t1 = torch.rand(16, 16) 43 | a = {"foo": t0, "bar": [t1], "baz": 42} 44 | b = {"foo": t0, "bar": [t1], "baz": 42} 45 | c = {"foo": t0, "bar": [t0], "baz": 42} 46 | d = {"foo": t1, "bar": [t1], "baz": 42} 47 | e = {"foo": t0, "bar": [t1], "baz": 43} 48 | 49 | assert_state_dict_eq(self, a, b) 50 | with self.assertRaises(AssertionError): 51 | assert_state_dict_eq(self, a, c) 52 | with self.assertRaises(AssertionError): 53 | assert_state_dict_eq(self, a, d) 54 | with self.assertRaises(AssertionError): 55 | assert_state_dict_eq(self, a, e) 56 | 57 | def test_check_state_dict_eq(self) -> None: 58 | t0 = torch.rand(16, 16) 59 | t1 = torch.rand(16, 16) 60 | a = {"foo": t0, "bar": [t1], "baz": 42} 61 | b = {"foo": t0, "bar": [t1], "baz": 42} 62 | c = {"foo": t0, "bar": [t0], "baz": 42} 63 | d = {"foo": t1, "bar": [t1], "baz": 42} 64 | e = {"foo": t0, "bar": [t1], "baz": 43} 65 | 66 | self.assertTrue(check_state_dict_eq(a, b)) 67 | self.assertFalse(check_state_dict_eq(a, c)) 68 | self.assertFalse(check_state_dict_eq(a, d)) 69 | self.assertFalse(check_state_dict_eq(a, e)) 70 | 71 | @staticmethod 72 | def _create_sharded_tensor() -> ShardedTensor: 73 | dim_0: int = 128 74 | dim_1: int = 16 75 | 76 | global_tensor = torch.rand((dim_0, dim_1)) 77 | 78 | rank = dist.get_rank() 79 | world_sz = dist.get_world_size() 80 | chunk_sz = int(dim_0 / world_sz) 81 | begin = rank * chunk_sz 82 | 83 | shard_view = torch.narrow(global_tensor, 0, begin, chunk_sz) 84 | shard = ShardedTensorShard( 85 | tensor=shard_view, 86 | metadata=ShardMetadata( 87 | shard_offsets=[begin, 0], 88 | shard_sizes=[chunk_sz, dim_1], 89 | placement=f"rank:{rank}/cpu", 90 | ), 91 | ) 92 | return init_from_local_shards([shard], (dim_0, dim_1)) 93 | 94 | @classmethod 95 | def _worker_sharded_tensor(cls) -> None: 96 | dist.init_process_group(backend="gloo") 97 | 98 | torch.manual_seed(42) 99 | foo = {"": cls._create_sharded_tensor()} 100 | torch.manual_seed(42) 101 | bar = {"": cls._create_sharded_tensor()} 102 | torch.manual_seed(777) 103 | baz = {"": cls._create_sharded_tensor()} 104 | 105 | tc = unittest.TestCase() 106 | assert_state_dict_eq(tc, foo, foo) 107 | assert_state_dict_eq(tc, foo, bar) 108 | with tc.assertRaises(AssertionError): 109 | assert_state_dict_eq(tc, foo, baz) 110 | 111 | tc.assertTrue(check_state_dict_eq(foo, foo)) 112 | tc.assertTrue(check_state_dict_eq(foo, bar)) 113 | tc.assertFalse(check_state_dict_eq(foo, baz)) 114 | 115 | def test_state_dict_eq_with_sharded_tensor(self) -> None: 116 | lc = get_pet_launch_config(nproc=4) 117 | pet.elastic_launch(lc, entrypoint=self._worker_sharded_tensor)() 118 | 119 | @staticmethod 120 | def _create_dtensor() -> DTensor: 121 | dim_0: int = 128 122 | dim_1: int = 16 123 | 124 | local_tensor = torch.rand((dim_0, dim_1)) 125 | 126 | mesh = DeviceMesh("cpu", mesh=[[0, 1], [2, 3]]) 127 | placements = [Replicate(), Shard(0)] 128 | dtensor = distribute_tensor( 129 | tensor=local_tensor, device_mesh=mesh, placements=placements 130 | ) 131 | 132 | return dtensor 133 | 134 | @classmethod 135 | def _worker_dtensor(cls) -> None: 136 | dist.init_process_group(backend="gloo") 137 | 138 | torch.manual_seed(42) 139 | foo = {"": cls._create_dtensor()} 140 | torch.manual_seed(42) 141 | bar = {"": cls._create_dtensor()} 142 | torch.manual_seed(777) 143 | baz = {"": cls._create_dtensor()} 144 | 145 | tc = unittest.TestCase() 146 | assert_state_dict_eq(tc, foo, foo) 147 | assert_state_dict_eq(tc, foo, 
bar)
148 |         with tc.assertRaises(AssertionError):
149 |             assert_state_dict_eq(tc, foo, baz)
150 |
151 |         tc.assertTrue(check_state_dict_eq(foo, foo))
152 |         tc.assertTrue(check_state_dict_eq(foo, bar))
153 |         tc.assertFalse(check_state_dict_eq(foo, baz))
154 |
155 |     def test_state_dict_eq_with_dtensor(self) -> None:
156 |         lc = get_pet_launch_config(nproc=4)
157 |         pet.elastic_launch(lc, entrypoint=self._worker_dtensor)()
158 |
--------------------------------------------------------------------------------
/benchmarks/deepspeed_opt/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import logging
9 | import os
10 | import time
11 | import uuid
12 | from enum import Enum
13 | from typing import Optional
14 |
15 | import deepspeed
16 | import torch
17 | import torch.distributed as dist
18 |
19 | from deepspeed import DeepSpeedEngine
20 | from torchsnapshot.tricks.deepspeed import patch_engine_to_use_torchsnapshot
21 | from transformers import OPTConfig, OPTModel
22 | from transformers.deepspeed import HfDeepSpeedConfig
23 |
24 | dschf: Optional[HfDeepSpeedConfig] = None
25 |
26 |
27 | # https://arxiv.org/pdf/2205.01068.pdf
28 | TRAIN_BATCH_SIZE = 1024**2
29 | NUM_HIDDEN_LAYERS = 48
30 | NUM_ATTENTION_HEADS = 56
31 | HIDDEN_SIZE = 7168
32 |
33 |
34 | class BenchmarkType(Enum):
35 |     TORCHSNAPSHOT = "torchsnapshot"
36 |     DEEPSPEED = "deepspeed"
37 |
38 |     def __str__(self):
39 |         return self.value
40 |
41 |
42 | def rank_0_print(msg: str) -> None:
43 |     if dist.get_rank() == 0:
44 |         print(msg)
45 |
46 |
47 | def initialize_deepspeed_opt() -> DeepSpeedEngine:
48 |     ds_config = {
49 |         "train_batch_size": TRAIN_BATCH_SIZE,
50 |         "fp16": {
51 |             "enabled": True,
52 |         },
53 |         "zero_optimization": {
54 |             "stage": 3,
55 |         },
56 |         "optimizer": {
57 |             "type": "Adam",
58 |             "params": {
59 |                 "lr": 2e-4,
60 |                 "weight_decay": 0.01,
61 |             },
62 |         },
63 |     }
64 |     # HfDeepSpeedConfig must be created before instantiating the model and kept alive (hence the module-level global).
65 |     # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
66 |     global dschf; dschf = HfDeepSpeedConfig(ds_config)  # noqa
67 |
68 |     with deepspeed.zero.Init():
69 |         config = OPTConfig(
70 |             num_hidden_layers=NUM_HIDDEN_LAYERS,
71 |             num_attention_heads=NUM_ATTENTION_HEADS,
72 |             hidden_size=HIDDEN_SIZE,
73 |         )
74 |         model = OPTModel(config)
75 |
76 |     engine, _, _, _ = deepspeed.initialize(
77 |         model=model, model_parameters=model.parameters(), config_params=ds_config
78 |     )
79 |     return engine
80 |
81 |
82 | def benchmark_torchsnapshot(
83 |     engine: DeepSpeedEngine, save_dir: str, benchmark_load: bool
84 | ) -> None:
85 |     patch_engine_to_use_torchsnapshot(engine)
86 |
87 |     rank_0_print("Saving a checkpoint with torchsnapshot...")
88 |     begin_ts = time.monotonic()
89 |     engine.save_checkpoint(save_dir=save_dir)
90 |     rank_0_print(
91 |         f"Completed saving with torchsnapshot (snapshot path: {save_dir}).\n"
92 |         f"Took {time.monotonic() - begin_ts:.2f} seconds."
93 |     )
94 |
95 |     if benchmark_load:
96 |         del engine
97 |         engine = initialize_deepspeed_opt()
98 |         patch_engine_to_use_torchsnapshot(engine)
99 |
100 |         rank_0_print("Loading the checkpoint with torchsnapshot...")
101 |         begin_ts = time.monotonic()
102 |         engine.load_checkpoint(load_dir=save_dir)
103 |         rank_0_print(
104 |             f"Completed loading with torchsnapshot.\n"
105 |             f"Took {time.monotonic() - begin_ts:.2f} seconds."
106 |         )
107 |
108 |
109 | def benchmark_deepspeed(
110 |     engine: DeepSpeedEngine, save_dir: str, benchmark_load: bool
111 | ) -> None:
112 |     rank_0_print("Saving a checkpoint with DeepSpeedEngine.save_checkpoint()...")
113 |     begin_ts = time.monotonic()
114 |     engine.save_checkpoint(save_dir=save_dir)
115 |     rank_0_print(
116 |         f"Completed saving with DeepSpeedEngine.save_checkpoint() (save_dir: {save_dir}).\n"
117 |         f"Took {time.monotonic() - begin_ts:.2f} seconds."
118 |     )
119 |     if benchmark_load:
120 |         del engine
121 |         engine = initialize_deepspeed_opt()
122 |         rank_0_print("Loading the checkpoint with DeepSpeedEngine.load_checkpoint()...")
123 |         begin_ts = time.monotonic()
124 |         engine.load_checkpoint(load_dir=save_dir)
125 |         rank_0_print(
126 |             f"Completed loading with DeepSpeedEngine.load_checkpoint().\n"
127 |             f"Took {time.monotonic() - begin_ts:.2f} seconds."
128 |         )
129 |
130 |
131 | def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) -> None:
132 |     logger = logging.getLogger("torchsnapshot.scheduler")
133 |     logger.setLevel(logging.DEBUG)
134 |
135 |     dist.init_process_group(backend="nccl")
136 |     local_rank = int(os.environ["LOCAL_RANK"])
137 |     device = torch.device(f"cuda:{local_rank}")
138 |     torch.cuda.set_device(device)
139 |
140 |     save_dir = f"{work_dir}/{uuid.uuid4()}"
141 |     object_list = [None] * dist.get_world_size()
142 |     object_list[dist.get_rank()] = save_dir
143 |     dist.broadcast_object_list(object_list=object_list, src=0)
144 |     save_dir = object_list[0]
145 |
146 |     engine = initialize_deepspeed_opt()
147 |     if benchmark_type == BenchmarkType.TORCHSNAPSHOT:
148 |         benchmark_torchsnapshot(
149 |             engine=engine, save_dir=save_dir, benchmark_load=benchmark_load
150 |         )
151 |     elif benchmark_type == BenchmarkType.DEEPSPEED:
152 |         benchmark_deepspeed(
153 |             engine=engine, save_dir=save_dir, benchmark_load=benchmark_load
154 |         )
155 |     else:
156 |         raise ValueError(f"Unrecognized benchmark type: {benchmark_type}")
157 |
158 |
159 | if __name__ == "__main__":
160 |     parser = argparse.ArgumentParser()
161 |     parser.add_argument(
162 |         "--benchmark-type",
163 |         type=BenchmarkType,
164 |         choices=list(BenchmarkType),
165 |         default=BenchmarkType.TORCHSNAPSHOT,
166 |     )
167 |     parser.add_argument("--work-dir", default="/tmp")
168 |     parser.add_argument("--benchmark-load", action="store_true", default=False)
169 |
170 |     args: argparse.Namespace = parser.parse_args()
171 |     main(
172 |         benchmark_type=args.benchmark_type,
173 |         work_dir=args.work_dir,
174 |         benchmark_load=args.benchmark_load,
175 |     )
176 |
--------------------------------------------------------------------------------
/benchmarks/fsdp/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
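#
# A typical launch (illustrative; the launcher and process count are
# assumptions, since the script itself only requires LOCAL_RANK to be set and
# an NCCL-capable environment; the flags come from the argparse setup below,
# and run.slurm in this directory covers Slurm clusters):
#
#   torchrun --nproc_per_node=8 benchmarks/fsdp/main.py \
#       --benchmark-type torchsnapshot --work-dir /tmp --benchmark-load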
6 | 7 | import argparse 8 | import os 9 | import time 10 | from enum import Enum 11 | from functools import partial 12 | from uuid import uuid4 13 | 14 | import torch 15 | from torch import distributed as dist, nn 16 | from torch.distributed.elastic.multiprocessing.errors import record 17 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType 18 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 19 | from torchsnapshot import Snapshot 20 | 21 | 22 | class BenchmarkType(Enum): 23 | TORCHSNAPSHOT = "torchsnapshot" 24 | TORCH_SAVE = "torch_save" 25 | 26 | def __str__(self): 27 | return self.value 28 | 29 | 30 | def rank_0_print(msg: str) -> None: 31 | if dist.get_rank() == 0: 32 | print(msg) 33 | 34 | 35 | def create_model() -> nn.Module: 36 | # 7.8GB model, 1.9B parameters 37 | model = nn.Transformer( 38 | d_model=864, 39 | num_encoder_layers=1, 40 | num_decoder_layers=20, 41 | nhead=12, 42 | dim_feedforward=50257, 43 | ) 44 | 45 | # 80GB 21B parameters 46 | # model = nn.Transformer( 47 | # d_model=4000, 48 | # num_encoder_layers=1, 49 | # num_decoder_layers=40, 50 | # nhead=40, 51 | # dim_feedforward=50257, 52 | # ) 53 | 54 | model_size = sum( 55 | p.numel() * p.element_size() for p in model.parameters() if p.requires_grad 56 | ) 57 | model_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 58 | 59 | rank_0_print(f"model parameters: {model_params:,}") 60 | rank_0_print(f"model size: {model_size / (1024 ** 3):.3} GB") 61 | 62 | return FSDP( 63 | model, 64 | auto_wrap_policy=partial( 65 | transformer_auto_wrap_policy, 66 | transformer_layer_cls={ 67 | nn.TransformerDecoderLayer, 68 | nn.TransformerEncoderLayer, 69 | }, 70 | ), 71 | device_id=int(os.environ["LOCAL_RANK"]), 72 | ) 73 | 74 | 75 | def benchmark_torchsnapshot( 76 | model: nn.Module, save_dir: str, benchmark_load: bool 77 | ) -> None: 78 | rank_0_print("Saving a checkpoint with torchsnapshot...") 79 | app_state = {"model": model} 80 | begin_ts = time.monotonic() 81 | with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): 82 | Snapshot.take( 83 | path=save_dir, 84 | app_state=app_state, 85 | ) 86 | dist.barrier() 87 | end_ts = time.monotonic() 88 | rank_0_print( 89 | f"Completed saving with torchsnapshot (snapshot path: {save_dir}).\n" 90 | f"Took {end_ts - begin_ts:.2f} seconds." 91 | ) 92 | 93 | if benchmark_load: 94 | rank_0_print("Loading the checkpoint with torchsnapshot...") 95 | begin_ts = time.monotonic() 96 | with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): 97 | snapshot = Snapshot(path=save_dir) 98 | snapshot.restore(app_state) 99 | end_ts = time.monotonic() 100 | rank_0_print( 101 | f"Completed loading with torchsnapshot.\n" 102 | f"Took {end_ts - begin_ts:.2f} seconds." 103 | ) 104 | 105 | 106 | def benchmark_torchsave(model: nn.Module, save_dir: str, benchmark_load: bool) -> None: 107 | rank_0_print("Saving a checkpoint with torch.save...") 108 | 109 | os.makedirs(save_dir, exist_ok=True) 110 | save_file = f"{save_dir}/state_dict-{dist.get_rank()}.pt" 111 | 112 | begin_ts = time.monotonic() 113 | with FSDP.state_dict_type( 114 | model, 115 | StateDictType.LOCAL_STATE_DICT, 116 | ): 117 | state_dict = model.state_dict() 118 | torch.save(state_dict, save_file) 119 | dist.barrier() 120 | end_ts = time.monotonic() 121 | rank_0_print( 122 | f"Completed saving with torch.save (path: {save_dir}).\n" 123 | f"Took {end_ts - begin_ts:.2f} seconds." 
124 | ) 125 | 126 | if benchmark_load: 127 | begin_ts = time.monotonic() 128 | with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT): 129 | model.load_state_dict(torch.load(save_file)) 130 | dist.barrier() 131 | end_ts = time.monotonic() 132 | rank_0_print( 133 | f"Completed loading with torch.save.\n" 134 | f"Took {end_ts - begin_ts:.2f} seconds." 135 | ) 136 | 137 | 138 | @record 139 | def main(benchmark_type: BenchmarkType, work_dir: str, benchmark_load: bool) -> None: 140 | dist.init_process_group("nccl") 141 | local_rank = int(os.environ["LOCAL_RANK"]) 142 | device = torch.device(f"cuda:{local_rank}") 143 | torch.cuda.set_device(device) 144 | 145 | save_dir = f"{work_dir}/{uuid4()}" 146 | object_list = [None] * dist.get_world_size() 147 | object_list[dist.get_rank()] = save_dir 148 | dist.broadcast_object_list(object_list=object_list, src=0) 149 | save_dir = object_list[0] 150 | 151 | model = create_model() 152 | model.to(device) 153 | 154 | if benchmark_type == BenchmarkType.TORCHSNAPSHOT: 155 | benchmark_torchsnapshot(model, save_dir, benchmark_load) 156 | elif benchmark_type == BenchmarkType.TORCH_SAVE: 157 | benchmark_torchsave(model, save_dir, benchmark_load) 158 | else: 159 | raise ValueError(f"Unrecognized benchmark type: {benchmark_type}") 160 | 161 | 162 | if __name__ == "__main__": 163 | parser = argparse.ArgumentParser() 164 | parser.add_argument( 165 | "--benchmark-type", 166 | type=BenchmarkType, 167 | choices=list(BenchmarkType), 168 | default=BenchmarkType.TORCHSNAPSHOT, 169 | ) 170 | parser.add_argument("--work-dir", default="/tmp") 171 | parser.add_argument("--benchmark-load", action="store_true", default=False) 172 | 173 | args: argparse.Namespace = parser.parse_args() 174 | main( 175 | benchmark_type=args.benchmark_type, 176 | work_dir=args.work_dir, 177 | benchmark_load=args.benchmark_load, 178 | ) 179 | -------------------------------------------------------------------------------- /tests/test_read_object.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
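# A minimal sketch of the Snapshot.read_object semantics exercised below,
# assuming a snapshot taken with app_state={"state": StateDict(foo=42, bar=tensor)}
# (mirroring test_read_object; `tensor` and `buf` are hypothetical names).
# read_object can also materialize a full tensor from a ShardedTensor entry,
# as _test_read_sharded_tensor below demonstrates.
#
#   snapshot.read_object("0/state/foo")               # -> 42
#   snapshot.read_object("0/state/foo", 777)          # obj_out appears to be ignored for plain objects -> 42
#   snapshot.read_object("0/state/bar", obj_out=buf)  # loads into `buf` and returns it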
7 | 8 | # pyre-strict 9 | 10 | import tempfile 11 | import unittest 12 | 13 | import torch 14 | import torch.distributed as dist 15 | import torch.distributed.launcher as pet 16 | import torchsnapshot 17 | from torch.distributed._shard import sharded_tensor 18 | from torch.distributed._shard.sharding_spec import ChunkShardingSpec 19 | from torchsnapshot.test_utils import get_pet_launch_config 20 | 21 | 22 | class ReadObjectTest(unittest.TestCase): 23 | def test_read_object(self) -> None: 24 | state = torchsnapshot.StateDict( 25 | foo=42, 26 | bar=torch.randn(20, 20), 27 | ) 28 | 29 | with tempfile.TemporaryDirectory() as path: 30 | snapshot = torchsnapshot.Snapshot.take( 31 | path=path, app_state={"state": state} 32 | ) 33 | 34 | self.assertEqual(snapshot.read_object("0/state/foo"), 42) 35 | self.assertEqual(snapshot.read_object("0/state/foo", 777), 42) 36 | 37 | baz = torch.randn(20, 20) 38 | self.assertFalse(torch.allclose(baz, state["bar"])) 39 | 40 | loaded_bar = snapshot.read_object("0/state/bar", baz) 41 | self.assertEqual(id(loaded_bar), id(baz)) 42 | self.assertNotEqual(id(loaded_bar), id(state["bar"])) 43 | self.assertTrue(torch.allclose(baz, state["bar"])) 44 | 45 | @staticmethod 46 | def _test_read_sharded_tensor() -> None: 47 | tc = unittest.TestCase() 48 | dist.init_process_group(backend="gloo") 49 | torch.manual_seed(42 + dist.get_rank()) 50 | 51 | # pyre-ignore [28] 52 | spec = ChunkShardingSpec( 53 | dim=0, 54 | placements=[f"rank:{rank}/cpu" for rank in range(dist.get_world_size())], 55 | ) 56 | foo = sharded_tensor.empty(spec, 20_000, 128) 57 | for shard in foo.local_shards(): 58 | torch.nn.init.uniform_(shard.tensor) 59 | 60 | bar = sharded_tensor.empty(spec, 20_000, 128) 61 | for shard in bar.local_shards(): 62 | torch.nn.init.uniform_(shard.tensor) 63 | 64 | for foo_shard, bar_shard in zip(foo.local_shards(), bar.local_shards()): 65 | tc.assertFalse(torch.allclose(foo_shard.tensor, bar_shard.tensor)) 66 | 67 | with tempfile.TemporaryDirectory() as path: 68 | snapshot = torchsnapshot.Snapshot.take( 69 | path=path, app_state={"state": torchsnapshot.StateDict(foo=foo)} 70 | ) 71 | snapshot.read_object("0/state/foo", obj_out=bar) 72 | baz = snapshot.read_object("0/state/foo") 73 | 74 | for foo_shard, bar_shard in zip(foo.local_shards(), bar.local_shards()): 75 | tc.assertTrue(torch.allclose(foo_shard.tensor, bar_shard.tensor)) 76 | 77 | tc.assertEqual(baz.shape, torch.Size([20_000, 128])) 78 | 79 | gathered_foo_tensor = torch.empty(20_000, 128) 80 | if dist.get_rank() == 0: 81 | foo.gather(dst=0, out=gathered_foo_tensor) 82 | tc.assertTrue(torch.allclose(baz, gathered_foo_tensor)) 83 | else: 84 | foo.gather(dst=0, out=None) 85 | 86 | def test_read_sharded_tensor(self) -> None: 87 | lc = get_pet_launch_config(nproc=4) 88 | pet.elastic_launch(lc, entrypoint=self._test_read_sharded_tensor)() 89 | 90 | @staticmethod 91 | def _quantize(path: str, tensor: torch.Tensor, tracing: bool) -> torch.Tensor: 92 | return torch.quantize_per_tensor(tensor, 0.1, 10, torch.qint8) 93 | 94 | @classmethod 95 | def _test_read_sharded_tensor_into_tensor(cls, quantized: bool) -> None: 96 | tc = unittest.TestCase() 97 | dist.init_process_group(backend="gloo") 98 | torch.manual_seed(42 + dist.get_rank()) 99 | 100 | # pyre-ignore [28] 101 | spec = ChunkShardingSpec( 102 | dim=0, 103 | placements=[f"rank:{rank}/cpu" for rank in range(dist.get_world_size())], 104 | ) 105 | foo = sharded_tensor.empty(spec, 20_000, 128) 106 | for shard in foo.local_shards(): 107 | torch.nn.init.uniform_(shard.tensor) 108 
| 109 | # Gather the sharded tensor for ease of comparison 110 | if dist.get_rank() == 0: 111 | foo_gathered = torch.empty(20_000, 128) 112 | foo.gather(dst=0, out=foo_gathered) 113 | else: 114 | foo_gathered = torch.empty(42) 115 | foo.gather(dst=0, out=None) 116 | 117 | # Create a tensor into which the sharded tensor will be loaded 118 | bar = torch.rand_like(foo_gathered) 119 | 120 | if quantized: 121 | foo_gathered = torch.quantize_per_tensor(foo_gathered, 0.1, 10, torch.qint8) 122 | bar = torch.quantize_per_tensor(bar, 0.1, 10, torch.qint8) 123 | 124 | # Control test: these two tensors should be different 125 | tc.assertFalse( 126 | torch.allclose(torch.dequantize(foo_gathered), torch.dequantize(bar)) 127 | ) 128 | 129 | with tempfile.TemporaryDirectory() as path: 130 | _custom_tensor_prepare_func = cls._quantize if quantized else None 131 | snapshot = torchsnapshot.Snapshot.take( 132 | path=path, 133 | app_state={"state": torchsnapshot.StateDict(foo=foo)}, 134 | _custom_tensor_prepare_func=_custom_tensor_prepare_func, 135 | ) 136 | if dist.get_rank() == 0: 137 | snapshot.read_object("0/state/foo", obj_out=bar) 138 | tc.assertTrue( 139 | torch.allclose( 140 | torch.dequantize(foo_gathered), torch.dequantize(bar) 141 | ) 142 | ) 143 | 144 | def test_read_sharded_tensor_into_tensor(self) -> None: 145 | lc = get_pet_launch_config(nproc=4) 146 | pet.elastic_launch(lc, entrypoint=self._test_read_sharded_tensor_into_tensor)( 147 | True # quantize=True 148 | ) 149 | lc = get_pet_launch_config(nproc=4) 150 | pet.elastic_launch(lc, entrypoint=self._test_read_sharded_tensor_into_tensor)( 151 | False # quantized=False 152 | ) 153 | -------------------------------------------------------------------------------- /torchsnapshot/asyncio_utils.py: -------------------------------------------------------------------------------- 1 | # pyre-unsafe 2 | 3 | import asyncio 4 | import functools 5 | import os 6 | import sys 7 | import threading 8 | from contextlib import contextmanager 9 | from heapq import heappop 10 | 11 | 12 | # copy-pasted from nest-asyncio, but modified to avoid patching the global 13 | # namespace and instead only patching the instance variable 14 | def _patch_loop(loop: asyncio.AbstractEventLoop) -> None: 15 | def run_forever(self): 16 | with manage_run(self), manage_asyncgens(self): 17 | while True: 18 | self._run_once() 19 | if self._stopping: 20 | break 21 | self._stopping = False 22 | 23 | def run_until_complete(self, future): 24 | with manage_run(self): 25 | f = asyncio.ensure_future(future, loop=self) 26 | if f is not future: 27 | f._log_destroy_pending = False 28 | while not f.done(): 29 | self._run_once() 30 | if self._stopping: 31 | break 32 | if not f.done(): 33 | raise RuntimeError("Event loop stopped before Future completed.") 34 | return f.result() 35 | 36 | def _run_once(self): 37 | """ 38 | Simplified re-implementation of asyncio's _run_once that 39 | runs handles as they become ready. 
40 | """ 41 | now = self.time() 42 | ready = self._ready 43 | scheduled = self._scheduled 44 | while scheduled and scheduled[0]._cancelled: 45 | heappop(scheduled) 46 | 47 | timeout = ( 48 | 0 49 | if ready or self._stopping 50 | else min(max(scheduled[0]._when - now, 0), 86400) 51 | if scheduled 52 | else None 53 | ) 54 | event_list = self._selector.select(timeout) 55 | self._process_events(event_list) 56 | 57 | end_time = self.time() + self._clock_resolution 58 | while scheduled and scheduled[0]._when < end_time: 59 | handle = heappop(scheduled) 60 | ready.append(handle) 61 | 62 | for _ in range(len(ready)): 63 | if not ready: 64 | break 65 | handle = ready.popleft() 66 | if not handle._cancelled: 67 | handle._run() 68 | handle = None 69 | 70 | @contextmanager 71 | def manage_run(self): 72 | """Set up the loop for running.""" 73 | self._check_closed() 74 | old_thread_id = self._thread_id 75 | old_running_loop = asyncio.events._get_running_loop() 76 | try: 77 | self._thread_id = threading.get_ident() 78 | asyncio.events._set_running_loop(self) 79 | self._num_runs_pending += 1 80 | if self._is_proactorloop: 81 | if self._self_reading_future is None: 82 | self.call_soon(self._loop_self_reading) 83 | yield 84 | finally: 85 | self._thread_id = old_thread_id 86 | asyncio.events._set_running_loop(old_running_loop) 87 | self._num_runs_pending -= 1 88 | if self._is_proactorloop: 89 | if ( 90 | self._num_runs_pending == 0 91 | and self._self_reading_future is not None 92 | ): 93 | ov = self._self_reading_future._ov 94 | self._self_reading_future.cancel() 95 | if ov is not None: 96 | self._proactor._unregister(ov) 97 | self._self_reading_future = None 98 | 99 | @contextmanager 100 | def manage_asyncgens(self): 101 | old_agen_hooks = sys.get_asyncgen_hooks() 102 | try: 103 | self._set_coroutine_origin_tracking(self._debug) 104 | if self._asyncgens is not None: 105 | sys.set_asyncgen_hooks( 106 | firstiter=self._asyncgen_firstiter_hook, 107 | finalizer=self._asyncgen_finalizer_hook, 108 | ) 109 | yield 110 | finally: 111 | self._set_coroutine_origin_tracking(False) 112 | if self._asyncgens is not None: 113 | sys.set_asyncgen_hooks(*old_agen_hooks) 114 | 115 | def _check_running(self): 116 | """Do not throw exception if loop is already running.""" 117 | pass 118 | 119 | # pyre-fixme[8]: Attribute has type `(self: AbstractEventLoop) -> None`; used as 120 | # `partial[typing.Any]`. 121 | loop.run_forever = functools.partial(run_forever, loop) 122 | # pyre-fixme[8]: Attribute has type `(self: AbstractEventLoop, future: 123 | # Union[Awaitable[Variable[_T]], Generator[typing.Any, None, Variable[_T]]]) -> 124 | # _T`; used as `partial[typing.Any]`. 125 | loop.run_until_complete = functools.partial(run_until_complete, loop) 126 | # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_run_once`. 127 | loop._run_once = functools.partial(_run_once, loop) 128 | # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_check_running`. 129 | loop._check_running = functools.partial(_check_running, loop) 130 | # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_nest_patched`. 131 | loop._nest_patched = True 132 | # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_num_runs_pending`. 133 | loop._num_runs_pending = 0 134 | # pyre-fixme[16]: `AbstractEventLoop` has no attribute `_is_proactorloop`. 135 | loop._is_proactorloop = os.name == "nt" and isinstance( 136 | loop, 137 | # pyre-fixme[16]: Module `asyncio` has no attribute `ProactorEventLoop`. 
asyncio.ProactorEventLoop,
139 |     )
140 |
141 |
142 | # TODO: this is *not* an amazing way to do this; see the refactoring note below.
143 | def maybe_nested_loop() -> asyncio.AbstractEventLoop:
144 |     try:
145 |         original = asyncio.get_running_loop()
146 |     except RuntimeError:
147 |         original = None
148 |
149 |     loop = asyncio.new_event_loop()
150 |     if original is None:
151 |         return loop
152 |     else:
153 |         # Need to monkey-patch the loop so it can be re-entrant, which makes things
154 |         # work on old versions of Jupyter
155 |         #
156 |         # It would be better if we could refactor the code to rely more on
157 |         # asyncio.run instead of passing the event loop into places, but oh well...
158 |         _patch_loop(loop)
159 |         return loop
160 |
--------------------------------------------------------------------------------
/tests/gpu_tests/test_partitioner_dtensor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 | #
5 | # This source code is licensed under the BSD-style license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | # pyre-strict
9 |
10 | import uuid
11 | from collections import defaultdict
12 | from typing import List
13 |
14 | import torch
15 | import torch.distributed as dist
16 | from torch.distributed._tensor import DTensor
17 | from torch.testing._internal.common_utils import (
18 |     instantiate_parametrized_tests,
19 |     parametrize,
20 | )
21 | from torch.testing._internal.distributed._tensor.common_dtensor import (
22 |     DTensorTestBase,
23 |     skip_if_lt_x_gpu,
24 |     with_comms,
25 | )
26 |
27 | from torchsnapshot.batcher import batch_write_requests
28 | from torchsnapshot.io_preparer import prepare_read
29 | from torchsnapshot.io_types import ReadIO, WriteIO
30 |
31 | from torchsnapshot.partitioner import (
32 |     consolidate_replicated_entries_dist,
33 |     partition_write_reqs,
34 | )
35 | from torchsnapshot.pg_wrapper import PGWrapper
36 | from torchsnapshot.serialization import (
37 |     BUFFER_PROTOCOL_SUPPORTED_DTYPES,
38 |     NCCL_SUPPORTED_DTYPES,
39 | )
40 | from torchsnapshot.storage_plugins.fs import FSStoragePlugin
41 | from torchsnapshot.test_utils import _dtensor_test_case, rand_tensor, tensor_eq
42 |
43 | WORLD_SIZE: int = 4
44 |
45 |
46 | @instantiate_parametrized_tests
47 | class TestPartitioner(DTensorTestBase):
48 |     @parametrize("dtype", NCCL_SUPPORTED_DTYPES)
49 |     @parametrize("enable_batcher", [True, False])
50 |     @skip_if_lt_x_gpu(WORLD_SIZE)
51 |     # pyre-fixme[56]: While applying decorator
52 |     # `torch.testing._internal.distributed._tensor.common_dtensor.with_comms`: For 1st
53 |     # argument expected `(object) -> object` but got `(self: TestPartitioner, dtype:
54 |     # dtype, enable_batcher: bool) -> Coroutine[typing.Any, typing.Any, None]`.
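    # A condensed sketch of the flow under test (the names are taken from the
    # test body below, so this mirrors rather than replaces it):
    #
    #   partitioned_entries, partitioned_wrs = partition_write_reqs(
    #       entries=entries, write_reqs=write_reqs, pg=PGWrapper(pg=None)
    #   )
    #   partitioned_entries = consolidate_replicated_entries_dist(
    #       partitioned_entries, pg=PGWrapper(pg=None), dedup=False
    #   )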
55 |     @with_comms
56 |     async def test_partitioner(
57 |         self,
58 |         dtype: torch.dtype,
59 |         enable_batcher: bool,
60 |     ) -> None:
61 |         """
62 |         Verify the behavior of the partitioner by:
63 |
64 |         - Writing DTensor objects with the partitioner enabled:
65 |           - Optionally enable the batcher
66 |         - Reading the written objects back and comparing them with the originals
67 |         """
68 |
69 |         tensors = []
70 |         entries = {}
71 |         write_reqs = defaultdict(list)
72 |
73 |         # Use the same seed to simulate replicated-ness
74 |         torch.manual_seed(42)
75 |
76 |         # DTensors
77 |         for idx in range(10):
78 |             logical_path = f"replicated_sharded_{idx}"
79 |             tensor, entry, wrs = _dtensor_test_case(
80 |                 dtype=dtype,
81 |                 shape=[64, 64],
82 |                 logical_path=logical_path,
83 |                 rank=dist.get_rank(),
84 |                 replicated=False,
85 |             )
86 |             tensors.append(tensor)
87 |             entries[logical_path] = entry
88 |             write_reqs[logical_path].extend(wrs)
89 |
90 |         # Perform partition
91 |         partitioned_entries, partitioned_write_reqs = partition_write_reqs(
92 |             entries=entries, write_reqs=write_reqs, pg=PGWrapper(pg=None)
93 |         )
94 |         partitioned_write_reqs = [
95 |             wr for wrs in partitioned_write_reqs.values() for wr in wrs
96 |         ]
97 |
98 |         # The partitioner should work with or without the batcher
99 |         if enable_batcher:
100 |             batched_entries, batched_write_reqs = batch_write_requests(
101 |                 entries=list(partitioned_entries.values()),
102 |                 write_reqs=partitioned_write_reqs,
103 |             )
104 |             # Make sure that batching happened
105 |             if dtype in BUFFER_PROTOCOL_SUPPORTED_DTYPES:
106 |                 assert len(batched_write_reqs) < len(partitioned_write_reqs)
107 |             partitioned_entries = dict(zip(partitioned_entries.keys(), batched_entries))
108 |             partitioned_write_reqs = batched_write_reqs
109 |
110 |         partitioned_entries = consolidate_replicated_entries_dist(
111 |             partitioned_entries, pg=PGWrapper(pg=None), dedup=False
112 |         )
113 |
114 |         # Verify that all logical paths are still present
115 |         for logical_path in entries.keys():
116 |             assert logical_path in partitioned_entries
117 |
118 |         # Gather the locations to be written by all ranks
119 |         locations = [wr.path for wr in partitioned_write_reqs]
120 |         # pyre-ignore
121 |         obj_list: List[List[str]] = [None] * dist.get_world_size()
122 |         dist.all_gather_object(obj_list, locations)
123 |         all_locations = [location for locs in obj_list for location in locs]
124 |
125 |         # Verify there are no duplicate write requests across ranks
126 |         assert len(all_locations) == len(set(all_locations))
127 |
128 |         # Fulfill the write requests
129 |         plugin = FSStoragePlugin(root=f"/tmp/{uuid.uuid4()}")
130 |         for wr in partitioned_write_reqs:
131 |             buf = await wr.buffer_stager.stage_buffer()
132 |             write_io = WriteIO(path=wr.path, buf=buf)
133 |             await plugin.write(write_io)
134 |
135 |         # Wait for all ranks to finish writing before beginning to read
136 |         dist.barrier()
137 |
138 |         # Verify the integrity of the writes by loading the persisted tensors and
139 |         # comparing them with the original tensors.
140 | dst_tensors = [] 141 | for tensor in tensors: 142 | if type(tensor) == DTensor: 143 | dst_tensors.append( 144 | DTensor.from_local( 145 | local_tensor=rand_tensor(tuple(tensor.shape), dtype=dtype), 146 | device_mesh=tensor.device_mesh, 147 | placements=tensor.placements, 148 | ) 149 | ) 150 | else: 151 | raise AssertionError(f"Unexpected tensor type {type(tensor)}") 152 | 153 | for logical_path, tensor, dst_tensor in zip( 154 | entries.keys(), tensors, dst_tensors 155 | ): 156 | assert not tensor_eq(tensor, dst_tensor) 157 | 158 | entry = partitioned_entries[logical_path] 159 | rrs, _ = prepare_read(entry, obj_out=dst_tensor) 160 | for rr in rrs: 161 | read_io = ReadIO(path=rr.path, byte_range=rr.byte_range) 162 | await plugin.read(read_io) 163 | await rr.buffer_consumer.consume_buffer(read_io.buf.getvalue()) 164 | 165 | assert tensor_eq(tensor, dst_tensor) 166 | 167 | await plugin.close() 168 | -------------------------------------------------------------------------------- /torchsnapshot/dist_store.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # pyre-strict 9 | 10 | from datetime import timedelta 11 | from typing import Dict, Optional 12 | 13 | import torch.distributed as dist 14 | from torch.distributed.elastic.utils.distributed import get_socket_with_port 15 | 16 | from .pg_wrapper import PGWrapper 17 | 18 | 19 | _DEFAULT_TCP_STORE_TIMEOUT = timedelta(seconds=600) 20 | 21 | _pg_to_store: Dict[Optional[dist.ProcessGroup], dist.Store] = {} 22 | 23 | 24 | def get_or_create_store(pg_wrapper: PGWrapper) -> dist.Store: 25 | """ 26 | Get or create a dist.Store. 27 | 28 | If a default store is present, return the store. Otherwise, bootstrap a 29 | store with the input process group. 30 | 31 | Args: 32 | pg_wrapper: The pg with which to bootstrap a store if a default store 33 | is not present. 34 | 35 | Returns: 36 | A dist.Store instance. 37 | """ 38 | store = None 39 | if dist.is_initialized(): 40 | store = dist.distributed_c10d._get_default_store() 41 | 42 | if store is not None: 43 | return store 44 | else: 45 | # The default store is only absent when the global process group is 46 | # initialized with the MPI backend. In this case, we bootstrap a store 47 | # with the input process group. 48 | if pg_wrapper.pg in _pg_to_store: 49 | return _pg_to_store[pg_wrapper.pg] 50 | store = create_store(pg_wrapper=pg_wrapper) 51 | _pg_to_store[pg_wrapper.pg] = store 52 | return store 53 | 54 | 55 | def create_store(pg_wrapper: PGWrapper) -> dist.Store: 56 | """ 57 | Bootstrap a dist.Store with a process group. 58 | 59 | Args: 60 | pg_wrapper: The pg with which to bootstrap a store if a default store 61 | is not present. 62 | 63 | Returns: 64 | The bootstrapped dist.Store instance. 
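
    Example (a minimal sketch, assuming a process group has been initialized
    so that ``PGWrapper(pg=None)`` maps to the default process group)::

        store = create_store(pg_wrapper=PGWrapper(pg=None))
        store.set("key", "value")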
65 | """ 66 | if pg_wrapper.get_rank() == 0: 67 | # Find a free port 68 | sock = get_socket_with_port() 69 | master_addr, master_port, _, _ = sock.getsockname() 70 | sock.close() 71 | # Broadcast master address/port to peers 72 | obj_list = [master_addr, master_port] 73 | else: 74 | # Receive master address/port from the leader rank 75 | obj_list = [None, None] 76 | pg_wrapper.broadcast_object_list(obj_list=obj_list, src=0) 77 | master_addr, master_port = obj_list[0], obj_list[1] 78 | 79 | store = dist.TCPStore( 80 | host_name=master_addr, 81 | port=master_port, 82 | world_size=pg_wrapper.get_world_size(), 83 | is_master=pg_wrapper.get_rank() == 0, 84 | timeout=_DEFAULT_TCP_STORE_TIMEOUT, 85 | wait_for_workers=True, 86 | ) 87 | _pg_to_store[pg_wrapper.pg] = store 88 | return store 89 | 90 | 91 | class LinearBarrier: 92 | """ 93 | A dist.Store-based linear barrier implementation. 94 | 95 | The barrier is performed in two stages: 96 | 97 | arrive - Non-leader ranks notify the leader rank that they've arrived at 98 | the barrier. 99 | 100 | depart - The leader rank notifies non-leader ranks that it has arrived at 101 | the barrier. 102 | 103 | The barrier is separated into two stages because this allows the leader 104 | rank to perform some actions in-between the two stages, with the knowledge 105 | that all ranks have arrived at the barrier, while holding other ranks in 106 | the barrier. 107 | """ 108 | 109 | def __init__( 110 | self, 111 | prefix: str, 112 | store: dist.Store, 113 | rank: int, 114 | world_size: int, 115 | leader_rank: int, 116 | ) -> None: 117 | self.prefix = prefix 118 | self.store = store 119 | self.rank = rank 120 | self.world_size = world_size 121 | self.leader_rank = leader_rank 122 | self.arrived = False 123 | self.departed = False 124 | 125 | def arrive(self, timeout: timedelta) -> None: 126 | """ 127 | The first stage of the barrier. 128 | 129 | Args: 130 | timeout: The timeout for the "arrive" stage. 131 | """ 132 | if self.arrived: 133 | raise RuntimeError("Can't call .arrive() multiple times on a barrier.") 134 | if self.departed: 135 | raise RuntimeError("Can't call .arrive() on a completed barrier.") 136 | self.arrived = True 137 | 138 | if self.rank == self.leader_rank: 139 | peer_keys = [ 140 | self._key(rank=rank) 141 | for rank in range(self.world_size) 142 | if rank != self.leader_rank 143 | ] 144 | self.store.wait(peer_keys, timeout) 145 | for key in peer_keys: 146 | err = self.store.get(key) 147 | if len(err) != 0: 148 | self.report_error(err=str(err)) 149 | raise RuntimeError(str(err)) 150 | else: 151 | self.store.set(self._key(rank=self.rank), "") 152 | 153 | def depart(self, timeout: timedelta) -> None: 154 | """ 155 | The second stage of the barrier. 156 | 157 | Args: 158 | timeout: The timeout for the "depart" stage. 159 | """ 160 | if not self.arrived: 161 | raise RuntimeError( 162 | "Can't call .depart() before calling .arrive() on a barrier." 163 | ) 164 | if self.departed: 165 | raise RuntimeError("Can't call .depart() on a completed barrier.") 166 | self.arrived = True 167 | 168 | if self.rank == self.leader_rank: 169 | self.store.set(self._key(self.leader_rank), "") 170 | else: 171 | leader_key = self._key(rank=self.leader_rank) 172 | self.store.wait([leader_key], timeout) 173 | err = self.store.get(leader_key) 174 | if len(err) != 0: 175 | raise RuntimeError(str(err)) 176 | 177 | def report_error(self, err: str) -> None: 178 | """ 179 | Report the error that prevents the current rank from completing the barrier. 
180 | 181 | Leader rank - can report error before calling .depart(). The error will 182 | be received by non-leader ranks in .depart(). 183 | 184 | Non-leader rank - can report error before calling .arrive(). The error 185 | will be received by the leader rank in .arrive() and non-leader ranks 186 | in .depart(). 187 | 188 | Args: 189 | err: The error to be propagated to peer ranks. 190 | """ 191 | self.store.set( 192 | self._key(self.rank), f"Rank {self.rank} encountered error: {err}" 193 | ) 194 | 195 | def _key(self, rank: int) -> str: 196 | return f"{self.prefix}_{rank}" 197 | --------------------------------------------------------------------------------
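# A minimal usage sketch for LinearBarrier (illustrative; the prefix and
# timeout values are assumptions, and `pg_wrapper` stands for a PGWrapper
# instance; the store would typically come from get_or_create_store above):
#
#   barrier = LinearBarrier(
#       prefix="example_barrier",
#       store=get_or_create_store(pg_wrapper=pg_wrapper),
#       rank=pg_wrapper.get_rank(),
#       world_size=pg_wrapper.get_world_size(),
#       leader_rank=0,
#   )
#   barrier.arrive(timeout=timedelta(seconds=600))
#   # The leader rank may perform work here while peers are held in the barrier.
#   barrier.depart(timeout=timedelta(seconds=600))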