├── .github └── workflows │ ├── build_docs.yml │ ├── cherry_pick_release.yml │ ├── lint_code.yml │ └── unit_test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── SECURITY.md ├── cupti_build.py ├── docs ├── Makefile ├── make.bat └── source │ ├── checkpointing │ ├── async │ │ ├── api.rst │ │ ├── api │ │ │ ├── core.rst │ │ │ ├── filesystem_async.rst │ │ │ ├── state_dict_saver.rst │ │ │ └── torch_ckpt.rst │ │ ├── examples.rst │ │ ├── examples │ │ │ ├── basic_example.rst │ │ │ └── writer_example.rst │ │ ├── index.rst │ │ └── usage_guide.rst │ └── local │ │ ├── api.rst │ │ ├── api │ │ ├── base_ckpt_manager.rst │ │ ├── base_state_dict.rst │ │ ├── basic_state_dict.rst │ │ ├── callback.rst │ │ ├── local_ckpt_manager.rst │ │ └── replication.rst │ │ ├── examples.rst │ │ ├── examples │ │ └── basic_example.rst │ │ ├── index.rst │ │ └── usage_guide.rst │ ├── conf.py │ ├── fault_tolerance │ ├── README-pci-topo-file.md │ ├── api.rst │ ├── api │ │ ├── callback.rst │ │ ├── client.rst │ │ ├── config.rst │ │ └── server.rst │ ├── examples.rst │ ├── examples │ │ ├── basic_example.rst │ │ ├── in_job_and_in_process_example.rst │ │ ├── train_ddp_heartbeats.rst │ │ └── train_ddp_sections.rst │ ├── index.rst │ ├── integration.rst │ ├── integration │ │ ├── heartbeats.rst │ │ ├── inprocess.rst │ │ ├── ptl.rst │ │ └── sections.rst │ └── usage_guide.rst │ ├── index.rst │ ├── inprocess │ ├── api.rst │ ├── api │ │ ├── abort.rst │ │ ├── compose.rst │ │ ├── exception.rst │ │ ├── finalize.rst │ │ ├── health_check.rst │ │ ├── initialize.rst │ │ ├── rank_assignment.rst │ │ ├── rank_filter.rst │ │ ├── state.rst │ │ └── wrap.rst │ ├── examples.rst │ ├── examples │ │ ├── basic_example.rst │ │ └── optimal_example.rst │ ├── index.rst │ └── usage_guide.rst │ ├── media │ ├── nvrx_core_features.png │ └── nvrx_docs_source.png │ ├── release-notes.md │ └── straggler_det │ ├── api.rst │ ├── api │ ├── callback.rst │ ├── reporting.rst │ ├── statistics.rst │ └── 
straggler.rst │ ├── examples.rst │ ├── examples │ └── basic_example.rst │ ├── index.rst │ └── usage_guide.rst ├── examples ├── checkpointing │ ├── async_ckpt.py │ ├── async_writer.py │ └── local_ckpt.py ├── fault_tolerance │ ├── basic_ft_example.py │ ├── dist_utils.py │ ├── fault_tol_cfg_heartbeats.yaml │ ├── fault_tol_cfg_sections.yaml │ ├── in_job_and_in_process_example.py │ ├── log_utils.py │ ├── run_inprocess_injob_example.sh │ ├── train_ddp_heartbeats_api.py │ └── train_ddp_sections_api.py ├── inprocess │ ├── basic_example.py │ └── optimal_example.py └── straggler │ └── example.py ├── pyproject.toml ├── src └── nvidia_resiliency_ext │ ├── checkpointing │ ├── __init__.py │ ├── async_ckpt │ │ ├── cached_metadata_filesystem_reader.py │ │ ├── core.py │ │ ├── filesystem_async.py │ │ ├── state_dict_saver.py │ │ └── torch_ckpt.py │ ├── local │ │ ├── __init__.py │ │ ├── base_state_dict.py │ │ ├── basic_state_dict.py │ │ ├── ckpt_managers │ │ │ ├── base_manager.py │ │ │ └── local_manager.py │ │ └── replication │ │ │ ├── __init__.py │ │ │ ├── _torch_future.py │ │ │ ├── group_utils.py │ │ │ ├── strategies.py │ │ │ ├── torch_device_utils.py │ │ │ └── utils.py │ └── utils.py │ ├── fault_tolerance │ ├── __init__.py │ ├── _ft_rendezvous.py │ ├── _torch_elastic_compat │ │ ├── __init__.py │ │ ├── agent │ │ │ ├── __init__.py │ │ │ └── server │ │ │ │ ├── __init__.py │ │ │ │ ├── api.py │ │ │ │ └── local_elastic_agent.py │ │ ├── events │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ └── handlers.py │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ └── api.py │ │ ├── multiprocessing │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── errors │ │ │ │ ├── __init__.py │ │ │ │ ├── error_handler.py │ │ │ │ └── handlers.py │ │ │ ├── redirects.py │ │ │ ├── subprocess_handler │ │ │ │ ├── __init__.py │ │ │ │ ├── handlers.py │ │ │ │ └── subprocess_handler.py │ │ │ └── tail_log.py │ │ ├── rendezvous │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── c10d_rendezvous_backend.py │ │ │ ├── dynamic_rendezvous.py 
│ │ │ ├── etcd_rendezvous.py │ │ │ ├── etcd_rendezvous_backend.py │ │ │ ├── etcd_server.py │ │ │ ├── etcd_store.py │ │ │ ├── registry.py │ │ │ ├── static_tcp_rendezvous.py │ │ │ └── utils.py │ │ ├── timer │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── file_based_local_timer.py │ │ │ └── local_timer.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── api.py │ │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── cycling_iterator.py │ │ │ └── elastic_distributed_sampler.py │ │ │ ├── distributed.py │ │ │ ├── log_level.py │ │ │ ├── logging.py │ │ │ └── store.py │ ├── config.py │ ├── data.py │ ├── dict_utils.py │ ├── ipc_connector.py │ ├── launcher.py │ ├── rank_monitor_client.py │ ├── rank_monitor_server.py │ ├── rank_monitor_state_machine.py │ ├── timeouts_calc.py │ └── utils.py │ ├── inprocess │ ├── __init__.py │ ├── abort.py │ ├── attribution.py │ ├── completion.py │ ├── compose.py │ ├── exception.py │ ├── finalize.py │ ├── health_check.py │ ├── initialize.py │ ├── monitor_process.py │ ├── monitor_thread.py │ ├── nested_restarter.py │ ├── param_utils.py │ ├── progress_watchdog.py │ ├── rank_assignment.py │ ├── rank_filter.py │ ├── sibling_monitor.py │ ├── state.py │ ├── store.py │ ├── terminate.py │ ├── tools │ │ ├── __init__.py │ │ └── inject_fault.py │ ├── utils.py │ └── wrap.py │ ├── ptl_resiliency │ ├── __init__.py │ ├── _utils.py │ ├── fault_tolerance_callback.py │ ├── fault_tolerance_sections_callback.py │ ├── local_checkpoint_callback.py │ └── straggler_det_callback.py │ ├── shared_utils │ ├── __init__.py │ └── health_check.py │ └── straggler │ ├── __init__.py │ ├── cupti.py │ ├── cupti_src │ ├── BufferPool.cpp │ ├── BufferPool.h │ ├── CircularBuffer.h │ ├── CuptiProfiler.cpp │ ├── CuptiProfiler.h │ └── cupti_module_py.cpp │ ├── dist_utils.py │ ├── interval_tracker.py │ ├── name_mapper.py │ ├── reporting.py │ ├── statistics.py │ └── straggler.py └── tests ├── checkpointing └── unit │ ├── __init__.py │ ├── conftest.py │ ├── test_async_save.py │ ├── test_async_writer.py 
│ ├── test_async_writer_msc.py │ ├── test_basic_local.py │ ├── test_cleanup.py │ └── test_utilities.py ├── fault_tolerance ├── func │ ├── _launcher_mode_test_worker.py │ ├── _workload_ctrl_test_worker.py │ ├── run_launcher_any_failed_mode_test.sh │ ├── run_launcher_min_healthy_mode_test.sh │ ├── run_local_ddp_test_heartbeats.sh │ ├── run_local_ddp_test_sections.sh │ ├── run_workload_ctrl_test_excl_node.sh │ └── run_workload_ctrl_test_shutdown.sh └── unit │ ├── __init__.py │ ├── _launcher_test_util.py │ ├── conftest.py │ ├── test_config.py │ ├── test_dynamic_rendezvous.py │ ├── test_init.py │ ├── test_ipc_connector.py │ ├── test_launcher.py │ ├── test_layered_restart_v1.py │ ├── test_process_utils.py │ ├── test_rank_monitor_server.py │ ├── test_reconnect.py │ ├── test_shutdown.py │ ├── test_shutdown_sections.py │ ├── test_timeouts.py │ ├── test_timeouts_calc.py │ ├── test_timeouts_sections.py │ └── utils.py ├── inprocess ├── __init__.py ├── app.py ├── common.py ├── test_abort.py ├── test_app.py ├── test_compose.py ├── test_health_check.py ├── test_monitor_thread.py ├── test_nested_restarter.py ├── test_progress_watchdog.py ├── test_rank_assignment.py ├── test_timeout.py ├── test_torch.py └── test_wrap.py ├── ptl_resiliency ├── func │ └── nemo20 │ │ ├── Dockerfile.ft_test │ │ ├── check_straggler_log.py │ │ ├── ft_test_asserts.sh │ │ ├── ft_test_launchers.sh │ │ ├── ft_test_llama3.py │ │ ├── local_ckpt_test.sh │ │ ├── straggler_test_llama3.py │ │ └── test_local_ckpt_llama3.py └── unit │ ├── __init__.py │ ├── test_ft_callback_hb.py │ ├── test_ft_callback_sections.py │ ├── test_ft_state_machine.py │ ├── test_local_ckpt_callback.py │ └── test_straggler_det_callback.py ├── shared_utils └── test_health_check.py └── straggler ├── README.md ├── func ├── check_log.py └── ddp_test.py └── unit ├── __init__.py ├── _utils.py ├── test_cupti_ext.py ├── test_cupti_manager.py ├── test_data_shared.py ├── test_det_section_api.py ├── test_individual_gpu_scores.py ├── 
test_interval_tracker.py ├── test_name_mapper.py ├── test_relative_gpu_scores.py ├── test_reporting.py ├── test_reporting_elapsed.py ├── test_sections.py └── test_wrap_callables.py /.github/workflows/build_docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | jobs: 10 | build_docs: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout repository 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.10' 21 | 22 | - name: Install dependencies 23 | run: | 24 | pip install -U sphinx sphinx-rtd-theme sphinxcontrib-napoleon sphinx_copybutton lightning psutil defusedxml 25 | 26 | - name: Build Documentation 27 | run: | 28 | sphinx-build -b html docs/source public/ 29 | if [ ! -d "public" ]; then 30 | echo "Error: Documentation build failed. 'public/' directory not found." 31 | exit 1 32 | fi 33 | 34 | - name: Deploy to GitHub Pages 35 | run: | 36 | git config --global user.name "GitHub Actions" 37 | git config --global user.email "actions@github.com" 38 | 39 | # Save generated documentation 40 | if [ -d "public" ]; then 41 | echo "Saving generated documentation..." 42 | ls -al public/ 43 | mv public /tmp/public_docs 44 | else 45 | echo "Error: 'public/' directory does not exist. Exiting." 46 | exit 1 47 | fi 48 | 49 | # Clean and switch to gh-pages branch 50 | git reset --hard 51 | git clean -fdx 52 | 53 | if git ls-remote --exit-code origin gh-pages; then 54 | git fetch origin gh-pages 55 | git checkout gh-pages 56 | else 57 | git checkout --orphan gh-pages 58 | fi 59 | 60 | # Clean old content and restore new documentation 61 | echo "Cleaning old content..." 62 | 63 | find . -maxdepth 1 ! -name '.git' ! -name '.' -exec rm -rf {} + 64 | echo "Restoring new documentation..." 65 | mv /tmp/public_docs/* . 
66 | 67 | # Deploy to GitHub Pages 68 | touch .nojekyll 69 | git add . 70 | if git diff --cached --quiet; then 71 | echo "No changes to commit. Skipping deployment." 72 | exit 0 73 | else 74 | git commit -m "Deploy updated documentation to GitHub Pages from commit $GITHUB_SHA" 75 | git push origin gh-pages --force 76 | fi 77 | 78 | -------------------------------------------------------------------------------- /.github/workflows/lint_code.yml: -------------------------------------------------------------------------------- 1 | name: Lint Code 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - name: Checkout repository 17 | uses: actions/checkout@v4 18 | 19 | - name: Set up Python 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: '3.10' 23 | 24 | - name: Install linting dependencies 25 | run: | 26 | pip install black==24.10.0 isort==5.13.2 ruff==0.6.9 27 | 28 | - name: Run Black 29 | run: | 30 | black --check . 31 | 32 | - name: Run isort 33 | run: | 34 | isort --check-only . 35 | 36 | - name: Run Ruff 37 | run: | 38 | ruff check . 
39 | -------------------------------------------------------------------------------- /.github/workflows/unit_test.yml: -------------------------------------------------------------------------------- 1 | name: Run Unit Tests 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | 10 | jobs: 11 | 12 | build_wheels: 13 | runs-on: ubuntu-24.04 14 | container: 15 | image: 'nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04' 16 | steps: 17 | - name: Update GCC 18 | run: | 19 | export DEBIAN_FRONTEND=noninteractive 20 | apt update && apt install -y build-essential gcc-10 g++-10 21 | - name: Install Python versions and pips 22 | run: | 23 | export DEBIAN_FRONTEND=noninteractive 24 | apt update && apt install -y software-properties-common curl 25 | add-apt-repository ppa:deadsnakes/ppa 26 | apt-get install -y python3.10 python3.10-dev 27 | apt-get install -y python3.11 python3.11-dev 28 | apt-get install -y python3.12 python3.12-dev 29 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 30 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11 31 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 32 | - name: Checkout code 33 | uses: actions/checkout@v4 34 | - name: Build wheel with Python 3.10 35 | run: | 36 | python3.10 -m pip install -U setuptools poetry build six pybind11 37 | python3.10 -m poetry build -f wheel 38 | - name: Build wheel with Python 3.11 39 | run: | 40 | python3.11 -m pip install -U setuptools poetry build six pybind11 41 | python3.11 -m poetry build -f wheel 42 | - name: Build wheel with Python 3.12 43 | run: | 44 | python3.12 -m pip install -U setuptools poetry build six pybind11 45 | python3.12 -m poetry build -f wheel 46 | - name: Upload the wheel artifact 47 | uses: actions/upload-artifact@v4 48 | with: 49 | name: resiliency-wheels 50 | path: dist/*.whl 51 | 52 | unit_tests_cpu_subset: 53 | runs-on: ubuntu-24.04 54 | needs: build_wheels 55 | strategy: 56 | matrix: 57 | container: 58 | - 
'pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime' 59 | - 'pytorch/pytorch:2.4.1-cuda12.1-cudnn9-runtime' 60 | - 'pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime' 61 | test_type: ['fault_tolerance', 'ptl_resiliency'] 62 | container: 63 | image: ${{ matrix.container }} 64 | env: 65 | MKL_SERVICE_FORCE_INTEL: 1 # Fix for "MKL_THREADING_LAYER=INTEL is incompatible with libgomp.so.1 library." 66 | steps: 67 | - name: Checkout code 68 | uses: actions/checkout@v4 69 | - name: Download wheels 70 | uses: actions/download-artifact@v4 71 | with: 72 | name: resiliency-wheels 73 | path: ./dist/ 74 | - name: Set up environment 75 | run: | 76 | pip install pytest lightning 77 | PY_VER_NODOT=$(python -c"import sysconfig; print(sysconfig.get_config_var('py_version_nodot'))") 78 | pip install ./dist/nvidia_resiliency_ext-*-cp${PY_VER_NODOT}-*.whl 79 | - name: Run unit tests 80 | shell: bash 81 | run: | 82 | if [[ "${{ matrix.test_type }}" == "straggler" ]]; then 83 | STRAGGLER_DET_CPU_TESTS_PATTERN="test_all_gather_object_calls_num \ 84 | or test_fail_if_not_initialized \ 85 | or test_individual_gpu_scores_one_rank \ 86 | or test_relative_gpu_scores_ \ 87 | or test_name_mapper_ \ 88 | or test_relative_gpu_scores_" 89 | pytest -s -vvv tests/straggler/unit/ -k "${STRAGGLER_DET_CPU_TESTS_PATTERN}" 90 | elif [[ "${{ matrix.test_type }}" == "fault_tolerance" ]]; then 91 | pytest -s -vvv ./tests/fault_tolerance/unit/ 92 | elif [[ "${{ matrix.test_type }}" == "ptl_resiliency" ]]; then 93 | pytest -s -vvv ./tests/ptl_resiliency/unit/test_ft_state_machine.py 94 | else 95 | echo "Unknown test type: ${{ matrix.test_type }}" 96 | exit 1 97 | fi 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | *.egg-info 4 | __pycache__ 5 | cupti_module.*.so 6 | .pytest_cache 7 | -------------------------------------------------------------------------------- 
/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | repos: 5 | 6 | - repo: https://github.com/PyCQA/isort 7 | rev: 5.13.2 8 | hooks: 9 | - id: isort 10 | exclude: docs/ 11 | 12 | - repo: https://github.com/psf/black-pre-commit-mirror 13 | rev: 24.10.0 14 | hooks: 15 | - id: black 16 | language_version: python3.10 17 | 18 | - repo: https://github.com/astral-sh/ruff-pre-commit 19 | rev: v0.6.9 20 | hooks: 21 | - id: ruff 22 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NVIDIA Resiliency Extension 2 | 3 | The NVIDIA Resiliency Extension (NVRx) integrates multiple resiliency-focused solutions for PyTorch-based workloads. Users can modularly integrate NVRx capabilities into their own infrastructure to maximize AI training productivity at scale. 
NVRx maximizes goodput by enabling system-wide health checks, quickly detecting faults at runtime and resuming training automatically. NVRx minimizes loss of work by enabling fast and frequent checkpointing. 4 | 5 | For detailed documentation and usage information about each component, please refer to https://nvidia.github.io/nvidia-resiliency-ext/. 6 | 7 | > ⚠️ NOTE: This project is still experimental and under active development. The code, features, and documentation are evolving rapidly. Please expect frequent updates and breaking changes. Contributions are welcome and we encourage you to watch for updates. 8 | 9 | Figure highlighting core NVRx features including automatic restart, hierarchical checkpointing, fault detection and health checks 10 | 11 | 12 | ## Core Components and Capabilities 13 | 14 | - **[Fault Tolerance](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/fault_tolerance/index.rst)** 15 | - Detection of hung ranks. 16 | - Restarting training in-job, without the need to reallocate SLURM nodes. 17 | 18 | - **[In-Process Restarting](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/inprocess/index.rst)** 19 | - Detecting failures and enabling quick recovery. 20 | 21 | - **[Async Checkpointing](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/checkpointing/async/index.rst)** 22 | - Providing an efficient framework for asynchronous checkpointing. 23 | 24 | - **[Local Checkpointing](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/checkpointing/local/index.rst)** 25 | - Providing an efficient framework for local checkpointing. 26 | 27 | - **[Straggler Detection](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/straggler_det/index.rst)** 28 | - Monitoring GPU and CPU performance of ranks. 29 | - Identifying slower ranks that may impede overall training efficiency. 
30 | 31 | - **[PyTorch Lightning Callbacks](https://github.com/NVIDIA/nvidia-resiliency-ext/blob/main/docs/source/fault_tolerance/integration/ptl.rst)** 32 | - Facilitating seamless NVRx integration with PyTorch Lightning. 33 | 34 | ## Installation 35 | 36 | ### From sources 37 | - `git clone https://github.com/NVIDIA/nvidia-resiliency-ext` 38 | - `cd nvidia-resiliency-ext` 39 | - `pip install .` 40 | 41 | 42 | ### From PyPI wheel 43 | - `pip install nvidia-resiliency-ext` 44 | 45 | ### Platform Support 46 | 47 | | Category | Supported Versions / Requirements | 48 | |----------------------|----------------------------------------------------------------------------| 49 | | Architecture | x86_64, arm64 | 50 | | Operating System | Ubuntu 22.04, 24.04 | 51 | | Python Version | >= 3.10, < 3.13 | 52 | | PyTorch Version | >= 2.3.1 (injob & chkpt), >= 2.5.1 (inprocess) | 53 | | CUDA & CUDA Toolkit | >= 12.5 (12.8 required for GPU health check) | 54 | | NVML Driver | >= 535 (570 required for GPU health check) | 55 | | NCCL Version | >= 2.21.5 (injob & chkpt), >= 2.26.2 (inprocess) | 56 | 57 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security 2 | 3 | NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. 4 | 5 | If you need to report a security issue, please use the appropriate contact points outlined below. 
**Please do not report security vulnerabilities through GitHub.** 6 | 7 | ## Reporting Potential Security Vulnerability in an NVIDIA Product 8 | 9 | To report a potential security vulnerability in any NVIDIA product: 10 | - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) 11 | - E-Mail: psirt@nvidia.com 12 | - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) 13 | - Please include the following information: 14 | - Product/Driver name and version/branch that contains the vulnerability 15 | - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) 16 | - Instructions to reproduce the vulnerability 17 | - Proof-of-concept or exploit code 18 | - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability 19 | 20 | While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. 21 | 22 | ## NVIDIA Product Security 23 | 24 | For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security 25 | -------------------------------------------------------------------------------- /cupti_build.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import glob 17 | import os 18 | 19 | from pybind11.setup_helpers import Pybind11Extension, build_ext 20 | 21 | 22 | def find_file_in_dir(cuda_path, sfile): 23 | """ 24 | Looks for file under the directory specified by the cuda_path argument. 25 | If file is not found, returns None 26 | 27 | Args: 28 | cuda_path (str): Directory where to look for the files 29 | sfile (str): The file to look for (e.g., 'libcupti.so'). 30 | 31 | Returns: 32 | tuple: (directory_of_file1, directory_of_file2) or (None, None) if either file is not found. 33 | """ 34 | 35 | for root, _, files in os.walk(cuda_path): 36 | if sfile in files: 37 | return root 38 | return None 39 | 40 | 41 | def _skip_ext_build(): 42 | ans = os.environ.get('STRAGGLER_DET_SKIP_CUPTI_EXT_BUILD', '0') 43 | return ans.lower() in ['1', 'on', 'yes', 'true'] 44 | 45 | 46 | def build(setup_kwargs): 47 | 48 | if _skip_ext_build(): 49 | print( 50 | "WARNING! CUPTI extension wont be build due to STRAGGLER_DET_SKIP_CUPTI_EXT_BUILD flag." 
51 | ) 52 | return 53 | 54 | include_dirs = None 55 | library_dirs = None 56 | 57 | cuda_path = os.environ.get("CUDA_PATH", "/usr/local/cuda") 58 | if not os.path.isdir(cuda_path): 59 | raise FileNotFoundError("cuda installation not found in /usr/local/cuda or $CUDA_PATH") 60 | 61 | cupti_h = "cupti.h" 62 | libcupti_so = "libcupti.so" 63 | idir = find_file_in_dir(cuda_path, cupti_h) 64 | ldir = find_file_in_dir(cuda_path, libcupti_so) 65 | if idir and ldir: 66 | include_dirs = [idir] 67 | library_dirs = [ldir] 68 | else: 69 | raise FileNotFoundError(f"required files {libcupti_so} and {cupti_h} not found") 70 | 71 | cpp_extension = Pybind11Extension( 72 | 'nvrx_cupti_module', 73 | # Sort .cpp files for reproducibility 74 | sorted(glob.glob('src/nvidia_resiliency_ext/straggler/cupti_src/*.cpp')), 75 | include_dirs=include_dirs, 76 | library_dirs=library_dirs, 77 | libraries=['cupti'], 78 | extra_compile_args=['-O3'], 79 | language='c++', 80 | cxx_std=17, 81 | ) 82 | ext_modules = [ 83 | cpp_extension, 84 | ] 85 | setup_kwargs.update( 86 | { 87 | "ext_modules": ext_modules, 88 | "cmdclass": {"build_ext": build_ext}, 89 | "zip_safe": False, 90 | } 91 | ) 92 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/api.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | =============================================================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: API documentation 7 | 8 | api/core 9 | api/filesystem_async 10 | api/state_dict_saver 11 | api/torch_ckpt 12 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/api/core.rst: -------------------------------------------------------------------------------- 1 | Asynchronous Checkpoint Core Utilities 2 | ====================================== 3 | 4 | .. 
automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.core 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/api/filesystem_async.rst: -------------------------------------------------------------------------------- 1 | Asynchronous FileSystemWriter Implementation 2 | ============================================ 3 | 4 | .. automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.filesystem_async 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/api/state_dict_saver.rst: -------------------------------------------------------------------------------- 1 | Asynchronous Pytorch Distributed Checkpoint save with optimized `FileSystemWriter` 2 | ================================================================================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.state_dict_saver 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/api/torch_ckpt.rst: -------------------------------------------------------------------------------- 1 | Asynchronous PyTorch `torch.save` with the Core utility 2 | ======================================================= 3 | 4 | .. automodule:: nvidia_resiliency_ext.checkpointing.async_ckpt.torch_ckpt 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | =============================================================================== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 2 6 | :caption: Examples 7 | 8 | examples/basic_example.rst 9 | examples/writer_example.rst 10 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/examples/basic_example.rst: -------------------------------------------------------------------------------- 1 | Basic usage example 2 | =============================================================================== 3 | 4 | .. literalinclude:: ../../../../../examples/checkpointing/async_ckpt.py 5 | :language: python 6 | :linenos: 7 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/examples/writer_example.rst: -------------------------------------------------------------------------------- 1 | FileSystemWriter example 2 | =============================================================================== 3 | 4 | .. literalinclude:: ../../../../../examples/checkpointing/async_writer.py 5 | :language: python 6 | :linenos: 7 | -------------------------------------------------------------------------------- /docs/source/checkpointing/async/index.rst: -------------------------------------------------------------------------------- 1 | Async Checkpointing 2 | ================================= 3 | 4 | The asynchronous checkpointing feature in the NVIDIA Resiliency Extension provides core utilities to offload checkpointing routines to the background. 5 | It leverages `torch.multiprocessing` to either fork a temporary process or spawn a persistent process for efficient, non-blocking checkpointing. 6 | 7 | Applications can monitor asynchronous checkpoint progress in a non-blocking manner 8 | and define a custom finalization step once all ranks complete their background checkpoint saving. 9 | 10 | This repository includes an implementation of asynchronous checkpointing utilities for both `torch.save` and `torch.distributed.save_state_dict`. 
11 | Our modified `torch.distributed.save_state_dict` interface is integrated with an optimized backend, `FileSystemWriterAsync`, which: 12 | • Runs in the async checkpoint process creating child parallel processes for intra-node parallelism, avoiding GIL contention. 13 | • Minimizes metadata communication overhead by metadata caching, ensuring efficient checkpoint saving. 14 | 15 | 16 | .. toctree:: 17 | :maxdepth: 2 18 | :caption: Contents: 19 | 20 | usage_guide 21 | api 22 | examples 23 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/api.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | =============================================================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: API documentation 7 | 8 | api/callback 9 | api/base_ckpt_manager 10 | api/local_ckpt_manager 11 | api/replication 12 | api/base_state_dict 13 | api/basic_state_dict 14 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/api/base_ckpt_manager.rst: -------------------------------------------------------------------------------- 1 | BaseCheckpointManager 2 | ====================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.ckpt_managers.base_manager 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/api/base_state_dict.rst: -------------------------------------------------------------------------------- 1 | BaseTensorAwareStateDict 2 | ======================== 3 | 4 | .. 
automodule:: nvidia_resiliency_ext.checkpointing.local.base_state_dict 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/api/basic_state_dict.rst: -------------------------------------------------------------------------------- 1 | BasicTensorAwareStateDict 2 | ========================= 3 | 4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.basic_state_dict 5 | :members: BasicTensorAwareStateDict 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/api/callback.rst: -------------------------------------------------------------------------------- 1 | PTL Callback support 2 | ==================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.ptl_resiliency.local_checkpoint_callback 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/api/local_ckpt_manager.rst: -------------------------------------------------------------------------------- 1 | LocalCheckpointManager 2 | ====================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/api/replication.rst: -------------------------------------------------------------------------------- 1 | Replication 2 | =========== 3 | 4 | .. 
automodule:: nvidia_resiliency_ext.checkpointing.local.replication.strategies 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | =============================================================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Examples 7 | 8 | examples/basic_example.rst 9 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/examples/basic_example.rst: -------------------------------------------------------------------------------- 1 | Basic usage example 2 | =============================================================================== 3 | 4 | .. literalinclude:: ../../../../../examples/checkpointing/local_ckpt.py 5 | :language: python 6 | :linenos: 7 | -------------------------------------------------------------------------------- /docs/source/checkpointing/local/index.rst: -------------------------------------------------------------------------------- 1 | Local Checkpointing 2 | =================== 3 | 4 | The local checkpointing mechanism is implemented via the Python `LocalCheckpointManager` class, 5 | which operates on a `TensorAwareStateDict` wrapper. 6 | This wrapper encapsulates the operations necessary for efficient replication and data transfers. 7 | 8 | For standard models, 9 | the provided `BasicTensorAwareStateDict` class is typically sufficient for integration. 10 | However, for more advanced use cases, a custom `TensorAwareStateDict` implementation may be required. 11 | 12 | To minimize saving overheads, 13 | integrating the asynchronous version of the `LocalCheckpointManager` method is strongly recommended. 
14 | 15 | Features: 16 | 17 | - Local saving: 18 | Each node saves checkpoint parts locally, either on SSDs or RAM disks, as configured by the user. 19 | - Synchronous and asynchronous support: 20 | Save checkpoints either synchronously or asynchronously, based on the application's requirements. 21 | - Automatic cleanup: 22 | Handles the cleanup of broken or outdated checkpoints automatically. 23 | - Optional replication: 24 | The `LocalCheckpointManager.save` method supports an optional replication mechanism 25 | to allow checkpoint recovery in case of node failure after a restart. 26 | - Configurable resiliency: 27 | The replication factor can be adjusted for enhanced resiliency. 28 | - Latest checkpoint detection: 29 | The `find_latest` method in `LocalCheckpointManager` identifies the most recent complete local checkpoint. 30 | - Automated retrieval: 31 | The `LocalCheckpointManager.load` method automatically retrieves valid checkpoint parts that 32 | are unavailable locally. 33 | 34 | For a comprehensive description of this functionality, including detailed 35 | requirements, restrictions, and usage examples, please refer to the :doc:`Usage 36 | Guide <usage_guide>` and :doc:`Examples <examples>`. 37 | 38 | 39 | .. toctree:: 40 | :maxdepth: 2 41 | :caption: Contents: 42 | 43 | usage_guide 44 | api 45 | examples 46 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import sys 18 | 19 | sys.path.insert(0, os.path.abspath('../../src')) 20 | 21 | 22 | # -- Project information ----------------------------------------------------- 23 | 24 | project = 'nvidia-resiliency-ext' 25 | copyright = '2024, NVIDIA Corporation' 26 | author = 'NVIDIA Corporation' 27 | 28 | # The full version, including alpha/beta/rc tags 29 | release = '0.1' 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | 'sphinx.ext.autodoc', 39 | 'sphinx.ext.viewcode', 40 | 'sphinx.ext.napoleon', 41 | 'sphinx.ext.intersphinx', 42 | 'sphinx_copybutton', 43 | ] 44 | intersphinx_mapping = { 45 | 'python': ('https://docs.python.org/3', None), 46 | 'torch': ('https://pytorch.org/docs/stable/', None), 47 | } 48 | 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | # List of patterns, relative to source directory, that match files and 54 | # directories to ignore when looking for source files. 55 | # This pattern also affects html_static_path and html_extra_path. 
56 | exclude_patterns = [] 57 | 58 | autoclass_content = 'both' 59 | autodoc_typehints = 'description' 60 | 61 | # -- Options for HTML output ------------------------------------------------- 62 | 63 | # The theme to use for HTML and HTML Help pages. See the documentation for 64 | # a list of builtin themes. 65 | # 66 | html_theme = 'sphinx_rtd_theme' 67 | 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 71 | html_static_path = ['_static'] 72 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/README-pci-topo-file.md: -------------------------------------------------------------------------------- 1 | # **Providing a PCI Topology File for GPU and NIC Topology Detection** 2 | 3 | ## **Overview** 4 | In certain environments, such as virtual machines (VMs) provided by cloud service providers (CSPs), the PCI device tree may not be fully populated. When this occurs, traversing the system PCI device tree to determine the GPU and NIC topology is not viable. 5 | 6 | To work around this limitation, users can specify a pre-defined PCI topology file using the following option: 7 | 8 | ``` 9 | --ft-pci-topo-file=<file_path> 10 | ``` 11 | where `<file_path>` is an XML file describing the PCI topology. 12 | 13 | ## **XML Format Requirements** 14 | The PCI topology file follows a structured XML format with the following key elements: 15 | 16 | 1. **`<cpu>` Block:** 17 | - Each CPU in the system is represented by a `<cpu>` block. 18 | 19 | 2. **`<pci>` Bridge Block:** 20 | - Within each `<cpu>` block, there are one or more `<pci>` blocks that represent PCI bridges. 21 | - Each PCI bridge has a unique `busid` attribute. 22 | 23 | 3. **GPU and IB PCI Devices:** 24 | - Within each PCI bridge block, the `<pci>` elements represent GPU and InfiniBand (IB) devices. 
25 | - Each device has its own `busid` and attributes such as `class`, `link_speed`, and `link_width`. 26 | 27 | ### **Example 1: Single PCI Bridge per CPU** 28 | In this example, each CPU has a single PCI bridge, connecting multiple GPUs and IB devices. 29 | 30 | ```xml 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | ``` 48 | 49 | ### **Example 2: Multiple PCI Bridges per CPU** 50 | This example shows a topology where each CPU has multiple PCI bridges, with GPUs and IB devices distributed across them. 51 | 52 | ```xml 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | ``` 74 | 75 | ## **Reference Example** 76 | For a detailed working example, refer to the [NDv4 topology file](https://github.com/Azure/azhpc-images/blob/master/topology/ndv4-topo.xml) in the Azure HPC images repository. 77 | 78 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/api.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | :caption: API documentation 7 | 8 | api/config 9 | api/client 10 | api/server 11 | api/callback 12 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/api/callback.rst: -------------------------------------------------------------------------------- 1 | Callback 2 | ======== 3 | 4 | .. automodule:: nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/api/client.rst: -------------------------------------------------------------------------------- 1 | Client 2 | ====== 3 | 4 | ..
automodule:: nvidia_resiliency_ext.fault_tolerance.rank_monitor_client 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/fault_tolerance/api/config.rst: -------------------------------------------------------------------------------- 1 | Config 2 | ====== 3 | 4 | .. automodule:: nvidia_resiliency_ext.fault_tolerance.config 5 | :members: 6 | :show-inheritance: -------------------------------------------------------------------------------- /docs/source/fault_tolerance/api/server.rst: -------------------------------------------------------------------------------- 1 | Server 2 | ====== 3 | 4 | .. automodule:: nvidia_resiliency_ext.fault_tolerance.rank_monitor_server 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Examples 7 | 8 | examples/basic_example.rst 9 | examples/train_ddp_heartbeats.rst 10 | examples/train_ddp_sections.rst 11 | examples/in_job_and_in_process_example.rst 12 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/examples/basic_example.rst: -------------------------------------------------------------------------------- 1 | Basic usage example 2 | =================== 3 | 4 | .. literalinclude:: ../../../../examples/fault_tolerance/basic_ft_example.py 5 | :language: python 6 | :linenos: -------------------------------------------------------------------------------- /docs/source/fault_tolerance/examples/in_job_and_in_process_example.rst: -------------------------------------------------------------------------------- 1 | FT Launcher & Inprocess integration example 2 | =========================================== 3 | 4 | .. 
literalinclude:: ../../../../examples/fault_tolerance/in_job_and_in_process_example.py 5 | :language: python 6 | :linenos: 7 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/examples/train_ddp_heartbeats.rst: -------------------------------------------------------------------------------- 1 | Heartbeat API usage example with DDP 2 | ==================================== 3 | 4 | .. literalinclude:: ../../../../examples/fault_tolerance/train_ddp_heartbeats_api.py 5 | :language: python 6 | :linenos: -------------------------------------------------------------------------------- /docs/source/fault_tolerance/examples/train_ddp_sections.rst: -------------------------------------------------------------------------------- 1 | Section API usage example with DDP 2 | ================================== 3 | 4 | .. literalinclude:: ../../../../examples/fault_tolerance/train_ddp_sections_api.py 5 | :language: python 6 | :linenos: -------------------------------------------------------------------------------- /docs/source/fault_tolerance/index.rst: -------------------------------------------------------------------------------- 1 | Fault Tolerance 2 | =============== 3 | 4 | Fault Tolerance is a Python package that features: 5 | * Workload hang detection. 6 | * Automatic calculation of timeouts used for hang detection. 7 | * Detection of rank(s) terminated due to an error. 8 | * Workload respawning in case of a failure. 9 | 10 | Fault Tolerance is included in the ``nvidia_resiliency_ext.fault_tolerance`` package. 11 | 12 | The ``nvidia-resiliency-ext`` package also includes the PTL callback ``FaultToleranceCallback`` that simplifies FT package integration with PyTorch Lightning-based workloads. 13 | ``FaultToleranceCallback`` is included in the ``nvidia_resiliency_ext.ptl_resiliency`` package. 14 | 15 | .. 
toctree:: 16 | :maxdepth: 2 17 | :caption: Contents: 18 | 19 | usage_guide 20 | integration 21 | api 22 | examples -------------------------------------------------------------------------------- /docs/source/fault_tolerance/integration.rst: -------------------------------------------------------------------------------- 1 | Integration Guides 2 | ================== 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | :caption: Integration Guides 7 | 8 | integration/heartbeats 9 | integration/sections 10 | integration/ptl 11 | integration/inprocess 12 | 13 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/integration/heartbeats.rst: -------------------------------------------------------------------------------- 1 | Heartbeats API Integration 2 | ************************** 3 | 4 | 1. Prerequisites 5 | ================= 6 | * Run ranks using ``ft_launcher``. The command line is mostly compatible with ``torchrun``. 7 | * Pass the FT config to the ``ft_launcher``. 8 | 9 | .. note:: 10 | Some clusters (e.g., SLURM) use SIGTERM as a default method of requesting a graceful workload shutdown. 11 | It is recommended to implement appropriate signal handling in a fault-tolerant workload. 12 | To avoid deadlocks and other unintended side effects, signal handling should be synchronized across all ranks. 13 | 14 | 2. FT configuration 15 | ==================== 16 | 17 | Timeouts for fault detection need to be adjusted for each workload: 18 | * ``initial_rank_heartbeat_timeout`` should be long enough to allow for workload initialization. 19 | * ``rank_heartbeat_timeout`` should be at least as long as the longest possible interval between steps. 20 | 21 | **Importantly, heartbeats are not sent during checkpoint loading and saving**, so the time for checkpoint-related operations should be taken into account. 
22 | 23 | Fixed timeout values can be used throughout the training runs, or timeouts can be calculated based on observed heartbeat intervals. 24 | `null` timeout values are interpreted as infinite timeouts. In such cases, values need to be calculated to make the FT usable. 25 | 26 | .. note:: 27 | When --ft-initial-rank-heartbeat-timeout and --ft-rank-heartbeat-timeout are not 28 | provided in the command-line arguments, the launcher defaults to FT's predefined values. These are 29 | not null/None; currently, the defaults are 60 minutes for --ft-initial-rank-heartbeat-timeout 30 | and 45 minutes for --ft-rank-heartbeat-timeout. 31 | 32 | Configuration file example: 33 | 34 | .. literalinclude:: ../../../../examples/fault_tolerance/fault_tol_cfg_heartbeats.yaml 35 | :language: yaml 36 | :linenos: 37 | 38 | 39 | A summary of all FT configuration items can be found in :class:`nvidia_resiliency_ext.fault_tolerance.config.FaultToleranceConfig` 40 | 41 | 3. Integration with PyTorch workload code 42 | ============================================ 43 | 1. Initialize a ``RankMonitorClient`` instance on each rank with ``RankMonitorClient.init_workload_monitoring()``. 44 | 2. *(Optional)* Restore the state of ``RankMonitorClient`` instances using ``RankMonitorClient.load_state_dict()``. 45 | 3. Periodically send heartbeats from ranks using ``RankMonitorClient.send_heartbeat()``. 46 | 4. *(Optional)* After a sufficient range of heartbeat intervals has been observed, call ``RankMonitorClient.calculate_and_set_hb_timeouts()`` to estimate timeouts. 47 | 5. *(Optional)* Save the ``RankMonitorClient`` instance's ``state_dict()`` to a file so that computed timeouts can be reused in the next run. 48 | 6. Shut down ``RankMonitorClient`` instances using ``RankMonitorClient.shutdown_workload_monitoring()``. 49 | 50 | Please refer to the :doc:`../examples/train_ddp_heartbeats` for an implementation example. 
51 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/integration/inprocess.rst: -------------------------------------------------------------------------------- 1 | FT Launcher & Inprocess integration 2 | *********************************** 3 | **FT launcher** integrates with **Inprocess recovery mechanisms**, improving fault tolerance by coordinating injob and inprocess fault recovery. 4 | 5 | 1. Heartbeat Mechanism 6 | ====================== 7 | * The **FT launcher heartbeat** remains active throughout execution to detect and mitigate potential hangs. 8 | * Users must configure timeouts manually, ensuring they exceed **inprocess operational timeouts** to prevent conflicts. 9 | 10 | 2. Worker Monitoring & Restart Policy 11 | ===================================== 12 | A new ``--ft-restart-policy`` argument in ``ft_launcher`` modifies the default worker monitor logic for better compatibility with :doc:`../../inprocess/index`. 13 | 14 | **Policy Options** 15 | 16 | * ``min-healthy``: Restarts workers only when the number of healthy worker groups falls below minimum specified in ``--nnodes``, as set in ``ft_launcher``. 17 | 18 | .. note:: 19 | 20 | For proper behavior, minimum specified in ``--nnodes`` should match the ``inprocess`` restarter setting by either: 21 | 22 | - Ensuring ``inprocess`` operates at the node level like ``injob`` by adding a ``rank_assignment`` filter to the wrapper, or 23 | - Making ``injob`` operate at the rank level like ``inprocess`` by specifying one rank per agent. 24 | 25 | See the `rank assignment guide <../../inprocess/usage_guide.html#rank-assignment>`_ for more details. 26 | 27 | **Example of rank_assignment:** 28 | 29 | .. 
code-block:: python 30 | 31 | rank_assignment = ( 32 | inprocess.Compose( 33 | inprocess.rank_assignment.ShiftRanks(), 34 | inprocess.rank_assignment.FilterCountGroupedByKey( 35 | key_or_fn=lambda rank, _: socket.gethostname(), 36 | condition=lambda count: count == 8, 37 | ), 38 | ), 39 | ) 40 | 41 | **Behavior in min-healthy mode:** 42 | 43 | * If enough nodes remain healthy, the worker monitor stays inactive while collaborating with :doc:`../../inprocess/index`. 44 | * If the threshold is breached, ``FT launcher`` takes over and restarts the training process. 45 | 46 | 47 | Supported & Unsupported Configurations 48 | ====================================== 49 | 50 | To ensure correct behavior with inprocess: 51 | 52 | ✅ Supported: 53 | 54 | * ``restart-policy=min-healthy`` **(Required)**: 55 | 56 | * Prevents unintended upscaling. 57 | * Disables any-failed worker monitoring. 58 | 59 | ❌ Unsupported: 60 | 61 | * ``any-failed`` with inprocess **(Not allowed)**: 62 | 63 | * Incompatible with inprocess restarts. 64 | * Causes FT launcher to misinterpret terminated processes as failures, triggering unnecessary restarts. 65 | * Enables upscaling, allowing FT launcher to restart training when a new node becomes available. 66 | * Can lead to undefined behavior when combined with inprocess restarts. 67 | 68 | In short, ``any-failed`` must not be used with inprocess, as it disrupts the intended fault recovery process. 69 | 70 | Please refer to the :doc:`../examples/in_job_and_in_process_example` for an implementation example. 71 | -------------------------------------------------------------------------------- /docs/source/fault_tolerance/integration/ptl.rst: -------------------------------------------------------------------------------- 1 | PyTorch Lightning Integration 2 | ***************************** 3 | 4 | This section describes Fault Tolerance integration with a PTL-based workload (e.g., NeMo) using ``FaultToleranceCallback``. 5 | 6 | 1.
Use ``ft_launcher`` to start the workload 7 | ============================================ 8 | 9 | Fault tolerance relies on a special launcher (``ft_launcher``), which is a modified ``torchrun``. 10 | If you are using NeMo, the `NeMo-Framework-Launcher `_ can be used to generate SLURM batch scripts with FT support. 11 | 12 | 2. Add the FT callback to the PTL trainer 13 | ========================================== 14 | 15 | Add the FT callback to the PTL callbacks. 16 | 17 | .. code-block:: python 18 | 19 | from nvidia_resiliency_ext.ptl_resiliency import FaultToleranceCallback 20 | 21 | fault_tol_cb = FaultToleranceCallback( 22 | autoresume=True, 23 | calculate_timeouts=True, 24 | logger_name="test_logger", 25 | exp_dir=tmp_path, 26 | ) 27 | 28 | trainer = pl.Trainer( 29 | ... 30 | callbacks=[..., fault_tol_cb], 31 | resume_from_checkpoint=True, 32 | ) 33 | 34 | 35 | Core FT callback functionality includes: 36 | * Establishing a connection with a rank monitor. 37 | * Sending heartbeats during training and evaluation steps. 38 | * Disconnecting from a rank monitor. 39 | 40 | Optionally, it can also: 41 | * Compute timeouts that will be used instead of timeouts defined in the FT config. 42 | * Create a flag file when the training is completed. 43 | 44 | FT callback initialization parameters are described in the ``FaultToleranceCallback`` constructor docstring: 45 | :class:`nvidia_resiliency_ext.ptl_resiliency.fault_tolerance_callback.FaultToleranceCallback` 46 | 47 | 3. Implementing auto-resume 48 | =========================== 49 | 50 | Auto-resume simplifies running training jobs that consist of multiple sequential runs. 51 | 52 | .. note:: 53 | Auto-resume is not part of the FT package. It is entirely implemented in a launcher script and the ``FaultToleranceCallback``. 54 | 55 | ``FaultToleranceCallback`` exposes an "interface" that allows implementing an auto-resume launcher script. 
Specifically, if ``autoresume=True``, 56 | the FT callback creates a special marker file when training is completed. The marker file location is expected to be set in the ``FAULT_TOL_FINISHED_FLAG_FILE`` environment variable. 57 | 58 | The following steps can be used to implement an auto-resume launcher script: 59 | * The launcher script starts ranks with ``ft_launcher``. 60 | * ``FAULT_TOL_FINISHED_FLAG_FILE`` should be passed to rank processes. 61 | * When ``ft_launcher`` exits, the launcher script checks if the ``FAULT_TOL_FINISHED_FLAG_FILE`` file was created. 62 | 63 | * If ``FAULT_TOL_FINISHED_FLAG_FILE`` exists, the auto-resume loop stops, as training is complete. 64 | * If ``FAULT_TOL_FINISHED_FLAG_FILE`` does not exist, the continuation job can be issued (other conditions can be checked, e.g., if the maximum number of failures is not reached). -------------------------------------------------------------------------------- /docs/source/fault_tolerance/integration/sections.rst: -------------------------------------------------------------------------------- 1 | Sections API Integration 2 | ************************ 3 | 4 | 1. Prerequisites 5 | ================= 6 | * Run ranks using ``ft_launcher``. The command line is mostly compatible with ``torchrun``. 7 | * Pass the FT config to the ``ft_launcher``. 8 | 9 | .. note:: 10 | Some clusters (e.g., SLURM) use SIGTERM as a default method of requesting a graceful workload shutdown. 11 | It is recommended to implement appropriate signal handling in a fault-tolerant workload. 12 | To avoid deadlocks and other unintended side effects, signal handling should be synchronized across all ranks. 13 | 14 | 2. FT configuration 15 | ==================== 16 | 17 | With the section-based API, timeouts must be set for all defined sections, which wrap operations like 18 | training/eval steps, checkpoint saving, and initialization. Additionally, an out-of-section timeout 19 | applies when no section is active. 20 | 21 | .. 
note:: 22 | Ensure out-of-section timeout is long enough to accommodate restart overhead, as excessively small values can cause imbalance. 23 | If needed, consider merging sections (e.g., moving 'init' into 'out-of-section') to provide more buffer time. 24 | 25 | Relevant FT configuration items are: 26 | * ``rank_section_timeouts`` is a map of a section name to its timeout. 27 | * ``rank_out_of_section_timeout`` is the out-of-section timeout. 28 | 29 | Fixed timeout values can be used throughout the training runs, or timeouts can be calculated based on observed intervals. 30 | `null` timeout values are interpreted as infinite timeouts. In such cases, values need to be calculated to make the FT usable. 31 | 32 | Config file example: 33 | 34 | .. literalinclude:: ../../../../examples/fault_tolerance/fault_tol_cfg_sections.yaml 35 | :language: yaml 36 | :linenos: 37 | 38 | A summary of all FT configuration items can be found in :class:`nvidia_resiliency_ext.fault_tolerance.config.FaultToleranceConfig` 39 | 40 | 3. Integration with PyTorch workload code 41 | ============================================ 42 | 1. Initialize a ``RankMonitorClient`` instance on each rank with ``RankMonitorClient.init_workload_monitoring()``. 43 | 2. *(Optional)* Restore the state of ``RankMonitorClient`` instances using ``RankMonitorClient.load_state_dict()``. 44 | 3. Mark some sections of the code with ``RankMonitorClient.start_section('
<section_name>')`` and ``RankMonitorClient.end_section('
')``. 45 | 4. *(Optional)* After a sufficient range of section intervals has been observed, call ``RankMonitorClient.calculate_and_set_section_timeouts()`` to estimate timeouts. 46 | 5. *(Optional)* Save the ``RankMonitorClient`` instance's ``state_dict()`` to a file so that computed timeouts can be reused in the next run. 47 | 6. Shut down ``RankMonitorClient`` instances using ``RankMonitorClient.shutdown_workload_monitoring()``. 48 | 49 | Please refer to the :doc:`../examples/train_ddp_sections` for an implementation example. 50 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | nvidia-resiliency-ext v0.4.0 2 | ============================= 3 | 4 | **nvidia-resiliency-ext** is a set of tools developed by NVIDIA to improve large-scale distributed training resiliency. 5 | 6 | .. image:: ./media/nvrx_docs_source.png 7 | :width: 750 8 | :alt: Figure highlighting core NVRx features including automatic restart, hierarchical checkpointing, fault detection and health checks. 9 | 10 | Features 11 | -------- 12 | 13 | * `Hang detection and automatic in-job restarting `_ 14 | * `In-process restarting `_ 15 | * `Async checkpointing `_ 16 | * `Local checkpointing `_ 17 | * `Straggler (slower ranks) detection `_ 18 | 19 | .. toctree:: 20 | :maxdepth: 3 21 | :caption: Documentation contents: 22 | 23 | fault_tolerance/index 24 | inprocess/index 25 | checkpointing/async/index 26 | checkpointing/local/index 27 | straggler_det/index 28 | -------------------------------------------------------------------------------- /docs/source/inprocess/api.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | =============================================================================== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 3 6 | :caption: API documentation 7 | 8 | api/wrap 9 | api/compose 10 | api/state 11 | api/rank_assignment 12 | api/rank_filter 13 | api/initialize 14 | api/abort 15 | api/finalize 16 | api/health_check 17 | api/exception 18 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/abort.rst: -------------------------------------------------------------------------------- 1 | Abort 2 | =============================================================================== 3 | 4 | .. autoclass:: nvidia_resiliency_ext.inprocess.abort.Abort 5 | :special-members: __call__ 6 | 7 | .. automodule:: nvidia_resiliency_ext.inprocess.abort 8 | :members: 9 | :exclude-members: Abort 10 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/compose.rst: -------------------------------------------------------------------------------- 1 | Compose 2 | =============================================================================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.inprocess 5 | :members: Compose 6 | :ignore-module-all: 7 | :no-index: 8 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/exception.rst: -------------------------------------------------------------------------------- 1 | Exceptions 2 | =============================================================================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.inprocess.exception 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/finalize.rst: -------------------------------------------------------------------------------- 1 | Finalize 2 | =============================================================================== 3 | 4 | .. autoclass:: nvidia_resiliency_ext.inprocess.finalize.Finalize 5 | :special-members: __call__ 6 | 7 | .. 
automodule:: nvidia_resiliency_ext.inprocess.finalize 8 | :members: 9 | :exclude-members: Finalize 10 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/health_check.rst: -------------------------------------------------------------------------------- 1 | Health Check 2 | =============================================================================== 3 | 4 | .. autoclass:: nvidia_resiliency_ext.inprocess.health_check.HealthCheck 5 | :special-members: __call__ 6 | 7 | .. automodule:: nvidia_resiliency_ext.inprocess.health_check 8 | :members: 9 | :exclude-members: HealthCheck 10 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/initialize.rst: -------------------------------------------------------------------------------- 1 | Initialize 2 | =============================================================================== 3 | 4 | .. autoclass:: nvidia_resiliency_ext.inprocess.initialize.Initialize 5 | :special-members: __call__ 6 | 7 | .. automodule:: nvidia_resiliency_ext.inprocess.initialize 8 | :members: 9 | :exclude-members: Initialize 10 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/rank_assignment.rst: -------------------------------------------------------------------------------- 1 | Rank Assignment 2 | =============================================================================== 3 | 4 | Rank Assignment 5 | --------------- 6 | Base class 7 | ^^^^^^^^^^ 8 | .. autoclass:: nvidia_resiliency_ext.inprocess.rank_assignment.RankAssignment 9 | :special-members: __call__ 10 | 11 | .. autoclass:: nvidia_resiliency_ext.inprocess.rank_assignment.RankAssignmentCtx 12 | 13 | .. autoexception:: nvidia_resiliency_ext.inprocess.rank_assignment.RankDiscarded 14 | 15 | Tree 16 | ^^^^ 17 | .. 
automodule:: nvidia_resiliency_ext.inprocess.rank_assignment 18 | :members: Layer, LayerFlag, Tree 19 | 20 | Composable Rank Assignments 21 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ 22 | .. automodule:: nvidia_resiliency_ext.inprocess.rank_assignment 23 | :members: FillGaps, ShiftRanks, FilterCountGroupedByKey 24 | :no-index: 25 | 26 | Rank Filtering 27 | -------------- 28 | Base class 29 | ^^^^^^^^^^ 30 | .. autoclass:: nvidia_resiliency_ext.inprocess.rank_assignment.RankFilter 31 | :special-members: __call__ 32 | 33 | Rank Filters 34 | ^^^^^^^^^^^^ 35 | .. automodule:: nvidia_resiliency_ext.inprocess.rank_assignment 36 | :members: ActivateAllRanks, MaxActiveWorldSize, ActiveWorldSizeDivisibleBy 37 | :no-index: 38 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/rank_filter.rst: -------------------------------------------------------------------------------- 1 | Rank Filter 2 | =============================================================================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.inprocess.rank_filter 5 | :members: 6 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/state.rst: -------------------------------------------------------------------------------- 1 | State 2 | =============================================================================== 3 | 4 | .. automodule:: nvidia_resiliency_ext.inprocess 5 | :members: Mode 6 | :no-index: 7 | 8 | .. automodule:: nvidia_resiliency_ext.inprocess 9 | :members: FrozenState, State 10 | :no-index: 11 | -------------------------------------------------------------------------------- /docs/source/inprocess/api/wrap.rst: -------------------------------------------------------------------------------- 1 | Wrapper 2 | =============================================================================== 3 | 4 | .. 
automodule:: nvidia_resiliency_ext.inprocess 5 | :members: Wrapper, CallWrapper 6 | :ignore-module-all: 7 | -------------------------------------------------------------------------------- /docs/source/inprocess/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | =============================================================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Examples 7 | 8 | examples/basic_example.rst 9 | examples/optimal_example.rst 10 | -------------------------------------------------------------------------------- /docs/source/inprocess/examples/basic_example.rst: -------------------------------------------------------------------------------- 1 | Basic usage example 2 | =============================================================================== 3 | 4 | .. literalinclude:: ../../../../examples/inprocess/basic_example.py 5 | :language: python 6 | :linenos: 7 | -------------------------------------------------------------------------------- /docs/source/inprocess/examples/optimal_example.rst: -------------------------------------------------------------------------------- 1 | Optimal usage example 2 | =============================================================================== 3 | 4 | .. literalinclude:: ../../../../examples/inprocess/optimal_example.py 5 | :language: python 6 | :linenos: 7 | -------------------------------------------------------------------------------- /docs/source/inprocess/index.rst: -------------------------------------------------------------------------------- 1 | Inprocess Restart 2 | ================= 3 | 4 | In-process restart mechanism is implemented via a Python function wrapper that 5 | adds restart capabilities to an existing Python function implementing 6 | distributed PyTorch workload. Upon a fault, the wrapped function is restarted 7 | across all distributed ranks, within the same operating system process. 
8 | Invoking restart of the wrapped function excludes distributed ranks that are 9 | terminated, missing, or deemed unhealthy. When a failure occurs on any worker, 10 | the wrapper ensures the function restarts simultaneously on all healthy ranks. 11 | This process continues until all ranks complete execution successfully or a 12 | termination condition is met. 13 | 14 | Compared to a traditional scheduler-level restart, restarting within the same 15 | process removes overheads associated with launching a new scheduler job, 16 | starting a container, initializing a new Python interpreter, loading 17 | dependencies, and creating a new CUDA context. 18 | 19 | Restarting in the same process also enables the reuse of pre-instantiated, 20 | process-group- and rank-independent objects across restart attempts. This reuse 21 | eliminates the overhead of repeated reinitialization and minimizes restart 22 | latency. 23 | 24 | Features: 25 | 26 | - automatic deinitialization of PyTorch distributed process group, and restart 27 | of the wrapped function upon encountering an unhandled Python exception in 28 | any distributed rank 29 | - timeout mechanism to detect and recover from deadlocks or livelocks, and a 30 | guarantee that the job is making meaningful forward progress 31 | - modular and customizable rank reassignment and health check functions 32 | - support for pre-allocated and pre-initialized reserve workers 33 | - gradual engineering ramp up: integration with existing codebase may start 34 | from restarting the entire ``main()`` function, then gradually refactor 35 | process-group-independent initialization into a separate function call in 36 | order to maximally reuse Python objects between restarts and minimize fault 37 | recovery overhead 38 | 39 | For a comprehensive description of this functionality, including detailed 40 | requirements, restrictions, and usage examples, please refer to the :doc:`Usage 41 | Guide ` and :doc:`Examples `. 42 | 43 | 44 | 45 | .. 
toctree:: 46 | :maxdepth: 2 47 | :caption: Contents: 48 | 49 | usage_guide 50 | api 51 | examples 52 | -------------------------------------------------------------------------------- /docs/source/media/nvrx_core_features.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/docs/source/media/nvrx_core_features.png -------------------------------------------------------------------------------- /docs/source/media/nvrx_docs_source.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/docs/source/media/nvrx_docs_source.png -------------------------------------------------------------------------------- /docs/source/straggler_det/api.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | ================= 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | :caption: API documentation 7 | 8 | api/straggler 9 | api/reporting 10 | api/statistics 11 | api/callback 12 | -------------------------------------------------------------------------------- /docs/source/straggler_det/api/callback.rst: -------------------------------------------------------------------------------- 1 | Callback 2 | ======== 3 | 4 | .. automodule:: nvidia_resiliency_ext.ptl_resiliency.straggler_det_callback 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/source/straggler_det/api/reporting.rst: -------------------------------------------------------------------------------- 1 | Reporting 2 | ========= 3 | 4 | .. 
automodule:: nvidia_resiliency_ext.straggler.reporting 5 | :members: 6 | :show-inheritance: 7 | 8 | 9 | -------------------------------------------------------------------------------- /docs/source/straggler_det/api/statistics.rst: -------------------------------------------------------------------------------- 1 | Statistics 2 | ========== 3 | 4 | .. automodule:: nvidia_resiliency_ext.straggler.statistics 5 | :members: 6 | :show-inheritance: 7 | 8 | 9 | -------------------------------------------------------------------------------- /docs/source/straggler_det/api/straggler.rst: -------------------------------------------------------------------------------- 1 | Straggler 2 | ========= 3 | 4 | .. automodule:: nvidia_resiliency_ext.straggler.straggler 5 | :members: 6 | :show-inheritance: 7 | -------------------------------------------------------------------------------- /docs/source/straggler_det/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Examples 7 | 8 | examples/basic_example.rst 9 | -------------------------------------------------------------------------------- /docs/source/straggler_det/examples/basic_example.rst: -------------------------------------------------------------------------------- 1 | Basic usage example 2 | =================== 3 | 4 | .. literalinclude:: ../../../../examples/straggler/example.py 5 | :language: python 6 | :linenos: -------------------------------------------------------------------------------- /docs/source/straggler_det/index.rst: -------------------------------------------------------------------------------- 1 | Straggler Detection 2 | =================== 3 | 4 | The **Straggler Detection** package's purpose is to detect slower ranks participating in a PyTorch distributed workload. 
def parse_args():
    """Parse command-line options for the async checkpointing example."""
    cli = argparse.ArgumentParser(
        description='Async Checkpointing Basic Example',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument(
        '--ckpt_dir',
        default="/tmp/test_async_ckpt/",
        help="Checkpoint directory for async checkpoints",
    )
    cli.add_argument(
        '--persistent_queue',
        action='store_true',
        help="Enables a persistent version of AsyncCallsQueue.",
    )
    return cli.parse_args()


class SimpleModel(nn.Module):
    """Tiny two-layer MLP (10 -> 5 -> 2) used to produce a non-trivial state dict."""

    def __init__(self):
        super().__init__()
        # Attribute names (fc1/fc2) are part of the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)
        self.activation = nn.ReLU()

    def forward(self, x):
        # fc1 -> ReLU -> fc2, written as one expression.
        return self.fc2(self.activation(self.fc1(x)))
def init_distributed_backend(backend="nccl"):
    """
    Initialize the distributed process group.

    Assumes the launcher already set the env:// variables (RANK, WORLD_SIZE,
    MASTER_ADDR, MASTER_PORT) and CUDA_VISIBLE_DEVICES.

    Args:
        backend (str): process-group backend; defaults to "nccl".

    Raises:
        Exception: re-raises any error from process-group initialization.
    """
    try:
        dist.init_process_group(
            backend=backend,  # Use NCCL backend
            init_method="env://",  # Use environment variables for initialization
        )
        logging.info(f"Rank {dist.get_rank()} initialized with {backend} backend.")

        # Ensure each process uses a different GPU.
        # NOTE(review): assumes a single node, i.e. global rank == local device
        # index — confirm before multi-node use.
        torch.cuda.set_device(dist.get_rank())
    except Exception as e:
        logging.error(f"Error initializing the distributed backend: {e}")
        raise


def cleanup(ckpt_dir):
    """Remove checkpoint files from ckpt_dir; only rank 0 performs deletions."""
    if dist.get_rank() == 0:
        logging.info(f"Cleaning up checkpoint directory: {ckpt_dir}")
        for file_item in os.scandir(ckpt_dir):
            if file_item.is_file():
                os.remove(file_item.path)


def main():
    """Run a minimal async-checkpoint save/load round trip on every rank."""
    args = parse_args()
    logging.info(f'{args}')

    # Initialize the distributed backend
    init_distributed_backend(backend="nccl")

    # Instantiate the model and move to CUDA
    model = SimpleModel().to("cuda")
    org_sd = model.state_dict()

    # Define checkpoint directory and per-rank file name
    ckpt_dir = args.ckpt_dir
    os.makedirs(ckpt_dir, exist_ok=True)
    logging.info(f"Created checkpoint directory: {ckpt_dir}")
    # Use the `dist` alias consistently (was torch.distributed.get_rank()).
    ckpt_file_name = os.path.join(ckpt_dir, f"ckpt_rank{dist.get_rank()}.pt")

    ckpt_impl = TorchAsyncCheckpoint(persistent_queue=args.persistent_queue)

    # Schedule the save, then block until the background writer finishes.
    ckpt_impl.async_save(org_sd, ckpt_file_name)
    ckpt_impl.finalize_async_save(blocking=True, no_dist=True, terminate=True)

    # Read the checkpoint back and verify it matches what was saved.
    loaded_sd = torch.load(ckpt_file_name, map_location="cuda")
    for k in loaded_sd.keys():
        assert torch.equal(loaded_sd[k], org_sd[k]), f"loaded_sd[{k}] != org_sd[{k}]"

    # Synchronize so all ranks finish loading before rank 0 deletes the files.
    dist.barrier()

    # Clean up checkpoint directory only on rank 0
    cleanup(ckpt_dir)

    # Ensure NCCL process group is properly destroyed
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
def setup_logging(log_all_ranks=True, filename=os.devnull, filemode='w'):
    """
    Configures logging.

    By default logs from all workers are printed to stderr, with entries
    prefixed by "N: " where N is the worker's rank. Entries printed to stderr
    carry no timestamps; full logs with timestamps go to the log file.
    """

    class RankFilter(logging.Filter):
        # Tags every record with the worker rank; when log_all_ranks is False,
        # drops records originating from non-zero ranks.
        def __init__(self, rank, log_all_ranks):
            self.rank = rank
            self.log_all_ranks = log_all_ranks

        def filter(self, record):
            record.rank = self.rank
            return True if self.log_all_ranks else self.rank == 0

    rank = dist_utils.get_rank()

    if log_all_ranks:
        logging_format = f"%(asctime)s - %(levelname)s - {rank} - %(message)s"
    else:
        logging_format = "%(asctime)s - %(levelname)s - %(message)s"
        # Only rank 0 writes the log file in single-rank mode.
        if rank != 0:
            filename = os.devnull

    # Remove any pre-existing root handlers so basicConfig takes effect.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        handler.close()

    logging.basicConfig(
        level=logging.DEBUG,
        format=logging_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        filename=filename,
        filemode=filemode,
    )

    # Mirror all records to stderr, without timestamps.
    stderr_handler = logging.StreamHandler(sys.stderr)
    stderr_handler.setLevel(logging.DEBUG)
    plain_format = f'{rank}: %(message)s' if log_all_ranks else '%(message)s'
    stderr_handler.setFormatter(logging.Formatter(plain_format))
    root_logger = logging.getLogger('')
    root_logger.addHandler(stderr_handler)
    root_logger.addFilter(RankFilter(rank, log_all_ranks))
All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | [tool.poetry] 5 | name = "nvidia-resiliency-ext" 6 | repository = "https://github.com/NVIDIA/nvidia-resiliency-ext" 7 | version = "0.4.0" 8 | description = "NVIDIA Resiliency Package" 9 | authors = ["NVIDIA Corporation"] 10 | readme = "README.md" 11 | license = "Apache 2.0" 12 | classifiers = [ 13 | "Development Status :: 4 - Beta", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.10", 16 | "Operating System :: OS Independent", 17 | ] 18 | packages = [ 19 | { include = "nvidia_resiliency_ext", from = "src" }, 20 | ] 21 | 22 | exclude = [ 23 | "src/nvidia_resiliency_ext/straggler/cupti_src" 24 | ] 25 | 26 | [tool.poetry.build] 27 | script = "cupti_build.py" 28 | generate-setup-file = true 29 | 30 | [build-system] 31 | requires = ["poetry-core>=1.0.0", "pybind11", "setuptools", "wheel"] 32 | build-backend = "poetry.core.masonry.api" 33 | 34 | [tool.poetry.dependencies] 35 | torch = ">=2.3.0" 36 | packaging = "*" 37 | python = ">=3.10" 38 | psutil = ">=6.0.0" 39 | pyyaml = "*" 40 | pynvml = ">=12.0.0" 41 | nvidia-ml-py = ">=12.570.86" 42 | defusedxml = "*" 43 | 44 | [tool.poetry.scripts] 45 | ft_launcher = "nvidia_resiliency_ext.fault_tolerance.launcher:main" 46 | 47 | 48 | [tool.isort] 49 | profile = "black" # black-compatible 50 | line_length = 100 # should match black parameters 51 | py_version = 310 # python 3.10 as a target version 52 | # filter_files and extend_skip_glob are needed for pre-commit to filter out the files 53 | filter_files = true 54 | extend_skip_glob = [ 55 | "setup.py", 56 | "cupti_build.py", 57 | "src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/*" 58 | ] 59 | 60 | [tool.black] 61 | line_length = 100 62 | skip_string_normalization = true 63 | # major year version is stable, see details in 64 | # https://black.readthedocs.io/en/stable/the_black_code_style/index.html 65 | # `required_version` is necessary for consistency 
(other `black` versions will fail to reformat files) 66 | required_version = "24" 67 | target-version = ['py310', 'py311', 'py312'] 68 | force-exclude = ''' 69 | # Force exclude, as this config is also used by pre-commit 70 | # https://stackoverflow.com/questions/73247204/black-not-respecting-extend-exclude-in-pyproject-toml 71 | # A regex preceded with ^/ will apply only to files and directories # in the root of the project. 72 | ( 73 | ^\/setup.py\/ 74 | | ^\/build.py\/ 75 | | ^\/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat\/ 76 | ) 77 | ''' 78 | 79 | [tool.ruff] 80 | extend-exclude = [ 81 | "setup.py", 82 | "build.py", 83 | "src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat" 84 | ] 85 | 86 | [tool.ruff.lint] 87 | # F841 Local variable `...` is assigned to but never used 88 | ignore = ["F841"] 89 | 90 | [tool.ruff.lint.per-file-ignores] 91 | # avoid "unused import" warnings for __init__.py files 92 | "__init__.py" = ["F401"] 93 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/checkpointing/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class CachedMetadataFileSystemReader(FileSystemReader):
    """
    FileSystemReader that reads checkpoint metadata at most once.

    The first call to ``read_metadata()`` hits the file system; the result is
    memoized on the instance and returned directly on every later call.

    Attributes:
        _cached_metadata (Metadata or None): memoized metadata; None until the
            first successful read.
    """

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        """
        Initialize with file system path.

        Args:
            path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
        """
        super().__init__(path=path)
        self._cached_metadata = None

    def read_metadata(self) -> Metadata:
        """
        Return checkpoint metadata, reading from storage only on first use.

        Returns:
            Metadata: Checkpoint metadata.
        """
        # Early return on the warm path; fall through to populate the cache.
        if self._cached_metadata is not None:
            return self._cached_metadata
        self._cached_metadata = super().read_metadata()
        return self._cached_metadata
15 | 16 | """ 17 | TorchAsyncCheckpoint defines a wrapper for the async version of `torch.save` with 18 | an additional method to synchronize async saving requests 19 | """ 20 | 21 | 22 | import logging 23 | 24 | import torch 25 | 26 | from ..utils import preload_tensors, wrap_for_async 27 | from .core import AsyncCallsQueue, AsyncRequest 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class TorchAsyncCheckpoint(object): 33 | async_fn = None 34 | 35 | def __init__(self, persistent_queue=False): 36 | self.save = torch.save 37 | self._async_calls_queue = AsyncCallsQueue(persistent=persistent_queue) 38 | # Use direct torch.save for persistent queue, avoid unnecessary wrapping 39 | TorchAsyncCheckpoint.async_fn = ( 40 | torch.save if persistent_queue else wrap_for_async(torch.save) 41 | ) 42 | 43 | def async_save(self, state_dict, *args, **kwargs): 44 | """ 45 | Keeps the original interface of `torch.save` 46 | Schedules a `AsyncReuqest` with preloading tensors to CPU with pinned memcpy 47 | """ 48 | 49 | preloaded_sd = preload_tensors(state_dict) 50 | torch.cuda.synchronize() 51 | async_request = AsyncRequest( 52 | TorchAsyncCheckpoint.async_fn, (preloaded_sd, *args), [], kwargs 53 | ) 54 | self._async_calls_queue.schedule_async_request(async_request) 55 | 56 | def finalize_async_save(self, blocking: bool = False, no_dist=True, terminate=False): 57 | """Finalizes active async save calls. 58 | 59 | Args: 60 | blocking (bool, optional): if True, will wait until all active requests 61 | are done. Otherwise, finalizes only the async request that already 62 | finished. Defaults to False. 63 | no_dist (bool, Optional): if True, training ranks simply check its 64 | asynchronous checkpoint writer without synchronization. 65 | terminate (bool, optional): if True, the asynchronous queue will 66 | be closed as the last action of this function. 
67 | """ 68 | if blocking and self._async_calls_queue.get_num_unfinalized_calls() > 0: 69 | if torch.distributed.get_rank() == 0: 70 | logger.info( 71 | 'Unfinalized async checkpoint saves. Finalizing them synchronously now.' 72 | ) 73 | 74 | self._async_calls_queue.maybe_finalize_async_calls(blocking, no_dist=no_dist) 75 | if terminate: 76 | self._async_calls_queue.close() 77 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/checkpointing/local/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/checkpointing/local/replication/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
def zip_strict(*args):
    """
    Alternative to Python's builtin zip(..., strict=True) (available in 3.10+).

    Besides back-porting the strictness check, the failure message is more
    verbose: it reports the lengths of every iterable, whereas builtin zip
    only says which iterable finished first.
    """
    materialized = [list(arg) for arg in args]
    lengths = [len(seq) for seq in materialized]
    assert len(set(lengths)) <= 1, f"Tried to zip iterables of unequal lengths: {lengths}!"
    return zip(*materialized)
5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | # SPDX-License-Identifier: BSD-3-Clause 10 | # Modifications made by NVIDIA 11 | # - This package is a copy of `torch.distributed.elastic` from PyTorch version 2.3 12 | # - All occurences of 'torch.distributed.elastic' were replaced with 'fault_tolerance._torch_elastic_compat' 13 | 14 | """ 15 | 16 | Torchelastic agent and user worker failover contract: 17 | 18 | **TL;DR;**: 19 | 20 | * TE(torchelastic) expects user workers to finish with the 5 minutes drift 21 | * It is better to design DDP app to fail for all workers, rather than a single one. 22 | * TE does not synchronize number of restarts between agents 23 | * TE re-rendezvous does not trigger restart decrease 24 | * When a single agent finishes its job(successfully or not), it will close rendezvous. 25 | If other agents still have workers in progress, they will be terminated. 26 | * Based on above, scale down does not work if at least single agent finishes the job. 27 | * When Scale up is detected by agents, it will not decrease ``max_restarts`` 28 | 29 | 30 | In general TE(torchelastic) can launch arbitrary user code, but there is some 31 | clarifications need to be done around what failover mechanism torchelastic 32 | provides and what failover mechanism it expects from user workers. 33 | 34 | Torchelastic currently supports DDP style applications. That means that 35 | TE expects *ALL* workers finish approximately at the same time. In practice, 36 | it is nearly to impossible to guarantee that all workers in arbitrary 37 | DDP application finish at the time, so TE provides a finalization barrier 38 | that waits for TIMEOUT(5 minutes) for worker finalization. 
39 | 40 | **Worker Failure** 41 | 42 | When worker fails, TE will check the number of restarts 43 | available, if there is more than 0 restarts, TE will start a new rendezvous 44 | round and restart the worker process. New rendezvous round will other 45 | TE agents to terminate their workers. 46 | 47 | .. note:: The TE agent does not synchronize restarts between themselves. 48 | When a single agent performs restart, it will trigger a local ``max_restarts`` 49 | decrease, other agent will not decrease their ``max_restarts``. 50 | the user to run the distributed application locally on a dev host. 51 | 52 | A single worker failure can cause the whole cluster to fail: 53 | If a single worker is constantly failing, it will cause the TE agent 54 | ``max_restarts`` to go to zero. This will cause an agent to finish its 55 | work and close rendezvous. If there are any other workers on different 56 | agents, they will be terminated. 57 | 58 | 59 | **Re-Rendezvous** 60 | 61 | Re-rendezvous occurs when TE agents detect a new node 62 | trying to joint a cluster. TE will not decrease ``max_restarts``. TE agents 63 | will terminate its workers and start a new rendezvous round. 64 | 65 | Note about DynamicRendezvous(etcd-v2, c10d-experimental): If the rendezvous 66 | has already max_nodes, the new node won't be added to the wait list right 67 | away since there is no need to tear down a rendezvous that is already fully 68 | utilized. The new node will wait until its timeout (600 secs by default) 69 | and periodically check the number of participants. If the number becomes 70 | less than max_nodes, it will be added to the wait list; otherwise, it will time out after 600 secs. 71 | 72 | *Scale up event*. When scale up event happens, torchelastic rendezvous 73 | will detect that there are new nodes trying to join. Torchelastic agent 74 | will stop all workers and perform re-rendezvous. Note: when scale up event 75 | happens, *``max_restarts``* will *not* decrease. 
76 | 77 | *Scale down event*. When scale down event happens, rendezvous will not 78 | notify the torchelastic agent about it. If TE agent launched with ``max_restarts=0`` , 79 | it relies on the underlying scheduler to handle job restart. If the ``max_restarts>0`` , 80 | TE agent will terminate workers and start a new rdzv round, which is a *Scale up event*. 81 | 82 | """ 83 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/agent/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/agent/__init__.py -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/agent/server/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | """ 10 | The elastic agent is the control plane of torchelastic. 11 | 12 | It is a process that launches and manages underlying worker processes. 13 | The agent is responsible for: 14 | 15 | 1. Working with distributed torch: the workers are started with all the 16 | necessary information to successfully and trivially call 17 | ``torch.distributed.init_process_group()``. 18 | 19 | 2. Fault tolerance: monitors workers and upon detecting worker failures 20 | or unhealthiness, tears down all workers and restarts everyone. 21 | 22 | 3. Elasticity: Reacts to membership changes and restarts workers with the new 23 | members. 
import json
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Dict, Optional, Union

__all__ = ['EventSource', 'Event', 'NodeState', 'RdzvEvent']

# Scalar JSON-serializable types allowed as event metadata values.
EventMetadataValue = Union[str, int, float, bool, None]


class EventSource(str, Enum):
    """Known identifiers of the event producers."""

    AGENT = "AGENT"
    WORKER = "WORKER"


@dataclass
class Event:
    """
    A generic event that occurs during torchelastic job execution.

    An event can represent any kind of meaningful action.

    Args:
        name: event name.
        source: the event producer, e.g. agent or worker.
        timestamp: timestamp in milliseconds when the event occurred.
        metadata: additional data associated with the event.
    """

    name: str
    source: EventSource
    timestamp: int = 0
    metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "Event"]) -> "Event":
        """Rebuild an ``Event`` from a JSON string; ``Event`` instances pass through."""
        if isinstance(data, Event):
            return data
        if isinstance(data, str):
            payload = json.loads(data)
        # Restore the enum member from its serialized name.
        payload["source"] = EventSource[payload["source"]]  # type: ignore[possibly-undefined]
        return Event(**payload)

    def serialize(self) -> str:
        """Encode this event as a JSON string."""
        return json.dumps(asdict(self))


class NodeState(str, Enum):
    """The states that a node can be in rendezvous."""

    INIT = "INIT"
    RUNNING = "RUNNING"
    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"


@dataclass
class RdzvEvent:
    """
    Dataclass representing any rendezvous event.

    Args:
        name: Event name. (E.g. current action being performed)
        run_id: The run id of the rendezvous
        message: The message describing the event
        hostname: Hostname of the node
        pid: The process id of the node
        node_state: The state of the node (INIT, RUNNING, SUCCEEDED, FAILED)
        master_endpoint: The master endpoint for the rendezvous store, if known
        rank: The rank of the node, if known
        local_id: The local_id of the node, if defined in dynamic_rendezvous.py
        error_trace: Error stack trace, if this is an error event.
    """

    name: str
    run_id: str
    message: str
    hostname: str
    pid: int
    node_state: NodeState
    master_endpoint: str = ""
    rank: Optional[int] = None
    local_id: Optional[int] = None
    error_trace: str = ""

    def __str__(self):
        return self.serialize()

    @staticmethod
    def deserialize(data: Union[str, "RdzvEvent"]) -> "RdzvEvent":
        """Rebuild an ``RdzvEvent`` from a JSON string; instances pass through."""
        if isinstance(data, RdzvEvent):
            return data
        if isinstance(data, str):
            payload = json.loads(data)
        # Restore the enum member from its serialized name.
        payload["node_state"] = NodeState[payload["node_state"]]  # type: ignore[possibly-undefined]
        return RdzvEvent(**payload)

    def serialize(self) -> str:
        """Encode this event as a JSON string."""
        return json.dumps(asdict(self))
# !/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Taken and modified from original source:
# https://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
import ctypes
import logging
import os
import sys
from contextlib import contextmanager
from functools import partial

IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"


logger = logging.getLogger(__name__)


def get_libc():
    """Return a handle to libc, or ``None`` on unsupported platforms.

    Redirection works by flushing and re-pointing the C-level std streams,
    which requires libc; Windows/macOS are not supported by this module.
    """
    if IS_WINDOWS or IS_MACOS:
        logger.warning(
            "NOTE: Redirects are currently not supported in Windows or MacOs."
        )
        return None
    else:
        return ctypes.CDLL("libc.so.6")


# Loaded once at import time; ``None`` on unsupported platforms.
libc = get_libc()


def _c_std(stream: str):
    """Return the C-level ``stdout``/``stderr`` FILE* symbol from libc."""
    return ctypes.c_void_p.in_dll(libc, stream)


def _python_std(stream: str):
    """Return Python's ``sys.stdout``/``sys.stderr`` object for ``stream``."""
    return {"stdout": sys.stdout, "stderr": sys.stderr}[stream]


# The only stream names accepted by ``redirect``.
_VALID_STD = {"stdout", "stderr"}


@contextmanager
def redirect(std: str, to_file: str):
    """
    Redirect ``std`` (one of ``"stdout"`` or ``"stderr"``) to a file in the path specified by ``to_file``.

    This method redirects the underlying std file descriptor (not just python's ``sys.stdout|stderr``).
    See usage for details.

    Directory of ``dst_filename`` is assumed to exist and the destination file
    is overwritten if it already exists.

    .. note:: Due to buffering cross source writes are not guaranteed to
              appear in wall-clock order. For instance in the example below
              it is possible for the C-outputs to appear before the python
              outputs in the log file.

    Usage:

    ::

     # syntactic-sugar for redirect("stdout", "tmp/stdout.log")
     with redirect_stdout("/tmp/stdout.log"):
        print("python stdouts are redirected")
        libc = ctypes.CDLL("libc.so.6")
        libc.printf(b"c stdouts are also redirected"
        os.system("echo system stdouts are also redirected")

     print("stdout restored")

    """
    if std not in _VALID_STD:
        raise ValueError(
            f"unknown standard stream <{std}>, must be one of {_VALID_STD}"
        )

    c_std = _c_std(std)
    python_std = _python_std(std)
    std_fd = python_std.fileno()

    def _redirect(dst):
        # Flush both the C and Python buffers before swapping the fd so no
        # buffered output lands in the wrong destination.
        libc.fflush(c_std)
        python_std.flush()
        # Re-point the process-level std fd at ``dst``; this also affects
        # output written by C extensions and child processes.
        os.dup2(dst.fileno(), std_fd)

    # Keep a duplicate of the original fd so it can be restored on exit;
    # ``finally`` guarantees restoration even if the body raises.
    with os.fdopen(os.dup(std_fd)) as orig_std, open(to_file, mode="w+b") as dst:
        _redirect(dst)
        try:
            yield
        finally:
            _redirect(orig_std)


# Convenience wrappers: redirect_stdout("/path") / redirect_stderr("/path").
redirect_stdout = partial(redirect, "stdout")
redirect_stderr = partial(redirect, "stderr")
from typing import Dict, Tuple

from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.multiprocessing.subprocess_handler.subprocess_handler import (
    SubprocessHandler,
)

__all__ = ["get_subprocess_handler"]


def get_subprocess_handler(
    entrypoint: str,
    args: Tuple,
    env: Dict[str, str],
    stdout: str,
    stderr: str,
    local_rank_id: int,
):
    """Factory returning a :class:`SubprocessHandler` for a single local worker.

    All arguments are forwarded unchanged to the handler's constructor.
    """
    handler_kwargs = dict(
        entrypoint=entrypoint,
        args=args,
        env=env,
        stdout=stdout,
        stderr=stderr,
        local_rank_id=local_rank_id,
    )
    return SubprocessHandler(**handler_kwargs)
# Severity: Low Confidence: High
# CWE: CWE-78 (https://cwe.mitre.org/data/definitions/78.html)
# More Info: https://bandit.readthedocs.io/en/1.7.9/blacklists/blacklist_imports.html#b404-import-subprocess
import subprocess  # nosec
import sys

from typing import Any, Dict, Optional, Tuple

__all__ = ["SubprocessHandler"]

IS_WINDOWS = sys.platform == "win32"


def _get_default_signal() -> signal.Signals:
    """Get the default termination signal. SIGTERM for unix, CTRL_C_EVENT for windows."""
    if IS_WINDOWS:
        return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
    else:
        return signal.SIGTERM


class SubprocessHandler:
    """
    Convenience wrapper around python's ``subprocess.Popen``. Keeps track of
    meta-objects associated to the process (e.g. stdout and stderr redirect fds).
    """

    def __init__(
        self,
        entrypoint: str,
        args: Tuple,
        env: Dict[str, str],
        stdout: str,
        stderr: str,
        local_rank_id: int,
    ):
        # Redirect targets for the child's streams; ``None`` (empty path)
        # means the child inherits the parent's stream.
        self._stdout = open(stdout, "w") if stdout else None
        self._stderr = open(stderr, "w") if stderr else None
        # inherit parent environment vars
        env_vars = os.environ.copy()
        env_vars.update(env)

        # Stringify every arg so Popen receives a homogeneous argv tuple.
        args_str = (entrypoint, *[str(e) for e in args])
        self.local_rank_id = local_rank_id
        self.proc: subprocess.Popen = self._popen(args_str, env_vars)

    def _popen(self, args: Tuple, env: Dict[str, str]) -> subprocess.Popen:
        """Spawn the worker process; on POSIX it gets its own session/group."""
        kwargs: Dict[str, Any] = {}
        if not IS_WINDOWS:
            # start_new_session puts the child in its own process group, which
            # lets ``close`` signal the whole group via ``os.killpg``.
            kwargs["start_new_session"] = True
        # Issue: [B603:subprocess_without_shell_equals_true] subprocess call - check for execution of untrusted input.
        # Severity: Low Confidence: High
        # CWE: CWE-78 (https://cwe.mitre.org/data/definitions/78.html)
        # More Info: https://bandit.readthedocs.io/en/1.7.9/plugins/b603_subprocess_without_shell_equals_true.html
        return subprocess.Popen(
            # pyre-fixme[6]: Expected `Union[typing.Sequence[Union[_PathLike[bytes],
            # _PathLike[str], bytes, str]], bytes, str]` for 1st param but got
            # `Tuple[str, *Tuple[Any, ...]]`.
            args=args,
            env=env,
            stdout=self._stdout,
            stderr=self._stderr,
            **kwargs,
            shell=False,
        )  # nosec

    def close(self, death_sig: Optional[signal.Signals] = None) -> None:
        """Signal the child (its whole group on POSIX) and close redirect files.

        ``death_sig`` defaults to SIGTERM on unix / CTRL_C_EVENT on Windows.
        Does not wait for the child to exit.
        """
        if not death_sig:
            death_sig = _get_default_signal()
        if IS_WINDOWS:
            self.proc.send_signal(death_sig)
        else:
            # Signal the process group created via start_new_session.
            os.killpg(self.proc.pid, death_sig)
        if self._stdout:
            self._stdout.close()
        if self._stderr:
            self._stderr.close()
from .api import RendezvousHandler, RendezvousParameters
from .api import rendezvous_handler_registry as handler_registry
from .dynamic_rendezvous import create_handler

__all__ = ['get_rendezvous_handler']

def _create_static_handler(params: RendezvousParameters) -> RendezvousHandler:
    # Imported lazily so the backend module is only loaded when this
    # backend is actually selected.
    from . import static_tcp_rendezvous

    return static_tcp_rendezvous.create_rdzv_handler(params)


def _create_etcd_handler(params: RendezvousParameters) -> RendezvousHandler:
    # Lazy import: etcd support (and its dependencies) is optional.
    from . import etcd_rendezvous

    return etcd_rendezvous.create_rdzv_handler(params)


def _create_etcd_v2_handler(params: RendezvousParameters) -> RendezvousHandler:
    from .etcd_rendezvous_backend import create_backend

    # The v2 backend is store-based: build the backend + store pair, then
    # wrap them in a dynamic rendezvous handler.
    backend, store = create_backend(params)

    return create_handler(store, backend, params)


def _create_c10d_handler(params: RendezvousParameters) -> RendezvousHandler:
    from .c10d_rendezvous_backend import create_backend

    backend, store = create_backend(params)

    return create_handler(store, backend, params)


def _register_default_handlers() -> None:
    # Register the built-in backends under their canonical names.
    handler_registry.register("etcd", _create_etcd_handler)
    handler_registry.register("etcd-v2", _create_etcd_v2_handler)
    handler_registry.register("c10d", _create_c10d_handler)
    handler_registry.register("static", _create_static_handler)


def get_rendezvous_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Obtain a reference to a :py:class`RendezvousHandler`.

    Custom rendezvous handlers can be registered by

    ::

      from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous import rendezvous_handler_registry
      from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous.registry import get_rendezvous_handler

      def create_my_rdzv(params: RendezvousParameters):
        return MyCustomRdzv(params)

      rendezvous_handler_registry.register("my_rdzv_backend_name", create_my_rdzv)

      my_rdzv_handler = get_rendezvous_handler("my_rdzv_backend_name", RendezvousParameters)
    """
    return handler_registry.create_handler(params)
from torch.distributed import Store, TCPStore, PrefixStore
from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous import RendezvousHandler, RendezvousParameters
from nvidia_resiliency_ext.fault_tolerance._torch_elastic_compat.rendezvous.utils import parse_rendezvous_endpoint

log = logging.getLogger(__name__)

_default_timeout_seconds = 600


class StaticTCPRendezvous(RendezvousHandler):
    """
    Static rendezvous that is a wrapper around the TCPStore.

    Creates TCPStore based on the input parameters with the
    listener on the agent with group_rank=0
    """

    def __init__(
        self,
        master_addr: str,
        master_port: int,
        rank: int,
        world_size: int,
        run_id: str,
        timeout: int,
    ):
        self.master_addr = master_addr
        self.master_port = master_port
        self.rank = rank
        self.world_size = world_size
        self.run_id = run_id
        self.timeout = datetime.timedelta(seconds=timeout)
        # Created lazily on the first next_rendezvous() call.
        self._store: Optional[Store] = None

    def get_backend(self) -> str:
        """Name of this rendezvous backend."""
        return "static"

    def next_rendezvous(self) -> Tuple[Store, int, int]:
        """Return ``(store, rank, world_size)``, creating the TCPStore once."""
        log.info("Creating TCPStore as the c10d::Store implementation")
        if not self._store:
            # Only the rank-0 agent hosts the store's listener.
            self._store = TCPStore(  # type: ignore[call-arg]
                self.master_addr,
                self.master_port,
                self.world_size,
                self.rank == 0,
                self.timeout,
                multi_tenant=True,
            )
        # Namespace keys by run id so multiple runs can share one TCPStore.
        return PrefixStore(self.run_id, self._store), self.rank, self.world_size

    def is_closed(self):
        # A static rendezvous never closes.
        return False

    def set_closed(self):
        # No-op: static rendezvous has no closed state.
        pass

    def num_nodes_waiting(self):
        # Static membership: no node ever waits to join.
        return 0

    def get_run_id(self) -> str:
        return self.run_id

    def shutdown(self) -> bool:
        return True
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """Build a :class:`StaticTCPRendezvous` from ``params``.

    Requires ``rank`` in ``params.config`` and a ``host:port`` endpoint;
    ``timeout`` in the config overrides the 600s default.

    Raises:
        ValueError: if ``rank``, the endpoint, or the port is missing.
    """
    if "rank" not in params.config:
        raise ValueError(
            # NOTE: a trailing space is needed here — adjacent string literals
            # are concatenated, and without it the message ran words together.
            "rank is absent in RendezvousParameters. "
            "Try add --node-rank to the cmd request"
        )
    endpoint = params.endpoint.strip()
    if not endpoint:
        raise ValueError(
            "endpoint is absent in RendezvousParameters. "
            "Try add --master-port and --master-addr to the cmd request"
        )
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, -1)
    if master_port == -1:
        raise ValueError(
            f"Port is absent in endpoint: {endpoint}. Try launching with --master-port"
        )
    world_size = params.max_nodes
    rank = cast(int, params.config.get("rank"))
    run_id = params.run_id
    if "timeout" in params.config:
        timeout = int(params.config["timeout"])
    else:
        timeout = _default_timeout_seconds
    return StaticTCPRendezvous(
        master_addr, master_port, rank, world_size, run_id, timeout
    )
14 | 15 | Usage:: 16 | 17 | import torchelastic.timer as timer 18 | import torchelastic.agent.server as agent 19 | 20 | def main(): 21 | start_method = "spawn" 22 | message_queue = mp.get_context(start_method).Queue() 23 | server = timer.LocalTimerServer(message, max_interval=0.01) 24 | server.start() # non-blocking 25 | 26 | spec = WorkerSpec( 27 | fn=trainer_func, 28 | args=(message_queue,), 29 | ...) 30 | agent = agent.LocalElasticAgent(spec, start_method) 31 | agent.run() 32 | 33 | def trainer_func(message_queue): 34 | timer.configure(timer.LocalTimerClient(message_queue)) 35 | with timer.expires(after=60): # 60 second expiry 36 | # do some work 37 | 38 | In the example above if ``trainer_func`` takes more than 60 seconds to 39 | complete, then the worker process is killed and the agent retries the worker group. 40 | """ 41 | 42 | from .api import TimerClient, TimerRequest, TimerServer, configure, expires # noqa: F401 43 | from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 44 | from .file_based_local_timer import FileTimerClient, FileTimerServer, FileTimerRequest # noqa: F401 45 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401 10 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/api.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import os 10 | import socket 11 | from string import Template 12 | from typing import List, Any 13 | 14 | 15 | def get_env_variable_or_raise(env_name: str) -> str: 16 | r""" 17 | Tries to retrieve environment variable. Raises ``ValueError`` 18 | if no environment variable found. 19 | 20 | Args: 21 | env_name (str): Name of the env variable 22 | """ 23 | value = os.environ.get(env_name, None) 24 | if value is None: 25 | msg = f"Environment variable {env_name} expected, but not set" 26 | raise ValueError(msg) 27 | return value 28 | 29 | 30 | def get_socket_with_port() -> socket.socket: 31 | addrs = socket.getaddrinfo( 32 | host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM 33 | ) 34 | for addr in addrs: 35 | family, type, proto, _, _ = addr 36 | s = socket.socket(family, type, proto) 37 | try: 38 | s.bind(("localhost", 0)) 39 | s.listen(0) 40 | return s 41 | except OSError as e: 42 | s.close() 43 | raise RuntimeError("Failed to create a socket") 44 | 45 | 46 | class macros: 47 | """ 48 | Defines simple macros for caffe2.distributed.launch cmd args substitution 49 | """ 50 | 51 | local_rank = "${local_rank}" 52 | 53 | @staticmethod 54 | def substitute(args: List[Any], local_rank: str) -> List[str]: 55 | args_sub = [] 56 | for arg in args: 57 | if isinstance(arg, str): 58 | sub = Template(arg).safe_substitute(local_rank=local_rank) 59 | 
class CyclingIterator:
    """
    Iterator decorator that replays an underlying iterator ``n`` times —
    useful to "unroll" a dataset across multiple training epochs.

    ``generator_fn(epoch)`` is invoked to (re)create the underlying iterator
    at the start of each cycle, with ``epoch`` counting up from
    ``start_epoch`` to at most ``n - 1``.

    Example: if ``generator_fn`` always returns ``iter([1, 2, 3])``, then
    ``CyclingIterator(n=2, generator_fn)`` yields ``1, 2, 3, 1, 2, 3``.
    """

    def __init__(self, n: int, generator_fn, start_epoch=0):
        self._n = n
        self._epoch = start_epoch
        self._generator_fn = generator_fn
        self._iter = generator_fn(self._epoch)

    def __iter__(self):
        return self

    def __next__(self):
        # Iterate instead of recursing so a long run of empty epochs cannot
        # exhaust the call stack; semantics are otherwise unchanged.
        while True:
            try:
                return next(self._iter)
            except StopIteration:
                if self._epoch >= self._n - 1:
                    raise  # all n cycles consumed — signal end of data
                self._epoch += 1
                self._iter = self._generator_fn(self._epoch)
class ElasticDistributedSampler(DistributedSampler):
    """
    Sampler restricting data loading to a subset of the dataset for elastic
    training.

    Typically used with :class:`torch.nn.parallel.DistributedDataParallel`:
    each process passes an instance as the DataLoader sampler and loads a
    disjoint slice of the original dataset.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        start_index (optional): Which index of the dataset to start sampling from
    """

    def __init__(self, dataset, num_replicas=None, rank=None, start_index=0):
        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank)
        if start_index >= len(dataset):
            raise ValueError(
                f"Start index {start_index} should be less than dataset size {len(dataset)}"
            )

        self.start_index = start_index
        # Per-replica sample count over the truncated range [start_index, len).
        remaining = len(self.dataset) - self.start_index  # type: ignore[arg-type]
        self.num_samples = math.ceil(remaining / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # Deterministic shuffle keyed on the current epoch.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        perm = torch.randperm(len(self.dataset) - self.start_index, generator=rng)  # type: ignore[arg-type]
        indices = (perm + self.start_index).tolist()

        # Pad with leading indices so the list divides evenly across replicas.
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Interleaved subsample owned by this rank.
        rank_indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(rank_indices) == self.num_samples
        return iter(rank_indices)

    def __len__(self):
        return self.num_samples
def get_log_level() -> str:
    """Default log level used by the pytorch-elastic compat loggers."""
    default_level = "WARNING"
    return default_level
32 | """ 33 | 34 | # Derive the name of the caller, if none provided 35 | # Use depth=2 since this function takes up one level in the call stack 36 | return _setup_logger(name or _derive_module_name(depth=2)) 37 | 38 | 39 | def _setup_logger(name: Optional[str] = None): 40 | log = logging.getLogger(name) 41 | log.setLevel(os.environ.get("LOGLEVEL", get_log_level())) 42 | return log 43 | 44 | 45 | def _derive_module_name(depth: int = 1) -> Optional[str]: 46 | """ 47 | Derives the name of the caller module from the stack frames. 48 | 49 | Args: 50 | depth: The position of the frame in the stack. 51 | """ 52 | try: 53 | stack = inspect.stack() 54 | assert depth < len(stack) 55 | # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index) 56 | frame_info = stack[depth] 57 | 58 | module = inspect.getmodule(frame_info[0]) 59 | if module: 60 | module_name = module.__name__ 61 | else: 62 | # inspect.getmodule(frame_info[0]) does NOT work (returns None) in 63 | # binaries built with @mode/opt 64 | # return the filename (minus the .py extension) as modulename 65 | filename = frame_info[1] 66 | module_name = os.path.splitext(os.path.basename(filename))[0] 67 | return module_name 68 | except Exception as e: 69 | warnings.warn( 70 | f"Error deriving logger module name, using . Exception: {e}", 71 | RuntimeWarning, 72 | ) 73 | return None 74 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/fault_tolerance/_torch_elastic_compat/utils/store.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
def get_all(store, rank: int, prefix: str, size: int):
    r"""
    Fetch the values stored under keys ``{prefix}0`` .. ``{prefix}{size-1}``.

    After reading, this rank publishes a ``{prefix}{rank}.FIN`` marker.
    Rank 0 additionally waits for every rank's marker before returning.

    Usage

    ::

     values = get_all(store, 'torchelastic/data', 3)
     value1 = values[0]  # retrieves the data for key torchelastic/data0
     value2 = values[1]  # retrieves the data for key torchelastic/data1
     value3 = values[2]  # retrieves the data for key torchelastic/data2

    """
    data_arr = [store.get(f"{prefix}{idx}") for idx in range(size)]
    store.set(f"{prefix}{rank}.FIN", b"FIN")
    if rank == 0:
        # Rank0 hosts the TCPStore daemon and must exit last; otherwise a
        # subsequent barrier may time out if rank0 finishes before the
        # other ranks complete get_all.
        for node_rank in range(size):
            store.get(f"{prefix}{node_rank}.FIN")

    return data_arr


def synchronize(
    store,
    data: bytes,
    rank: int,
    world_size: int,
    key_prefix: str,
    barrier_timeout: float = 300,
) -> List[bytes]:
    """
    Synchronize ``world_size`` agents via the underlying c10d store;
    each agent's ``data`` becomes visible to every other agent.

    Note: The data on the path is not deleted, as a result there can be
    stale data if you use the same key_prefix twice.
    """
    store.set_timeout(timedelta(seconds=barrier_timeout))
    store.set(f"{key_prefix}{rank}", data)
    return get_all(store, rank, key_prefix, world_size)


def barrier(
    store, rank: int, world_size: int, key_prefix: str, barrier_timeout: float = 300
) -> None:
    """
    A global lock between agents.

    Note: Since the data is not removed from the store, the barrier can be
    used once per unique ``key_prefix``.
    """
    payload = f"{rank}".encode()
    synchronize(store, payload, rank, world_size, key_prefix, barrier_timeout)
import ( 18 | completion, 19 | exception, 20 | finalize, 21 | health_check, 22 | initialize, 23 | monitor_thread, 24 | nested_restarter, 25 | rank_assignment, 26 | state, 27 | terminate, 28 | ) 29 | from .compose import Compose 30 | from .state import FrozenState, Mode, State 31 | from .wrap import CallWrapper, Wrapper 32 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/inprocess/attribution.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES 2 | # Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
class Interruption(enum.Enum):
    # Values are auto-assigned; only member identity/name is meaningful.
    EXCEPTION = enum.auto()
    BASE_EXCEPTION = enum.auto()
    SOFT_TIMEOUT = enum.auto()
    HARD_TIMEOUT = enum.auto()
    TERMINATED = enum.auto()
    UNRESPONSIVE = enum.auto()
    MONITOR_PROCESS_EXCEPTION = enum.auto()


@dataclasses.dataclass(frozen=True)
class InterruptionRecord:
    rank: int
    interruption: Interruption

    @classmethod
    def from_str(cls, string: str):
        """Parse ``rank=<int>`` and ``Interruption.<NAME>`` out of ``string``."""
        rank_match = re.search(r'rank=(\d+)', string)
        interruption_match = re.search(r'Interruption\.(\w+)', string)

        if rank_match is None or interruption_match is None:
            raise ValueError("Invalid State string format")

        return cls(
            rank=int(rank_match.group(1)),
            interruption=Interruption[interruption_match.group(1)],
        )


def format_interruption_records(records):
    """
    Render records as ``<interruption> on ranks={...}`` fragments, one per
    consecutive run of identical interruption types (groupby semantics).
    """
    parts = []
    for interruption, group in itertools.groupby(records, key=lambda r: r.interruption):
        ranks = {record.rank for record in group}
        parts.append(f'{interruption} on {ranks=}')
    return ', '.join(parts)
class Completion(abc.ABC):
    r'''
    Abstract base class for ``global_finalize_success`` argument for
    :py:class:`inprocess.Wrapper`.

    :py:class:`Completion` is executed by any unterminated rank when
    it has completed the workload wrapped by inprocess.

    Multiple instances of :py:class:`Completion` could be composed with
    :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Implementation of a :py:class:`Completion`.

        Subclasses must forward the ``state`` object so composed callbacks
        keep chaining.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            Forwarded read-only input ``state``.
        '''
        # Guards against a subclass calling super().__call__ by mistake.
        raise NotImplementedError
def find_common_ancestor(*instances):
    '''
    Return the most-derived class shared by the MROs of all instances.

    Candidate classes are intersected across every instance's MRO, then the
    first match in MRO order of the first instance is returned. Since
    ``object`` is in every MRO, a non-empty ``instances`` always yields a
    result.
    '''
    shared = set(type(instances[0]).mro())
    for inst in instances[1:]:
        shared.intersection_update(type(inst).mro())

    if not shared:
        return None
    for cls in type(instances[0]).mro():
        if cls in shared:
            return cls
class Compose:
    r'''
    Performs functional composition (chaining) of multiple callable class
    instances.

    Output of the previous callable is passed as input to the next callable,
    and the output of the last callable is returned as the final output of a
    :py:class:`Compose` instance.

    Constructed :py:class:`Compose` object is an instance of the lowest common
    ancestor in `method resolution order
    <https://docs.python.org/3/glossary.html#term-method-resolution-order>`_ of
    all input callable class instances.

    Example:

    .. code-block:: python

        composed = Compose(a, b, c)
        ret = composed(arg) # is equivalent to ret = a(b(c(arg)))
    '''

    def __new__(cls, *instances: Callable[[T], T]):

        # Build a one-off subclass inheriting from both Compose and the
        # lowest common ancestor of all supplied callables, so isinstance()
        # checks against the shared base also accept the composite.
        common_ancestor = find_common_ancestor(*instances)
        DynamicCompose = type(
            'DynamicCompose',
            (Compose, common_ancestor),
            {
                'instances': instances,
                # Plain object.__new__ prevents re-entering Compose.__new__
                # when DynamicCompose() is instantiated below.
                '__new__': object.__new__,
            },
        )
        return DynamicCompose()

    def __init__(self, *args, **kwargs):
        # Intentionally a no-op: construction happens entirely in __new__.
        pass

    def __call__(self, *args: Any):
        # Apply callables right-to-left: Compose(a, b, c)(x) == a(b(c(x))).
        for instance in reversed(self.instances):
            ret = instance(*args or ())
            # Warn when a callable that received real arguments returned
            # None, i.e. it likely dropped the value it should forward.
            if ret is None and args and args != (None,):
                msg = (
                    f'{type(self).__name__} didn\'t chain arguments after '
                    f'calling {instance=} with {args=}'
                )
                warnings.warn(msg)
            # A non-tuple result is wrapped so the next callable (if it
            # accepts parameters) receives it as one positional argument;
            # tuple results are forwarded unwrapped (spread on next call).
            if not isinstance(ret, tuple) and len(inspect.signature(instance).parameters) > 0:
                args = (ret,)
            else:
                args = ret
        return ret
class RestartError(Exception):
    r'''
    Root of the exception hierarchy raised by
    :py:class:`inprocess.Wrapper`.
    '''


class RestartAbort(BaseException):
    r'''
    Terminal condition for :py:class:`inprocess.Wrapper`: restart attempts
    stop immediately. Derives from :py:exc:`BaseException` rather than
    :py:exc:`Exception` so generic ``except Exception`` handlers don't
    swallow it.
    '''


class HealthCheckError(RestartError):
    r'''
    :py:exc:`RestartError` indicating that
    :py:class:`inprocess.health_check.HealthCheck` raised errors, and
    execution shouldn't be restarted on this distributed rank.
    '''


class InternalError(RestartError):
    r'''
    :py:class:`inprocess.Wrapper` internal error.
    '''


class TimeoutError(RestartError):
    r'''
    :py:class:`inprocess.Wrapper` timeout error.

    Note: intentionally shadows the builtin ``TimeoutError`` inside this
    module's namespace.
    '''
class Finalize(abc.ABC):
    r'''
    Abstract base class for ``finalize`` argument for
    :py:class:`inprocess.Wrapper`.

    :py:class:`Finalize` brings the process into a state where a restart of the
    wrapped function may be attempted, e.g.: deinitialize any global variables
    or synchronize with any asynchronous tasks issued by the wrapped function
    that was not already performed by exception handlers in the wrapped
    function.

    Any failure during execution of :py:class:`Finalize` should raise an
    exception. In this case the health check is skipped, exception is reraised
    by the wrapper, and it should cause termination of the main Python
    interpreter process.

    :py:class:`Finalize` class is executed after a fault was detected,
    distributed group was destroyed, but before the
    :py:class:`inprocess.health_check.HealthCheck` is performed.

    Multiple instances of :py:class:`Finalize` could be composed with
    :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Implementation of a :py:class:`Finalize`.

        Subclasses must forward ``state`` so composed callbacks keep chaining.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            Forwarded read-only input ``state``.
        '''
        # Guards against a subclass calling super().__call__ by mistake.
        raise NotImplementedError
class ThreadedFinalize(Finalize):
    r'''
    Executes the provided finalize ``fn`` function with specified positional
    and keyword arguments in a separate :py:class:`threading.Thread`.

    Raises an exception if execution takes longer than the specified
    ``timeout``.

    Args:
        timeout: timeout for a thread executing ``fn``
        fn: function to be executed
        args: tuple of positional arguments (``None`` is treated as empty)
        kwargs: dictionary of keyword arguments (``None`` is treated as empty)
    '''

    def __init__(
        self,
        timeout: datetime.timedelta,
        fn: Callable[..., Any],
        args: Optional[tuple[Any, ...]] = (),
        kwargs: Optional[dict[str, Any]] = None,
    ):
        # Both container parameters are declared Optional; normalize None to
        # an empty container so __call__ can hand them to Thread unchanged
        # (previously only kwargs was normalized and args=None would crash
        # when the thread invoked fn).
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}

        self.timeout = timeout
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Runs ``fn`` in a daemon thread and waits up to ``timeout``.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            Forwarded read-only input ``state``.

        Raises:
            exception.TimeoutError: if ``fn`` is still running after
                ``timeout``. The thread itself is not killed; being a
                daemon, it won't block interpreter exit.
        '''
        rank = state.rank
        thread = threading.Thread(
            target=self.fn,
            name=f'{type(self).__name__}-{rank}',
            args=self.args,
            kwargs=self.kwargs,
            daemon=True,
        )
        thread.start()
        thread.join(self.timeout.total_seconds())
        if thread.is_alive():
            raise exception.TimeoutError

        return state
class Initialize(abc.ABC):
    r'''
    Abstract base class for ``initialize`` argument for
    :py:class:`inprocess.Wrapper`.

    :py:class:`Initialize` is executed at the start of every restart iteration,
    including the first one. :py:class:`Initialize` can raise exceptions (e.g.,
    if specific preconditions are not met). Raising a standard Python
    :py:exc:`Exception` triggers another restart, while raising a
    :py:exc:`BaseException` terminates the wrapper.

    Multiple instances of :py:class:`Initialize` could be composed with
    :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Implementation of a :py:class:`Initialize`.

        Subclasses must forward ``state`` so composed callbacks keep chaining.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            Forwarded read-only input ``state``.
        '''
        # Guards against a subclass calling super().__call__ by mistake.
        raise NotImplementedError
class RetryController(Initialize):
    r'''
    Controls retry logic for distributed training based on specified iteration
    and world size limits.

    Raises :py:exc:`inprocess.exception.RestartAbort` when the configured
    conditions are not met, terminating further restart attempts.

    Args:
        max_iterations: the maximum number of iterations allowed before
            aborting retries. If :py:obj:`None`, there is no iteration limit
        min_world_size: The minimum required world size to proceed with
            execution
        min_active_world_size: The minimum required active world size to
            proceed with execution
    '''

    def __init__(
        self,
        max_iterations: Optional[int] = None,
        min_world_size: int = 1,
        min_active_world_size: int = 1,
    ):
        self.max_iterations = max_iterations
        self.min_world_size = min_world_size
        self.min_active_world_size = min_active_world_size

    def __call__(self, state: FrozenState) -> FrozenState:
        # Evaluate each abort condition separately for readability.
        too_small = state.world_size < self.min_world_size
        too_few_active = state.active_world_size < self.min_active_world_size
        exhausted = (
            self.max_iterations is not None and state.iteration >= self.max_iterations
        )
        if too_small or too_few_active or exhausted:
            msg = (
                f'{state.iteration=} {self.max_iterations=} '
                f'{state.world_size=} {self.min_world_size=} '
                f'{state.active_world_size=} {self.min_active_world_size=} '
            )
            raise exception.RestartAbort(msg)
        return state
class NestedRestarterLogger(RankMonitorLogger):
    """Logger used in the nested restarter process"""

    def __init__(self):
        # Fixed name plus the restarter flag distinguish these records from
        # ordinary rank-monitor log lines.
        super().__init__(name="InprocessRestarter", is_restarter_logger=True)


@dataclasses.dataclass
class NestedRestarterCallback:
    r'''
    Callback for logging the NVRx nested restarter integration.

    Emits a ``[NestedRestarter]`` log line describing the current restarter
    state (and optional stage), but only on ``special_rank``.
    '''

    # Single logger instance shared by all callback instances. This is a
    # plain class attribute (no annotation), so it is NOT a dataclass field.
    _shared_logger = NestedRestarterLogger()

    restarter_state: str
    restarter_stage: Optional[str] = None
    # NOTE(review): sharing one default logger object across instances looks
    # intentional here (all callbacks funnel into the same logger unless one
    # is passed explicitly) — confirm if per-instance loggers are ever needed.
    logger: NestedRestarterLogger = dataclasses.field(default=_shared_logger)
    special_rank: int = 0

    def __call__(self, state: FrozenState) -> FrozenState:

        # Only the designated rank emits restarter log lines.
        if state.initial_rank == self.special_rank:
            self.logger.set_connected_rank(state.initial_rank)
            msg = f'[NestedRestarter] name=[InProcess] state={self.restarter_state}'
            if self.restarter_stage is not None:
                msg += f" stage={self.restarter_stage}"

            self.logger.log_for_restarter(msg)

        return state
@dataclasses.dataclass
class NestedRestarterHandlingStarting(Abort, NestedRestarterCallback):
    # Logs state='handling' stage='starting' via the shared callback.
    restarter_state: str = 'handling'
    restarter_stage: str = 'starting'

    def __call__(self, state: FrozenState) -> FrozenState:
        # Delegate straight to the shared logging callback; no extra logic.
        return NestedRestarterCallback.__call__(self, state)


@dataclasses.dataclass
class NestedRestarterFinalized(Completion, NestedRestarterCallback):
    # Logs state='finalized' (no stage) via the shared callback.
    restarter_state: str = 'finalized'

    def __call__(self, state: FrozenState) -> FrozenState:
        # Delegate straight to the shared logging callback; no extra logic.
        return NestedRestarterCallback.__call__(self, state)


@dataclasses.dataclass
class NestedRestarterAborted(Terminate, NestedRestarterCallback):
    # Logs state='aborted' (no stage) via the shared callback.
    restarter_state: str = 'aborted'

    def __call__(self, state: FrozenState) -> FrozenState:
        # Delegate straight to the shared logging callback; no extra logic.
        return NestedRestarterCallback.__call__(self, state)
class Terminate(abc.ABC):
    r'''
    Abstract base class for ``global_finalize_failure`` argument for
    :py:class:`inprocess.Wrapper`.

    :py:class:`Terminate` is executed by any unterminated rank when
    that rank terminates.

    Multiple instances of :py:class:`Terminate` could be composed with
    :py:class:`inprocess.Compose` to achieve the desired behavior.
    '''

    @abc.abstractmethod
    def __call__(self, state: FrozenState) -> FrozenState:
        r'''
        Implementation of a :py:class:`Terminate`.

        Subclasses must forward ``state`` so composed callbacks keep chaining.

        Args:
            state: read-only :py:class:`Wrapper` state

        Returns:
            Forwarded read-only input ``state``.
        '''
        # Guards against a subclass calling super().__call__ by mistake.
        raise NotImplementedError
import inject_fault 18 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/ptl_resiliency/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from ._utils import SimulatedFaultParams # noqa: F401 17 | from .fault_tolerance_callback import FaultToleranceCallback # noqa: F401 18 | from .fault_tolerance_sections_callback import FaultToleranceSectionsCallback # noqa: F401 19 | from .straggler_det_callback import StragglerDetectionCallback # noqa: F401 20 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/shared_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/src/nvidia_resiliency_ext/shared_utils/__init__.py -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/straggler/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
class CuptiManager:
    """Provide thread safe access to the CUPTI extension.

    Implements simple usage counter, to track active profiling runs:
    CUPTI activity tracking starts when the first caller invokes
    ``start_profiling`` and stops when the last one calls
    ``stop_profiling``. Every public method takes ``self.lock``.
    """

    def __init__(self, bufferSize: int = 1_000_000, numBuffers: int = 8, statsMaxLenPerKernel: int = 4096):
        """
        Args:
            bufferSize (int, optional): CUPTI buffer size. Defaults to 1MB.
            numBuffers (int, optional): Num of CUPTI buffers in a pool . Defaults to 8.
            statsMaxLenPerKernel (int, optional): Max number of timing entries per kernel.
                (when this limit is reached, oldest timing entries are discarded). Defaults to 4096.
        """

        # lazy load the extension module, to avoid circular import
        import nvrx_cupti_module as cupti_module  # type: ignore

        self.cupti_ext = cupti_module.CuptiProfiler(
            bufferSize=bufferSize,
            numBuffers=numBuffers,
            statsMaxLenPerKernel=statsMaxLenPerKernel,
        )
        self.is_initialized = False  # toggled by initialize()/shutdown()
        self.started_cnt = 0  # number of currently active profiling runs
        self.lock = threading.Lock()  # guards all state and cupti_ext calls

    def _ensure_initialized(self):
        """Check for CuptiProfiler initialization.

        Raises:
            RuntimeError: if ``initialize()`` has not been called yet.
        """
        if not self.is_initialized:
            raise RuntimeError("CuptiManager was not initialized")

    def initialize(self):
        """Call CuptiProfiler initialization method, registering CUPTI
        callbacks for profiling."""
        with self.lock:
            self.cupti_ext.initialize()
            self.is_initialized = True

    def shutdown(self):
        """Finalize CUPTI.

        Resets the usage counter, so any still-active profiling runs are
        implicitly terminated.
        """
        with self.lock:
            self.cupti_ext.shutdown()
            self.is_initialized = False
            self.started_cnt = 0

    def start_profiling(self):
        """Enable CUDA kernels activity tracking."""
        with self.lock:
            self._ensure_initialized()
            # Only the first caller actually starts CUPTI; subsequent
            # callers just bump the usage counter.
            if self.started_cnt == 0:
                self.cupti_ext.start()
            self.started_cnt += 1

    def stop_profiling(self):
        """Disable CUDA kernels activity tracking.

        Raises:
            RuntimeError: if there was no matching ``start_profiling`` call.
        """
        with self.lock:
            self._ensure_initialized()
            if self.started_cnt > 0:
                self.started_cnt -= 1
                # Last active user is gone: stop CUPTI tracking.
                if self.started_cnt == 0:
                    self.cupti_ext.stop()
            else:
                raise RuntimeError("No active profiling run.")

    def get_results(self):
        """Calculate kernel execution timing statistics.

        Returns:
            A copy of the per-kernel stats mapping from the extension,
            so callers cannot mutate the profiler's internal state.
        """
        with self.lock:
            self._ensure_initialized()
            stats = self.cupti_ext.get_stats()
            return stats.copy()

    def reset_results(self):
        """Reset kernel execution timing records."""
        with self.lock:
            self._ensure_initialized()
            self.cupti_ext.reset()
"""Calculate kernel execution timing statistics.""" 86 | with self.lock: 87 | self._ensure_initialized() 88 | stats = self.cupti_ext.get_stats() 89 | return stats.copy() 90 | 91 | def reset_results(self): 92 | """Reset kernel execution timing records.""" 93 | with self.lock: 94 | self._ensure_initialized() 95 | self.cupti_ext.reset() 96 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/straggler/cupti_src/BufferPool.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | #include "BufferPool.h" 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | 25 | 26 | BufferPool::BufferPool(size_t bufferSize, int numBuffers) 27 | : bufferSize(bufferSize), numBuffers(numBuffers) { 28 | for (int i = 0; i < numBuffers; ++i) { 29 | uint8_t* newBuffer = (uint8_t*)malloc(bufferSize); 30 | if (!newBuffer) { 31 | throw std::bad_alloc(); 32 | } 33 | freeBuffers.push_back(newBuffer); 34 | } 35 | } 36 | 37 | BufferPool::~BufferPool() { 38 | while (!freeBuffers.empty()) { 39 | free(freeBuffers.back()); 40 | freeBuffers.pop_back(); 41 | } 42 | } 43 | 44 | uint8_t* BufferPool::getBuffer() { 45 | std::lock_guard lock(mutex); 46 | if (freeBuffers.empty()) { 47 | return nullptr; // prob better to allocate a new buffer 48 | } 49 | uint8_t* buffer = freeBuffers.back(); 50 | freeBuffers.pop_back(); 51 | return buffer; 52 | } 53 | 54 | void BufferPool::releaseBuffer(uint8_t* buffer) { 55 | std::lock_guard lock(mutex); 56 | freeBuffers.push_back(buffer); 57 | } 58 | 59 | size_t BufferPool::getBufferSize() const { 60 | return bufferSize; 61 | } 62 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/straggler/cupti_src/BufferPool.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
#pragma once

// NOTE(review): original include names were stripped during extraction;
// reconstructed below -- confirm against version control.
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <vector>

// Fixed-size pool of raw byte buffers (used as CUPTI activity buffers;
// see CuptiProfiler). Thread-safe: getBuffer/releaseBuffer lock `mutex`.
class BufferPool {
public:
    // bufferSize: bytes per buffer; numBuffers: buffers pre-allocated up front.
    BufferPool(size_t bufferSize = 1024 * 1024 * 4, int numBuffers = 20);
    ~BufferPool();

    uint8_t* getBuffer();                 // nullptr when the pool is exhausted
    void releaseBuffer(uint8_t* buffer);  // hand a buffer back to the pool
    size_t getBufferSize() const;         // bytes per buffer

private:
    std::vector<uint8_t*> freeBuffers;  // element type reconstructed -- confirm
    size_t bufferSize;
    int numBuffers;
    std::mutex mutex;  // guards freeBuffers
};
#pragma once

// NOTE(review): include name reconstructed (garbled in extraction).
#include <cstddef>
#include <vector>

// Fixed-capacity ring buffer: when full, push_back silently overwrites
// the oldest element. Not thread-safe; callers synchronize externally.
// NOTE(review): template parameter list reconstructed as <typename T>.
template <typename T>
class CircularBuffer {
private:
    std::vector<T> _buffer;  // backing storage, always _capacity elements
    size_t _head;            // index of the oldest stored element
    size_t _tail;            // index where the next element is written
    size_t _size;            // number of elements currently stored
    size_t _capacity;        // maximum number of elements

public:
    // Precondition: capacity > 0 (the modulo arithmetic below divides
    // by _capacity).
    explicit CircularBuffer(size_t capacity=32) :
        _buffer(capacity), _head(0), _tail(0), _size(0), _capacity(capacity) {}

    ~CircularBuffer() = default;

    bool empty() const {
        return _size == 0;
    }

    bool full() const {
        return _size == _capacity;
    }

    size_t size() const {
        return _size;
    }

    size_t capacity() const {
        return _capacity;
    }

    // Append `value`. When full, the head advances instead of the size
    // growing, i.e. the oldest element is dropped.
    void push_back(const T& value) {
        _buffer[_tail] = value;
        _tail = (_tail + 1) % _capacity;
        if (full()) {
            _head = (_head + 1) % _capacity;
        } else {
            _size++;
        }
    }

    // Copy the contents into a plain vector, oldest element first.
    std::vector<T> linearize() const {
        std::vector<T> res(_size);
        for (size_t linearIndex = 0; linearIndex < _size; linearIndex++) {
            res[linearIndex] = _buffer[(_head + linearIndex) % _capacity];
        }
        return res;
    }
};
#pragma once

// NOTE(review): all header names below were stripped during extraction;
// this list is a best-guess reconstruction -- confirm against VCS.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <map>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include <cuda.h>
#include <cupti.h>
#include <pybind11/pybind11.h>

#include "BufferPool.h"
#include "CircularBuffer.h"

namespace py = pybind11;

// Aggregated timing statistics for one kernel. Stats default to NAN
// until computed; num_calls counts observed launches.
struct KernelStats {
    KernelStats() : num_calls(0), min(NAN), max(NAN), median(NAN), avg(NAN), stddev(NAN) {
    }
    int num_calls;
    float min, max, median, avg, stddev;
    std::string toString() const;
};

// Records per-kernel execution durations via the CUPTI Activity API and
// reduces them to KernelStats. Exposed to Python as
// `nvrx_cupti_module.CuptiProfiler` (see cupti_module_py.cpp).
class CuptiProfiler {
public:
    CuptiProfiler(size_t cuptiBufferSize = 1024 * 1024 * 8,
                  size_t cuptiBuffersNum = 8,
                  size_t statsMaxLenPerKernel = 1024);
    ~CuptiProfiler();

    void initializeProfiling();  // register CUPTI callbacks
    void shutdownProfiling();    // finalize CUPTI

    void startProfiling();       // enable activity tracking
    void stopProfiling();        // disable activity tracking

    void reset();                // drop all accumulated timings

    // Compute statistics from the recorded durations.
    // NOTE(review): key/value template arguments reconstructed from the
    // pybind bindings (string -> KernelStats) -- confirm.
    std::map<std::string, KernelStats> getStats();

private:
    // CUPTI buffer callbacks are C-style and cannot capture `this`;
    // the static trampolines below presumably route through this
    // singleton pointer -- confirm in the .cpp.
    static CuptiProfiler* instance;
    BufferPool _bufferPool;
    // Bounded per-kernel duration history (capacity _statsMaxLenPerKernel).
    // NOTE(review): template arguments reconstructed -- confirm.
    std::unordered_map<std::string, CircularBuffer<float>> _kernelDurations;
    std::mutex _kernelDurationsMutex;  // guards _kernelDurations
    size_t _statsMaxLenPerKernel;
    bool _isInitialized {false};
    bool _isStarted {false};

    void CUPTIAPI bufferRequested(uint8_t **buffer, size_t *size, size_t *maxNumRecords);
    void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer, size_t size, size_t validSize);

    static void CUPTIAPI bufferRequestedTrampoline(uint8_t **buffer, size_t *size, size_t *maxNumRecords);
    static void CUPTIAPI bufferCompletedTrampoline(CUcontext ctx, uint32_t streamId, uint8_t *buffer, size_t size, size_t validSize);
};
// NOTE(review): header name and template arguments throughout this file
// were stripped during extraction and are reconstructed -- confirm.
#include <pybind11/pybind11.h>
#include "CuptiProfiler.h"

namespace py = pybind11;


// Convert the profiler's std::map of stats into a Python dict
// (kernel identifier string -> KernelStats).
static py::dict get_stats_py(CuptiProfiler* profiler_inst) {
    auto stats = profiler_inst->getStats();
    py::dict dict;
    for (auto& [key, value] : stats) {
        dict[py::cast(key)] = py::cast(value);
    }
    return dict;
}

// Python extension module `nvrx_cupti_module`, loaded lazily by
// straggler/cupti.py (CuptiManager).
PYBIND11_MODULE(nvrx_cupti_module, m) {
    py::class_<KernelStats>(m, "KernelStats")
        .def(py::init<>())
        .def_readwrite("min", &KernelStats::min)
        .def_readwrite("max", &KernelStats::max)
        .def_readwrite("median", &KernelStats::median)
        .def_readwrite("avg", &KernelStats::avg)
        .def_readwrite("stddev", &KernelStats::stddev)
        .def_readwrite("num_calls", &KernelStats::num_calls)
        .def("__str__", &KernelStats::toString);

    py::class_<CuptiProfiler>(m, "CuptiProfiler")
        .def(py::init<size_t, size_t, size_t>(),
             py::arg("bufferSize") = 1024 * 1024 * 8,
             py::arg("numBuffers") = 8,
             py::arg("statsMaxLenPerKernel") = 1024)
        .def("start", &CuptiProfiler::startProfiling, "Start profiling.")
        .def("stop", &CuptiProfiler::stopProfiling, "Stop profiling.")
        .def("initialize", &CuptiProfiler::initializeProfiling, "Initialize CUPTI.")
        .def("shutdown", &CuptiProfiler::shutdownProfiling, "Shutdown CUPTI.")
        .def("get_stats", get_stats_py, "Retrieve kernel execution statistics.")
        .def("reset", &CuptiProfiler::reset, "Reset statistics.");
}
@dataclasses.dataclass
class ReportIntervalTracker:
    """
    `ReportIntervalTracker` is used to calculate the reporting intervals based on a specified time interval.

    For the first ``INTERVAL_ESTIMATION_ITERS`` measured steps it collects
    step durations, then converts the target ``time_interval`` into an
    iteration count (``iter_interval``) agreed across all ranks.

    Attributes:
        INTERVAL_ESTIMATION_ITERS (int): Number of iterations used for estimating the report interval.
        time_interval (float): Target time interval for reporting.
        current_iter (int): Counter for the current iteration.
        iter_interval (int, optional): Computed iteration interval based on the target time interval.
        prev_iter_start_time (float, optional): Monotonic timestamp taken at the previous ``iter_increase`` call.
        step_times (Sequence[float]): Step durations collected during estimation.
            NOTE(review): this is in fact a mutable list (``.append``/``.clear``
            are used below); the annotation could be tightened to ``List[float]``.
        profiling_interval (int): Lower bound for the computed interval.
    """

    INTERVAL_ESTIMATION_ITERS: int = 16
    time_interval: float = 60.0
    current_iter: int = 0
    iter_interval: Optional[int] = None
    prev_iter_start_time: Optional[float] = None
    step_times: Sequence[float] = dataclasses.field(default_factory=list)
    profiling_interval: int = 1

    def _gather_report_interval(self):
        """
        Gathers the report interval across all distributed processes and sets the maximum interval.
        """
        assert self.iter_interval is None, "Report iteration interval has already been gathered."

        step_times = torch.tensor(self.step_times, dtype=torch.float32)
        median_step_time = torch.median(step_times)

        # Move to the current CUDA device so the all_reduce below can run
        # over NCCL. NOTE(review): this assumes CUDA is available -- confirm.
        gathered_interval = (self.time_interval / median_step_time).to(torch.cuda.current_device())
        if torch.distributed.is_initialized():
            # All ranks adopt the largest interval so reports stay aligned.
            torch.distributed.all_reduce(gathered_interval, op=torch.distributed.ReduceOp.MAX)
        # it makes no sense to report more frequently than the profiling interval
        self.iter_interval = int(max(gathered_interval.item(), self.profiling_interval))

    def iter_increase(self):
        """
        Increases the iteration counter and gathers the report interval if the estimation phase is completed.
        """
        self.current_iter += 1

        if self.iter_interval is None:
            # Still estimating: record the duration of the step that just
            # finished (skipped on the very first call, when there is no
            # previous timestamp yet).
            if self.prev_iter_start_time is not None:
                step_time = time.monotonic() - self.prev_iter_start_time
                self.step_times.append(step_time)
                if len(self.step_times) == self.INTERVAL_ESTIMATION_ITERS:
                    self._gather_report_interval()
                    self.step_times.clear()
                    assert self.iter_interval is not None
            self.prev_iter_start_time = time.monotonic()

    def is_interval_elapsed(self):
        """
        Checks if the current iteration is a reporting interval.

        Returns:
            bool: True if the interval has elapsed, False otherwise.
        """
        return (self.iter_interval is not None) and (self.current_iter % self.iter_interval == 0)
80 | """ 81 | return (self.iter_interval is not None) and (self.current_iter % self.iter_interval == 0) 82 | -------------------------------------------------------------------------------- /src/nvidia_resiliency_ext/straggler/statistics.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | 18 | 19 | class Statistic(enum.Enum): 20 | """Enumeration of constants representing common statistical measures that 21 | are used for performance analysis and reporting.""" 22 | 23 | MIN = enum.auto() 24 | MAX = enum.auto() 25 | MED = enum.auto() 26 | AVG = enum.auto() 27 | STD = enum.auto() 28 | NUM = enum.auto() 29 | 30 | def __str__(self): 31 | return f"{self.name}" 32 | 33 | def __repr__(self): 34 | cls_name = self.__class__.__name__ 35 | return f"{cls_name}.{self.name}" 36 | -------------------------------------------------------------------------------- /tests/checkpointing/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # spdx-filecopyrighttext: copyright (c) 2024 nvidia corporation & affiliates. all rights reserved. 
def empty_dir(path: Path):
    """Delete everything inside ``path`` while keeping ``path`` itself.

    Only rank 0 touches the filesystem; every other rank returns
    immediately so concurrent workers do not race on the same directory.
    """
    if Utils.rank > 0:
        return
    for entry in path.iterdir():
        remover = rmtree if entry.is_dir() else Path.unlink
        remover(entry)
class TempNamedDir(TemporaryDirectory):
    """TemporaryDirectory with a fully named directory.
    Empties the dir if not empty.

    Unlike ``TemporaryDirectory``, the path is chosen by the caller so it
    can be shared between distributed ranks; rank 0 owns creation and
    cleanup, and optional barriers (``sync=True``) keep other ranks in step.
    """

    def __init__(self, name: Union[str, Path], sync=True, ignore_cleanup_errors=False) -> None:
        # Deliberately does NOT call super().__init__(), which would create
        # a randomly named directory instead of the caller-provided one.
        self.name = str(name)
        if Utils.rank == 0:
            os.makedirs(name, exist_ok=True)
            empty_dir(Path(name))
        if sync:
            import torch

            # Collective: other ranks wait until rank 0 created/emptied the dir.
            torch.distributed.barrier()
        else:
            # Without a barrier, every rank makes sure the dir exists itself.
            os.makedirs(name, exist_ok=True)

        self._ignore_cleanup_errors = ignore_cleanup_errors
        # Mirrors TemporaryDirectory's own finalizer so the directory is
        # still removed if the object is garbage collected uncleanly.
        self._finalizer = weakref.finalize(
            self, self._cleanup, self.name, warn_message="Implicitly cleaning up {!r}".format(self)
        )
        self.sync = sync

    def cleanup(self, override_sync: Optional[bool] = None) -> None:
        # Remove the directory (rank 0 only), optionally after a barrier
        # so no rank is still using it.
        sync = self.sync if override_sync is None else override_sync
        if sync:
            import torch

            torch.distributed.barrier()

        if Utils.rank == 0:
            super().cleanup()

    def __enter__(self):
        # TemporaryDirectory.__enter__ returns self.name; wrap it in a Path.
        path = Path(super().__enter__())
        if self.sync:
            import torch

            torch.distributed.barrier()
        return path

    def __exit__(self, exc_type, exc_val, exc_tb):
        raised = exc_type is not None
        # On failure the directory is intentionally left in place --
        # presumably to ease post-mortem debugging; confirm.
        if not raised:
            self.cleanup()
@pytest.fixture(scope="session")
def tmp_path_dist_ckpt(tmp_path_factory) -> Path:
    """Common directory for saving the checkpoint.

    Can't use pytest `tmp_path_factory` directly because directory must be shared between processes.
    """

    # Derive a fixed, process-independent path next to pytest's tmp root
    # so every rank resolves the same directory. The 'ignored' dir exists
    # only to locate that root.
    tmp_dir = tmp_path_factory.mktemp('ignored', numbered=False)
    tmp_dir = tmp_dir.parent.parent / 'tmp_dist_ckpt'

    if Utils.rank == 0:
        # Rank 0 owns creation/cleanup; sync=False because other ranks do
        # not enter this context manager (a barrier here would hang).
        with TempNamedDir(tmp_dir, sync=False):
            yield tmp_dir

    else:
        # Non-zero ranks just use the shared path.
        yield tmp_dir
import TempNamedDir 20 | from .test_utilities import TestModel, Utils 21 | 22 | 23 | class TestAsyncSave: 24 | def setup_method(self, method): 25 | Utils.set_world_size(1) 26 | 27 | def teardown_method(self, method): 28 | Utils.set_world_size() 29 | 30 | def test_async_is_equivalent_to_sync(self, tmp_path_dist_ckpt): 31 | Utils.initialize_distributed() 32 | model = TestModel((1024, 1024), 10) 33 | ckpt_impl = TorchAsyncCheckpoint() 34 | state_dict = model.state_dict() 35 | with ( 36 | TempNamedDir(tmp_path_dist_ckpt / 'test_equivalence_async') as async_ckpt_dir, 37 | TempNamedDir(tmp_path_dist_ckpt / 'test_equivalence_sync') as sync_ckpt_dir, 38 | ): 39 | # async 40 | ckpt_impl.async_save(state_dict, async_ckpt_dir / 'test') 41 | 42 | # sync 43 | ckpt_impl.save(state_dict, sync_ckpt_dir / 'test') 44 | 45 | # finalize async 46 | ckpt_impl.finalize_async_save(blocking=True) 47 | 48 | # load and compare 49 | device = torch.device(f"cuda:{torch.cuda.current_device()}") 50 | loaded_async_state_dict = torch.load(async_ckpt_dir / 'test', map_location=device) 51 | loaded_sync_state_dict = torch.load(sync_ckpt_dir / 'test', map_location=device) 52 | for k in loaded_sync_state_dict.keys(): 53 | assert k in loaded_async_state_dict, f"{k} is not in loaded async state_dict" 54 | assert torch.equal( 55 | loaded_async_state_dict[k], loaded_sync_state_dict[k] 56 | ), f"loaded_async_state_dict[{k}] != loaded_sync_state_dict[{k}]" 57 | assert torch.equal( 58 | loaded_async_state_dict[k], state_dict[k] 59 | ), f"loaded_async_state_dict[{k}] != src_state_dict[{k}]" 60 | -------------------------------------------------------------------------------- /tests/checkpointing/unit/test_cleanup.py: -------------------------------------------------------------------------------- 1 | # spdx-filecopyrighttext: copyright (c) 2024 nvidia corporation & affiliates. all rights reserved. 
2 | # spdx-license-identifier: apache-2.0 3 | # 4 | # licensed under the apache license, version 2.0 (the "license"); 5 | # you may not use this file except in compliance with the license. 6 | # you may obtain a copy of the license at 7 | # 8 | # http://www.apache.org/licenses/license-2.0 9 | # 10 | # unless required by applicable law or agreed to in writing, software 11 | # distributed under the license is distributed on an "as is" basis, 12 | # without warranties or conditions of any kind, either express or implied. 13 | # see the license for the specific language governing permissions and 14 | # limitations under the license. 15 | 16 | import logging 17 | import re 18 | import time 19 | from pathlib import Path 20 | 21 | import pytest 22 | 23 | from nvidia_resiliency_ext.checkpointing.local.ckpt_managers.local_manager import ( 24 | LocalCheckpointManager, 25 | ) 26 | 27 | from . import TempNamedDir 28 | from .test_utilities import SimpleTensorAwareStateDict, Utils 29 | 30 | 31 | class TestLocalCheckpointing: 32 | def setup_method(self, method): 33 | Utils.initialize_distributed() 34 | 35 | def teardown_method(self, method): 36 | pass 37 | 38 | def _async_save(self, async_save_request, async_save): 39 | if async_save: 40 | async_save_request.execute_sync() 41 | else: 42 | assert async_save_request is None 43 | 44 | @pytest.mark.parametrize(('use_ramdisk'), [False, True]) 45 | @pytest.mark.parametrize(('async_save'), [True]) 46 | def test_basic_save_load_scenarios(self, tmp_path_dist_ckpt, use_ramdisk, async_save, caplog): 47 | if use_ramdisk: 48 | tmp_path_dist_ckpt = Path("/dev/shm") 49 | with ( 50 | TempNamedDir(tmp_path_dist_ckpt / "test_save_load") as local_ckpt_dir, 51 | caplog.at_level(logging.DEBUG), 52 | ): 53 | local_ckpt_dir = local_ckpt_dir / "subdir" # Test handling of non-existent directories 54 | 55 | # Test performance on SSD only to save compute time. 
56 | tensor_num = 10 if use_ramdisk else 16384 57 | 58 | checkpoint_manager = LocalCheckpointManager(local_ckpt_dir) 59 | # "Multiple saves" 60 | intermediete_state_dict = SimpleTensorAwareStateDict(iteration=1, tensor_num=tensor_num) 61 | # SAVE 62 | async_save_request = checkpoint_manager.save(intermediete_state_dict, 1, async_save) 63 | self._async_save(async_save_request, async_save) 64 | ckpt_id = checkpoint_manager._ckpt_id(1) 65 | first_ckpt_path = checkpoint_manager._local_ckpt_path_from_id(ckpt_id) 66 | assert first_ckpt_path.exists() 67 | intermediete_state_dict = SimpleTensorAwareStateDict(iteration=2, tensor_num=tensor_num) 68 | # SAVE 69 | async_save_request = checkpoint_manager.save(intermediete_state_dict, 2, async_save) 70 | self._async_save(async_save_request, async_save) 71 | ckpt_id = checkpoint_manager._ckpt_id(2) 72 | second_ckpt_path = checkpoint_manager._local_ckpt_path_from_id(ckpt_id) 73 | assert second_ckpt_path.exists() 74 | time.sleep(0.8) 75 | assert not first_ckpt_path.exists() 76 | 77 | def extract_finalize_time_from_log(caplog): 78 | pattern = r"finalize_fn took ([\d.]+)s" 79 | matches = re.findall(pattern, caplog.text) 80 | if matches: 81 | return float(matches[-1]) # Return the last match as a float 82 | return None 83 | 84 | time_to_finalize = extract_finalize_time_from_log(caplog) 85 | # Async cleanup based on Processes: ~0.04s, sync: >0.1s 86 | assert time_to_finalize < 0.03 87 | -------------------------------------------------------------------------------- /tests/fault_tolerance/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /tests/fault_tolerance/unit/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def pytest_configure(): 6 | logging.basicConfig( 7 | level=os.getenv('FT_UNIT_TEST_LOGLEVEL', 'DEBUG'), 8 | format="%(asctime)s - %(levelname)s - %(message)s", 9 | datefmt="%Y-%m-%d %H:%M:%S", 10 | ) 11 | -------------------------------------------------------------------------------- /tests/fault_tolerance/unit/test_ipc_connector.py: -------------------------------------------------------------------------------- 1 | import os 2 | from multiprocessing import Pool 3 | 4 | import pytest 5 | 6 | from nvidia_resiliency_ext.fault_tolerance.ipc_connector import IpcConnector 7 | 8 | 9 | def _sender_process(rank): 10 | socket_path = '/tmp/test_ipc_socket' 11 | sender = IpcConnector(socket_path) 12 | sender.send( 13 | ( 14 | rank, 15 | "Test Message 1", 16 | ) 17 | ) 18 | sender.send( 19 | ( 20 | rank, 21 | "Test Message 2", 22 | ) 23 | ) 24 | sender.send( 25 | ( 26 | rank, 27 | "STOP", 28 | ) 29 | ) 30 | 31 | 32 | def test_ipc_connector_send_receive(): 33 | socket_path = '/tmp/test_ipc_socket' 34 | 35 | receiver = IpcConnector(socket_path) 36 | receiver.start_receiving() 37 | 38 | # 2nd start receiving should fail 39 | with pytest.raises(Exception): 40 | receiver.start_receiving() 41 | 42 | # receiver should be empty 43 | assert not receiver.peek_received() 44 | assert not 
receiver.fetch_received() 45 | 46 | # clear on empty does nothing 47 | receiver.clear() 48 | assert not receiver.peek_received() 49 | assert not receiver.fetch_received() 50 | 51 | # try to send and receive a few times 52 | attempts = 4 53 | num_processes = 4 54 | ranks = range(num_processes) 55 | for _ in range(attempts): 56 | # send messages from sub-processes 57 | with Pool(processes=num_processes) as pool: 58 | _ = pool.map(_sender_process, ranks) 59 | # peek_received should not clear internal message queue 60 | assert len(receiver.peek_received()) == num_processes * 3 61 | assert len(receiver.peek_received()) == num_processes * 3 62 | for t in range(num_processes): 63 | assert (t, "Test Message 1") in receiver.peek_received() 64 | assert (t, "Test Message 2") in receiver.peek_received() 65 | assert (t, "STOP") in receiver.peek_received() 66 | # should be empty again after .clear 67 | receiver.clear() 68 | assert len(receiver.peek_received()) == 0 69 | # send messages from sub-processes again 70 | with Pool(processes=num_processes) as pool: 71 | _ = pool.map(_sender_process, ranks) 72 | # fetch_received clears the internal message queue 73 | received_messages = receiver.fetch_received() 74 | assert len(receiver.fetch_received()) == 0 75 | assert len(received_messages) == num_processes * 3 76 | for t in range(num_processes): 77 | assert (t, "Test Message 1") in received_messages 78 | assert (t, "Test Message 2") in received_messages 79 | assert (t, "STOP") in received_messages 80 | 81 | receiver.stop_receiving() 82 | assert not os.path.exists(socket_path) 83 | receiver.stop_receiving() # no op 84 | -------------------------------------------------------------------------------- /tests/fault_tolerance/unit/test_process_utils.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing as mp
import time

import pytest

from nvidia_resiliency_ext.fault_tolerance.utils import (
    is_process_alive,
    wait_until_process_terminated,
)


def _sleeping_process(time_to_sleep):
    # Helper child process: sleep for the requested time, then exit cleanly.
    time.sleep(time_to_sleep)


def test_is_process_alive():
    # Child sleeps for 2s, so it must be alive right after start and be
    # reported dead once wait_until_process_terminated returns.
    proc_obj = mp.Process(target=_sleeping_process, args=(2,))
    proc_obj.start()
    assert is_process_alive(proc_obj.pid)
    wait_until_process_terminated(proc_obj.pid, timeout=10)
    assert not is_process_alive(proc_obj.pid)


def test_wait_until_process_terminated():
    # The 0.1s timeout is much shorter than the child's 3s sleep, so the
    # wait is expected to raise.
    proc_obj = mp.Process(target=_sleeping_process, args=(3,))
    proc_obj.start()
    with pytest.raises(Exception):
        wait_until_process_terminated(proc_obj.pid, timeout=0.1)
--------------------------------------------------------------------------------
/tests/inprocess/__init__.py:
--------------------------------------------------------------------------------
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

# Quiet down PyTorch C++ logging for the test suite unless the user has set
# a level explicitly.
if 'TORCH_CPP_LOG_LEVEL' not in os.environ:
    os.environ['TORCH_CPP_LOG_LEVEL'] = 'error'


# Prefer the NVML-based CUDA availability check, which avoids initializing a
# CUDA context in the parent process — presumably important for the
# fork-based tests below; TODO confirm.
if 'PYTORCH_NVML_BASED_CUDA_CHECK' not in os.environ:
    os.environ['PYTORCH_NVML_BASED_CUDA_CHECK'] = '1'
--------------------------------------------------------------------------------
/tests/inprocess/test_abort.py:
--------------------------------------------------------------------------------
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import multiprocessing
import unittest

import torch

import nvidia_resiliency_ext.inprocess as inprocess

from . import common


@common.apply_all_tests(common.retry())
@unittest.skipIf(not torch.cuda.is_available(), 'cuda not available')
class TestAbort(unittest.TestCase):
    @staticmethod
    def launch(fn, world_size=2, timeout=datetime.timedelta(seconds=10)):
        # Spawn one forked process per rank, join each with a timeout, and
        # kill the whole group as soon as any rank fails; returns exitcodes.
        procs = []
        ctx = multiprocessing.get_context('fork')
        barrier = ctx.Barrier(world_size)
        for rank in range(world_size):
            p = ctx.Process(target=fn, args=(rank, world_size, barrier))
            p.start()
            procs.append(p)

        for p in procs:
            p.join(timeout.total_seconds())
            if p.exitcode != 0:
                for p in procs:
                    p.kill()
        exitcodes = [p.exitcode for p in procs]
        return exitcodes

    def test_multi_group(self):

        def run(rank, world_size, barrier):
            # Object under test: aborts torch.distributed collectives/state.
            abort = inprocess.abort.AbortTorchDistributed()

            store = torch.distributed.TCPStore(
                host_name='localhost',
                port=29501,
                is_master=(rank == 0),
                timeout=datetime.timedelta(seconds=5),
            )
            torch.cuda.set_device(rank)
            device = torch.device('cuda')
            torch.distributed.init_process_group(
                backend='nccl', store=store, rank=rank, world_size=world_size
            )
            barrier.wait()
            size = 128
            t1 = torch.ones(size, device=device)
            t2 = torch.ones(size, device=device)
            default_group = torch.distributed.distributed_c10d._get_default_group()
            torch.distributed.all_reduce(t1, group=default_group)
            torch.cuda.synchronize()
            # Sub-group containing only rank 0; collectives on it involve
            # just that rank.
            new_group = torch.distributed.new_group([0])
            if rank == 0:
                torch.distributed.all_reduce(t2, group=new_group)
            torch.cuda.synchronize()

            for i in range(3):
                if rank == 0:
                    torch.distributed.all_reduce(t2, group=new_group)
                torch.distributed.all_reduce(t1, group=default_group)
                # Each rank aborts at a different iteration; abort must
                # unblock pending collectives on both groups so no rank
                # hangs past the launch() timeout.
                if i == 1 and rank == 1:
                    abort(None)
                    break
                if i == 2 and rank == 0:
                    abort(None)
                    break
                torch.cuda.synchronize()

        # All ranks are expected to exit cleanly despite the mid-collective
        # aborts.
        exitcodes = self.launch(run)
        self.assertTrue(all(ec == 0 for ec in exitcodes), exitcodes)
--------------------------------------------------------------------------------
/tests/inprocess/test_compose.py:
--------------------------------------------------------------------------------
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import nvidia_resiliency_ext.inprocess as inprocess


class TestCompose(unittest.TestCase):
    """Tests for inprocess.Compose, which chains callables, feeding the
    return value of one into the next."""

    def test_empty(self):
        # Zero-argument callables: each is invoked exactly once.
        counter = 0

        class Fn:
            def __call__(self):
                nonlocal counter
                counter += 1

        composed = inprocess.Compose(Fn(), Fn(), Fn(), Fn())
        composed()
        self.assertEqual(counter, 4)

    def test_none(self):
        # None propagates through the chain unchanged.
        counter = 0

        class Fn:
            def __call__(self, x):
                nonlocal counter
                counter += 1
                return x

        composed = inprocess.Compose(Fn(), Fn(), Fn(), Fn())
        ret = composed(None)
        self.assertIs(ret, None)
        self.assertEqual(counter, 4)

    def test_return(self):
        # Single callable: Compose returns its value.
        class Fn:
            def __call__(self):
                return 1

        composed = inprocess.Compose(Fn())
        ret = composed()
        self.assertEqual(ret, 1)

    def test_init_arg(self):
        # Constructor state of the wrapped callable is preserved.
        class Fn:
            def __init__(self, arg):
                self.arg = arg

            def __call__(self):
                return self.arg

        composed = inprocess.Compose(Fn(1))
        ret = composed()
        self.assertEqual(ret, 1)

    def test_no_return_warns(self):
        # A callable that drops its input triggers a UserWarning.
        class Fn:
            def __call__(self, x):
                pass

        composed = inprocess.Compose(Fn(), Fn())
        with self.assertWarns(UserWarning):
            ret = composed(1)
        self.assertEqual(ret, None)

    def test_propagate(self):
        # Each stage sees the previous stage's return value.
        class Fn:
            def __call__(self, counter):
                return counter + 1

        composed = inprocess.Compose(Fn(), Fn(), Fn(), Fn())
        ret = composed(0)
        self.assertEqual(ret, 4)

    def test_tuple(self):
        # Tuple returns are unpacked into the next stage's arguments.
        class Fn:
            def __call__(self, a, b, c):
                return a + 1, b + 1, c + 1

        composed = inprocess.Compose(Fn(), Fn(), Fn(), Fn())
        ret = composed(0, 1, 2)
        self.assertEqual(ret, (4, 5, 6))

    def test_basic_subclass(self):
        # Compose reports the common base of sibling callables, not the
        # concrete classes themselves.
        class Base:
            pass

        class Foo(Base):
            def __call__(self):
                pass

        class Bar(Base):
            def __call__(self):
                pass

        composed = inprocess.Compose(Foo(), Bar())
        self.assertIsInstance(composed, Base)
        self.assertNotIsInstance(composed, Foo)
        self.assertNotIsInstance(composed, Bar)

    def test_nested_subclass(self):
        # With a linear hierarchy, Compose is an instance of the deepest
        # common ancestor (Foo), but not of the leaf (Bar).
        class Base:
            pass

        class Foo(Base):
            def __call__(self):
                pass

        class Bar(Foo):
            def __call__(self):
                pass

        composed = inprocess.Compose(Foo(), Bar())
        self.assertIsInstance(composed, Base)
        self.assertIsInstance(composed, Foo)
        self.assertNotIsInstance(composed, Bar)
--------------------------------------------------------------------------------
/tests/inprocess/test_health_check.py:
--------------------------------------------------------------------------------
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import multiprocessing
import sys
import threading
import time
import unittest

import torch

import nvidia_resiliency_ext.inprocess as inprocess

from . import common  # noqa: F401


@unittest.skipIf(not torch.cuda.is_available(), 'cuda not available')
class TestCudaHealthCheck(unittest.TestCase):
    @staticmethod
    def launch(fn, timeout=datetime.timedelta(seconds=10)):
        # Run fn in a forked child, kill it on timeout, and return its exit
        # code together with the elapsed wall-clock time.
        ctx = multiprocessing.get_context('fork')
        proc = ctx.Process(target=fn)
        start_time = time.perf_counter()
        proc.start()

        proc.join(timeout.total_seconds())
        if proc.exitcode is None:
            proc.kill()
            proc.join()
        stop_time = time.perf_counter()
        elapsed = stop_time - start_time
        return proc.exitcode, elapsed

    def test_basic(self):
        # A healthy device passes the check without raising.
        def run():
            check = inprocess.health_check.CudaHealthCheck()
            check(None)
            torch.cuda.synchronize()

        exitcode, _ = self.launch(run)
        self.assertEqual(exitcode, 0)

    def test_timeout(self):
        # A long-running kernel (cuda._sleep) must make the check raise
        # TimeoutError rather than block; child exits 0 only in that case.
        def run():
            torch.ones(1).cuda()
            check = inprocess.health_check.CudaHealthCheck(datetime.timedelta(seconds=1))
            torch.cuda._sleep(1 << 40)
            try:
                check(None)
                sys.exit(1)
            except inprocess.exception.TimeoutError:
                sys.exit(0)

        exitcode, elapsed = self.launch(run)
        self.assertEqual(exitcode, 0)
        self.assertLess(elapsed, 10)

    # Silence the background-thread exception hook during the induced CUDA
    # error. NOTE(review): `unittest.mock` is relied on without an explicit
    # `import unittest.mock` — presumably loaded transitively (e.g. by
    # torch); confirm.
    @unittest.mock.patch.object(threading, 'excepthook', new=lambda _: None)
    def test_raises(self):
        # An out-of-bounds device-side access should surface as a CUDA
        # RuntimeError from the health check.
        def run():
            check = inprocess.health_check.CudaHealthCheck(datetime.timedelta(seconds=5))
            b = torch.ones(1, dtype=torch.int64).cuda()
            a = torch.ones(1, dtype=torch.int64).cuda()
            a[b] = 0
            try:
                check(None)
                sys.exit(1)
            except RuntimeError as ex:
                if 'CUDA' in str(ex):
                    sys.exit(0)
                sys.exit(1)

        exitcode, elapsed = self.launch(run)
        self.assertEqual(exitcode, 0)
        self.assertLess(elapsed, 10)
--------------------------------------------------------------------------------
/tests/ptl_resiliency/func/nemo20/Dockerfile.ft_test:
--------------------------------------------------------------------------------
# Test image used for FT tests with NeMo 2.0 (`ft_test_sim_nodes.sh`)

ARG BASE_IMG
FROM ${BASE_IMG}

COPY . /workdir/nvidia_resiliency_ext
RUN pip install /workdir/nvidia_resiliency_ext

# Smoke-check that the package imports inside the image.
RUN python -c "import nvidia_resiliency_ext.fault_tolerance"

WORKDIR /workdir/nvidia_resiliency_ext
--------------------------------------------------------------------------------
/tests/ptl_resiliency/func/nemo20/ft_test_asserts.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Verify `ft_test_launchers.sh` output from a given stage
# Usage is `./ft_test_asserts.sh <stage>`

set -x
set -o pipefail

: "${FT_CONT_OUT_DIR:?Error: FT_CONT_OUT_DIR is not set or empty}"
: "${LOG_FILE:?Error: LOG_FILE is not set or empty}"

# Fail unless the expected string appears in the stage log.
function assert_log_contains {
    expected_str="$1"
    if ! grep -q "${expected_str}" ${FT_CONT_OUT_DIR}/${LOG_FILE}; then
        echo "Expected string not found in logs: ${expected_str}"
        exit 1
    fi
}

# Fail if the forbidden string appears in the stage log.
function assert_not_in_log {
    not_expected_str="$1"
    if grep -q "${not_expected_str}" ${FT_CONT_OUT_DIR}/${LOG_FILE}; then
        echo "Not expected string found in logs: ${not_expected_str}"
        exit 1
    fi
}

function assert_checkpoint_saved {
    # Use compgen to expand the glob: `[ -d ".../step*-last" ]` would test a
    # literal path containing `*` and never match a real checkpoint dir, so
    # the original check could not detect a saved checkpoint at all.
    if ! compgen -G "${FT_CONT_OUT_DIR}/default/checkpoints/step*-last" > /dev/null; then
        echo "Expected last checkpoint to be saved, but not found in ${FT_CONT_OUT_DIR}/default/checkpoints/"
        exit 1
    fi
}

# Each completed launcher run logs this registration line exactly once.
function assert_number_of_runs {
    expected_num=$1
    actual_num=$(grep -c "All distributed processes registered." ${FT_CONT_OUT_DIR}/${LOG_FILE})
    if [ "$expected_num" -ne "$actual_num" ]; then
        echo "Expected runs: ${expected_num}, but got ${actual_num}"
        exit 1
    fi
}

function assert_all_launchers_succeeded {
    assert_not_in_log "Some rank(s) exited with non-zero exit code"
}

function assert_launchers_failed {
    assert_log_contains "Some rank(s) exited with non-zero exit code"
}

case "$1" in
    1)
        assert_log_contains "Simulating fault"
        assert_log_contains "FT timeout elapsed"
        assert_checkpoint_saved
        assert_launchers_failed
        ;;
    2)
        assert_log_contains "Time limit reached."
        assert_log_contains "Updated FT timeouts."
        assert_all_launchers_succeeded
        ;;
    3)
        assert_number_of_runs 3
        assert_log_contains "Simulating fault"
        assert_log_contains "FT timeout elapsed"
        assert_launchers_failed
        ;;
    4)
        assert_log_contains "Time limit reached."
        assert_log_contains "Updated FT timeouts."
        assert_all_launchers_succeeded
        ;;
    *)
        echo "Invalid stage for assertions."
        exit 1
        ;;
esac

echo "Assertions for stage $1 passed."
--------------------------------------------------------------------------------
/tests/ptl_resiliency/unit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA/nvidia-resiliency-ext/6ab773c668838ecd530ddaaa13f618ad466b7c61/tests/ptl_resiliency/unit/__init__.py
--------------------------------------------------------------------------------
/tests/straggler/README.md:
--------------------------------------------------------------------------------
# Straggler Detection Tests
## Running unit tests:
```
python3 -m pytest -rs -x ./tests/unit/
```

### Unit tests coverage:
- `test_cupti_ext.py`: Test API of C++ class `CuptiProfiler`.
9 | 10 | - `test_cupti_manager.py`: Test `CuptiManager`, Python class wrapping `CuptiProfiler`. 11 | 12 | - `test_data_shared.py`: Ensure size of shared data is reduced after kernel names are mapped into IDs. 13 | - `test_det_section_api.py`: Test `Detector.detection_section` context manager behavior. 14 | - `test_individual_gpu_scores.py`: Verify individual GPU scores values, `rank_to_node` report field. 15 | - `test_interval_tracker.py`: Test `ReportIntervalTracker` reporting interval estimation functionality. 16 | - `test_name_mapper.py`: Test `NameMapper` API. 17 | - `test_relative_gpu_scores.py`: Same as `test_individual_gpu_scores.py` for relative GPU scores. 18 | - `test_reporting_elapsed.py`: Test `Detector.generate_report_if_interval_elapsed` and `ReportIntervalTracker` functionality. 19 | - `test_reporting.py`: Test `Detector.generate_report` and `Report.identify_stragglers` functionality. 20 | - `test_sections.py`: Simulate straggler sections and verify correct detection based on both individual and relative scores. 21 | - `test_wrap_callables.py`: Test `Detector.wrap_callables` behavior. 22 | 23 | ### Running multi-GPU tests: 24 | While `test_reporting_elapsed.py` and `test_reporting.py` perform unit tests to check functionalities, the complete testing scenarios are validated when running on multiple GPUs. 25 | ``` 26 | torchrun --nproc-per-node=8 tests/unit/test_reporting.py 27 | torchrun --nproc-per-node=8 tests/unit/test_reporting_elapsed.py 28 | ``` 29 | 30 | 31 | ## Running functional tests 32 | Testing straggler detection with various options combinations with various combinations of arguments in multi GPU setting. Below are examples of how to run tests using `torchrun`. 33 | 34 | ### Example commands: 35 | 36 | 1. Test with `.generate_if_interval_elapsed` and `gather_on_rank0=False`: 37 | ``` 38 | torchrun --nproc-per-node=8 tests/func/ddp_test.py --generate_if_elapsed --no_gather_on_rank0 39 | ``` 40 | 41 | 2. 
Test with `gather_on_rank0=False` and `scores_to_compute=["relative_perf_scores"]`: 42 | ``` 43 | torchrun --nproc-per-node=8 tests/func/ddp_test.py --no_gather_on_rank0 --no_indiv_scores 44 | ``` 45 | 46 | ### Available arguments 47 | 48 | - `--no_rel_scores`: Do not compute relative performance scores. 49 | - `--no_indiv_scores`: Do not compute individual performance scores. 50 | - `--no_gather_on_rank0`: Set `gather_on_rank0` to `False`. 51 | - `--use_wrap_forward`: Use wrap callables instead of detection section (default: `False`). 52 | - `--report_iter_interval`: Interval for generating report in iterations (default: `1000`). 53 | - `--generate_if_elapsed`: Generate report if interval elapsed (default: `False`). 54 | - `--report_time_interval`: Time interval for generating report if interval elapsed in seconds (default: `1`). 55 | 56 | -------------------------------------------------------------------------------- /tests/straggler/func/check_log.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 

import argparse
import re


def cmd_num_reports(log_file_lines, args):
    """Check that the number of 'STRAGGLER REPORT' entries is within [args.min, args.max].

    Raises:
        ValueError: if the count falls outside the inclusive range.
    """
    report_lines = [ln for ln in log_file_lines if 'STRAGGLER REPORT' in ln]
    num_reports = len(report_lines)
    if not (args.min <= num_reports <= args.max):
        raise ValueError(
            f"Invalid number of reports: {num_reports}. Valid range is from {args.min} to {args.max}."
        )


def _check_gpu_stragglers(log_file_lines, pattern, expected_stragglers):
    """Collect ranks matched by `pattern` and compare with the expected set.

    NOTE: the error message says "relative" even when called from the
    individual-straggler command — kept as-is for output compatibility.
    """
    found_stragglers = set()
    for ln in log_file_lines:
        match = pattern.search(ln)
        if match:
            found_stragglers.add(int(match.group(1)))
    if found_stragglers != expected_stragglers:
        raise ValueError(
            f"Invalid relative GPU stragglers. Found: {found_stragglers}. Expected: {expected_stragglers}."
        )


def cmd_relative_gpu_stragglers(log_file_lines, args):
    """Verify the exact set of ranks reported as relative stragglers."""
    pattern = re.compile(r'DETECTED RELATIVE STRAGGLER GPU RANK=(\d+)')
    _check_gpu_stragglers(log_file_lines, pattern, set(args.ranks))


def cmd_individual_gpu_stragglers(log_file_lines, args):
    """Verify the exact set of ranks reported as individual stragglers."""
    pattern = re.compile(r'DETECTED INDIVIDUAL STRAGGLER GPU RANK=(\d+)')
    _check_gpu_stragglers(log_file_lines, pattern, set(args.ranks))


def read_log_file(log_file):
    """Read the log into a list of stripped lines; require a 'DONE' entry.

    Raises:
        ValueError: if the log has no 'DONE' entry (i.e. the run did not finish).
    """
    lines = []
    contains_done_entry = False
    with open(log_file, 'r') as f:
        # Iterate the file object directly instead of materializing readlines().
        for ln in f:
            lines.append(ln.strip())
            if 'DONE' in ln:
                contains_done_entry = True
    if not contains_done_entry:
        raise ValueError("Log file does not contain a 'DONE' entry.")
    return lines


def main():
    """Parse CLI arguments and dispatch to the selected verification command."""
    parser = argparse.ArgumentParser(description="Verify a log file created by ddp_test.py.")
    parser.add_argument('--log', required=True, help='Path to the log file')

    subparsers = parser.add_subparsers(dest='command', help='Sub-commands')

    subp1 = subparsers.add_parser('num_reports', help='Number of straggler reports.')
    subp1.add_argument('--min', type=int)
    subp1.add_argument('--max', type=int)
    subp1.set_defaults(func=cmd_num_reports)

    subp2 = subparsers.add_parser('relative_gpu_stragglers', help='Relative GPU stragglers.')
    subp2.add_argument('--ranks', type=int, nargs='*', default=set())
    subp2.set_defaults(func=cmd_relative_gpu_stragglers)

    subp3 = subparsers.add_parser('individual_gpu_stragglers', help='Individual GPU stragglers.')
    subp3.add_argument('--ranks', type=int, nargs='*', default=set())
    subp3.set_defaults(func=cmd_individual_gpu_stragglers)

    args = parser.parse_args()

    if args.command:
        lines = read_log_file(args.log)
        args.func(lines, args)
    else:
        parser.print_help()


if __name__ == "__main__":
    main()
import pytest
import torch

from nvidia_resiliency_ext.straggler.cupti import CuptiManager


def test_cupti_manager_start_stop():
    """Nested start/stop: profiling stays active until every start has a matching stop."""
    manager = CuptiManager()
    mat_a = torch.randn(1000, 1000, device="cuda")
    mat_b = torch.randn(1000, 1000, device="cuda")
    torch.cuda.synchronize()
    manager.initialize()
    # Stopping before any start must raise.
    with pytest.raises(Exception):
        manager.stop_profiling()
    # Two nested starts...
    manager.start_profiling()
    manager.start_profiling()
    # ...a single stop leaves profiling active.
    manager.stop_profiling()
    # This matmul should be captured.
    torch.matmul(mat_a, mat_b)
    torch.cuda.synchronize()
    # Second stop matches the second start; profiling is now off.
    manager.stop_profiling()
    # This matmul should NOT be captured.
    torch.matmul(mat_a, mat_b)
    torch.cuda.synchronize()
    results = manager.get_results()
    manager.shutdown()
    # Exactly one kernel entry with a single recorded call.
    assert len(results) == 1
    _kernel_name, kernel_stats = next(iter(results.items()))
    assert kernel_stats.num_calls == 1


def test_cupti_manager_captures_all_started_kernels():
    """Only matmuls launched while at least one start is unmatched are counted."""
    manager = CuptiManager()
    manager.initialize()
    mat_a = torch.randn(1000, 1000, device="cuda")
    mat_b = torch.randn(1000, 1000, device="cuda")
    # Do not capture the randn kernels above.
    torch.cuda.synchronize()
    # CUDA activity before start_profiling() must not be captured.
    for _ in range(100):
        _ = torch.matmul(mat_a, mat_b)
    # Begin capturing: the next 50 matmuls count.
    manager.start_profiling()
    for _ in range(50):
        _ = torch.matmul(mat_a, mat_b)
    # Nested start; the following stop leaves one start unmatched,
    # so capturing continues for 50 more matmuls.
    manager.start_profiling()
    manager.stop_profiling()
    for _ in range(50):
        _ = torch.matmul(mat_a, mat_b)
    # Second stop fully disables capturing; these 100 matmuls are ignored.
    manager.stop_profiling()
    for _ in range(100):
        _ = torch.matmul(mat_a, mat_b)
    torch.cuda.synchronize()
    results = manager.get_results()
    manager.shutdown()
    # Only the matmul kernel was captured, 50 + 50 = 100 calls in total.
    assert len(results) == 1
    _kernel_name, kernel_stats = next(iter(results.items()))
    assert kernel_stats.num_calls == 100
import time

from nvidia_resiliency_ext.straggler import interval_tracker


def test_estimate():
    """iter_interval is estimated after INTERVAL_ESTIMATION_ITERS steps and tracks step timing."""
    tracker = interval_tracker.ReportIntervalTracker()
    tracker.time_interval = 0.5

    # No estimate exists before any iteration was recorded.
    assert tracker.iter_interval is None

    for step in range(120):
        tracker.iter_increase()
        time.sleep(0.01)
        if tracker.current_iter <= tracker.INTERVAL_ESTIMATION_ITERS:
            # The estimate only becomes available once
            # INTERVAL_ESTIMATION_ITERS iterations have elapsed.
            assert tracker.iter_interval is None
        else:
            assert tracker.iter_interval is not None
            expect_elapsed = (tracker.current_iter % tracker.iter_interval) == 0
            assert tracker.is_interval_elapsed() == expect_elapsed
        # A few slower initial steps must not skew the estimate.
        if step < tracker.INTERVAL_ESTIMATION_ITERS // 2:
            time.sleep(0.04)

    # Per-step timing samples are discarded once the estimate is computed.
    assert not tracker.step_times
    # ~0.01 s per iteration against a 0.5 s target interval -> roughly 50 iterations.
    assert abs(tracker.iter_interval - 50) < 5