├── .github
└── workflows
│ ├── build_docs.yml
│ ├── cherry_pick_release.yml
│ ├── lint_code.yml
│ └── unit_test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── cupti_build.py
├── docs
├── Makefile
├── make.bat
└── source
│ ├── checkpointing
│ ├── async
│ │ ├── api.rst
│ │ ├── api
│ │ │ ├── core.rst
│ │ │ ├── filesystem_async.rst
│ │ │ ├── state_dict_saver.rst
│ │ │ └── torch_ckpt.rst
│ │ ├── examples.rst
│ │ ├── examples
│ │ │ ├── basic_example.rst
│ │ │ └── writer_example.rst
│ │ ├── index.rst
│ │ └── usage_guide.rst
│ └── local
│ │ ├── api.rst
│ │ ├── api
│ │ ├── base_ckpt_manager.rst
│ │ ├── base_state_dict.rst
│ │ ├── basic_state_dict.rst
│ │ ├── callback.rst
│ │ ├── local_ckpt_manager.rst
│ │ └── replication.rst
│ │ ├── examples.rst
│ │ ├── examples
│ │ └── basic_example.rst
│ │ ├── index.rst
│ │ └── usage_guide.rst
│ ├── conf.py
│ ├── fault_tolerance
│ ├── README-pci-topo-file.md
│ ├── api.rst
│ ├── api
│ │ ├── callback.rst
│ │ ├── client.rst
│ │ ├── config.rst
│ │ └── server.rst
│ ├── examples.rst
│ ├── examples
│ │ ├── basic_example.rst
│ │ ├── in_job_and_in_process_example.rst
│ │ ├── train_ddp_heartbeats.rst
│ │ └── train_ddp_sections.rst
│ ├── index.rst
│ ├── integration.rst
│ ├── integration
│ │ ├── heartbeats.rst
│ │ ├── inprocess.rst
│ │ ├── ptl.rst
│ │ └── sections.rst
│ └── usage_guide.rst
│ ├── index.rst
│ ├── inprocess
│ ├── api.rst
│ ├── api
│ │ ├── abort.rst
│ │ ├── compose.rst
│ │ ├── exception.rst
│ │ ├── finalize.rst
│ │ ├── health_check.rst
│ │ ├── initialize.rst
│ │ ├── rank_assignment.rst
│ │ ├── rank_filter.rst
│ │ ├── state.rst
│ │ └── wrap.rst
│ ├── examples.rst
│ ├── examples
│ │ ├── basic_example.rst
│ │ └── optimal_example.rst
│ ├── index.rst
│ └── usage_guide.rst
│ ├── media
│ ├── nvrx_core_features.png
│ └── nvrx_docs_source.png
│ ├── release-notes.md
│ └── straggler_det
│ ├── api.rst
│ ├── api
│ ├── callback.rst
│ ├── reporting.rst
│ ├── statistics.rst
│ └── straggler.rst
│ ├── examples.rst
│ ├── examples
│ └── basic_example.rst
│ ├── index.rst
│ └── usage_guide.rst
├── examples
├── checkpointing
│ ├── async_ckpt.py
│ ├── async_writer.py
│ └── local_ckpt.py
├── fault_tolerance
│ ├── basic_ft_example.py
│ ├── dist_utils.py
│ ├── fault_tol_cfg_heartbeats.yaml
│ ├── fault_tol_cfg_sections.yaml
│ ├── in_job_and_in_process_example.py
│ ├── log_utils.py
│ ├── run_inprocess_injob_example.sh
│ ├── train_ddp_heartbeats_api.py
│ └── train_ddp_sections_api.py
├── inprocess
│ ├── basic_example.py
│ └── optimal_example.py
└── straggler
│ └── example.py
├── pyproject.toml
├── src
└── nvidia_resiliency_ext
│ ├── checkpointing
│ ├── __init__.py
│ ├── async_ckpt
│ │ ├── cached_metadata_filesystem_reader.py
│ │ ├── core.py
│ │ ├── filesystem_async.py
│ │ ├── state_dict_saver.py
│ │ └── torch_ckpt.py
│ ├── local
│ │ ├── __init__.py
│ │ ├── base_state_dict.py
│ │ ├── basic_state_dict.py
│ │ ├── ckpt_managers
│ │ │ ├── base_manager.py
│ │ │ └── local_manager.py
│ │ └── replication
│ │ │ ├── __init__.py
│ │ │ ├── _torch_future.py
│ │ │ ├── group_utils.py
│ │ │ ├── strategies.py
│ │ │ ├── torch_device_utils.py
│ │ │ └── utils.py
│ └── utils.py
│ ├── fault_tolerance
│ ├── __init__.py
│ ├── _ft_rendezvous.py
│ ├── _torch_elastic_compat
│ │ ├── __init__.py
│ │ ├── agent
│ │ │ ├── __init__.py
│ │ │ └── server
│ │ │ │ ├── __init__.py
│ │ │ │ ├── api.py
│ │ │ │ └── local_elastic_agent.py
│ │ ├── events
│ │ │ ├── __init__.py
│ │ │ ├── api.py
│ │ │ └── handlers.py
│ │ ├── metrics
│ │ │ ├── __init__.py
│ │ │ └── api.py
│ │ ├── multiprocessing
│ │ │ ├── __init__.py
│ │ │ ├── api.py
│ │ │ ├── errors
│ │ │ │ ├── __init__.py
│ │ │ │ ├── error_handler.py
│ │ │ │ └── handlers.py
│ │ │ ├── redirects.py
│ │ │ ├── subprocess_handler
│ │ │ │ ├── __init__.py
│ │ │ │ ├── handlers.py
│ │ │ │ └── subprocess_handler.py
│ │ │ └── tail_log.py
│ │ ├── rendezvous
│ │ │ ├── __init__.py
│ │ │ ├── api.py
│ │ │ ├── c10d_rendezvous_backend.py
│ │ │ ├── dynamic_rendezvous.py
│ │ │ ├── etcd_rendezvous.py
│ │ │ ├── etcd_rendezvous_backend.py
│ │ │ ├── etcd_server.py
│ │ │ ├── etcd_store.py
│ │ │ ├── registry.py
│ │ │ ├── static_tcp_rendezvous.py
│ │ │ └── utils.py
│ │ ├── timer
│ │ │ ├── __init__.py
│ │ │ ├── api.py
│ │ │ ├── file_based_local_timer.py
│ │ │ └── local_timer.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ ├── api.py
│ │ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── cycling_iterator.py
│ │ │ └── elastic_distributed_sampler.py
│ │ │ ├── distributed.py
│ │ │ ├── log_level.py
│ │ │ ├── logging.py
│ │ │ └── store.py
│ ├── config.py
│ ├── data.py
│ ├── dict_utils.py
│ ├── ipc_connector.py
│ ├── launcher.py
│ ├── rank_monitor_client.py
│ ├── rank_monitor_server.py
│ ├── rank_monitor_state_machine.py
│ ├── timeouts_calc.py
│ └── utils.py
│ ├── inprocess
│ ├── __init__.py
│ ├── abort.py
│ ├── attribution.py
│ ├── completion.py
│ ├── compose.py
│ ├── exception.py
│ ├── finalize.py
│ ├── health_check.py
│ ├── initialize.py
│ ├── monitor_process.py
│ ├── monitor_thread.py
│ ├── nested_restarter.py
│ ├── param_utils.py
│ ├── progress_watchdog.py
│ ├── rank_assignment.py
│ ├── rank_filter.py
│ ├── sibling_monitor.py
│ ├── state.py
│ ├── store.py
│ ├── terminate.py
│ ├── tools
│ │ ├── __init__.py
│ │ └── inject_fault.py
│ ├── utils.py
│ └── wrap.py
│ ├── ptl_resiliency
│ ├── __init__.py
│ ├── _utils.py
│ ├── fault_tolerance_callback.py
│ ├── fault_tolerance_sections_callback.py
│ ├── local_checkpoint_callback.py
│ └── straggler_det_callback.py
│ ├── shared_utils
│ ├── __init__.py
│ └── health_check.py
│ └── straggler
│ ├── __init__.py
│ ├── cupti.py
│ ├── cupti_src
│ ├── BufferPool.cpp
│ ├── BufferPool.h
│ ├── CircularBuffer.h
│ ├── CuptiProfiler.cpp
│ ├── CuptiProfiler.h
│ └── cupti_module_py.cpp
│ ├── dist_utils.py
│ ├── interval_tracker.py
│ ├── name_mapper.py
│ ├── reporting.py
│ ├── statistics.py
│ └── straggler.py
└── tests
├── checkpointing
└── unit
│ ├── __init__.py
│ ├── conftest.py
│ ├── test_async_save.py
│ ├── test_async_writer.py
│ ├── test_async_writer_msc.py
│ ├── test_basic_local.py
│ ├── test_cleanup.py
│ └── test_utilities.py
├── fault_tolerance
├── func
│ ├── _launcher_mode_test_worker.py
│ ├── _workload_ctrl_test_worker.py
│ ├── run_launcher_any_failed_mode_test.sh
│ ├── run_launcher_min_healthy_mode_test.sh
│ ├── run_local_ddp_test_heartbeats.sh
│ ├── run_local_ddp_test_sections.sh
│ ├── run_workload_ctrl_test_excl_node.sh
│ └── run_workload_ctrl_test_shutdown.sh
└── unit
│ ├── __init__.py
│ ├── _launcher_test_util.py
│ ├── conftest.py
│ ├── test_config.py
│ ├── test_dynamic_rendezvous.py
│ ├── test_init.py
│ ├── test_ipc_connector.py
│ ├── test_launcher.py
│ ├── test_layered_restart_v1.py
│ ├── test_process_utils.py
│ ├── test_rank_monitor_server.py
│ ├── test_reconnect.py
│ ├── test_shutdown.py
│ ├── test_shutdown_sections.py
│ ├── test_timeouts.py
│ ├── test_timeouts_calc.py
│ ├── test_timeouts_sections.py
│ └── utils.py
├── inprocess
├── __init__.py
├── app.py
├── common.py
├── test_abort.py
├── test_app.py
├── test_compose.py
├── test_health_check.py
├── test_monitor_thread.py
├── test_nested_restarter.py
├── test_progress_watchdog.py
├── test_rank_assignment.py
├── test_timeout.py
├── test_torch.py
└── test_wrap.py
├── ptl_resiliency
├── func
│ └── nemo20
│ │ ├── Dockerfile.ft_test
│ │ ├── check_straggler_log.py
│ │ ├── ft_test_asserts.sh
│ │ ├── ft_test_launchers.sh
│ │ ├── ft_test_llama3.py
│ │ ├── local_ckpt_test.sh
│ │ ├── straggler_test_llama3.py
│ │ └── test_local_ckpt_llama3.py
└── unit
│ ├── __init__.py
│ ├── test_ft_callback_hb.py
│ ├── test_ft_callback_sections.py
│ ├── test_ft_state_machine.py
│ ├── test_local_ckpt_callback.py
│ └── test_straggler_det_callback.py
├── shared_utils
└── test_health_check.py
└── straggler
├── README.md
├── func
├── check_log.py
└── ddp_test.py
└── unit
├── __init__.py
├── _utils.py
├── test_cupti_ext.py
├── test_cupti_manager.py
├── test_data_shared.py
├── test_det_section_api.py
├── test_individual_gpu_scores.py
├── test_interval_tracker.py
├── test_name_mapper.py
├── test_relative_gpu_scores.py
├── test_reporting.py
├── test_reporting_elapsed.py
├── test_sections.py
└── test_wrap_callables.py
/.github/workflows/build_docs.yml:
--------------------------------------------------------------------------------
1 | name: Build and Deploy Docs
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | workflow_dispatch:
8 |
9 | jobs:
10 | build_docs:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - name: Checkout repository
15 | uses: actions/checkout@v4
16 |
17 | - name: Set up Python
18 | uses: actions/setup-python@v4
19 | with:
20 | python-version: '3.10'
21 |
22 | - name: Install dependencies
23 | run: |
24 | pip install -U sphinx sphinx-rtd-theme sphinxcontrib-napoleon sphinx_copybutton lightning psutil defusedxml
25 |
26 | - name: Build Documentation
27 | run: |
28 | sphinx-build -b html docs/source public/
29 | if [ ! -d "public" ]; then
30 | echo "Error: Documentation build failed. 'public/' directory not found."
31 | exit 1
32 | fi
33 |
34 | - name: Deploy to GitHub Pages
35 | run: |
36 | git config --global user.name "GitHub Actions"
37 | git config --global user.email "actions@github.com"
38 |
39 | # Save generated documentation
40 | if [ -d "public" ]; then
41 | echo "Saving generated documentation..."
42 | ls -al public/
43 | mv public /tmp/public_docs
44 | else
45 | echo "Error: 'public/' directory does not exist. Exiting."
46 | exit 1
47 | fi
48 |
49 | # Clean and switch to gh-pages branch
50 | git reset --hard
51 | git clean -fdx
52 |
53 | if git ls-remote --exit-code origin gh-pages; then
54 | git fetch origin gh-pages
55 | git checkout gh-pages
56 | else
57 | git checkout --orphan gh-pages
58 | fi
59 |
60 | # Clean old content and restore new documentation
61 | echo "Cleaning old content..."
62 |
63 | find . -maxdepth 1 ! -name '.git' ! -name '.' -exec rm -rf {} +
64 | echo "Restoring new documentation..."
65 | mv /tmp/public_docs/* .
66 |
67 | # Deploy to GitHub Pages
68 | touch .nojekyll
69 | git add .
70 | if git diff --cached --quiet; then
71 | echo "No changes to commit. Skipping deployment."
72 | exit 0
73 | else
74 | git commit -m "Deploy updated documentation to GitHub Pages from commit $GITHUB_SHA"
75 | git push origin gh-pages --force
76 | fi
77 |
78 |
--------------------------------------------------------------------------------
/.github/workflows/lint_code.yml:
--------------------------------------------------------------------------------
1 | name: Lint Code
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | lint:
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - name: Checkout repository
17 | uses: actions/checkout@v4
18 |
19 | - name: Set up Python
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: '3.10'
23 |
24 | - name: Install linting dependencies
25 | run: |
26 | pip install black==24.10.0 isort==5.13.2 ruff==0.6.9
27 |
28 | - name: Run Black
29 | run: |
30 | black --check .
31 |
32 | - name: Run isort
33 | run: |
34 | isort --check-only .
35 |
36 | - name: Run Ruff
37 | run: |
38 | ruff check .
39 |
--------------------------------------------------------------------------------
/.github/workflows/unit_test.yml:
--------------------------------------------------------------------------------
1 | name: Run Unit Tests
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 |
10 | jobs:
11 |
12 | build_wheels:
13 | runs-on: ubuntu-24.04
14 | container:
15 | image: 'nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04'
16 | steps:
17 | - name: Update GCC
18 | run: |
19 | export DEBIAN_FRONTEND=noninteractive
20 | apt update && apt install -y build-essential gcc-10 g++-10
21 | - name: Install Python versions and pips
22 | run: |
23 | export DEBIAN_FRONTEND=noninteractive
24 | apt update && apt install -y software-properties-common curl
25 | add-apt-repository ppa:deadsnakes/ppa
26 | apt-get install -y python3.10 python3.10-dev
27 | apt-get install -y python3.11 python3.11-dev
28 | apt-get install -y python3.12 python3.12-dev
29 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
30 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
31 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12
32 | - name: Checkout code
33 | uses: actions/checkout@v4
34 | - name: Build wheel with Python 3.10
35 | run: |
36 | python3.10 -m pip install -U setuptools poetry build six pybind11
37 | python3.10 -m poetry build -f wheel
38 | - name: Build wheel with Python 3.11
39 | run: |
40 | python3.11 -m pip install -U setuptools poetry build six pybind11
41 | python3.11 -m poetry build -f wheel
42 | - name: Build wheel with Python 3.12
43 | run: |
44 | python3.12 -m pip install -U setuptools poetry build six pybind11
45 | python3.12 -m poetry build -f wheel
46 | - name: Upload the wheel artifact
47 | uses: actions/upload-artifact@v4
48 | with:
49 | name: resiliency-wheels
50 | path: dist/*.whl
51 |
52 | unit_tests_cpu_subset:
53 | runs-on: ubuntu-24.04
54 | needs: build_wheels
55 | strategy:
56 | matrix:
57 | container:
58 | - 'pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime'
59 | - 'pytorch/pytorch:2.4.1-cuda12.1-cudnn9-runtime'
60 | - 'pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime'
61 | test_type: ['fault_tolerance', 'ptl_resiliency']
62 | container:
63 | image: ${{ matrix.container }}
64 | env:
65 | MKL_SERVICE_FORCE_INTEL: 1 # Fix for "MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library."
66 | steps:
67 | - name: Checkout code
68 | uses: actions/checkout@v4
69 | - name: Download wheels
70 | uses: actions/download-artifact@v4
71 | with:
72 | name: resiliency-wheels
73 | path: ./dist/
74 | - name: Set up environment
75 | run: |
76 | pip install pytest lightning
77 | PY_VER_NODOT=$(python -c"import sysconfig; print(sysconfig.get_config_var('py_version_nodot'))")
78 | pip install ./dist/nvidia_resiliency_ext-*-cp${PY_VER_NODOT}-*.whl
79 | - name: Run unit tests
80 | shell: bash
81 | run: |
82 | if [[ "${{ matrix.test_type }}" == "straggler" ]]; then
83 | STRAGGLER_DET_CPU_TESTS_PATTERN="test_all_gather_object_calls_num \
84 | or test_fail_if_not_initialized \
85 | or test_individual_gpu_scores_one_rank \
86 | or test_relative_gpu_scores_ \
87 | or test_name_mapper_ \
88 | or test_relative_gpu_scores_"
89 | pytest -s -vvv tests/straggler/unit/ -k "${STRAGGLER_DET_CPU_TESTS_PATTERN}"
90 | elif [[ "${{ matrix.test_type }}" == "fault_tolerance" ]]; then
91 | pytest -s -vvv ./tests/fault_tolerance/unit/
92 | elif [[ "${{ matrix.test_type }}" == "ptl_resiliency" ]]; then
93 | pytest -s -vvv ./tests/ptl_resiliency/unit/test_ft_state_machine.py
94 | else
95 | echo "Unknown test type: ${{ matrix.test_type }}"
96 | exit 1
97 | fi
98 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build
2 | dist
3 | *.egg-info
4 | __pycache__
5 | cupti_module.*.so
6 | .pytest_cache
7 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | default_language_version:
2 | python: python3
3 |
4 | repos:
5 |
6 | - repo: https://github.com/PyCQA/isort
7 | rev: 5.13.2
8 | hooks:
9 | - id: isort
10 | exclude: docs/
11 |
12 | - repo: https://github.com/psf/black-pre-commit-mirror
13 | rev: 24.10.0
14 | hooks:
15 | - id: black
16 | language_version: python3.10
17 |
18 | - repo: https://github.com/astral-sh/ruff-pre-commit
19 | rev: v0.6.9
20 | hooks:
21 | - id: ruff
22 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NVIDIA Resiliency Extension
2 |
3 | The NVIDIA Resiliency Extension (NVRx) integrates multiple resiliency-focused solutions for PyTorch-based workloads. Users can modularly integrate NVRx capabilities into their own infrastructure to maximize AI training productivity at scale. NVRx maximizes goodput by enabling system-wide health checks, quickly detecting faults at runtime and resuming training automatically. NVRx minimizes loss of work by enabling fast and frequent checkpointing.
4 |
5 | For detailed documentation and usage information about each component, please refer to https://nvidia.github.io/nvidia-resiliency-ext/.
6 |
7 | > ⚠️ NOTE: This project is still experimental and under active development. The code, features, and documentation are evolving rapidly. Please expect frequent updates and breaking changes. Contributions are welcome and we encourage you to watch for updates.
8 |
9 |
10 |
11 |
12 | ## Core Components and Capabilities
13 |
14 | - **[Fault Tolerance](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/fault_tolerance/index.rst)**
15 | - Detection of hung ranks.
16 | - Restarting training in-job, without the need to reallocate SLURM nodes.
17 |
18 | - **[In-Process Restarting](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/inprocess/index.rst)**
19 | - Detecting failures and enabling quick recovery.
20 |
21 | - **[Async Checkpointing](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/checkpointing/async/index.rst)**
22 | - Providing an efficient framework for asynchronous checkpointing.
23 |
24 | - **[Local Checkpointing](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/checkpointing/local/index.rst)**
25 | - Providing an efficient framework for local checkpointing.
26 |
27 | - **[Straggler Detection](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/straggler_det/index.rst)**
28 | - Monitoring GPU and CPU performance of ranks.
29 | - Identifying slower ranks that may impede overall training efficiency.
30 |
31 | - **[PyTorch Lightning Callbacks](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/fault_tolerance/integration/ptl.rst)**
32 | - Facilitating seamless NVRx integration with PyTorch Lightning.
33 |
34 | ## Installation
35 |
36 | ### From sources
37 | - `git clone https://github.com/NVIDIA/nvidia-resiliency-ext`
38 | - `cd nvidia-resiliency-ext`
39 | - `pip install .`
40 |
41 |
42 | ### From PyPI wheel
43 | - `pip install nvidia-resiliency-ext`
44 |
45 | ### Platform Support
46 |
47 | | Category | Supported Versions / Requirements |
48 | |----------------------|----------------------------------------------------------------------------|
49 | | Architecture | x86_64, arm64 |
50 | | Operating System | Ubuntu 22.04, 24.04 |
51 | | Python Version | >= 3.10, < 3.13 |
52 | | PyTorch Version | >= 2.3.1 (injob & chkpt), >= 2.5.1 (inprocess) |
53 | | CUDA & CUDA Toolkit | >= 12.5 (12.8 required for GPU health check) |
54 | | NVML Driver | >= 535 (570 required for GPU health check) |
55 | | NCCL Version | >= 2.21.5 (injob & chkpt), >= 2.26.2 (inprocess) |
56 |
57 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security
2 |
3 | NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization.
4 |
5 | If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub.**
6 |
7 | ## Reporting Potential Security Vulnerability in an NVIDIA Product
8 |
9 | To report a potential security vulnerability in any NVIDIA product:
10 | - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html)
11 | - E-Mail: psirt@nvidia.com
12 | - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key)
13 | - Please include the following information:
14 | - Product/Driver name and version/branch that contains the vulnerability
15 | - Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
16 | - Instructions to reproduce the vulnerability
17 | - Proof-of-concept or exploit code
18 | - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability
19 |
20 | While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information.
21 |
22 | ## NVIDIA Product Security
23 |
24 | For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security
25 |
--------------------------------------------------------------------------------
/cupti_build.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import glob
17 | import os
18 |
19 | from pybind11.setup_helpers import Pybind11Extension, build_ext
20 |
21 |
def find_file_in_dir(cuda_path, sfile):
    """Recursively search a directory tree for a single file.

    Walks the tree rooted at `cuda_path` and returns the first directory
    that directly contains `sfile`.

    Args:
        cuda_path (str): Root directory to search (e.g., the CUDA install dir).
        sfile (str): File name to look for (e.g., 'libcupti.so').

    Returns:
        str or None: Path of the directory containing `sfile`, or None if
        the file was not found anywhere under `cuda_path`.
    """
    for root, _, files in os.walk(cuda_path):
        if sfile in files:
            return root
    return None
39 |
40 |
41 | def _skip_ext_build():
42 | ans = os.environ.get('STRAGGLER_DET_SKIP_CUPTI_EXT_BUILD', '0')
43 | return ans.lower() in ['1', 'on', 'yes', 'true']
44 |
45 |
def build(setup_kwargs):
    """Poetry build hook: compile the CUPTI profiling extension.

    Locates the CUPTI header and shared library under the CUDA installation,
    configures a pybind11 C++ extension for the straggler-detection CUPTI
    sources, and registers it in `setup_kwargs`. Skipped entirely when the
    STRAGGLER_DET_SKIP_CUPTI_EXT_BUILD environment variable is truthy.

    Args:
        setup_kwargs (dict): Mutable setup() keyword arguments, updated in
            place with `ext_modules`, `cmdclass`, and `zip_safe`.

    Raises:
        FileNotFoundError: If no CUDA installation directory exists, or if
            the CUPTI header/library cannot be located under it.
    """
    if _skip_ext_build():
        print(
            "WARNING! CUPTI extension won't be built due to STRAGGLER_DET_SKIP_CUPTI_EXT_BUILD flag."
        )
        return

    # CUDA root: $CUDA_PATH if set, otherwise the conventional install location.
    cuda_path = os.environ.get("CUDA_PATH", "/usr/local/cuda")
    if not os.path.isdir(cuda_path):
        raise FileNotFoundError("cuda installation not found in /usr/local/cuda or $CUDA_PATH")

    # Search the whole CUDA tree for the CUPTI header and shared library;
    # their layout varies between CUDA toolkit versions.
    cupti_h = "cupti.h"
    libcupti_so = "libcupti.so"
    idir = find_file_in_dir(cuda_path, cupti_h)
    ldir = find_file_in_dir(cuda_path, libcupti_so)
    if not (idir and ldir):
        raise FileNotFoundError(f"required files {libcupti_so} and {cupti_h} not found")

    cpp_extension = Pybind11Extension(
        'nvrx_cupti_module',
        # Sort .cpp files for reproducibility
        sorted(glob.glob('src/nvidia_resiliency_ext/straggler/cupti_src/*.cpp')),
        include_dirs=[idir],
        library_dirs=[ldir],
        libraries=['cupti'],
        extra_compile_args=['-O3'],
        language='c++',
        cxx_std=17,
    )
    setup_kwargs.update(
        {
            "ext_modules": [cpp_extension],
            "cmdclass": {"build_ext": build_ext},
            "zip_safe": False,
        }
    )
92 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/api.rst:
--------------------------------------------------------------------------------
1 | API documentation
2 | ===============================================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: API documentation
7 |
8 | api/core
9 | api/filesystem_async
10 | api/state_dict_saver
11 | api/torch_ckpt
12 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/api/core.rst:
--------------------------------------------------------------------------------
1 | Asynchronous Checkpoint Core Utilities
2 | ======================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.core
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/api/filesystem_async.rst:
--------------------------------------------------------------------------------
1 | Asynchronous FileSystemWriter Implementation
2 | ============================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/api/state_dict_saver.rst:
--------------------------------------------------------------------------------
1 | Asynchronous Pytorch Distributed Checkpoint save with optimized `FileSystemWriter`
2 | ==================================================================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/api/torch_ckpt.rst:
--------------------------------------------------------------------------------
1 | Asynchronous PyTorch `torch.save` with the Core utility
2 | =======================================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ===============================================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Examples
7 |
8 | examples/basic_example.rst
9 | examples/writer_example.rst
10 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/examples/basic_example.rst:
--------------------------------------------------------------------------------
1 | Basic usage example
2 | ===============================================================================
3 |
4 | .. literalinclude:: ../../../../../examples/checkpointing/async_ckpt.py
5 | :language: python
6 | :linenos:
7 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/examples/writer_example.rst:
--------------------------------------------------------------------------------
1 | FileSystemWriter example
2 | ===============================================================================
3 |
4 | .. literalinclude:: ../../../../../examples/checkpointing/async_writer.py
5 | :language: python
6 | :linenos:
7 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/async/index.rst:
--------------------------------------------------------------------------------
1 | Async Checkpointing
2 | =================================
3 |
4 | The asynchronous checkpointing feature in the NVIDIA Resiliency Extension provides core utilities to offload checkpointing routines to the background.
5 | It leverages `torch.multiprocessing` to either fork a temporary process or spawn a persistent process for efficient, non-blocking checkpointing.
6 |
7 | Applications can monitor asynchronous checkpoint progress in a non-blocking manner
8 | and define a custom finalization step once all ranks complete their background checkpoint saving.
9 |
10 | This repository includes an implementation of asynchronous checkpointing utilities for both `torch.save` and `torch.distributed.save_state_dict`.
11 | Our modified `torch.distributed.save_state_dict` interface is integrated with an optimized backend, `FileSystemWriterAsync`, which:
12 | • Runs in the async checkpoint process creating child parallel processes for intra-node parallelism, avoiding GIL contention.
13 | • Minimizes metadata communication overhead by metadata caching, ensuring efficient checkpoint saving.
14 |
15 |
16 | .. toctree::
17 | :maxdepth: 2
18 | :caption: Contents:
19 |
20 | usage_guide
21 | api
22 | examples
23 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/api.rst:
--------------------------------------------------------------------------------
1 | API documentation
2 | ===============================================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: API documentation
7 |
8 | api/callback
9 | api/base_ckpt_manager
10 | api/local_ckpt_manager
11 | api/replication
12 | api/base_state_dict
13 | api/basic_state_dict
14 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/api/base_ckpt_manager.rst:
--------------------------------------------------------------------------------
1 | BaseCheckpointManager
2 | ======================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.ckpt_managers.base_manager
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/api/base_state_dict.rst:
--------------------------------------------------------------------------------
1 | BaseTensorAwareStateDict
2 | ========================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.base_state_dict
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/api/basic_state_dict.rst:
--------------------------------------------------------------------------------
1 | BasicTensorAwareStateDict
2 | =========================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.basic_state_dict
5 | :members: BasicTensorAwareStateDict
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/api/callback.rst:
--------------------------------------------------------------------------------
1 | PTL Callback support
2 | ====================
3 |
4 | .. automodule:: nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/api/local_ckpt_manager.rst:
--------------------------------------------------------------------------------
1 | LocalCheckpointManager
2 | ======================
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/api/replication.rst:
--------------------------------------------------------------------------------
1 | Replication
2 | ===========
3 |
4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.replication.strategies
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ===============================================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Examples
7 |
8 | examples/basic_example.rst
9 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/examples/basic_example.rst:
--------------------------------------------------------------------------------
1 | Basic usage example
2 | ===============================================================================
3 |
4 | .. literalinclude:: ../../../../../examples/checkpointing/local_ckpt.py
5 | :language: python
6 | :linenos:
7 |
--------------------------------------------------------------------------------
/docs/source/checkpointing/local/index.rst:
--------------------------------------------------------------------------------
1 | Local Checkpointing
2 | ===================
3 |
4 | The local checkpointing mechanism is implemented via the Python `LocalCheckpointManager` class,
5 | which operates on a `TensorAwareStateDict` wrapper.
6 | This wrapper encapsulates the operations necessary for efficient replication and data transfers.
7 |
8 | For standard models,
9 | the provided `BasicTensorAwareStateDict` class is typically sufficient for integration.
10 | However, for more advanced use cases, a custom `TensorAwareStateDict` implementation may be required.
11 |
12 | To minimize saving overheads,
13 | integrating the asynchronous version of the `LocalCheckpointManager` method is strongly recommended.
14 |
15 | Features:
16 |
17 | - Local saving:
18 | Each node saves checkpoint parts locally, either on SSDs or RAM disks, as configured by the user.
19 | - Synchronous and asynchronous support:
20 | Save checkpoints either synchronously or asynchronously, based on the application's requirements.
21 | - Automatic cleanup:
22 | Handles the cleanup of broken or outdated checkpoints automatically.
23 | - Optional replication:
24 | The `LocalCheckpointManager.save` method supports an optional replication mechanism
25 | to allow checkpoint recovery in case of node failure after a restart.
26 | - Configurable resiliency:
27 | The replication factor can be adjusted for enhanced resiliency.
28 | - Latest checkpoint detection:
29 | The `find_latest` method in `LocalCheckpointManager` identifies the most recent complete local checkpoint.
30 | - Automated retrieval:
31 | The `LocalCheckpointManager.load` method automatically retrieves valid checkpoint parts that
32 | are unavailable locally.
33 |
34 | For a comprehensive description of this functionality, including detailed
35 | requirements, restrictions, and usage examples, please refer to the :doc:`Usage
36 | Guide <usage_guide>` and :doc:`Examples <examples>`.
37 |
38 |
39 | .. toctree::
40 | :maxdepth: 2
41 | :caption: Contents:
42 |
43 | usage_guide
44 | api
45 | examples
46 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import sys
18 |
19 | sys.path.insert(0, os.path.abspath('../../src'))
20 |
21 |
22 | # -- Project information -----------------------------------------------------
23 |
24 | project = 'nvidia-resiliency-ext'
25 | copyright = '2024, NVIDIA Corporation'
26 | author = 'NVIDIA Corporation'
27 |
28 | # The full version, including alpha/beta/rc tags
29 | release = '0.1'
30 |
31 |
32 | # -- General configuration ---------------------------------------------------
33 |
34 | # Add any Sphinx extension module names here, as strings. They can be
35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
36 | # ones.
37 | extensions = [
38 | 'sphinx.ext.autodoc',
39 | 'sphinx.ext.viewcode',
40 | 'sphinx.ext.napoleon',
41 | 'sphinx.ext.intersphinx',
42 | 'sphinx_copybutton',
43 | ]
44 | intersphinx_mapping = {
45 | 'python': ('https://docs.python.org/3', None),
46 | 'torch': ('https://pytorch.org/docs/stable/', None),
47 | }
48 |
49 |
50 | # Add any paths that contain templates here, relative to this directory.
51 | templates_path = ['_templates']
52 |
53 | # List of patterns, relative to source directory, that match files and
54 | # directories to ignore when looking for source files.
55 | # This pattern also affects html_static_path and html_extra_path.
56 | exclude_patterns = []
57 |
58 | autoclass_content = 'both'
59 | autodoc_typehints = 'description'
60 |
61 | # -- Options for HTML output -------------------------------------------------
62 |
63 | # The theme to use for HTML and HTML Help pages. See the documentation for
64 | # a list of builtin themes.
65 | #
66 | html_theme = 'sphinx_rtd_theme'
67 |
68 | # Add any paths that contain custom static files (such as style sheets) here,
69 | # relative to this directory. They are copied after the builtin static files,
70 | # so a file named "default.css" will overwrite the builtin "default.css".
71 | html_static_path = ['_static']
72 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/README-pci-topo-file.md:
--------------------------------------------------------------------------------
1 | # **Providing a PCI Topology File for GPU and NIC Topology Detection**
2 |
3 | ## **Overview**
4 | In certain environments, such as virtual machines (VMs) provided by cloud service providers (CSPs), the PCI device tree may not be fully populated. When this occurs, traversing the system PCI device tree to determine the GPU and NIC topology is not viable.
5 |
6 | To work around this limitation, users can specify a pre-defined PCI topology file using the following option:
7 |
8 | ```
9 | --ft-pci-topo-file=<pci_topo_file>
10 | ```
11 | where `<pci_topo_file>` is an XML file describing the PCI topology.
12 |
13 | ## **XML Format Requirements**
14 | The PCI topology file follows a structured XML format with the following key elements:
15 |
16 | 1. **`<cpu>` Block:**
17 | - Each CPU in the system is represented by a `<cpu>` block.
18 |
19 | 2. **`<pci>` Bridge Block:**
20 | - Within each `<cpu>` block, there are one or more `<pci>` blocks that represent PCI bridges.
21 | - Each PCI bridge has a unique `busid` attribute.
22 |
23 | 3. **GPU and IB PCI Devices:**
24 | - Within each PCI bridge block, the `<pci>` elements represent GPU and InfiniBand (IB) devices.
25 | - Each device has its own `busid` and attributes such as `class`, `link_speed`, and `link_width`.
26 |
27 | ### **Example 1: Single PCI Bridge per CPU**
28 | In this example, each CPU has a single PCI bridge, connecting multiple GPUs and IB devices.
29 |
30 | ```xml
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 | ```
48 |
49 | ### **Example 2: Multiple PCI Bridges per CPU**
50 | This example shows a topology where each CPU has multiple PCI bridges, with GPUs and IB devices distributed across them.
51 |
52 | ```xml
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 | ```
74 |
75 | ## **Reference Example**
76 | For a detailed working example, refer to the [NDv4 topology file](https://github.com/Azure/azhpc-images/blob/master/topology/ndv4-topo.xml) in the Azure HPC images repository.
77 |
78 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/api.rst:
--------------------------------------------------------------------------------
1 | API documentation
2 | =================
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 | :caption: API documentation
7 |
8 | api/config
9 | api/client
10 | api/server
11 | api/callback
12 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/api/callback.rst:
--------------------------------------------------------------------------------
1 | Callback
2 | ========
3 |
4 | .. automodule:: nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback
5 | :members:
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/api/client.rst:
--------------------------------------------------------------------------------
1 | Client
2 | ======
3 |
4 | .. automodule:: nvidia_resiliency_ext.fault_tolerance.rank_monitor_client
5 | :members:
6 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/api/config.rst:
--------------------------------------------------------------------------------
1 | Config
2 | ======
3 |
4 | .. automodule:: nvidia_resiliency_ext.fault_tolerance.config
5 | :members:
6 | :show-inheritance:
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/api/server.rst:
--------------------------------------------------------------------------------
1 | Server
2 | ======
3 |
4 | .. automodule:: nvidia_resiliency_ext.fault_tolerance.rank_monitor_server
5 | :members:
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Examples
7 |
8 | examples/basic_example.rst
9 | examples/train_ddp_heartbeats.rst
10 | examples/train_ddp_sections.rst
11 | examples/in_job_and_in_process_example.rst
12 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/examples/basic_example.rst:
--------------------------------------------------------------------------------
1 | Basic usage example
2 | ===================
3 |
4 | .. literalinclude:: ../../../../examples/fault_tolerance/basic_ft_example.py
5 | :language: python
6 | :linenos:
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/examples/in_job_and_in_process_example.rst:
--------------------------------------------------------------------------------
1 | FT Launcher & Inprocess integration example
2 | ===========================================
3 |
4 | .. literalinclude:: ../../../../examples/fault_tolerance/in_job_and_in_process_example.py
5 | :language: python
6 | :linenos:
7 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/examples/train_ddp_heartbeats.rst:
--------------------------------------------------------------------------------
1 | Heartbeat API usage example with DDP
2 | ====================================
3 |
4 | .. literalinclude:: ../../../../examples/fault_tolerance/train_ddp_heartbeats_api.py
5 | :language: python
6 | :linenos:
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/examples/train_ddp_sections.rst:
--------------------------------------------------------------------------------
1 | Section API usage example with DDP
2 | ==================================
3 |
4 | .. literalinclude:: ../../../../examples/fault_tolerance/train_ddp_sections_api.py
5 | :language: python
6 | :linenos:
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/index.rst:
--------------------------------------------------------------------------------
1 | Fault Tolerance
2 | ===============
3 |
4 | Fault Tolerance is a Python package that features:
5 | * Workload hang detection.
6 | * Automatic calculation of timeouts used for hang detection.
7 | * Detection of rank(s) terminated due to an error.
8 | * Workload respawning in case of a failure.
9 |
10 | Fault Tolerance is included in the ``nvidia_resiliency_ext.fault_tolerance`` package.
11 |
12 | The ``nvidia-resiliency-ext`` package also includes the PTL callback ``FaultToleranceCallback`` that simplifies FT package integration with PyTorch Lightning-based workloads.
13 | ``FaultToleranceCallback`` is included in the ``nvidia_resiliency_ext.ptl_resiliency`` package.
14 |
15 | .. toctree::
16 | :maxdepth: 2
17 | :caption: Contents:
18 |
19 | usage_guide
20 | integration
21 | api
22 | examples
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/integration.rst:
--------------------------------------------------------------------------------
1 | Integration Guides
2 | ==================
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 | :caption: Integration Guides
7 |
8 | integration/heartbeats
9 | integration/sections
10 | integration/ptl
11 | integration/inprocess
12 |
13 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/integration/heartbeats.rst:
--------------------------------------------------------------------------------
1 | Heartbeats API Integration
2 | **************************
3 |
4 | 1. Prerequisites
5 | =================
6 | * Run ranks using ``ft_launcher``. The command line is mostly compatible with ``torchrun``.
7 | * Pass the FT config to the ``ft_launcher``.
8 |
9 | .. note::
10 | Some clusters (e.g., SLURM) use SIGTERM as a default method of requesting a graceful workload shutdown.
11 | It is recommended to implement appropriate signal handling in a fault-tolerant workload.
12 | To avoid deadlocks and other unintended side effects, signal handling should be synchronized across all ranks.
13 |
14 | 2. FT configuration
15 | ====================
16 |
17 | Timeouts for fault detection need to be adjusted for each workload:
18 | * ``initial_rank_heartbeat_timeout`` should be long enough to allow for workload initialization.
19 | * ``rank_heartbeat_timeout`` should be at least as long as the longest possible interval between steps.
20 |
21 | **Importantly, heartbeats are not sent during checkpoint loading and saving**, so the time for checkpoint-related operations should be taken into account.
22 |
23 | Fixed timeout values can be used throughout the training runs, or timeouts can be calculated based on observed heartbeat intervals.
24 | `null` timeout values are interpreted as infinite timeouts. In such cases, values need to be calculated to make the FT usable.
25 |
26 | .. note::
27 | When --ft-initial-rank-heartbeat-timeout and --ft-rank-heartbeat-timeout are not
28 | provided in the command-line arguments, the launcher defaults to FT's predefined values. These are
29 | not null/None; currently, the defaults are 60 minutes for --ft-initial-rank-heartbeat-timeout
30 | and 45 minutes for --ft-rank-heartbeat-timeout.
31 |
32 | Configuration file example:
33 |
34 | .. literalinclude:: ../../../../examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml
35 | :language: yaml
36 | :linenos:
37 |
38 |
39 | A summary of all FT configuration items can be found in :class:`nvidia_resiliency_ext.fault_tolerance.config.FaultToleranceConfig`
40 |
41 | 3. Integration with PyTorch workload code
42 | ============================================
43 | 1. Initialize a ``RankMonitorClient`` instance on each rank with ``RankMonitorClient.init_workload_monitoring()``.
44 | 2. *(Optional)* Restore the state of ``RankMonitorClient`` instances using ``RankMonitorClient.load_state_dict()``.
45 | 3. Periodically send heartbeats from ranks using ``RankMonitorClient.send_heartbeat()``.
46 | 4. *(Optional)* After a sufficient range of heartbeat intervals has been observed, call ``RankMonitorClient.calculate_and_set_hb_timeouts()`` to estimate timeouts.
47 | 5. *(Optional)* Save the ``RankMonitorClient`` instance's ``state_dict()`` to a file so that computed timeouts can be reused in the next run.
48 | 6. Shut down ``RankMonitorClient`` instances using ``RankMonitorClient.shutdown_workload_monitoring()``.
49 |
50 | Please refer to the :doc:`../examples/train_ddp_heartbeats` for an implementation example.
51 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/integration/inprocess.rst:
--------------------------------------------------------------------------------
1 | FT Launcher & Inprocess integration
2 | ***********************************
3 | **FT launcher** integrates with **Inprocess recovery mechanisms**, improving fault tolerance by coordinating injob and inprocess fault recovery.
4 |
5 | 1. Heartbeat Mechanism
6 | ======================
7 | * The **FT launcher heartbeat** remains active throughout execution to detect and mitigate potential hangs.
8 | * Users must configure timeouts manually, ensuring they exceed **inprocess operational timeouts** to prevent conflicts.
9 |
10 | 2. Worker Monitoring & Restart Policy
11 | =====================================
12 | A new ``--ft-restart-policy`` argument in ``ft_launcher`` modifies the default worker monitor logic for better compatibility with :doc:`../../inprocess/index`.
13 |
14 | **Policy Options**
15 |
16 | * ``min-healthy``: Restarts workers only when the number of healthy worker groups falls below the minimum specified in ``--nnodes``, as set in ``ft_launcher``.
17 |
18 | .. note::
19 |
20 | For proper behavior, minimum specified in ``--nnodes`` should match the ``inprocess`` restarter setting by either:
21 |
22 | - Ensuring ``inprocess`` operates at the node level like ``injob`` by adding a ``rank_assignment`` filter to the wrapper, or
23 | - Making ``injob`` operate at the rank level like ``inprocess`` by specifying one rank per agent.
24 |
25 | See the `rank assignment guide <../../inprocess/usage_guide.html#rank-assignment>`_ for more details.
26 |
27 | **Example of rank_assignment:**
28 |
29 | .. code-block:: python
30 |
31 | rank_assignment = (
32 | inprocess.Compose(
33 | inprocess.rank_assignment.ShiftRanks(),
34 | inprocess.rank_assignment.FilterGroupedByKey(
35 | key_or_fn=lambda _, __: socket.gethostname(),
36 | condition=lambda count: count == 8,
37 | ),
38 | ),
39 | )
40 |
41 | **Behavior in min-healthy mode:**
42 |
43 | * If enough nodes remain healthy, the worker monitor stays inactive while collaborating with :doc:`../../inprocess/index`.
44 | * If the threshold is breached, ``FT launcher`` takes over and restarts the training process.
45 |
46 |
47 | Supported & Unsupported Configurations
48 | ======================================
49 |
50 | To ensure correct behavior with inprocess:
51 |
52 | ✅ Supported:
53 |
54 | * ``restart-policy=min-healthy`` **(Required)**:
55 |
56 | * Prevents unintended upscaling.
57 | * Disables any-failed worker monitoring.
58 |
59 | ❌ Unsupported:
60 |
61 | * ``any-failed`` with inprocess **(Not allowed)**:
62 |
63 | * Incompatible with inprocess restarts.
64 | * Causes FT launcher to misinterpret terminated processes as failures, triggering unnecessary restarts.
65 | * Enables upscaling, allowing FT launcher to restart training when a new node becomes available.
66 | * Can lead to undefined behavior when combined with inprocess restarts.
67 |
68 | In short, ``any-failed`` must not be used with inprocess, as it disrupts the intended fault recovery process.
69 |
70 | Please refer to the :doc:`../examples/in_job_and_in_process_example` for an implementation example.
71 |
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/integration/ptl.rst:
--------------------------------------------------------------------------------
1 | PyTorch Lightning Integration
2 | *****************************
3 |
4 | This section describes Fault Tolerance integration with a PTL-based workload (i.e., NeMo) using ``FaultToleranceCallback``.
5 |
6 | 1. Use ``ft_launcher`` to start the workload
7 | ============================================
8 |
9 | Fault tolerance relies on a special launcher (``ft_launcher``), which is a modified ``torchrun``.
10 | If you are using NeMo, the `NeMo-Framework-Launcher <https://github.com/NVIDIA/NeMo-Framework-Launcher>`_ can be used to generate SLURM batch scripts with FT support.
11 |
12 | 2. Add the FT callback to the PTL trainer
13 | ==========================================
14 |
15 | Add the FT callback to the PTL callbacks.
16 |
17 | .. code-block:: python
18 |
19 | from nvidia_resiliency_ext.ptl_resiliency import FaultToleranceCallback
20 |
21 | fault_tol_cb = FaultToleranceCallback(
22 | autoresume=True,
23 | calculate_timeouts=True,
24 | logger_name="test_logger",
25 | exp_dir=tmp_path,
26 | )
27 |
28 | trainer = pl.Trainer(
29 | ...
30 | callbacks=[..., fault_tol_cb],
31 | resume_from_checkpoint=True,
32 | )
33 |
34 |
35 | Core FT callback functionality includes:
36 | * Establishing a connection with a rank monitor.
37 | * Sending heartbeats during training and evaluation steps.
38 | * Disconnecting from a rank monitor.
39 |
40 | Optionally, it can also:
41 | * Compute timeouts that will be used instead of timeouts defined in the FT config.
42 | * Create a flag file when the training is completed.
43 |
44 | FT callback initialization parameters are described in the ``FaultToleranceCallback`` constructor docstring:
45 | :class:`nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback.FaultToleranceCallback`
46 |
47 | 3. Implementing auto-resume
48 | ===========================
49 |
50 | Auto-resume simplifies running training jobs that consist of multiple sequential runs.
51 |
52 | .. note::
53 | Auto-resume is not part of the FT package. It is entirely implemented in a launcher script and the ``FaultToleranceCallback``.
54 |
55 | ``FaultToleranceCallback`` exposes an "interface" that allows implementing an auto-resume launcher script. Specifically, if ``autoresume=True``,
56 | the FT callback creates a special marker file when training is completed. The marker file location is expected to be set in the ``FAULT_TOL_FINISHED_FLAG_FILE`` environment variable.
57 |
58 | The following steps can be used to implement an auto-resume launcher script:
59 | * The launcher script starts ranks with ``ft_launcher``.
60 | * ``FAULT_TOL_FINISHED_FLAG_FILE`` should be passed to rank processes.
61 | * When ``ft_launcher`` exits, the launcher script checks if the ``FAULT_TOL_FINISHED_FLAG_FILE`` file was created.
62 |
63 | * If ``FAULT_TOL_FINISHED_FLAG_FILE`` exists, the auto-resume loop stops, as training is complete.
64 | * If ``FAULT_TOL_FINISHED_FLAG_FILE`` does not exist, the continuation job can be issued (other conditions can be checked, e.g., if the maximum number of failures is not reached).
--------------------------------------------------------------------------------
/docs/source/fault_tolerance/integration/sections.rst:
--------------------------------------------------------------------------------
1 | Sections API Integration
2 | ************************
3 |
4 | 1. Prerequisites
5 | =================
6 | * Run ranks using ``ft_launcher``. The command line is mostly compatible with ``torchrun``.
7 | * Pass the FT config to the ``ft_launcher``.
8 |
9 | .. note::
10 | Some clusters (e.g., SLURM) use SIGTERM as a default method of requesting a graceful workload shutdown.
11 | It is recommended to implement appropriate signal handling in a fault-tolerant workload.
12 | To avoid deadlocks and other unintended side effects, signal handling should be synchronized across all ranks.
13 |
14 | 2. FT configuration
15 | ====================
16 |
17 | With the section-based API, timeouts must be set for all defined sections, which wrap operations like
18 | training/eval steps, checkpoint saving, and initialization. Additionally, an out-of-section timeout
19 | applies when no section is active.
20 |
21 | .. note::
22 | Ensure out-of-section timeout is long enough to accommodate restart overhead, as excessively small values can cause imbalance.
23 | If needed, consider merging sections (e.g., moving 'init' into 'out-of-section') to provide more buffer time.
24 |
25 | Relevant FT configuration items are:
26 | * ``rank_section_timeouts`` is a map of a section name to its timeout.
27 | * ``rank_out_of_section_timeout`` is the out-of-section timeout.
28 |
29 | Fixed timeout values can be used throughout the training runs, or timeouts can be calculated based on observed intervals.
30 | `null` timeout values are interpreted as infinite timeouts. In such cases, values need to be calculated to make the FT usable.
31 |
32 | Config file example:
33 |
34 | .. literalinclude:: ../../../../examples/fault_tolerance/fault_tol_cfg_sections.yaml
35 | :language: yaml
36 | :linenos:
37 |
38 | A summary of all FT configuration items can be found in :class:`nvidia_resiliency_ext.fault_tolerance.config.FaultToleranceConfig`
39 |
40 | 3. Integration with PyTorch workload code
41 | ============================================
42 | 1. Initialize a ``RankMonitorClient`` instance on each rank with ``RankMonitorClient.init_workload_monitoring()``.
43 | 2. *(Optional)* Restore the state of ``RankMonitorClient`` instances using ``RankMonitorClient.load_state_dict()``.
44 | 3. Mark some sections of the code with ``RankMonitorClient.start_section('')`` and ``RankMonitorClient.end_section('')``.
45 | 4. *(Optional)* After a sufficient range of section intervals has been observed, call ``RankMonitorClient.calculate_and_set_section_timeouts()`` to estimate timeouts.
46 | 5. *(Optional)* Save the ``RankMonitorClient`` instance's ``state_dict()`` to a file so that computed timeouts can be reused in the next run.
47 | 6. Shut down ``RankMonitorClient`` instances using ``RankMonitorClient.shutdown_workload_monitoring()``.
48 |
49 | Please refer to the :doc:`../examples/train_ddp_sections` for an implementation example.
50 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | nvidia-resiliency-ext v0.4.0
2 | =============================
3 |
4 | **nvidia-resiliency-ext** is a set of tools developed by NVIDIA to improve large-scale distributed training resiliency.
5 |
6 | .. image:: ./media/nvrx_docs_source.png
7 | :width: 750
8 | :alt: Figure highlighting core NVRx features including automatic restart, hierarchical checkpointing, fault detection and health checks.
9 |
10 | Features
11 | --------
12 |
13 | * `Hang detection and automatic in-job restarting <fault_tolerance/index.html>`_
14 | * `In-process restarting <inprocess/index.html>`_
15 | * `Async checkpointing <checkpointing/async/index.html>`_
16 | * `Local checkpointing <checkpointing/local/index.html>`_
17 | * `Straggler (slower ranks) detection <straggler_det/index.html>`_
18 |
19 | .. toctree::
20 | :maxdepth: 3
21 | :caption: Documentation contents:
22 |
23 | fault_tolerance/index
24 | inprocess/index
25 | checkpointing/async/index
26 | checkpointing/local/index
27 | straggler_det/index
28 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api.rst:
--------------------------------------------------------------------------------
1 | API documentation
2 | ===============================================================================
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 | :caption: API documentation
7 |
8 | api/wrap
9 | api/compose
10 | api/state
11 | api/rank_assignment
12 | api/rank_filter
13 | api/initialize
14 | api/abort
15 | api/finalize
16 | api/health_check
17 | api/exception
18 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/abort.rst:
--------------------------------------------------------------------------------
1 | Abort
2 | ===============================================================================
3 |
4 | .. autoclass:: nvidia_resiliency_ext.inprocess.abort.Abort
5 | :special-members: __call__
6 |
7 | .. automodule:: nvidia_resiliency_ext.inprocess.abort
8 | :members:
9 | :exclude-members: Abort
10 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/compose.rst:
--------------------------------------------------------------------------------
1 | Compose
2 | ===============================================================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.inprocess
5 | :members: Compose
6 | :ignore-module-all:
7 | :no-index:
8 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/exception.rst:
--------------------------------------------------------------------------------
1 | Exceptions
2 | ===============================================================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.inprocess.exception
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/finalize.rst:
--------------------------------------------------------------------------------
1 | Finalize
2 | ===============================================================================
3 |
4 | .. autoclass:: nvidia_resiliency_ext.inprocess.finalize.Finalize
5 | :special-members: __call__
6 |
7 | .. automodule:: nvidia_resiliency_ext.inprocess.finalize
8 | :members:
9 | :exclude-members: Finalize
10 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/health_check.rst:
--------------------------------------------------------------------------------
1 | Health Check
2 | ===============================================================================
3 |
4 | .. autoclass:: nvidia_resiliency_ext.inprocess.health_check.HealthCheck
5 | :special-members: __call__
6 |
7 | .. automodule:: nvidia_resiliency_ext.inprocess.health_check
8 | :members:
9 | :exclude-members: HealthCheck
10 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/initialize.rst:
--------------------------------------------------------------------------------
1 | Initialize
2 | ===============================================================================
3 |
4 | .. autoclass:: nvidia_resiliency_ext.inprocess.initialize.Initialize
5 | :special-members: __call__
6 |
7 | .. automodule:: nvidia_resiliency_ext.inprocess.initialize
8 | :members:
9 | :exclude-members: Initialize
10 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/rank_assignment.rst:
--------------------------------------------------------------------------------
1 | Rank Assignment
2 | ===============================================================================
3 |
4 | Rank Assignment
5 | ---------------
6 | Base class
7 | ^^^^^^^^^^
8 | .. autoclass:: nvidia_resiliency_ext.inprocess.rank_assignment.RankAssignment
9 | :special-members: __call__
10 |
11 | .. autoclass:: nvidia_resiliency_ext.inprocess.rank_assignment.RankAssignmentCtx
12 |
13 | .. autoexception:: nvidia_resiliency_ext.inprocess.rank_assignment.RankDiscarded
14 |
15 | Tree
16 | ^^^^
17 | .. automodule:: nvidia_resiliency_ext.inprocess.rank_assignment
18 | :members: Layer, LayerFlag, Tree
19 |
20 | Composable Rank Assignments
21 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^
22 | .. automodule:: nvidia_resiliency_ext.inprocess.rank_assignment
23 | :members: FillGaps, ShiftRanks, FilterCountGroupedByKey
24 | :no-index:
25 |
26 | Rank Filtering
27 | --------------
28 | Base class
29 | ^^^^^^^^^^
30 | .. autoclass:: nvidia_resiliency_ext.inprocess.rank_assignment.RankFilter
31 | :special-members: __call__
32 |
33 | Rank Filters
34 | ^^^^^^^^^^^^
35 | .. automodule:: nvidia_resiliency_ext.inprocess.rank_assignment
36 | :members: ActivateAllRanks, MaxActiveWorldSize, ActiveWorldSizeDivisibleBy
37 | :no-index:
38 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/rank_filter.rst:
--------------------------------------------------------------------------------
1 | Rank Filter
2 | ===============================================================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.inprocess.rank_filter
5 | :members:
6 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/state.rst:
--------------------------------------------------------------------------------
1 | State
2 | ===============================================================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.inprocess
5 | :members: Mode
6 | :no-index:
7 |
8 | .. automodule:: nvidia_resiliency_ext.inprocess
9 | :members: FrozenState, State
10 | :no-index:
11 |
--------------------------------------------------------------------------------
/docs/source/inprocess/api/wrap.rst:
--------------------------------------------------------------------------------
1 | Wrapper
2 | ===============================================================================
3 |
4 | .. automodule:: nvidia_resiliency_ext.inprocess
5 | :members: Wrapper, CallWrapper
6 | :ignore-module-all:
7 |
--------------------------------------------------------------------------------
/docs/source/inprocess/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ===============================================================================
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Examples
7 |
8 | examples/basic_example.rst
9 | examples/optimal_example.rst
10 |
--------------------------------------------------------------------------------
/docs/source/inprocess/examples/basic_example.rst:
--------------------------------------------------------------------------------
1 | Basic usage example
2 | ===============================================================================
3 |
4 | .. literalinclude:: ../../../../examples/inprocess/basic_example.py
5 | :language: python
6 | :linenos:
7 |
--------------------------------------------------------------------------------
/docs/source/inprocess/examples/optimal_example.rst:
--------------------------------------------------------------------------------
1 | Optimal usage example
2 | ===============================================================================
3 |
4 | .. literalinclude:: ../../../../examples/inprocess/optimal_example.py
5 | :language: python
6 | :linenos:
7 |
--------------------------------------------------------------------------------
/docs/source/inprocess/index.rst:
--------------------------------------------------------------------------------
1 | Inprocess Restart
2 | =================
3 |
4 | In-process restart mechanism is implemented via a Python function wrapper that
5 | adds restart capabilities to an existing Python function implementing
6 | distributed PyTorch workload. Upon a fault, the wrapped function is restarted
7 | across all distributed ranks, within the same operating system process.
8 | Invoking restart of the wrapped function excludes distributed ranks that are
9 | terminated, missing, or deemed unhealthy. When a failure occurs on any worker,
10 | the wrapper ensures the function restarts simultaneously on all healthy ranks.
11 | This process continues until all ranks complete execution successfully or a
12 | termination condition is met.
13 |
14 | Compared to a traditional scheduler-level restart, restarting within the same
15 | process removes overheads associated with launching a new scheduler job,
16 | starting a container, initializing a new Python interpreter, loading
17 | dependencies, and creating a new CUDA context.
18 |
19 | Restarting in the same process also enables the reuse of pre-instantiated,
20 | process-group- and rank-independent objects across restart attempts. This reuse
21 | eliminates the overhead of repeated reinitialization and minimizes restart
22 | latency.
23 |
24 | Features:
25 |
26 | - automatic deinitialization of PyTorch distributed process group, and restart
27 | of the wrapped function upon encountering an unhandled Python exception in
28 | any distributed rank
29 | - timeout mechanism to detect and recover from deadlocks or livelocks, and a
30 | guarantee that the job is making meaningful forward progress
31 | - modular and customizable rank reassignment and health check functions
32 | - support for pre-allocated and pre-initialized reserve workers
33 | - gradual engineering ramp up: integration with existing codebase may start
34 | from restarting the entire ``main()`` function, then gradually refactor
35 | process-group-independent initialization into a separate function call in
36 | order to maximally reuse Python objects between restarts and minimize fault
37 | recovery overhead
38 |
39 | For a comprehensive description of this functionality, including detailed
40 | requirements, restrictions, and usage examples, please refer to the
41 | :doc:`Usage Guide <usage_guide>` and :doc:`Examples <examples>`.
42 |
43 |
44 |
45 | .. toctree::
46 | :maxdepth: 2
47 | :caption: Contents:
48 |
49 | usage_guide
50 | api
51 | examples
52 |
--------------------------------------------------------------------------------
/docs/source/media/nvrx_core_features.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/docs/source/media/nvrx_core_features.png
--------------------------------------------------------------------------------
/docs/source/media/nvrx_docs_source.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/docs/source/media/nvrx_docs_source.png
--------------------------------------------------------------------------------
/docs/source/straggler_det/api.rst:
--------------------------------------------------------------------------------
1 | API documentation
2 | =================
3 |
4 | .. toctree::
5 | :maxdepth: 3
6 | :caption: API documentation
7 |
8 | api/straggler
9 | api/reporting
10 | api/statistics
11 | api/callback
12 |
--------------------------------------------------------------------------------
/docs/source/straggler_det/api/callback.rst:
--------------------------------------------------------------------------------
1 | Callback
2 | ========
3 |
4 | .. automodule:: nvidia_resiliency_ext.ptl_resiliency.straggler_det_callback
5 | :members:
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/docs/source/straggler_det/api/reporting.rst:
--------------------------------------------------------------------------------
1 | Reporting
2 | =========
3 |
4 | .. automodule:: nvidia_resiliency_ext.straggler.reporting
5 | :members:
6 | :show-inheritance:
7 |
8 |
9 |
--------------------------------------------------------------------------------
/docs/source/straggler_det/api/statistics.rst:
--------------------------------------------------------------------------------
1 | Statistics
2 | ==========
3 |
4 | .. automodule:: nvidia_resiliency_ext.straggler.statistics
5 | :members:
6 | :show-inheritance:
7 |
8 |
9 |
--------------------------------------------------------------------------------
/docs/source/straggler_det/api/straggler.rst:
--------------------------------------------------------------------------------
1 | Straggler
2 | =========
3 |
4 | .. automodule:: nvidia_resiliency_ext.straggler.straggler
5 | :members:
6 | :show-inheritance:
7 |
--------------------------------------------------------------------------------
/docs/source/straggler_det/examples.rst:
--------------------------------------------------------------------------------
1 | Examples
2 | ========
3 |
4 | .. toctree::
5 | :maxdepth: 2
6 | :caption: Examples
7 |
8 | examples/basic_example.rst
9 |
--------------------------------------------------------------------------------
/docs/source/straggler_det/examples/basic_example.rst:
--------------------------------------------------------------------------------
1 | Basic usage example
2 | ===================
3 |
4 | .. literalinclude:: ../../../../examples/straggler/example.py
5 | :language: python
6 | :linenos:
--------------------------------------------------------------------------------
/docs/source/straggler_det/index.rst:
--------------------------------------------------------------------------------
1 | Straggler Detection
2 | ===================
3 |
4 | The **Straggler Detection** package's purpose is to detect slower ranks participating in a PyTorch distributed workload.
5 | The ``nvidia-resiliency-ext`` package also includes the PTL callback ``StragglerDetectionCallback`` that simplifies integration with PyTorch Lightning-based workloads.
6 |
7 | Straggler Detection is included in the ``nvidia_resiliency_ext.straggler`` package.
8 | ``StragglerDetectionCallback`` is included in the ``nvidia_resiliency_ext.ptl_resiliency`` package.
9 |
10 | .. toctree::
11 | :maxdepth: 2
12 | :caption: Contents:
13 |
14 | usage_guide
15 | api
16 | examples
--------------------------------------------------------------------------------
/examples/checkpointing/async_ckpt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import torch.nn as nn
8 |
9 | from nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt import TorchAsyncCheckpoint
10 |
11 | # Set up basic logging configuration
12 | logging.basicConfig(level=logging.INFO)
13 |
14 |
def parse_args():
    """Parse the command-line options for the async checkpointing example."""
    parser = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Destination directory for the per-rank checkpoint files.
    parser.add_argument(
        '--ckpt_dir',
        default="/tmp/test_async_ckpt/",
        help="Checkpoint directory for async checkpoints",
    )
    # Opt into the persistent variant of the async writer queue.
    parser.add_argument(
        '--persistent_queue',
        action='store_true',
        help="Enables a persistent version of AsyncCallsQueue.",
    )
    return parser.parse_args()
31 |
32 |
33 | # Define a simple model
class SimpleModel(nn.Module):
    """Tiny two-layer MLP used to produce a non-trivial state dict.

    Maps a 10-dimensional input to 2 outputs through a ReLU-activated
    hidden layer of width 5.
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names (fc1/fc2/activation) determine state_dict keys,
        # which main() compares after reloading — keep them stable.
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
        self.activation = nn.ReLU()

    def forward(self, x):
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)
45 |
46 |
def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group for NCCL backend.
    Assumes the environment variables (CUDA_VISIBLE_DEVICES, etc.) are already set.

    Args:
        backend (str): torch.distributed backend name (default: "nccl").

    Raises:
        Exception: re-raises any failure from process-group initialization.
    """
    try:
        dist.init_process_group(
            backend=backend,  # Use NCCL backend
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU. Use LOCAL_RANK (set by
        # torchrun/ft_launcher) rather than the global rank: the global rank
        # exceeds the per-node device count on multi-node runs.
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        torch.cuda.set_device(local_rank)
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise
64 |
65 |
def cleanup(ckpt_dir):
    """Delete checkpoint files in ckpt_dir; only rank 0 performs the deletion."""
    if dist.get_rank() != 0:
        return
    logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
    # Remove regular files only; subdirectories (if any) are left in place.
    for entry in os.scandir(ckpt_dir):
        if entry.is_file():
            os.remove(entry.path)
72 |
73 |
def main():
    """Example driver: asynchronously checkpoint a model, then verify it.

    Each rank writes its own copy of the model state dict to
    ``<ckpt_dir>/ckpt_rank<rank>.pt``, waits for the async write to finish,
    reloads the file, and asserts it matches the in-memory state dict.
    """
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move to CUDA
    model = SimpleModel().to("cuda")
    # Reference state dict to compare the reloaded checkpoint against.
    org_sd = model.state_dict()
    # Define checkpoint directory and manager
    ckpt_dir = args.ckpt_dir
    os.makedirs(ckpt_dir, exist_ok=True)
    logging.info(f"Created checkpoint directory: {ckpt_dir}")
    # One checkpoint file per rank.
    ckpt_file_name = os.path.join(ckpt_dir, f"ckpt_rank{torch.distributed.get_rank()}.pt")

    ckpt_impl = TorchAsyncCheckpoint(persistent_queue=args.persistent_queue)

    # Schedule the asynchronous save of the state dict.
    ckpt_impl.async_save(org_sd, ckpt_file_name)

    # Block until the save completes, then shut down the async queue.
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True, terminate=True)

    loaded_sd = torch.load(ckpt_file_name, map_location="cuda")

    # Verify the reloaded tensors match the originals exactly.
    for k in loaded_sd.keys():
        assert torch.equal(loaded_sd[k], org_sd[k]), f"loaded_sd[{k}] != org_sd[{k}]"

    # Synchronize processes to ensure all have completed the loading
    dist.barrier()

    # Clean up checkpoint directory only on rank 0
    cleanup(ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
114 |
--------------------------------------------------------------------------------
/examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml:
--------------------------------------------------------------------------------
1 | fault_tolerance:
2 | initial_rank_heartbeat_timeout: null
3 | rank_heartbeat_timeout: null
4 | log_level: "DEBUG"
5 |
--------------------------------------------------------------------------------
/examples/fault_tolerance/fault_tol_cfg_sections.yaml:
--------------------------------------------------------------------------------
1 | fault_tolerance:
2 | initial_rank_heartbeat_timeout: null
3 | rank_heartbeat_timeout: null
4 | rank_section_timeouts:
5 | init: 20
6 | step: 10
7 | checkpoint: 30
8 | rank_out_of_section_timeout: 30
9 | log_level: "DEBUG"
10 |
--------------------------------------------------------------------------------
/examples/fault_tolerance/log_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import os
18 | import sys
19 |
20 | import dist_utils
21 |
22 |
def setup_logging(log_all_ranks=True, filename=os.devnull, filemode='w'):
    """
    Configures logging.

    By default log records from all workers are printed to stderr, each entry
    prefixed with "N: " where N is the worker's rank; stderr entries carry no
    timestamps. Full logs with timestamps are written to ``filename``.

    Args:
        log_all_ranks (bool): if True, emit records from every rank; otherwise
            only rank 0 logs (non-zero ranks get their file redirected to
            os.devnull).
        filename (str): destination file for the full, timestamped log.
        filemode (str): mode used to open ``filename`` (e.g. 'w' or 'a').
    """

    class RankFilter(logging.Filter):
        # Stamps each record with the rank and optionally drops non-zero ranks.
        def __init__(self, rank, log_all_ranks):
            self.rank = rank
            self.log_all_ranks = log_all_ranks

        def filter(self, record):
            record.rank = self.rank
            return True if self.log_all_ranks else self.rank == 0

    rank = dist_utils.get_rank()
    rank_filter = RankFilter(rank, log_all_ranks)

    if log_all_ranks:
        logging_format = f"%(asctime)s - %(levelname)s - {rank} - %(message)s"
    else:
        logging_format = "%(asctime)s - %(levelname)s - %(message)s"
        if rank != 0:
            filename = os.devnull

    # Drop handlers installed by any earlier configuration so the
    # basicConfig() call below takes effect.
    for old_handler in logging.root.handlers[:]:
        logging.root.removeHandler(old_handler)
        old_handler.close()

    logging.basicConfig(
        level=logging.DEBUG,
        format=logging_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        filename=filename,
        filemode=filemode,
    )

    # Mirror records to stderr without timestamps.
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.DEBUG)
    if log_all_ranks:
        stderr_formatter = logging.Formatter(f'{rank}: %(message)s')
    else:
        stderr_formatter = logging.Formatter('%(message)s')
    stderr_handler.setFormatter(stderr_formatter)
    logging.getLogger('').addHandler(stderr_handler)
    logging.getLogger('').addFilter(rank_filter)
73 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | [tool.poetry]
5 | name = "nvidia-resiliency-ext"
6 | repository = "https://github.com/NVIDIA/nvidia-resiliency-ext"
7 | version = "0.4.0"
8 | description = "NVIDIA Resiliency Package"
9 | authors = ["NVIDIA Corporation"]
10 | readme = "README.md"
11 | license = "Apache 2.0"
12 | classifiers = [
13 | "Development Status :: 4 - Beta",
14 | "Programming Language :: Python :: 3",
15 | "Programming Language :: Python :: 3.10",
16 | "Operating System :: OS Independent",
17 | ]
18 | packages = [
19 | { include = "nvidia_resiliency_ext", from = "src" },
20 | ]
21 |
22 | exclude = [
23 | "src/nvidia_resiliency_ext/straggler/cupti_src"
24 | ]
25 |
26 | [tool.poetry.build]
27 | script = "cupti_build.py"
28 | generate-setup-file = true
29 |
30 | [build-system]
31 | requires = ["poetry-core>=1.0.0", "pybind11", "setuptools", "wheel"]
32 | build-backend = "poetry.core.masonry.api"
33 |
34 | [tool.poetry.dependencies]
35 | torch = ">=2.3.0"
36 | packaging = "*"
37 | python = ">=3.10"
38 | psutil = ">=6.0.0"
39 | pyyaml = "*"
40 | pynvml = ">=12.0.0"
41 | nvidia-ml-py = ">=12.570.86"
42 | defusedxml = "*"
43 |
44 | [tool.poetry.scripts]
45 | ft_launcher = "nvidia_resiliency_ext.fault_tolerance.launcher:main"
46 |
47 |
48 | [tool.isort]
49 | profile = "black" # black-compatible
50 | line_length = 100 # should match black parameters
51 | py_version = 310 # python 3.10 as a target version
52 | # filter_files and extend_skip_glob are needed for pre-commit to filter out the files
53 | filter_files = true
54 | extend_skip_glob = [
55 | "setup.py",
56 | "cupti_build.py",
57 | "src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/*"
58 | ]
59 |
60 | [tool.black]
61 | line_length = 100
62 | skip_string_normalization = true
63 | # major year version is stable, see details in
64 | # https://black.readthedocs.io/en/stable/the_black_code_style/index.html
65 | # `required_version` is necessary for consistency (other `black` versions will fail to reformat files)
66 | required_version = "24"
67 | target-version = ['py310', 'py311', 'py312']
68 | force-exclude = '''
69 | # Force exclude, as this config is also used by pre-commit
70 | # https://stackoverflow.com/questions/73247204/black-not-respecting-extend-exclude-in-pyproject-toml
71 | # A regex preceded with ^/ will apply only to files and directories in the root of the project.
72 | (
73 | ^\/setup.py\/
74 | | ^\/build.py\/
75 | | ^\/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat\/
76 | )
77 | '''
78 |
79 | [tool.ruff]
80 | extend-exclude = [
81 | "setup.py",
82 | "build.py",
83 | "src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat"
84 | ]
85 |
86 | [tool.ruff.lint]
87 | # F841 Local variable `...` is assigned to but never used
88 | ignore = ["F841"]
89 |
90 | [tool.ruff.lint.per-file-ignores]
91 | # avoid "unused import" warnings for __init__.py files
92 | "__init__.py" = ["F401"]
93 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/checkpointing/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/checkpointing/async_ckpt/cached_metadata_filesystem_reader.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """FS reader with cached-metadata support."""
17 |
18 | import os
19 | from typing import Union
20 |
21 | from torch.distributed.checkpoint import FileSystemReader, Metadata
22 |
23 |
class CachedMetadataFileSystemReader(FileSystemReader):
    """
    A FileSystemReader that memoizes checkpoint metadata.

    The first call to ``read_metadata()`` reads from the file system; every
    subsequent call returns the cached result.

    Attributes:
        _cached_metadata (Metadata or None): memoized metadata, populated lazily.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Create a reader rooted at ``path``.

        Args:
            path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
        """
        super().__init__(path=path)
        self._cached_metadata = None

    def read_metadata(self) -> Metadata:
        """
        Return checkpoint metadata, hitting the file system only on the first call.

        Returns:
            Metadata: Checkpoint metadata.
        """
        if self._cached_metadata is not None:
            return self._cached_metadata
        self._cached_metadata = super().read_metadata()
        return self._cached_metadata
52 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/checkpointing/async_ckpt/torch_ckpt.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """
17 | TorchAsyncCheckpoint defines a wrapper for the async version of `torch.save` with
18 | an additional method to synchronize async saving requests
19 | """
20 |
21 |
22 | import logging
23 |
24 | import torch
25 |
26 | from ..utils import preload_tensors, wrap_for_async
27 | from .core import AsyncCallsQueue, AsyncRequest
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 |
class TorchAsyncCheckpoint(object):
    """Async wrapper around `torch.save`.

    `async_save` schedules a save request on an `AsyncCallsQueue`;
    `finalize_async_save` synchronizes outstanding requests. The plain
    synchronous `torch.save` remains available as ``self.save``.
    """

    # Callable used by scheduled async requests to persist the state dict.
    async_fn = None

    def __init__(self, persistent_queue=False):
        # Keep a handle to the synchronous torch.save for direct use.
        self.save = torch.save
        self._async_calls_queue = AsyncCallsQueue(persistent=persistent_queue)
        # Use direct torch.save for persistent queue, avoid unnecessary wrapping
        # NOTE(review): this rebinds the CLASS attribute, so the most recently
        # constructed instance decides `async_fn` for all instances.
        TorchAsyncCheckpoint.async_fn = (
            torch.save if persistent_queue else wrap_for_async(torch.save)
        )

    def async_save(self, state_dict, *args, **kwargs):
        """
        Keeps the original interface of `torch.save`.
        Schedules an `AsyncRequest` with preloading tensors to CPU with pinned memcpy.
        """

        preloaded_sd = preload_tensors(state_dict)
        # Ensure device-to-host copies have completed before scheduling the write.
        torch.cuda.synchronize()
        async_request = AsyncRequest(
            TorchAsyncCheckpoint.async_fn, (preloaded_sd, *args), [], kwargs
        )
        self._async_calls_queue.schedule_async_request(async_request)

    def finalize_async_save(self, blocking: bool = False, no_dist=True, terminate=False):
        """Finalizes active async save calls.

        Args:
            blocking (bool, optional): if True, will wait until all active requests
                are done. Otherwise, finalizes only the async request that already
                finished. Defaults to False.
            no_dist (bool, Optional): if True, training ranks simply check its
                asynchronous checkpoint writer without synchronization.
            terminate (bool, optional): if True, the asynchronous queue will
                be closed as the last action of this function.
        """
        if blocking and self._async_calls_queue.get_num_unfinalized_calls() > 0:
            # Only rank 0 logs, to avoid duplicated messages across ranks.
            if torch.distributed.get_rank() == 0:
                logger.info(
                    'Unfinalized async checkpoint saves. Finalizing them synchronously now.'
                )

        self._async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=no_dist)
        if terminate:
            self._async_calls_queue.close()
77 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/checkpointing/local/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/checkpointing/local/replication/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/checkpointing/local/replication/utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
def zip_strict(*args):
    """
    Backport of Python's builtin ``zip(..., strict=True)`` (available in 3.10+).

    Besides working on earlier Python versions, it is also more verbose on
    failure: it reports the length of every iterable, whereas Python's ``zip``
    only says which iterable finished earlier.
    """
    materialized = [list(iterable) for iterable in args]
    lens = [len(seq) for seq in materialized]
    # All lengths must be equal (zero iterables trivially satisfy this).
    assert len(set(lens)) <= 1, f"Tried to zip iterables of unequal lengths: {lens}!"
    return zip(*materialized)
27 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .config import FaultToleranceConfig # noqa: F401
17 | from .data import WorkloadAction # noqa: F401
18 | from .data import WorkloadControlRequest # noqa: F401
19 | from .rank_monitor_client import RankMonitorClient # noqa: F401
20 | from .rank_monitor_client import RankMonitorClientError # noqa: F401
21 | from .rank_monitor_server import RankMonitorServer # noqa: F401
22 | from .rank_monitor_state_machine import InvalidStateTransitionException # noqa: F401
23 | from .rank_monitor_state_machine import RankMonitorState # noqa: F401
24 | from .rank_monitor_state_machine import RankMonitorStateMachine # noqa: F401
25 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env/python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # SPDX-License-Identifier: BSD-3-Clause
10 | # Modifications made by NVIDIA
11 | # - This package is a copy of `torch.distributed.elastic` from PyTorch version 2.3
# - All occurrences of 'torch.distributed.elastic' were replaced with 'fault_tolerance._torch_elastic_compat'
13 |
14 | """
15 |
16 | Torchelastic agent and user worker failover contract:
17 |
18 | **TL;DR;**:
19 |
20 | * TE(torchelastic) expects user workers to finish with the 5 minutes drift
21 | * It is better to design DDP app to fail for all workers, rather than a single one.
22 | * TE does not synchronize number of restarts between agents
23 | * TE re-rendezvous does not trigger restart decrease
24 | * When a single agent finishes its job(successfully or not), it will close rendezvous.
25 | If other agents still have workers in progress, they will be terminated.
26 | * Based on above, scale down does not work if at least single agent finishes the job.
27 | * When Scale up is detected by agents, it will not decrease ``max_restarts``
28 |
29 |
In general TE(torchelastic) can launch arbitrary user code, but some
clarification is needed around what failover mechanism torchelastic
provides and what failover mechanism it expects from user workers.
33 |
34 | Torchelastic currently supports DDP style applications. That means that
TE expects *ALL* workers to finish approximately at the same time. In practice,
it is nearly impossible to guarantee that all workers in an arbitrary
DDP application finish at the same time, so TE provides a finalization barrier
38 | that waits for TIMEOUT(5 minutes) for worker finalization.
39 |
40 | **Worker Failure**
41 |
42 | When worker fails, TE will check the number of restarts
43 | available, if there is more than 0 restarts, TE will start a new rendezvous
round and restart the worker process. The new rendezvous round will cause other
TE agents to terminate their workers.
46 |
.. note:: TE agents do not synchronize restarts between themselves.
    When a single agent performs a restart, it triggers a local ``max_restarts``
    decrease; other agents will not decrease their ``max_restarts``.
51 |
52 | A single worker failure can cause the whole cluster to fail:
53 | If a single worker is constantly failing, it will cause the TE agent
54 | ``max_restarts`` to go to zero. This will cause an agent to finish its
55 | work and close rendezvous. If there are any other workers on different
56 | agents, they will be terminated.
57 |
58 |
59 | **Re-Rendezvous**
60 |
61 | Re-rendezvous occurs when TE agents detect a new node
trying to join a cluster. TE will not decrease ``max_restarts``. TE agents
63 | will terminate its workers and start a new rendezvous round.
64 |
65 | Note about DynamicRendezvous(etcd-v2, c10d-experimental): If the rendezvous
66 | has already max_nodes, the new node won't be added to the wait list right
67 | away since there is no need to tear down a rendezvous that is already fully
68 | utilized. The new node will wait until its timeout (600 secs by default)
69 | and periodically check the number of participants. If the number becomes
70 | less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs.
71 |
72 | *Scale up event*. When scale up event happens, torchelastic rendezvous
73 | will detect that there are new nodes trying to join. Torchelastic agent
74 | will stop all workers and perform re-rendezvous. Note: when scale up event
75 | happens, *``max_restarts``* will *not* decrease.
76 |
77 | *Scale down event*. When scale down event happens, rendezvous will not
78 | notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` ,
79 | it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` ,
80 | TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*.
81 |
82 | """
83 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/agent/__init__.py
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/agent/server/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | """
10 | The elastic agent is the control plane of torchelastic.
11 |
12 | It is a process that launches and manages underlying worker processes.
13 | The agent is responsible for:
14 |
15 | 1. Working with distributed torch: the workers are started with all the
16 | necessary information to successfully and trivially call
17 | ``torch.distributed.init_process_group()``.
18 |
19 | 2. Fault tolerance: monitors workers and upon detecting worker failures
20 | or unhealthiness, tears down all workers and restarts everyone.
21 |
22 | 3. Elasticity: Reacts to membership changes and restarts workers with the new
23 | members.
24 |
The simplest agents are deployed per node and work with local processes.
26 | A more advanced agent can launch and manage workers remotely. Agents can
27 | be completely decentralized, making decisions based on the workers it manages.
28 | Or can be coordinated, communicating to other agents (that manage workers
29 | in the same job) to make a collective decision.
30 | """
31 |
32 | from .api import ( # noqa: F401
33 | ElasticAgent,
34 | RunResult,
35 | SimpleElasticAgent,
36 | Worker,
37 | WorkerGroup,
38 | WorkerSpec,
39 | WorkerState,
40 | )
41 | from .local_elastic_agent import TORCHELASTIC_ENABLE_FILE_TIMER, TORCHELASTIC_TIMER_FILE
42 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/events/api.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import json
10 | from dataclasses import asdict, dataclass, field
11 | from enum import Enum
12 | from typing import Dict, Union, Optional
13 |
14 | __all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent']
15 |
16 | EventMetadataValue = Union[str, int, float, bool, None]
17 |
18 |
class EventSource(str, Enum):
    """Known identifiers of the event producers."""

    AGENT = "AGENT"  # event emitted by the elastic agent process
    WORKER = "WORKER"  # event emitted by a worker process
24 |
25 |
@dataclass
class Event:
    """
    The class represents the generic event that occurs during the torchelastic job execution.

    The event can be any kind of meaningful action.

    Args:
        name: event name.
        source: the event producer, e.g. agent or worker
        timestamp: timestamp in milliseconds when event occurred.
        metadata: additional data that is associated with the event.
    """

    name: str
    source: EventSource
    timestamp: int = 0
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "Event"]) -> "Event":
        """Build an ``Event`` from its JSON string form; pass an ``Event`` through unchanged.

        Raises:
            TypeError: if ``data`` is neither a JSON string nor an ``Event``.
        """
        if isinstance(data, Event):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            data_dict["source"] = EventSource[data_dict["source"]]
            return Event(**data_dict)
        # Previously this fell through and implicitly returned None; fail loudly instead.
        raise TypeError(f"Cannot deserialize {type(data)} into Event")

    def serialize(self) -> str:
        """Serialize the event to a JSON string (dataclass fields as a dict)."""
        return json.dumps(asdict(self))
59 |
60 |
class NodeState(str, Enum):
    """The states that a node can be in rendezvous."""

    INIT = "INIT"  # node created but rendezvous not yet started
    RUNNING = "RUNNING"  # node actively participating in the job
    SUCCEEDED = "SUCCEEDED"  # node finished its work successfully
    FAILED = "FAILED"  # node terminated with an error
68 |
69 |
@dataclass
class RdzvEvent:
    """
    Dataclass to represent any rendezvous event.

    Args:
        name: Event name. (E.g. Current action being performed)
        run_id: The run id of the rendezvous
        message: The message describing the event
        hostname: Hostname of the node
        pid: The process id of the node
        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
        master_endpoint: The master endpoint for the rendezvous store, if known
        rank: The rank of the node, if known
        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
        error_trace: Error stack trace, if this is an error event.
    """

    name: str
    run_id: str
    message: str
    hostname: str
    pid: int
    node_state: NodeState
    master_endpoint: str = ""
    rank: Optional[int] = None
    local_id: Optional[int] = None
    error_trace: str = ""

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
        """Build an ``RdzvEvent`` from its JSON string form; pass an ``RdzvEvent`` through unchanged.

        Raises:
            TypeError: if ``data`` is neither a JSON string nor an ``RdzvEvent``.
        """
        if isinstance(data, RdzvEvent):
            return data
        if isinstance(data, str):
            data_dict = json.loads(data)
            data_dict["node_state"] = NodeState[data_dict["node_state"]]
            return RdzvEvent(**data_dict)
        # Previously this fell through and implicitly returned None; fail loudly instead.
        raise TypeError(f"Cannot deserialize {type(data)} into RdzvEvent")

    def serialize(self) -> str:
        """Serialize the event to a JSON string (dataclass fields as a dict)."""
        return json.dumps(asdict(self))
113 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/events/handlers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import logging
10 | from typing import Dict
11 |
12 |
# Registry of shared logging handlers, created once at import time.
_log_handlers: Dict[str, logging.Handler] = {
    "console": logging.StreamHandler(),
    "dynamic_rendezvous": logging.NullHandler(),
    "null": logging.NullHandler(),
}


def get_logging_handler(destination: str = "null") -> logging.Handler:
    """Return the shared logging handler registered under ``destination``.

    Args:
        destination: one of ``"console"``, ``"dynamic_rendezvous"`` or ``"null"``.

    Raises:
        KeyError: if ``destination`` is not a registered handler name.
    """
    # ``global`` was removed: it is unnecessary for read-only access.
    return _log_handlers[destination]
23 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/multiprocessing/errors/handlers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 | # Multiprocessing error-reporting module
9 |
10 | # SPDX-License-Identifier: BSD-3-Clause
11 | # Modifications made by NVIDIA
# All occurrences of 'torch.distributed.elastic' were replaced with 'nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat'
13 |
14 | from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.multiprocessing.errors.error_handler import ErrorHandler
15 |
16 | __all__ = ['get_error_handler']
17 |
def get_error_handler():
    """Return a fresh ``ErrorHandler`` instance, the default error-reporting strategy."""
    return ErrorHandler()
20 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/multiprocessing/redirects.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # Taken and modified from original source:
10 | # https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
11 | import ctypes
12 | import logging
13 | import os
14 | import sys
15 | from contextlib import contextmanager
16 | from functools import partial
17 |
# Platform flags: fd-level redirection is only implemented on top of glibc.
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"


logger = logging.getLogger(__name__)


def get_libc():
    """Load libc via ctypes; returns ``None`` (with a warning) on Windows/macOS."""
    if not (IS_WINDOWS or IS_MACOS):
        return ctypes.CDLL("libc.so.6")
    logger.warning(
        "NOTE: Redirects are currently not supported in Windows or MacOs."
    )
    return None
33 |
34 |
# Handle to libc (``None`` on Windows/macOS, where redirection is unsupported).
libc = get_libc()


def _c_std(stream: str):
    # Resolve the C-level FILE* symbol ("stdout"/"stderr") exported by libc.
    return ctypes.c_void_p.in_dll(libc, stream)
40 |
41 |
42 | def _python_std(stream: str):
43 | return {"stdout": sys.stdout, "stderr": sys.stderr}[stream]
44 |
45 |
46 | _VALID_STD = {"stdout", "stderr"}
47 |
48 |
@contextmanager
def redirect(std: str, to_file: str):
    """
    Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``.

    This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``).
    See usage for details.

    Directory of ``to_file`` is assumed to exist and the destination file
    is overwritten if it already exists.

    .. note:: Due to buffering cross source writes are not guaranteed to
              appear in wall-clock order. For instance in the example below
              it is possible for the C-outputs to appear before the python
              outputs in the log file.

    Usage:

    ::

     # syntactic-sugar for redirect("stdout", "tmp/stdout.log")
     with redirect_stdout("/tmp/stdout.log"):
        print("python stdouts are redirected")
        libc = ctypes.CDLL("libc.so.6")
        libc.printf(b"c stdouts are also redirected")
        os.system("echo system stdouts are also redirected")

     print("stdout restored")

    """
    if std not in _VALID_STD:
        raise ValueError(
            f"unknown standard stream <{std}>, must be one of {_VALID_STD}"
        )

    # C-level FILE* and python-level stream object for the same std channel.
    c_std = _c_std(std)
    python_std = _python_std(std)
    std_fd = python_std.fileno()

    def _redirect(dst):
        # Flush both the C and python buffers before swapping the fd, so no
        # buffered output lands in the wrong destination.
        libc.fflush(c_std)
        python_std.flush()
        os.dup2(dst.fileno(), std_fd)

    # Duplicate the original fd first so it can be restored on exit.
    with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
        _redirect(dst)
        try:
            yield
        finally:
            # Always restore the original stream, even if the body raised.
            _redirect(orig_std)
99 |
100 |
# Convenience partials that redirect only the named stream.
redirect_stdout = partial(redirect, "stdout")
redirect_stderr = partial(redirect, "stderr")
103 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/multiprocessing/subprocess_handler/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # SPDX-License-Identifier: BSD-3-Clause
10 | # Modifications made by NVIDIA
# All occurrences of 'torch.distributed.elastic' were replaced with 'nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat'
12 |
13 | from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.multiprocessing.subprocess_handler.handlers import (
14 | get_subprocess_handler,
15 | )
16 | from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.multiprocessing.subprocess_handler.subprocess_handler import (
17 | SubprocessHandler,
18 | )
19 |
20 | __all__ = ["SubprocessHandler", "get_subprocess_handler"]
21 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/multiprocessing/subprocess_handler/handlers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # SPDX-License-Identifier: BSD-3-Clause
10 | # Modifications made by NVIDIA
# All occurrences of 'torch.distributed.elastic' were replaced with 'nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat'
12 |
13 | from typing import Dict, Tuple
14 |
15 | from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.multiprocessing.subprocess_handler.subprocess_handler import (
16 | SubprocessHandler,
17 | )
18 |
19 | __all__ = ["get_subprocess_handler"]
20 |
21 |
def get_subprocess_handler(
    entrypoint: str,
    args: Tuple,
    env: Dict[str, str],
    stdout: str,
    stderr: str,
    local_rank_id: int,
):
    """Factory returning a ``SubprocessHandler`` that launches ``entrypoint`` with ``args``.

    Args:
        entrypoint: executable to launch.
        args: positional command-line arguments (stringified by the handler).
        env: extra environment variables layered over the parent environment.
        stdout: file path for the child's stdout; falsy keeps the parent's stream.
        stderr: file path for the child's stderr; falsy keeps the parent's stream.
        local_rank_id: local rank associated with the launched worker.
    """
    return SubprocessHandler(
        entrypoint=entrypoint,
        args=args,
        env=env,
        stdout=stdout,
        stderr=stderr,
        local_rank_id=local_rank_id,
    )
38 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/multiprocessing/subprocess_handler/subprocess_handler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # SPDX-License-Identifier: BSD-3-Clause
10 | # Modifications made by NVIDIA
# Added shell=False to Popen to mitigate security threat
# Added suppression for subprocess low severity issue
13 |
14 | import os
15 | import signal
16 | # Issue: [B404:blacklist] Consider possible security implications associated with the subprocess module.
17 | # Severity: Low Confidence: High
18 | # CWE: CWE-78 (https://cwe.mitre.org/data/definitions/78.html)
19 | # More Info: https://bandit.readthedocs.io/en/1.7.9/blacklists/blacklist_imports.html#b404-import-subprocess
20 | import subprocess # nosec
21 | import sys
22 |
23 | from typing import Any, Dict, Optional, Tuple
24 |
25 | __all__ = ["SubprocessHandler"]
26 |
27 | IS_WINDOWS = sys.platform == "win32"
28 |
29 |
30 | def _get_default_signal() -> signal.Signals:
31 | """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
32 | if IS_WINDOWS:
33 | return signal.CTRL_C_EVENT # type: ignore[attr-defined] # noqa: F821
34 | else:
35 | return signal.SIGTERM
36 |
37 |
class SubprocessHandler:
    """
    Convenience wrapper around python's ``subprocess.Popen``. Keeps track of
    meta-objects associated to the process (e.g. stdout and stderr redirect fds).
    """

    def __init__(
        self,
        entrypoint: str,
        args: Tuple,
        env: Dict[str, str],
        stdout: str,
        stderr: str,
        local_rank_id: int,
    ):
        # Redirect files are opened here and closed in close(); falsy paths
        # mean "inherit the parent's stream" (Popen receives None).
        self._stdout = open(stdout, "w") if stdout else None
        self._stderr = open(stderr, "w") if stderr else None
        # inherit parent environment vars
        env_vars = os.environ.copy()
        env_vars.update(env)

        # Stringify every argument; Popen expects a sequence of strings.
        args_str = (entrypoint, *[str(e) for e in args])
        self.local_rank_id = local_rank_id
        self.proc: subprocess.Popen = self._popen(args_str, env_vars)

    def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
        """Launch the child process with redirected streams; shell is disabled."""
        kwargs: Dict[str, Any] = {}
        if not IS_WINDOWS:
            # A new session makes the child a process-group leader, so close()
            # can signal the whole group via os.killpg.
            kwargs["start_new_session"] = True
        # Issue: [B603:subprocess_without_shell_equals_true] subprocess call - check for execution of untrusted input.
        # Severity: Low Confidence: High
        # CWE: CWE-78 (https://cwe.mitre.org/data/definitions/78.html)
        # More Info: https://bandit.readthedocs.io/en/1.7.9/plugins/b603_subprocess_without_shell_equals_true.html
        return subprocess.Popen(
            # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
            # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
            # `Tuple[str, *Tuple[Any, ...]]`.
            args=args,
            env=env,
            stdout=self._stdout,
            stderr=self._stderr,
            **kwargs,
            shell=False,
        ) # nosec

    def close(self, death_sig: Optional[signal.Signals] = None) -> None:
        """Signal the child (the whole process group on POSIX) and close redirect files.

        Args:
            death_sig: signal to deliver; defaults to SIGTERM on unix and
                CTRL_C_EVENT on Windows (see ``_get_default_signal``).
        """
        if not death_sig:
            death_sig = _get_default_signal()
        if IS_WINDOWS:
            self.proc.send_signal(death_sig)
        else:
            os.killpg(self.proc.pid, death_sig)
        if self._stdout:
            self._stdout.close()
        if self._stderr:
            self._stderr.close()
94 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/rendezvous/registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # SPDX-License-Identifier: BSD-3-Clause
8 | # Modifications made by NVIDIA
# All occurrences of 'torch.distributed.elastic' were replaced with 'nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat'
10 | from .api import RendezvousHandler, RendezvousParameters
11 | from .api import rendezvous_handler_registry as handler_registry
12 | from .dynamic_rendezvous import create_handler
13 |
14 | __all__ = ['get_rendezvous_handler']
15 |
def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create a handler for the ``static`` backend (fixed master addr/port TCPStore)."""
    from . import static_tcp_rendezvous

    return static_tcp_rendezvous.create_rdzv_handler(params)
20 |
21 |
def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create a handler for the ``etcd`` backend (imported lazily to avoid a hard dependency)."""
    from . import etcd_rendezvous

    return etcd_rendezvous.create_rdzv_handler(params)
26 |
27 |
def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create a dynamic-rendezvous handler backed by the ``etcd-v2`` store backend."""
    from .etcd_rendezvous_backend import create_backend

    backend, store = create_backend(params)

    return create_handler(store, backend, params)
34 |
35 |
def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Create a dynamic-rendezvous handler backed by the ``c10d`` store backend."""
    from .c10d_rendezvous_backend import create_backend

    backend, store = create_backend(params)

    return create_handler(store, backend, params)
42 |
43 |
def _register_default_handlers() -> None:
    """Register the built-in rendezvous backends (etcd, etcd-v2, c10d, static)."""
    handler_registry.register("etcd", _create_etcd_handler)
    handler_registry.register("etcd-v2", _create_etcd_v2_handler)
    handler_registry.register("c10d", _create_c10d_handler)
    handler_registry.register("static", _create_static_handler)
49 |
50 |
def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Obtain a reference to a :py:class:`RendezvousHandler`.

    The handler is selected by the backend named in ``params``.

    Custom rendezvous handlers can be registered by

    ::

      from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous import rendezvous_handler_registry
      from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous.registry import get_rendezvous_handler

      def create_my_rdzv(params: RendezvousParameters):
        return MyCustomRdzv(params)

      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)

      my_rdzv_params = RendezvousParameters(backend="my_rdzv_backend_name", ...)
      my_rdzv_handler = get_rendezvous_handler(my_rdzv_params)
    """
    return handler_registry.create_handler(params)
70 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/rendezvous/static_tcp_rendezvous.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # SPDX-License-Identifier: BSD-3-Clause
10 | # Modifications made by NVIDIA
# All occurrences of 'torch.distributed.elastic' were replaced with 'nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat'
12 | import datetime
13 | import logging
14 | from typing import Tuple, cast, Optional
15 |
16 | # pyre-ignore[21]: Could not find name `Store` in `torch.distributed`.
17 | from torch.distributed import Store, TCPStore, PrefixStore
18 | from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous import RendezvousHandler, RendezvousParameters
19 | from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous.utils import parse_rendezvous_endpoint
20 |
21 | log = logging.getLogger(__name__)
22 |
23 | _default_timeout_seconds = 600
24 |
25 |
class StaticTCPRendezvous(RendezvousHandler):
    """
    Static rendezvous that is a wrapper around the TCPStore.

    Creates TCPStore based on the input parameters with the
    listener on the agent with group_rank=0
    """

    def __init__(
        self,
        master_addr: str,
        master_port: int,
        rank: int,
        world_size: int,
        run_id: str,
        timeout: int,
    ):
        # Endpoint of the agent that hosts the TCPStore listener (group_rank=0).
        self.master_addr = master_addr
        self.master_port = master_port
        self.rank = rank
        self.world_size = world_size
        self.run_id = run_id
        # Store operations time out after this interval (seconds -> timedelta).
        self.timeout = datetime.timedelta(seconds=timeout)
        # Created lazily on the first next_rendezvous() call, then reused.
        self._store: Optional[Store] = None

    def get_backend(self) -> str:
        """Name of this rendezvous backend."""
        return "static"

    def next_rendezvous(self) -> Tuple[Store, int, int]:
        """Return ``(store, rank, world_size)``; the TCPStore is created once and reused."""
        log.info("Creating TCPStore as the c10d::Store implementation")
        if not self._store:
            # Only rank 0 hosts the listener; other ranks connect as clients.
            is_master = self.rank == 0
            self._store = TCPStore(  # type: ignore[call-arg]
                self.master_addr,
                self.master_port,
                self.world_size,
                is_master,
                self.timeout,
                multi_tenant=True,
            )
        # Namespace all keys by run_id so multiple runs can share one store.
        store = PrefixStore(self.run_id, self._store)
        return store, self.rank, self.world_size

    def is_closed(self):
        # Static rendezvous never closes: membership is fixed for the job.
        return False

    def set_closed(self):
        # No-op: a static rendezvous cannot be closed.
        pass

    def num_nodes_waiting(self):
        # No elasticity: nodes never wait to join a static rendezvous.
        return 0

    def get_run_id(self) -> str:
        return self.run_id

    def shutdown(self) -> bool:
        # Nothing to tear down; report success.
        return True
83 |
84 |
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a ``StaticTCPRendezvous`` from ``params``.

    Requires ``rank`` in ``params.config`` and a ``host:port`` endpoint;
    a ``timeout`` config entry overrides the 600-second default.

    Raises:
        ValueError: if the rank, the endpoint, or the endpoint port is missing.
    """
    if "rank" not in params.config:
        raise ValueError(
            # Fixed: adjacent literals previously concatenated without a separator
            # ("...RendezvousParameters.Try add...").
            "rank is absent in RendezvousParameters. "
            "Try add --node-rank to the cmd request"
        )
    endpoint = params.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "endpoint is absent in RendezvousParameters. "
            "Try add --master-port and --master-addr to the cmd request"
        )
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    if master_port == -1:
        raise ValueError(
            f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
        )
    world_size = params.max_nodes
    rank = cast(int, params.config.get("rank"))
    run_id = params.run_id
    # Fall back to the module default when no explicit timeout is configured.
    if "timeout" in params.config:
        timeout = int(params.config["timeout"])
    else:
        timeout = _default_timeout_seconds
    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )
112 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/timer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """
8 | Expiration timers are set up on the same process as the agent and
9 | used from your script to deal with stuck workers. When you go into
10 | a code-block that has the potential to get stuck you can acquire
11 | an expiration timer, which instructs the timer server to kill the
12 | process if it does not release the timer by the self-imposed expiration
13 | deadline.
14 |
15 | Usage::
16 |
17 | import torchelastic.timer as timer
18 | import torchelastic.agent.server as agent
19 |
20 | def main():
21 | start_method = "spawn"
22 | message_queue = mp.get_context(start_method).Queue()
    server = timer.LocalTimerServer(message_queue, max_interval=0.01)
24 | server.start() # non-blocking
25 |
26 | spec = WorkerSpec(
27 | fn=trainer_func,
28 | args=(message_queue,),
29 | ...)
30 | agent = agent.LocalElasticAgent(spec, start_method)
31 | agent.run()
32 |
33 | def trainer_func(message_queue):
34 | timer.configure(timer.LocalTimerClient(message_queue))
35 | with timer.expires(after=60): # 60 second expiry
36 | # do some work
37 |
38 | In the example above if ``trainer_func`` takes more than 60 seconds to
39 | complete, then the worker process is killed and the agent retries the worker group.
40 | """
41 |
42 | from .api import TimerClient, TimerRequest, TimerServer, configure, expires # noqa: F401
43 | from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401
44 | from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest # noqa: F401
45 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401
10 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/api.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import os
10 | import socket
11 | from string import Template
12 | from typing import List, Any
13 |
14 |
def get_env_variable_or_raise(env_name: str) -> str:
    r"""
    Retrieves an environment variable, raising ``ValueError`` when the
    variable is not set.

    Args:
        env_name (str): Name of the env variable
    """
    value = os.environ.get(env_name)
    if value is not None:
        return value
    raise ValueError(f"Environment variable {env_name} expected, but not set")
28 |
29 |
def get_socket_with_port() -> socket.socket:
    """
    Returns a listening socket bound to an OS-assigned free port on localhost.

    Each address family returned by ``getaddrinfo`` is attempted in turn; the
    first one that can be bound and put into listening mode is returned. The
    caller owns the socket and is responsible for closing it.

    Raises:
        RuntimeError: if no address family yields a usable socket.
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    # Avoid shadowing the `type` builtin; unpack directly in the loop header.
    for family, sock_type, proto, _, _ in addrs:
        s = socket.socket(family, sock_type, proto)
        try:
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError:
            # This address family is unusable on this host; close the socket
            # and try the next candidate instead of failing immediately.
            s.close()
    raise RuntimeError("Failed to create a socket")
44 |
45 |
class macros:
    """
    Simple macro substitution for caffe2.distributed.launch cmd args.

    String arguments have ``${local_rank}`` replaced by the supplied value;
    non-string arguments pass through untouched.
    """

    local_rank = "${local_rank}"

    @staticmethod
    def substitute(args: List[Any], local_rank: str) -> List[str]:
        # safe_substitute leaves unknown placeholders intact instead of raising.
        return [
            Template(arg).safe_substitute(local_rank=local_rank) if isinstance(arg, str) else arg
            for arg in args
        ]
63 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from .cycling_iterator import CyclingIterator # noqa: F401
10 | from .elastic_distributed_sampler import ElasticDistributedSampler # noqa: F401
11 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/data/cycling_iterator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 |
class CyclingIterator:
    """
    An iterator decorator that replays the underlying iterator ``n`` times,
    which is handy for "unrolling" a dataset across several training epochs.

    ``generator_fn(epoch)`` is invoked to produce the underlying iterator for
    each cycle, where ``epoch`` counts up from ``start_epoch`` to ``n - 1``.

    For example, if ``generator_fn`` always yields ``[1,2,3]`` then
    ``CyclingIterator(n=2, generator_fn)`` iterates over ``[1,2,3,1,2,3]``.
    """

    def __init__(self, n: int, generator_fn, start_epoch=0):
        self._n = n
        self._epoch = start_epoch
        self._generator_fn = generator_fn
        self._iter = generator_fn(self._epoch)

    def __iter__(self):
        return self

    def __next__(self):
        # Iterative (rather than recursive) cycling: on exhaustion, advance
        # to the next epoch's iterator until all n cycles are consumed.
        while True:
            try:
                return next(self._iter)
            except StopIteration:
                if self._epoch >= self._n - 1:
                    raise
                self._epoch += 1
                self._iter = self._generator_fn(self._epoch)
44 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/data/elastic_distributed_sampler.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | import math
10 |
11 | import torch
12 | from torch.utils.data.distributed import DistributedSampler
13 |
14 |
class ElasticDistributedSampler(DistributedSampler):
    """
    Sampler restricting data loading to a subset of the dataset, tailored for
    elastic training.

    Especially useful together with
    :class:`torch.nn.parallel.DistributedDataParallel`: each process can hand
    a DistributedSampler instance to its DataLoader and load a subset of the
    original dataset exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        start_index (optional): Which index of the dataset to start sampling from
    """

    def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
        if start_index >= len(dataset):
            raise ValueError(
                f"Start index {start_index} should be less than dataset size {len(dataset)}"
            )

        self.start_index = start_index
        # Samples per replica, rounding up so every replica gets equal work.
        remaining = len(self.dataset) - self.start_index  # type: ignore[arg-type]
        self.num_samples = int(math.ceil(float(remaining) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Deterministic shuffle keyed on the current epoch.
        generator = torch.Generator()
        generator.manual_seed(self.epoch)
        permutation = torch.randperm(
            len(self.dataset) - self.start_index, generator=generator  # type: ignore[arg-type]
        )
        indices = (permutation + self.start_index).tolist()

        # Pad with leading indices so the total divides evenly across replicas.
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Each rank takes every num_replicas-th element, offset by its rank.
        own = indices[self.rank : self.total_size : self.num_replicas]
        assert len(own) == self.num_samples

        return iter(own)

    def __len__(self):
        return self.num_samples
71 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/log_level.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 |
def get_log_level() -> str:
    """Return the default log level for pytorch loggers ("WARNING")."""
    return "WARNING"
15 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/logging.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # SPDX-License-Identifier: BSD-3-Clause
10 | # Modifications made by NVIDIA
# All occurrences of 'torch.distributed.elastic' were replaced with 'nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat'
12 |
13 | import inspect
14 | import logging
15 | import os
16 | import warnings
17 | from typing import Optional
18 |
19 | from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.utils.log_level import get_log_level
20 |
21 |
def get_logger(name: Optional[str] = None):
    """
    Sets up a simple logger writing to stderr. The log level comes from the
    LOGLEVEL env variable, defaulting to WARNING. When no name is given, the
    caller's module name is used.

    Args:
        name: Name of the logger. If no name provided, the name will
            be derived from the call stack.
    """
    if not name:
        # depth=2: skip this function's own frame when walking the stack.
        name = _derive_module_name(depth=2)
    return _setup_logger(name)
37 |
38 |
def _setup_logger(name: Optional[str] = None):
    # LOGLEVEL env variable takes precedence over the library default level.
    logger = logging.getLogger(name)
    logger.setLevel(os.environ.get("LOGLEVEL", get_log_level()))
    return logger
43 |
44 |
45 | def _derive_module_name(depth: int = 1) -> Optional[str]:
46 | """
47 | Derives the name of the caller module from the stack frames.
48 |
49 | Args:
50 | depth: The position of the frame in the stack.
51 | """
52 | try:
53 | stack = inspect.stack()
54 | assert depth < len(stack)
55 | # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index)
56 | frame_info = stack[depth]
57 |
58 | module = inspect.getmodule(frame_info[0])
59 | if module:
60 | module_name = module.__name__
61 | else:
62 | # inspect.getmodule(frame_info[0]) does NOT work (returns None) in
63 | # binaries built with @mode/opt
64 | # return the filename (minus the .py extension) as modulename
65 | filename = frame_info[1]
66 | module_name = os.path.splitext(os.path.basename(filename))[0]
67 | return module_name
68 | except Exception as e:
69 | warnings.warn(
70 | f"Error deriving logger module name, using . Exception: {e}",
71 | RuntimeWarning,
72 | )
73 | return None
74 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/store.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # All rights reserved.
5 | #
6 | # This source code is licensed under the BSD-style license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | from datetime import timedelta
10 | from typing import List
11 |
12 |
def get_all(store, rank: int, prefix: str, size: int):
    r"""
    Fetches the values stored under keys ``{prefix}{idx}`` for ``idx`` in
    ``range(size)``.

    Each rank publishes a ``{prefix}{rank}.FIN`` marker once it has read all
    keys. Rank 0 then waits for every marker before returning, so that all
    other processes finish the procedure before it exits.

    Usage

    ::

        values = get_all(store, rank, 'torchelastic/data', 3)
        value1 = values[0]  # retrieves the data for key torchelastic/data0
        value2 = values[1]  # retrieves the data for key torchelastic/data1
        value3 = values[2]  # retrieves the data for key torchelastic/data2

    """
    data_arr = [store.get(f"{prefix}{idx}") for idx in range(size)]
    # Announce that this rank has finished reading every key.
    store.set(f"{prefix}{rank}.FIN", b"FIN")
    if rank == 0:
        # Rank0 runs the TCPStore daemon, as a result it needs to exit last.
        # Otherwise, the barrier may timeout if rank0 process finished the work
        # before other processes finished `get_all` method
        for node_rank in range(size):
            store.get(f"{prefix}{node_rank}.FIN")

    return data_arr
45 |
46 |
def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
) -> List[bytes]:
    """
    Synchronizes ``world_size`` agents between each other using the underlying
    c10d store. The ``data`` will be available on each of the agents.

    Note: The data on the path is not deleted, as a result there can be stale data if
    you use the same key_prefix twice.
    """
    # Bound how long store operations may block before raising.
    store.set_timeout(timedelta(seconds=barrier_timeout))
    # Publish this rank's payload, then gather everyone's payloads.
    store.set(f"{key_prefix}{rank}", data)
    return get_all(store, rank, key_prefix, world_size)
66 |
67 |
def barrier(
    store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
    """
    A global lock between agents.

    Note: Since the data is not removed from the store, the barrier can be used
    once per unique ``key_prefix``.
    """
    # The payload content is irrelevant; synchronize() provides the rendezvous.
    payload = str(rank).encode()
    synchronize(store, payload, rank, world_size, key_prefix, barrier_timeout)
79 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from . import (
18 | completion,
19 | exception,
20 | finalize,
21 | health_check,
22 | initialize,
23 | monitor_thread,
24 | nested_restarter,
25 | rank_assignment,
26 | state,
27 | terminate,
28 | )
29 | from .compose import Compose
30 | from .state import FrozenState, Mode, State
31 | from .wrap import CallWrapper, Wrapper
32 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/attribution.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import dataclasses
18 | import enum
19 | import itertools
20 | import re
21 |
22 |
class Interruption(enum.Enum):
    """Kinds of workload interruption that can be attributed to a rank.

    A value is paired with the affected rank in :py:class:`InterruptionRecord`
    and rendered by :py:func:`format_interruption_records`.
    """

    EXCEPTION = enum.auto()
    BASE_EXCEPTION = enum.auto()
    SOFT_TIMEOUT = enum.auto()
    HARD_TIMEOUT = enum.auto()
    TERMINATED = enum.auto()
    UNRESPONSIVE = enum.auto()
    MONITOR_PROCESS_EXCEPTION = enum.auto()
31 |
32 |
@dataclasses.dataclass(frozen=True)
class InterruptionRecord:
    # Rank on which the interruption was observed, and the interruption kind.
    rank: int
    interruption: Interruption

    @classmethod
    def from_str(cls, string: str):
        """Reconstructs a record from a string containing ``rank=<N>`` and
        ``Interruption.<NAME>`` fragments; raises ValueError otherwise."""
        rank_match = re.search(r'rank=(\d+)', string)
        interruption_match = re.search(r'Interruption\.(\w+)', string)

        if rank_match is None or interruption_match is None:
            raise ValueError("Invalid State string format")

        return cls(
            rank=int(rank_match.group(1)),
            interruption=Interruption[interruption_match.group(1)],
        )
51 |
52 |
def format_interruption_records(records):
    """Renders records as a comma-separated list of
    '<interruption> on ranks={...}' entries, grouping consecutive records
    that share the same interruption kind."""
    parts = []
    for interruption, group in itertools.groupby(records, key=lambda r: r.interruption):
        ranks = set(record.rank for record in group)
        parts.append(f'{interruption} on {ranks=}')
    return ', '.join(parts)
62 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/completion.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import abc
18 |
19 | from .state import FrozenState
20 |
21 |
class Completion(abc.ABC):
    r'''
    Abstract base class for the ``global_finalize_success`` argument of
    :py:class:`inprocess.Wrapper`.

    A :py:class:`Completion` runs on any unterminated rank once that rank has
    completed the workload wrapped by inprocess.

    Multiple :py:class:`Completion` instances may be chained with
    :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Executes the :py:class:`Completion` logic.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            The forwarded read-only input ``state``.
        '''
        raise NotImplementedError
46 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/compose.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import inspect
18 | import warnings
19 | from collections.abc import Callable
20 | from typing import Any, TypeVar
21 |
22 | T = TypeVar('T')
23 |
24 |
def find_common_ancestor(*instances):
    # Intersect the MROs of all instance types; the first class in the first
    # instance's MRO that survives the intersection is the most-derived
    # common ancestor.
    shared = set(type(instances[0]).mro())
    for instance in instances[1:]:
        shared.intersection_update(type(instance).mro())

    if not shared:
        return None
    return next(cls for cls in type(instances[0]).mro() if cls in shared)
37 |
38 |
class Compose:
    r'''
    Performs functional composition (chaining) of multiple callable class
    instances.

    Output of the previous callable is passed as input to the next callable,
    and the output of the last callable is returned as the final output of a
    :py:class:`Compose` instance.

    Constructed :py:class:`Compose` object is an instance of the lowest common
    ancestor in `method resolution order
    <https://docs.python.org/3/glossary.html#term-method-resolution-order>`_ of
    all input callable class instances.

    Example:

    .. code-block:: python

        composed = Compose(a, b, c)
        ret = composed(arg) # is equivalent to ret = a(b(c(arg)))
    '''

    def __new__(cls, *instances: Callable[[T], T]):
        # Build a dynamic subclass inheriting from both Compose and the lowest
        # common ancestor of all composed instances, so isinstance() checks
        # against that common base also succeed on the composite object.
        common_ancestor = find_common_ancestor(*instances)
        DynamicCompose = type(
            'DynamicCompose',
            (Compose, common_ancestor),
            {
                'instances': instances,
                # Restore default construction: without this, instantiating
                # DynamicCompose() below would re-enter Compose.__new__.
                '__new__': object.__new__,
            },
        )
        return DynamicCompose()

    def __init__(self, *args, **kwargs):
        # Intentionally a no-op: all construction happens in __new__, and the
        # common-ancestor base __init__ must not run with Compose's arguments.
        pass

    def __call__(self, *args: Any):
        # Apply the composed callables right-to-left: Compose(a, b, c)(x)
        # evaluates a(b(c(x))).
        for instance in reversed(self.instances):
            ret = instance(*args or ())
            if ret is None and args and args != (None,):
                # A callable that received arguments but returned None broke
                # the chain; downstream callables would lose the state.
                msg = (
                    f'{type(self).__name__} didn\'t chain arguments after '
                    f'calling {instance=} with {args=}'
                )
                warnings.warn(msg)
            if not isinstance(ret, tuple) and len(inspect.signature(instance).parameters) > 0:
                # Wrap a scalar return so it splats as one positional argument
                # for the next callable; tuple returns are splatted as-is.
                args = (ret,)
            else:
                # NOTE(review): a non-tuple, non-None return from a
                # zero-parameter callable would fail to splat on the next
                # iteration — presumably such callables return None or a
                # tuple; confirm with callers.
                args = ret
        return ret
91 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/exception.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
class RestartError(Exception):
    r'''
    Base :py:exc:`Exception` for all exceptions raised by
    :py:class:`inprocess.Wrapper`.
    '''
25 |
26 |
class RestartAbort(BaseException):
    r'''
    Terminal Python :py:exc:`BaseException` signalling that the
    :py:class:`inprocess.Wrapper` must be aborted immediately, bypassing any
    further restart attempts.
    '''
35 |
36 |
class HealthCheckError(RestartError):
    r'''
    :py:exc:`RestartError` indicating that
    :py:class:`inprocess.health_check.HealthCheck` raised errors, so execution
    shouldn't be restarted on this distributed rank.
    '''
45 |
46 |
class InternalError(RestartError):
    r'''
    Internal error of :py:class:`inprocess.Wrapper`.
    '''
53 |
54 |
class TimeoutError(RestartError):
    r'''
    Timeout error of :py:class:`inprocess.Wrapper`.
    '''
61 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/finalize.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import abc
18 | import datetime
19 | import threading
20 | from typing import Any, Callable, Optional
21 |
22 | from . import exception
23 | from .state import FrozenState
24 |
25 |
class Finalize(abc.ABC):
    r'''
    Abstract base class for the ``finalize`` argument of
    :py:class:`inprocess.Wrapper`.

    A :py:class:`Finalize` brings the process into a state where a restart of
    the wrapped function may be attempted, e.g.: deinitialize any global
    variables or synchronize with any asynchronous tasks issued by the wrapped
    function that was not already performed by exception handlers in the
    wrapped function.

    Any failure during execution of :py:class:`Finalize` should raise an
    exception. In this case the health check is skipped, the exception is
    reraised by the wrapper, and it should cause termination of the main
    Python interpreter process.

    :py:class:`Finalize` runs after a fault was detected and the distributed
    group was destroyed, but before
    :py:class:`inprocess.health_check.HealthCheck` is performed.

    Multiple :py:class:`Finalize` instances may be chained with
    :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Executes the :py:class:`Finalize` logic.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            The forwarded read-only input ``state``.
        '''
        raise NotImplementedError
62 |
63 |
class ThreadedFinalize(Finalize):
    r'''
    Executes the provided finalize ``fn`` function with specified positional
    and keyword arguments in a separate :py:class:`threading.Thread`.

    Raises :py:exc:`exception.TimeoutError` if execution takes longer than the
    specified ``timeout``.

    Args:
        timeout: timeout for a thread executing ``fn``
        fn: function to be executed
        args: tuple of positional arguments
        kwargs: dictionary of keyword arguments
    '''

    def __init__(
        self,
        timeout: datetime.timedelta,
        fn: Callable[..., Any],
        args: Optional[tuple[Any, ...]] = (),
        kwargs: Optional[dict[str, Any]] = None,
    ):
        # Normalize both optional containers; previously only kwargs was
        # handled, so passing args=None (permitted by the annotation) would
        # crash threading.Thread at call time.
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}

        self.timeout = timeout
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def __call__(self, state: FrozenState) -> FrozenState:
        rank = state.rank
        thread = threading.Thread(
            target=self.fn,
            name=f'{type(self).__name__}-{rank}',
            args=self.args,
            kwargs=self.kwargs,
            daemon=True,
        )
        thread.start()
        # join() returns after the timeout even if the target is still
        # running; a live thread here means the finalizer exceeded its budget.
        thread.join(self.timeout.total_seconds())
        if thread.is_alive():
            raise exception.TimeoutError(f'{type(self).__name__} timed out after {self.timeout}')

        return state
109 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/initialize.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import abc
18 | from typing import Optional
19 |
20 | from . import exception
21 | from .state import FrozenState
22 |
23 |
class Initialize(abc.ABC):
    r'''
    Abstract base class for the ``initialize`` argument of
    :py:class:`inprocess.Wrapper`.

    An :py:class:`Initialize` runs at the start of every restart iteration,
    including the first one, and may raise exceptions (e.g., when specific
    preconditions are not met). Raising a standard Python :py:exc:`Exception`
    triggers another restart, while raising a :py:exc:`BaseException`
    terminates the wrapper.

    Multiple :py:class:`Initialize` instances may be chained with
    :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Executes the :py:class:`Initialize` logic.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            The forwarded read-only input ``state``.
        '''
        raise NotImplementedError
51 |
52 |
class RetryController(Initialize):
    r'''
    Gatekeeper for restart attempts, enforcing iteration and world-size limits
    for distributed training.

    Execution proceeds only while the world size, active world size, and
    iteration count stay within the configured bounds; otherwise a
    :py:exc:`inprocess.exception.RestartAbort` is raised.

    Args:
        max_iterations: the maximum number of iterations allowed before
            aborting retries. If :py:obj:`None`, there is no iteration limit
        min_world_size: The minimum required world size to proceed with
            execution
        min_active_world_size: The minimum required active world size to
            proceed with execution
    '''

    def __init__(
        self,
        max_iterations: Optional[int] = None,
        min_world_size: int = 1,
        min_active_world_size: int = 1,
    ):
        self.max_iterations = max_iterations
        self.min_world_size = min_world_size
        self.min_active_world_size = min_active_world_size

    def __call__(self, state: FrozenState) -> FrozenState:
        world_too_small = state.world_size < self.min_world_size
        active_too_small = state.active_world_size < self.min_active_world_size
        iterations_exhausted = (
            self.max_iterations is not None and state.iteration >= self.max_iterations
        )
        if world_too_small or active_too_small or iterations_exhausted:
            msg = (
                f'{state.iteration=} {self.max_iterations=} '
                f'{state.world_size=} {self.min_world_size=} '
                f'{state.active_world_size=} {self.min_active_world_size=} '
            )
            raise exception.RestartAbort(msg)
        return state
94 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/nested_restarter.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import dataclasses
18 | from typing import Optional
19 |
20 | from ..fault_tolerance.rank_monitor_server import RankMonitorLogger
21 | from .abort import Abort
22 | from .completion import Completion
23 | from .initialize import Initialize
24 | from .state import FrozenState
25 | from .terminate import Terminate
26 |
27 |
class NestedRestarterLogger(RankMonitorLogger):
    """Logger dedicated to the nested restarter process."""

    def __init__(self):
        # Flag this logger as a restarter logger so that its records are
        # handled by the restarter-specific logging path.
        super().__init__(name="InprocessRestarter", is_restarter_logger=True)
33 |
34 |
@dataclasses.dataclass
class NestedRestarterCallback:
    r'''
    Callback for logging the NVRx nested restarter integration.

    Emits a single restarter log line, and only from the rank selected as
    ``special_rank``, so the message is not duplicated across the job.
    '''

    # Single logger instance shared by every callback in this module.
    _shared_logger = NestedRestarterLogger()

    restarter_state: str
    restarter_stage: Optional[str] = None
    logger: NestedRestarterLogger = dataclasses.field(default=_shared_logger)
    special_rank: int = 0

    def __call__(self, state: FrozenState) -> FrozenState:

        if state.initial_rank == self.special_rank:
            self.logger.set_connected_rank(state.initial_rank)
            message = f'[NestedRestarter] name=[InProcess] state={self.restarter_state}'
            if self.restarter_stage is not None:
                message = message + f" stage={self.restarter_stage}"

            self.logger.log_for_restarter(message)

        return state
59 |
60 |
@dataclasses.dataclass
class NestedRestarterHandlingCompleted(Initialize, NestedRestarterCallback):
    """Nested-restarter callback intended as the wrapper's ``initialize`` hook.

    Reports state ``'initialize'`` on the first invocation; every later
    invocation reports state ``'handling'`` with stage ``'completed'``.
    """

    restarter_state: str = 'initialize'
    # No stage is reported until after the first invocation.
    restarter_stage: Optional[str] = None

    def __init__(self, special_rank: int = 0):
        # NOTE: this explicit __init__ takes precedence over the
        # dataclass-generated one; restarter_state/restarter_stage are not set
        # here, so reads fall back to the class attributes until __call__
        # mutates them on the instance.
        self._called_once = False
        self.special_rank = special_rank
        self.logger = NestedRestarterCallback._shared_logger

    def __call__(self, state: FrozenState) -> FrozenState:

        # Apply the callback functionality
        state = NestedRestarterCallback.__call__(self, state)

        if not self._called_once:
            self._called_once = True
            # Flip to the post-restart reporting mode for subsequent calls.
            self.restarter_state = 'handling'
            self.restarter_stage = 'completed'

        return state
83 |
84 |
@dataclasses.dataclass
class NestedRestarterHandlingStarting(Abort, NestedRestarterCallback):
    """Logs state 'handling' / stage 'starting'; subclasses :py:class:`Abort`."""

    restarter_state: str = 'handling'
    restarter_stage: str = 'starting'

    def __call__(self, state: FrozenState) -> FrozenState:
        # Delegate to the shared logging callback.
        forwarded = NestedRestarterCallback.__call__(self, state)
        return forwarded
92 |
93 |
@dataclasses.dataclass
class NestedRestarterFinalized(Completion, NestedRestarterCallback):
    """Logs state 'finalized'; subclasses :py:class:`Completion`."""

    restarter_state: str = 'finalized'

    def __call__(self, state: FrozenState) -> FrozenState:
        # Delegate to the shared logging callback.
        forwarded = NestedRestarterCallback.__call__(self, state)
        return forwarded
100 |
101 |
@dataclasses.dataclass
class NestedRestarterAborted(Terminate, NestedRestarterCallback):
    """Logs state 'aborted'; subclasses :py:class:`Terminate`."""

    restarter_state: str = 'aborted'

    def __call__(self, state: FrozenState) -> FrozenState:
        # Delegate to the shared logging callback.
        forwarded = NestedRestarterCallback.__call__(self, state)
        return forwarded
108 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/terminate.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import abc
18 |
19 | from .state import FrozenState
20 |
21 |
class Terminate(abc.ABC):
    r'''
    Abstract base class for the ``global_finalize_failure`` argument of
    :py:class:`inprocess.Wrapper`.

    A :py:class:`Terminate` is executed by any unterminated rank when that
    rank terminates.

    Several :py:class:`Terminate` instances may be combined with
    :py:class:`inprocess.Compose` to obtain the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Implementation of a :py:class:`Terminate`.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            Forwarded read-only input ``state``.
        '''
        raise NotImplementedError
46 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/inprocess/tools/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from . import inject_fault
18 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/ptl_resiliency/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from ._utils import SimulatedFaultParams # noqa: F401
17 | from .fault_tolerance_callback import FaultToleranceCallback # noqa: F401
18 | from .fault_tolerance_sections_callback import FaultToleranceSectionsCallback # noqa: F401
19 | from .straggler_det_callback import StragglerDetectionCallback # noqa: F401
20 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/shared_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/src/nvidia_resiliency_ext/shared_utils/__init__.py
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/straggler/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .reporting import Report, StragglerId # noqa: F401
17 | from .statistics import Statistic # noqa: F401
18 | from .straggler import CallableId, Detector # noqa: F401
19 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/straggler/cupti.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import threading
17 |
18 |
class CuptiManager:
    """Provide thread safe access to the CUPTI extension.

    Implements simple usage counter, to track active profiling runs:
    profiling starts on the first `start_profiling()` call and actually stops
    only when every caller has invoked `stop_profiling()`.
    """

    def __init__(self, bufferSize=1_000_000, numBuffers=8, statsMaxLenPerKernel=4096):
        """
        Args:
            bufferSize (int, optional): CUPTI buffer size. Defaults to 1MB.
            numBuffers (int, optional): Num of CUPTI buffers in a pool . Defaults to 8.
            statsMaxLenPerKernel (int, optional): Max number of timing entries per kernel.
                (when this limit is reached, oldest timing entries are discarded). Defaults to 4096.
        """

        # lazy load the extension module, to avoid circular import
        import nvrx_cupti_module as cupti_module  # type: ignore

        self.cupti_ext = cupti_module.CuptiProfiler(
            bufferSize=bufferSize,
            numBuffers=numBuffers,
            statsMaxLenPerKernel=statsMaxLenPerKernel,
        )
        # True after `initialize()` registered CUPTI callbacks; reset by `shutdown()`.
        self.is_initialized = False
        # Number of currently active profiling runs (start/stop pairs).
        self.started_cnt = 0
        # Serializes all access to the CUPTI extension and the counters above.
        self.lock = threading.Lock()

    def _ensure_initialized(self):
        """Check for CuptiProfiler initialization.

        Raises:
            RuntimeError: if `initialize()` has not been called yet.
        """
        if not self.is_initialized:
            raise RuntimeError("CuptiManager was not initialized")

    def initialize(self):
        """Call CuptiProfiler initialization method, registering CUPTI
        callbacks for profiling."""
        with self.lock:
            self.cupti_ext.initialize()
            self.is_initialized = True

    def shutdown(self):
        """Finalize CUPTI."""
        with self.lock:
            self.cupti_ext.shutdown()
            # Reset state so a subsequent `initialize()` starts from scratch.
            self.is_initialized = False
            self.started_cnt = 0

    def start_profiling(self):
        """Enable CUDA kernels activity tracking."""
        with self.lock:
            self._ensure_initialized()
            # Only the first concurrent caller actually starts CUPTI tracking.
            if self.started_cnt == 0:
                self.cupti_ext.start()
            self.started_cnt += 1

    def stop_profiling(self):
        """Disable CUDA kernels activity tracking.

        Raises:
            RuntimeError: if there is no matching `start_profiling()` call.
        """
        with self.lock:
            self._ensure_initialized()
            if self.started_cnt > 0:
                self.started_cnt -= 1
                # Stop CUPTI tracking once the last active run finishes.
                if self.started_cnt == 0:
                    self.cupti_ext.stop()
            else:
                raise RuntimeError("No active profiling run.")

    def get_results(self):
        """Calculate kernel execution timing statistics.

        Returns a copy, so the caller's view is not mutated by further profiling.
        """
        with self.lock:
            self._ensure_initialized()
            stats = self.cupti_ext.get_stats()
            return stats.copy()

    def reset_results(self):
        """Reset kernel execution timing records."""
        with self.lock:
            self._ensure_initialized()
            self.cupti_ext.reset()
96 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/straggler/cupti_src/BufferPool.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
#include "BufferPool.h"

#include <cstdint>
#include <cstdlib>
#include <mutex>
#include <new>
#include <vector>
24 |
25 |
26 | BufferPool::BufferPool(size_t bufferSize, int numBuffers)
27 | : bufferSize(bufferSize), numBuffers(numBuffers) {
28 | for (int i = 0; i < numBuffers; ++i) {
29 | uint8_t* newBuffer = (uint8_t*)malloc(bufferSize);
30 | if (!newBuffer) {
31 | throw std::bad_alloc();
32 | }
33 | freeBuffers.push_back(newBuffer);
34 | }
35 | }
36 |
37 | BufferPool::~BufferPool() {
38 | while (!freeBuffers.empty()) {
39 | free(freeBuffers.back());
40 | freeBuffers.pop_back();
41 | }
42 | }
43 |
44 | uint8_t* BufferPool::getBuffer() {
45 | std::lock_guard lock(mutex);
46 | if (freeBuffers.empty()) {
47 | return nullptr; // prob better to allocate a new buffer
48 | }
49 | uint8_t* buffer = freeBuffers.back();
50 | freeBuffers.pop_back();
51 | return buffer;
52 | }
53 |
54 | void BufferPool::releaseBuffer(uint8_t* buffer) {
55 | std::lock_guard lock(mutex);
56 | freeBuffers.push_back(buffer);
57 | }
58 |
// Size in bytes of each buffer in the pool (fixed at construction time).
size_t BufferPool::getBufferSize() const {
    return bufferSize;
}
62 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/straggler/cupti_src/BufferPool.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <vector>
23 |
// Fixed-size pool of raw byte buffers with thread-safe checkout/return.
class BufferPool {
 public:
    // bufferSize: size in bytes of each buffer; numBuffers: number of
    // buffers preallocated by the constructor.
    BufferPool(size_t bufferSize = 1024 * 1024 * 4, int numBuffers = 20);
    ~BufferPool();

    // Returns a buffer from the pool, or nullptr when the pool is exhausted.
    uint8_t* getBuffer();
    // Returns a buffer previously obtained from getBuffer() to the pool.
    void releaseBuffer(uint8_t* buffer);
    // Size in bytes of each pooled buffer.
    size_t getBufferSize() const;

 private:
    std::vector<uint8_t*> freeBuffers;  // buffers currently available
    size_t bufferSize;
    int numBuffers;
    std::mutex mutex;  // guards freeBuffers
};
39 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/straggler/cupti_src/CircularBuffer.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
#include <vector>
21 |
// Fixed-capacity ring buffer: push_back() overwrites the oldest element
// once the buffer is full.
template <typename T>
class CircularBuffer {
private:
    std::vector<T> _buffer;  // backing storage, allocated once
    size_t _head;            // index of the oldest element
    size_t _tail;            // index where the next element is written
    size_t _size;            // number of stored elements
    size_t _capacity;        // maximum number of elements

public:
    // capacity must be positive: push_back() computes indices modulo capacity.
    explicit CircularBuffer(size_t capacity=32) :
        _buffer(capacity), _head(0), _tail(0), _size(0), _capacity(capacity) {}

    ~CircularBuffer() = default;

    bool empty() const {
        return _size == 0;
    }

    bool full() const {
        return _size == _capacity;
    }

    size_t size() const {
        return _size;
    }

    size_t capacity() const {
        return _capacity;
    }

    // Appends `value`; when full, the oldest element is overwritten.
    void push_back(const T& value) {
        _buffer[_tail] = value;
        _tail = (_tail + 1) % _capacity;
        if (full()) {
            _head = (_head + 1) % _capacity;
        } else {
            _size++;
        }
    }

    // Returns the stored elements in insertion order (oldest first).
    std::vector<T> linearize() const {
        std::vector<T> res(_size);
        for (size_t linearIndex = 0; linearIndex < _size; linearIndex++) {
            res[linearIndex] = _buffer[(_head + linearIndex) % _capacity];
        }
        return res;
    }
};
71 |
--------------------------------------------------------------------------------
/src/nvidia_resiliency_ext/straggler/cupti_src/CuptiProfiler.h:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | * SPDX-License-Identifier: Apache-2.0
4 | *
5 | * Licensed under the Apache License, Version 2.0 (the "License");
6 | * you may not use this file except in compliance with the License.
7 | * You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | #pragma once
19 |
20 | #include