├── .clang-format ├── .clang-tidy ├── .flake8 ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── config.yml │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── _build_conda.yaml │ ├── _build_doc.yaml │ ├── _build_wheel.yaml │ ├── _deploy.yaml │ ├── _lint.yaml │ ├── _test_conda.yaml │ ├── _test_wheel.yaml │ ├── nightly.yaml │ ├── push.yaml │ ├── push_doc.yaml │ └── release.yaml ├── .gitignore ├── .gitmodules ├── .isort.cfg ├── CHANGELOG.md ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── LSan.supp ├── README.md ├── VERSION ├── cmake └── Helpers.cmake ├── docker ├── ci-base │ ├── Dockerfile │ ├── install-common │ ├── install-git │ └── install-python ├── ci-clang │ ├── Dockerfile │ ├── install-clang │ └── install-cmake-ninja ├── ci-conda │ ├── Dockerfile.cpu │ ├── Dockerfile.cu117 │ ├── Dockerfile.cu118 │ ├── install-conda │ ├── install-cuda-11.7 │ └── install-cuda-11.8 └── ci-wheel │ ├── Dockerfile.cpu │ ├── Dockerfile.cu117 │ ├── Dockerfile.cu118 │ ├── install-awscli │ ├── install-cuda-11.7 │ ├── install-cuda-11.8 │ ├── install-cudnn-8.3.2 │ ├── install-devtoolset-10 │ └── install-devtoolset-11 ├── docs ├── Makefile ├── requirements.txt └── src │ ├── _static │ └── img │ │ ├── fake-tensor-dispatch.png │ │ ├── fake-tensor.png │ │ └── variable-hooks.png │ ├── conf.py │ ├── deferred_init.rst │ ├── fake_tensor.rst │ ├── fake_tensor_and_deferred_init.rst │ ├── gossip_grad.rst │ ├── index.rst │ └── slow_momentum_fsdp.rst ├── packaging └── conda │ ├── build.sh │ ├── conda_build_config.yaml │ ├── install-debug.sh │ ├── install-devel.sh │ ├── install-lib.sh │ ├── install-python.sh │ ├── meta.yaml │ └── variants │ ├── cu117.yaml │ └── cu118.yaml ├── requirements-devel.txt ├── requirements.txt ├── scripts ├── set-version └── strip-debug-symbols ├── setup.py ├── src ├── cc │ ├── torchdistx-config.cmake.in │ └── torchdistx │ │ ├── CMakeLists.txt │ │ ├── deferred_init.cc │ │ ├── deferred_init.h │ │ ├── fake.cc │ │ ├── fake.h │ │ ├── macros.h │ │ ├── stack_utils.cc │ │ └── stack_utils.h └── python │ └── torchdistx │ ├── _C.pyi │ ├── _C │ ├── CMakeLists.txt │ ├── deferred_init.cc │ ├── fake.cc │ ├── module.cc │ └── module.h │ ├── __init__.py │ ├── deferred_init.py │ ├── fake.py │ ├── gossip_grad.py │ ├── optimizers │ ├── __init__.py │ └── anyprecision_optimizer.py │ ├── py.typed │ └── slowmo │ ├── __init__.py │ ├── slowmo_comm.py │ └── slowmo_optimizer.py ├── tests ├── cc │ └── .gitkeep └── python │ ├── test_anyprecision_optimizer.py │ ├── test_comm_hooks_fsdp.py │ ├── test_deferred_init.py │ └── test_fake.py ├── use-cpu.txt ├── use-cu117.txt └── use-cu118.txt /.clang-format: -------------------------------------------------------------------------------- 1 | BasedOnStyle: Google 2 | AllowShortFunctionsOnASingleLine: Empty 3 | AllowShortLambdasOnASingleLine: Empty 4 | ColumnLimit: 100 5 | IncludeBlocks: Preserve 6 | WhitespaceSensitiveMacros: [ 7 | TORCH_CHECK, 8 | TORCH_CHECK_VALUE, 9 | TORCH_CHECK_NOT_IMPLEMENTED, 10 | TORCH_INTERNAL_ASSERT, 11 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY, 12 | ] 13 | -------------------------------------------------------------------------------- /.clang-tidy: -------------------------------------------------------------------------------- 1 | Checks: "*,\ 2 | -altera-*,\ 3 | -android-*,\ 4 | -bugprone-easily-swappable-parameters,\ 5 | -clang-analyzer-*,\ 6 | -clang-diagnostic-extra-semi-stmt,\ 7 | -clang-diagnostic-return-std-move-in-c++11,\ 8 | -cppcoreguidelines-avoid-non-const-global-variables,\ 9 | 
-cppcoreguidelines-non-private-member-variables-in-classes,\ 10 | -cppcoreguidelines-pro-bounds-array-to-pointer-decay,\ 11 | -cppcoreguidelines-pro-bounds-pointer-arithmetic,\ 12 | -cppcoreguidelines-pro-type-vararg,\ 13 | -facebook-*,\ 14 | -fuchsia-*,\ 15 | -google-readability-todo,\ 16 | -hicpp-*,\ 17 | -llvm-*, 18 | -llvmlibc-*, 19 | -misc-const-correctness, 20 | -misc-no-recursion, 21 | -misc-non-private-member-variables-in-classes,\ 22 | -modernize-use-nodiscard, 23 | -modernize-use-trailing-return-type, 24 | -readability-else-after-return, 25 | -readability-function-cognitive-complexity, 26 | -readability-identifier-length, 27 | -readability-named-parameter, 28 | -readability-redundant-access-specifiers" 29 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Line length recommended by black. 3 | max-line-length = 88 4 | extend-ignore = 5 | # See https://github.com/PyCQA/pycodestyle/issues/373 6 | E203, 7 | # See https://github.com/psf/black/issues/40 8 | E302, 9 | per-file-ignores= 10 | # Ignore `imported but unused`. 11 | __init__.py: F401 12 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @cbalioglu @rohan-varma @H-Huang 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Create a report to help us improve 4 | labels: bug 5 | --- 6 | 7 | **Describe the bug:** 8 | A clear and concise description of what the bug is. 9 | 10 | **Describe how to reproduce:** 11 | Steps to reproduce the behavior. 12 | 13 | **Describe the expected behavior:** 14 | A clear and concise description of what you expected to happen. 15 | 16 | **Environment:** 17 | - OS: [e.g. Ubuntu 18.04] 18 | - Version [e.g. 0.1.0] 19 | 20 | **Additional context:** 21 | Add any other context about the problem here. 22 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | contact_links: 3 | - name: Ask a Question 4 | url: https://discuss.pytorch.org 5 | about: Ask PyTorch Distributed related questions 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature Request 3 | about: Suggest an idea 4 | labels: enhancement 5 | --- 6 | 7 | **Is your feature request related to a problem? Please describe:** 8 | A clear and concise description of what the problem is. 9 | 10 | **Describe the solution you would like:** 11 | A clear and concise description of what you want to happen. 12 | 13 | **Describe the alternatives you have considered:** 14 | A clear and concise description of any alternative solutions or features you have considered. 15 | 16 | **Additional context:** 17 | Add any other context about the feature request here. 
18 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | **What does this PR do? Please describe:** 2 | A summary of the change or the issue that is fixed. 3 | 4 | Fixes #{issue number} 5 | 6 | **Does your PR introduce any breaking changes? If yes, please list them:** 7 | List of all backwards-incompatible API changes. 8 | 9 | **Check list:** 10 | - [ ] Was this **discussed and approved** via a GitHub issue? (not for typos or docs) 11 | - [ ] Did you read the [contributor guideline](https://github.com/pytorch/torchdistx/blob/main/CONTRIBUTING.md)? 12 | - [ ] Did you make sure that your **PR does only one thing** instead of bundling different changes together? 13 | - [ ] Did you make sure to **update the documentation** with your changes? (if necessary) 14 | - [ ] Did you write any **new necessary tests**? 15 | - [ ] Did you verify new and **existing tests pass** locally with your changes? 16 | - [ ] Did you **update the [CHANGELOG](https://github.com/pytorch/torchdistx/blob/main/CHANGELOG.md)**? (not for typos, docs, or minor internal changes) 17 | -------------------------------------------------------------------------------- /.github/workflows/_build_conda.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: Reusable - Build the Conda packages 8 | 9 | on: 10 | workflow_call: 11 | inputs: 12 | matrix: 13 | type: string 14 | required: true 15 | dev_stamp: 16 | type: boolean 17 | default: false 18 | 19 | defaults: 20 | run: 21 | shell: bash 22 | 23 | jobs: 24 | build_conda: 25 | name: Build the Conda packages 26 | runs-on: ubuntu-18.04 27 | container: 28 | image: ghcr.io/pytorch/torchdistx-ci-conda:2-${{ matrix.build_variant }} 29 | strategy: 30 | matrix: ${{ fromJSON(inputs.matrix) }} 31 | steps: 32 | - name: Check-out the repository 33 | uses: actions/checkout@v3 34 | with: 35 | submodules: recursive 36 | - name: Stamp the package version with the current date 37 | if: inputs.dev_stamp 38 | run: | 39 | version=$(cat VERSION) 40 | 41 | scripts/set-version ${version/-*} dev $(date +%Y%m%d) 42 | - name: Run Conda Build 43 | working-directory: packaging/conda 44 | env: 45 | BUILD_VARIANT: ${{ matrix.build_variant }} 46 | SANITIZER: ${{ matrix.sanitizer }} 47 | run: | 48 | mkdir ~/conda-build 49 | 50 | variants="--python ${{ matrix.py }}" 51 | 52 | if [[ $BUILD_VARIANT != "cpu" ]]; then 53 | variants+=" --variant-config-files variants/$BUILD_VARIANT.yaml" 54 | fi 55 | 56 | if [[ $SANITIZER != "nosan" ]]; then 57 | variants+=" --variants {sanitizers:[\"${SANITIZER/_/;}\"]} --no-test" 58 | fi 59 | 60 | conda build $variants\ 61 | --channel pytorch-nightly\ 62 | --channel conda-forge\ 63 | --output-folder ~/conda-build\ 64 | --no-include-recipe\ 65 | . 
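      # Illustrative note (not part of the original workflow): for a hypothetical
      # matrix entry such as py=3.8, build_variant=cu117, sanitizer=asan_ubsan, the
      # step above expands roughly to:
      #   conda build --python 3.8 --variant-config-files variants/cu117.yaml \
      #     --variants {sanitizers:["asan;ubsan"]} --no-test \
      #     --channel pytorch-nightly --channel conda-forge \
      #     --output-folder ~/conda-build --no-include-recipe .
      # since ${SANITIZER/_/;} rewrites "asan_ubsan" to "asan;ubsan".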
66 | - name: Upload the Conda build output to staging 67 | uses: actions/upload-artifact@v3 68 | with: 69 | name: conda-build-py${{ matrix.py }}-${{ matrix.build_variant }}-${{ matrix.sanitizer }} 70 | path: ~/conda-build 71 | retention-days: 1 72 | -------------------------------------------------------------------------------- /.github/workflows/_build_doc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: Reusable - Build the documentation 8 | 9 | on: 10 | workflow_call: 11 | 12 | defaults: 13 | run: 14 | shell: bash 15 | 16 | jobs: 17 | build_doc: 18 | name: Build the documentation 19 | runs-on: ubuntu-18.04 20 | container: 21 | image: ghcr.io/pytorch/torchdistx-ci-clang:13 22 | steps: 23 | - name: Check-out the repository 24 | uses: actions/checkout@v3 25 | with: 26 | submodules: recursive 27 | - name: Set up the Python virtual environment 28 | run: | 29 | python3.8 -m venv ~/venvs/docs 30 | 31 | source ~/venvs/docs/bin/activate 32 | 33 | pip install --requirement use-cpu.txt\ 34 | --requirement requirements.txt\ 35 | --requirement docs/requirements.txt\ 36 | --no-cache-dir 37 | - name: Build the library 38 | run: | 39 | source ~/venvs/docs/bin/activate 40 | 41 | cmake -GNinja\ 42 | -DCMAKE_BUILD_TYPE=Release\ 43 | -DTORCHDIST_TREAT_WARNINGS_AS_ERRORS=ON\ 44 | -B build 45 | 46 | cmake --build build 47 | - name: Install the Wheel package locally 48 | run: | 49 | source ~/venvs/docs/bin/activate 50 | 51 | pip install --editable . 52 | - name: Build the documentation 53 | working-directory: docs 54 | run: | 55 | source ~/venvs/docs/bin/activate 56 | 57 | make html 58 | - name: Copy the version file into the documentation 59 | run: | 60 | cp VERSION docs/build/html 61 | - name: Upload the documentation to staging 62 | uses: actions/upload-artifact@v3 63 | with: 64 | name: docs 65 | path: docs/build/html 66 | retention-days: 1 67 | -------------------------------------------------------------------------------- /.github/workflows/_build_wheel.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
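# Added explanatory note (not part of the original file): this is a reusable
# workflow triggered through `workflow_call`. Callers pass the build matrix as a
# JSON string that is parsed below with fromJSON(inputs.matrix), for example
# (condensed from push.yaml):
#
#   build_wheel:
#     needs: lint_check
#     uses: ./.github/workflows/_build_wheel.yaml
#     with:
#       matrix: |
#         { py: ['3.8'], build_variant: ['cpu', 'cu117', 'cu118'], sanitizer: ['nosan'] }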
6 | 7 | name: Reusable - Build the Wheel package 8 | 9 | on: 10 | workflow_call: 11 | inputs: 12 | matrix: 13 | type: string 14 | required: true 15 | dev_stamp: 16 | type: boolean 17 | default: false 18 | 19 | defaults: 20 | run: 21 | shell: bash 22 | 23 | jobs: 24 | build_wheel: 25 | name: Build the Wheel package 26 | runs-on: ubuntu-18.04 27 | container: 28 | image: ghcr.io/pytorch/torchdistx-ci-wheel:2-${{ matrix.build_variant }} 29 | strategy: 30 | matrix: ${{ fromJSON(inputs.matrix) }} 31 | steps: 32 | - name: Check-out the repository 33 | uses: actions/checkout@v3 34 | with: 35 | submodules: recursive 36 | - name: Stamp the package version with the current date 37 | if: inputs.dev_stamp 38 | run: | 39 | version=$(cat VERSION) 40 | 41 | scripts/set-version ${version/-*} dev $(date +%Y%m%d) 42 | - name: Set up the Python virtual environment 43 | run: | 44 | python${{ matrix.py }} -m venv ~/venvs/build 45 | 46 | source ~/venvs/build/bin/activate 47 | 48 | pip install --requirement use-${{ matrix.build_variant }}.txt\ 49 | --requirement requirements.txt\ 50 | --no-cache-dir 51 | - name: Build the library 52 | env: 53 | SANITIZER: ${{ matrix.sanitizer }} 54 | run: | 55 | source ~/venvs/build/bin/activate 56 | 57 | if [[ $SANITIZER == "nosan" ]]; then 58 | unset SANITIZER 59 | fi 60 | 61 | cmake -DCMAKE_BUILD_TYPE=Release\ 62 | -DTORCHDIST_TREAT_WARNINGS_AS_ERRORS=ON\ 63 | -DTORCHDIST_DEVELOP_PYTHON=OFF\ 64 | -DTORCHDIST_INSTALL_STANDALONE=ON\ 65 | -DTORCHDIST_SANITIZERS="${SANITIZER/_/;}"\ 66 | -B build 67 | 68 | cmake --build build -j $(nproc) 69 | - name: Create the Wheel package 70 | run: | 71 | source ~/venvs/build/bin/activate 72 | 73 | pip wheel .\ 74 | --build-option --plat-name\ 75 | --build-option manylinux_2_17_x86_64\ 76 | --no-deps\ 77 | --wheel-dir ~/wheelhouse 78 | - name: Upload the Wheel package to staging 79 | uses: actions/upload-artifact@v3 80 | with: 81 | name: wheel-py${{ matrix.py }}-${{ matrix.build_variant }}-${{ matrix.sanitizer }} 82 | path: ~/wheelhouse 83 | retention-days: 1 84 | -------------------------------------------------------------------------------- /.github/workflows/_deploy.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: Reusable - Deploy 8 | 9 | on: 10 | workflow_call: 11 | inputs: 12 | matrix: 13 | type: string 14 | required: true 15 | s3_wheel_path: 16 | type: string 17 | required: true 18 | doc_folder_override: 19 | type: string 20 | secrets: 21 | anaconda_token: 22 | required: true 23 | aws_key_id: 24 | required: true 25 | aws_access_key: 26 | required: true 27 | 28 | defaults: 29 | run: 30 | shell: bash 31 | 32 | jobs: 33 | deploy_doc: 34 | name: Deploy the documentation 35 | runs-on: ubuntu-18.04 36 | steps: 37 | - name: Download the documentation from staging 38 | uses: actions/download-artifact@v3 39 | with: 40 | name: docs 41 | path: ~/docs 42 | - name: Check-out the gh-pages branch of the repository 43 | uses: actions/checkout@v3 44 | with: 45 | ref: gh-pages 46 | - name: Set up Git 47 | run: | 48 | # See https://github.com/actions/checkout/issues/766. 
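          # Added clarification: inside the container the checkout ends up owned by a
          # different user than the one running Git, and recent Git versions refuse to
          # operate on such "dubious ownership" work trees unless the directory is
          # explicitly marked safe, which the next line does.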
49 | git config --global --add safe.directory "$GITHUB_WORKSPACE" 50 | 51 | git config user.name "github-actions" 52 | git config user.email "github-actions@github.com" 53 | - name: Commit and push the documentation 54 | env: 55 | DOC_FOLDER_OVERRIDE: ${{ inputs.doc_folder_override }} 56 | run: | 57 | rsync --recursive --delete-after ~/docs/ ${DOC_FOLDER_OVERRIDE:-$(cat ~/docs/VERSION)} 58 | 59 | git add --all 60 | 61 | if ! git diff --staged --quiet; then 62 | git commit --message "Documentation generated from $(git rev-parse --short "$GITHUB_SHA")" 63 | git push 64 | fi 65 | 66 | deploy_conda: 67 | name: Deploy the Conda packages 68 | needs: deploy_doc 69 | runs-on: ubuntu-18.04 70 | container: 71 | image: ghcr.io/pytorch/torchdistx-ci-conda:2-cpu 72 | strategy: 73 | matrix: ${{ fromJSON(inputs.matrix) }} 74 | max-parallel: 1 75 | steps: 76 | - name: Download the Conda build output from staging 77 | uses: actions/download-artifact@v3 78 | with: 79 | name: conda-build-py${{ matrix.py }}-${{ matrix.build_variant }}-nosan 80 | path: ~/conda-build 81 | - name: Upload the Conda packages to Anaconda 82 | run: | 83 | find ~/conda-build -name '*.tar.bz2' -type f\ 84 | -exec anaconda --token ${{ secrets.anaconda_token }} upload --force '{}' \+ 85 | 86 | deploy_wheel: 87 | name: Deploy the Wheel package 88 | needs: deploy_conda 89 | runs-on: ubuntu-18.04 90 | container: 91 | image: ghcr.io/pytorch/torchdistx-ci-wheel:2-cpu 92 | strategy: 93 | matrix: ${{ fromJSON(inputs.matrix) }} 94 | max-parallel: 1 95 | steps: 96 | - name: Download the Wheel package from staging 97 | uses: actions/download-artifact@v3 98 | with: 99 | name: wheel-py${{ matrix.py }}-${{ matrix.build_variant }}-nosan 100 | path: ~/wheelhouse 101 | - name: Upload the Wheel package to S3 102 | env: 103 | AWS_ACCESS_KEY_ID: ${{ secrets.aws_key_id }} 104 | AWS_SECRET_ACCESS_KEY: ${{ secrets.aws_access_key }} 105 | AWS_DEFAULT_REGION: us-east-1 106 | run: | 107 | for pkg in ~/wheelhouse/*.whl; do 108 | aws s3 cp "$pkg" "s3://${{ inputs.s3_wheel_path }}/${{ matrix.build_variant }}/" --acl public-read 109 | done 110 | -------------------------------------------------------------------------------- /.github/workflows/_lint.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
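# Added note (not part of the original file): the same checks can be reproduced
# locally from an environment created with requirements-devel.txt, roughly:
#
#   flake8 setup.py src tests
#   black --check setup.py src tests
#   isort --check-only setup.py src tests
#   mypy --pretty --show-error-codes setup.py src tests
#
# mirroring the steps defined below; the clang-tidy step additionally needs the
# configured CMake build directory (-p=build).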
6 | 7 | name: Reusable - Lint check 8 | 9 | on: 10 | workflow_call: 11 | 12 | defaults: 13 | run: 14 | shell: bash 15 | 16 | jobs: 17 | lint_check: 18 | name: Lint check 19 | runs-on: ubuntu-18.04 20 | container: 21 | image: ghcr.io/pytorch/torchdistx-ci-clang:13 22 | steps: 23 | - name: Check-out the repository 24 | uses: actions/checkout@v3 25 | with: 26 | submodules: recursive 27 | - name: Set up the Python virtual environment 28 | run: | 29 | python3.8 -m venv ~/venvs/lint 30 | 31 | source ~/venvs/lint/bin/activate 32 | 33 | pip install --requirement use-cpu.txt\ 34 | --requirement requirements-devel.txt\ 35 | --no-cache-dir 36 | - id: run_cmake_config 37 | name: Configure the CMake project 38 | run: | 39 | source ~/venvs/lint/bin/activate 40 | 41 | cmake -DCMAKE_CXX_COMPILER=clang++-13 -B build 42 | - name: Run clang-format 43 | if: always() && steps.run_cmake_config.outcome == 'success' 44 | run: | 45 | source ~/venvs/lint/bin/activate 46 | 47 | find src tests -name '*.cc' -type f\ 48 | -exec clang-format-13 --Werror --dry-run '{}' \+ 49 | - name: Run clang-tidy 50 | if: always() && steps.run_cmake_config.outcome == 'success' 51 | run: | 52 | source ~/venvs/lint/bin/activate 53 | 54 | find src tests -name '*.cc' -type f\ 55 | -exec clang-tidy-13 --warnings-as-errors='*' -p=build '{}' \+ 56 | - name: Run flake8 57 | if: always() && steps.run_cmake_config.outcome == 'success' 58 | run: | 59 | source ~/venvs/lint/bin/activate 60 | 61 | flake8 setup.py src tests 62 | - name: Run black 63 | if: always() && steps.run_cmake_config.outcome == 'success' 64 | run: | 65 | source ~/venvs/lint/bin/activate 66 | 67 | black --check setup.py src tests 68 | - name: Run isort 69 | if: always() && steps.run_cmake_config.outcome == 'success' 70 | run: | 71 | source ~/venvs/lint/bin/activate 72 | 73 | isort --check-only setup.py src tests 74 | - name: Run mypy 75 | if: always() && steps.run_cmake_config.outcome == 'success' 76 | run: | 77 | source ~/venvs/lint/bin/activate 78 | 79 | mypy --pretty --show-error-codes setup.py src tests 80 | - name: Run shellcheck 81 | if: always() && steps.run_cmake_config.outcome == 'success' 82 | run: | 83 | source ~/venvs/lint/bin/activate 84 | 85 | shellcheck --severity=warning scripts/*\ 86 | docker/ci-*/install-*\ 87 | packaging/conda/*.sh 88 | -------------------------------------------------------------------------------- /.github/workflows/_test_conda.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | name: Reusable - Test the Conda packages 8 | 9 | on: 10 | workflow_call: 11 | inputs: 12 | matrix: 13 | type: string 14 | required: true 15 | 16 | defaults: 17 | run: 18 | shell: bash 19 | 20 | jobs: 21 | test_conda: 22 | name: Test the Conda packages 23 | runs-on: ubuntu-18.04 24 | container: 25 | image: ghcr.io/pytorch/torchdistx-ci-conda:2-cpu 26 | strategy: 27 | matrix: ${{ fromJSON(inputs.matrix) }} 28 | steps: 29 | - name: Check-out the repository 30 | uses: actions/checkout@v3 31 | with: 32 | submodules: recursive 33 | - name: Download the Conda build output from staging 34 | uses: actions/download-artifact@v3 35 | with: 36 | name: conda-build-py${{ matrix.py }}-${{ matrix.build_variant }}-${{ matrix.sanitizer }} 37 | path: ~/conda-build 38 | - name: Set up the Conda environment 39 | run: | 40 | conda create --yes\ 41 | --name test\ 42 | --channel ~/conda-build\ 43 | --channel pytorch-nightly\ 44 | --channel conda-forge\ 45 | numpy\ 46 | expecttest==0.1.3\ 47 | pytest==7.0.1\ 48 | python==${{ matrix.py }}\ 49 | torchdistx 50 | - name: Set the sanitizer variables 51 | if: matrix.sanitizer != 'nosan' 52 | env: 53 | BUILD_VARIANT: ${{ matrix.build_variant }} 54 | SANITIZER: ${{ matrix.sanitizer }} 55 | run: | 56 | { 57 | conda_prefix=/root/miniconda3/envs/test 58 | 59 | if [[ $SANITIZER == "asan_ubsan" ]]; then 60 | if [[ $BUILD_VARIANT == "cu102" || $BUILD_VARIANT == "cu113" ]]; then 61 | asan_ver=5 62 | else 63 | asan_ver=6 64 | fi 65 | 66 | echo "SANITIZER_LIBRARY=$conda_prefix/lib/libasan.so.$asan_ver" 67 | elif [[ $SANITIZER == "tsan" ]]; then 68 | echo "SANITIZER_LIBRARY=$conda_prefix/lib/libtsan.so.0" 69 | fi 70 | 71 | # Sanitizer Options 72 | if [[ $SANITIZER == "asan_ubsan" ]]; then 73 | echo "LSAN_OPTIONS=suppressions=LSan.supp,exitcode=0,log_path=$HOME/asan.out" 74 | fi 75 | } >> $GITHUB_ENV 76 | - name: Run the Python tests 77 | env: 78 | SANITIZER: ${{ matrix.sanitizer }} 79 | run: | 80 | conda run --name test env LD_PRELOAD=$SANITIZER_LIBRARY pytest tests 81 | 82 | # Unfortunately Python leaks quite a bit of memory, so we cannot rely 83 | # on the output of LSan. Instead we use a rudimentary way to find out 84 | # whether we have any leakage caused by our tests. We simply check if 85 | # any stack frame has a symbol containing the word 'torchdistx'. 86 | if [[ $SANITIZER == "asan_ubsan" ]]; then 87 | if find ~ -maxdepth 1 -name 'asan.out.*' -exec cat '{}' \+ | tee /dev/stderr | grep --quiet 'torchdistx'; then 88 | exit 1 89 | fi 90 | fi 91 | -------------------------------------------------------------------------------- /.github/workflows/_test_wheel.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
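# Added note (not part of the original file): the sanitizer-enabled matrix entries
# below work by preloading the matching sanitizer runtime into the Python process
# before pytest starts, conceptually:
#
#   LD_PRELOAD=/usr/lib64/libasan.so.6 LSAN_OPTIONS=suppressions=LSan.supp,exitcode=0 pytest tests
#
# which is what the "Set the sanitizer variables" and "Run the Python tests" steps
# assemble from the matrix values.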
6 | 7 | name: Reusable - Test the Wheel package 8 | 9 | on: 10 | workflow_call: 11 | inputs: 12 | matrix: 13 | type: string 14 | required: true 15 | 16 | defaults: 17 | run: 18 | shell: bash 19 | 20 | jobs: 21 | test_conda: 22 | name: Test the Wheel package 23 | runs-on: ubuntu-18.04 24 | container: 25 | image: ghcr.io/pytorch/torchdistx-ci-wheel:2-${{ matrix.build_variant }} 26 | strategy: 27 | matrix: ${{ fromJSON(inputs.matrix) }} 28 | steps: 29 | - name: Check-out the repository 30 | uses: actions/checkout@v3 31 | with: 32 | submodules: recursive 33 | - name: Download the Wheel package from staging 34 | uses: actions/download-artifact@v3 35 | with: 36 | name: wheel-py${{ matrix.py }}-${{ matrix.build_variant }}-${{ matrix.sanitizer }} 37 | path: ~/wheelhouse 38 | - name: Set up the Python virtual environment 39 | run: | 40 | python${{ matrix.py }} -m venv ~/venvs/test 41 | 42 | source ~/venvs/test/bin/activate 43 | 44 | pip install ~/wheelhouse/*.whl\ 45 | --requirement use-${{ matrix.build_variant }}.txt\ 46 | --requirement requirements-devel.txt\ 47 | --no-cache-dir 48 | - name: Set the sanitizer variables 49 | if: matrix.sanitizer != 'nosan' 50 | env: 51 | BUILD_VARIANT: ${{ matrix.build_variant }} 52 | SANITIZER: ${{ matrix.sanitizer }} 53 | run: | 54 | { 55 | if [[ $SANITIZER == "asan_ubsan" ]]; then 56 | if [[ $BUILD_VARIANT == "cu102" || $BUILD_VARIANT == "cu113" ]]; then 57 | asan_ver=5 58 | else 59 | asan_ver=6 60 | fi 61 | 62 | echo "SANITIZER_LIBRARY=/usr/lib64/libasan.so.$asan_ver" 63 | elif [[ $SANITIZER == "tsan" ]]; then 64 | echo "SANITIZER_LIBRARY=/usr/lib64/libtsan.so.0" 65 | fi 66 | 67 | # Sanitizer Options 68 | if [[ $SANITIZER == "asan_ubsan" ]]; then 69 | echo "LSAN_OPTIONS=suppressions=LSan.supp,exitcode=0,log_path=$HOME/asan.out" 70 | fi 71 | } >> $GITHUB_ENV 72 | - name: Run the Python tests 73 | env: 74 | SANITIZER: ${{ matrix.sanitizer }} 75 | run: | 76 | source ~/venvs/test/bin/activate 77 | 78 | LD_PRELOAD=$SANITIZER_LIBRARY pytest tests 79 | 80 | # Unfortunately Python leaks quite a bit of memory, so we cannot rely 81 | # on the output of LSan. Instead we use a rudimentary way to find out 82 | # whether we have any leakage caused by our tests. We simply check if 83 | # any stack frame has a symbol containing the word 'torchdistx'. 84 | if [[ $SANITIZER == "asan_ubsan" ]]; then 85 | if find ~ -maxdepth 1 -name 'asan.out.*' -exec cat '{}' \+ | tee /dev/stderr | grep --quiet 'torchdistx'; then 86 | exit 1 87 | fi 88 | fi 89 | -------------------------------------------------------------------------------- /.github/workflows/nightly.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: Build and deploy a nightly release 8 | 9 | on: 10 | schedule: 11 | # At 1:15AM UTC on every Monday, Wednesday, and Friday. 
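    # (Field breakdown added for clarity: minute hour day-of-month month day-of-week,
    # so '15 1 * * 1,3,5' is minute 15 of hour 1, any day of month, any month,
    # on weekdays 1, 3, and 5, i.e. Monday, Wednesday, and Friday.)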
12 | - cron: '15 1 * * 1,3,5' 13 | workflow_dispatch: 14 | inputs: 15 | deploy: 16 | type: boolean 17 | default: false 18 | 19 | jobs: 20 | lint_check: 21 | name: Build 22 | uses: ./.github/workflows/_lint.yaml 23 | 24 | build_doc: 25 | name: Build 26 | needs: lint_check 27 | uses: ./.github/workflows/_build_doc.yaml 28 | 29 | build_wheel: 30 | name: Build 31 | needs: lint_check 32 | uses: ./.github/workflows/_build_wheel.yaml 33 | with: 34 | matrix: | 35 | { 36 | py: ['3.8', '3.9', '3.10'], 37 | build_variant: ['cpu', 'cu117', 'cu118'], 38 | sanitizer: ['nosan'], 39 | include: [ 40 | { 41 | py: '3.8', 42 | build_variant: 'cpu', 43 | sanitizer: 'asan_ubsan' 44 | } 45 | ] 46 | } 47 | dev_stamp: true 48 | 49 | build_conda: 50 | name: Build 51 | needs: lint_check 52 | uses: ./.github/workflows/_build_conda.yaml 53 | with: 54 | matrix: | 55 | { 56 | py: ['3.8', '3.9', '3.10'], 57 | build_variant: ['cpu', 'cu117', 'cu118'], 58 | sanitizer: ['nosan'], 59 | include: [ 60 | { 61 | py: '3.8', 62 | build_variant: 'cpu', 63 | sanitizer: 'asan_ubsan' 64 | } 65 | ] 66 | } 67 | dev_stamp: true 68 | 69 | test_wheel_cpu: 70 | name: Test (CPU) 71 | needs: build_wheel 72 | uses: ./.github/workflows/_test_wheel.yaml 73 | with: 74 | matrix: | 75 | { 76 | py: ['3.8', '3.9', '3.10'], 77 | build_variant: ['cpu'], 78 | sanitizer: ['nosan'], 79 | include: [ 80 | { 81 | py: '3.8', 82 | build_variant: 'cpu', 83 | sanitizer: 'asan_ubsan' 84 | } 85 | ] 86 | } 87 | 88 | test_wheel_cu117: 89 | name: Test (CUDA 11.7) 90 | needs: test_wheel_cpu 91 | uses: ./.github/workflows/_test_wheel.yaml 92 | with: 93 | matrix: | 94 | { 95 | py: ['3.8', '3.9', '3.10'], 96 | build_variant: ['cu117'], 97 | sanitizer: ['nosan'] 98 | } 99 | 100 | test_wheel_cu118: 101 | name: Test (CUDA 11.8) 102 | needs: test_wheel_cu117 103 | uses: ./.github/workflows/_test_wheel.yaml 104 | with: 105 | matrix: | 106 | { 107 | py: ['3.8', '3.9', '3.10'], 108 | build_variant: ['cu118'], 109 | sanitizer: ['nosan'] 110 | } 111 | 112 | test_conda_cpu: 113 | name: Test (CPU) 114 | needs: build_conda 115 | uses: ./.github/workflows/_test_conda.yaml 116 | with: 117 | matrix: | 118 | { 119 | py: ['3.8', '3.9', '3.10'], 120 | build_variant: ['cpu'], 121 | sanitizer: ['nosan'], 122 | include: [ 123 | { 124 | py: '3.8', 125 | build_variant: 'cpu', 126 | sanitizer: 'asan_ubsan' 127 | } 128 | ] 129 | } 130 | 131 | test_conda_cu117: 132 | name: Test (CUDA 11.7) 133 | needs: test_conda_cpu 134 | uses: ./.github/workflows/_test_conda.yaml 135 | with: 136 | matrix: | 137 | { 138 | py: ['3.8', '3.9', '3.10'], 139 | build_variant: ['cu117'], 140 | sanitizer: ['nosan'] 141 | } 142 | 143 | test_conda_cu118: 144 | name: Test (CUDA 11.8) 145 | needs: test_conda_cu117 146 | uses: ./.github/workflows/_test_conda.yaml 147 | with: 148 | matrix: | 149 | { 150 | py: ['3.8', '3.9', '3.10'], 151 | build_variant: ['cu118'], 152 | sanitizer: ['nosan'] 153 | } 154 | 155 | deploy: 156 | name: Deploy 157 | if: github.event_name == 'schedule' || github.event.inputs.deploy == 'true' 158 | needs: [build_doc, test_wheel_cu118, test_conda_cu118] 159 | uses: ./.github/workflows/_deploy.yaml 160 | with: 161 | matrix: | 162 | { 163 | py: ['3.8', '3.9', '3.10'], 164 | build_variant: ['cpu', 'cu117', 'cu118'], 165 | } 166 | s3_wheel_path: pytorch/whl/nightly 167 | doc_folder_override: nightly 168 | secrets: 169 | anaconda_token: ${{ secrets.ANACONDA_NIGHTLY_TOKEN }} 170 | aws_key_id: ${{ secrets.AWS_PYTORCH_KEY_ID }} 171 | aws_access_key: ${{ secrets.AWS_PYTORCH_ACCESS_KEY }} 172 | 
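# Added note (not part of the original file): besides the cron schedule above, this
# workflow can also be started by hand; assuming the GitHub CLI is available, one way is
#
#   gh workflow run nightly.yaml -f deploy=true
#
# which makes the final deploy job run because of its github.event.inputs.deploy check.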
-------------------------------------------------------------------------------- /.github/workflows/push.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: Build and test the library 8 | 9 | on: 10 | push: 11 | branches: 12 | - main 13 | paths-ignore: 14 | - 'docker/**' 15 | - 'docs/**' 16 | - '**.md' 17 | pull_request: 18 | paths-ignore: 19 | - 'docker/**' 20 | - 'docs/**' 21 | - '**.md' 22 | workflow_dispatch: 23 | 24 | jobs: 25 | lint_check: 26 | name: Build 27 | uses: ./.github/workflows/_lint.yaml 28 | 29 | build_wheel: 30 | name: Build 31 | needs: lint_check 32 | uses: ./.github/workflows/_build_wheel.yaml 33 | with: 34 | matrix: | 35 | { 36 | py: ['3.8'], 37 | build_variant: ['cpu', 'cu117', 'cu118'], 38 | sanitizer: ['nosan'], 39 | include: [ 40 | { 41 | py: '3.8', 42 | build_variant: 'cpu', 43 | sanitizer: 'asan_ubsan' 44 | } 45 | ] 46 | } 47 | 48 | test_wheel_cpu: 49 | name: Test (CPU) 50 | needs: build_wheel 51 | uses: ./.github/workflows/_test_wheel.yaml 52 | with: 53 | matrix: | 54 | { 55 | py: ['3.8'], 56 | build_variant: ['cpu'], 57 | sanitizer: ['nosan', 'asan_ubsan'] 58 | } 59 | 60 | test_wheel_cu117: 61 | name: Test (CUDA 11.7) 62 | needs: test_wheel_cpu 63 | uses: ./.github/workflows/_test_wheel.yaml 64 | with: 65 | matrix: | 66 | { 67 | include: [ 68 | { 69 | py: '3.8', 70 | build_variant: 'cu117', 71 | sanitizer: 'nosan' 72 | }, 73 | ] 74 | } 75 | 76 | test_wheel_cu118: 77 | name: Test (CUDA 11.8) 78 | needs: test_wheel_cu117 79 | uses: ./.github/workflows/_test_wheel.yaml 80 | with: 81 | matrix: | 82 | { 83 | include: [ 84 | { 85 | py: '3.8', 86 | build_variant: 'cu118', 87 | sanitizer: 'nosan' 88 | }, 89 | ] 90 | } 91 | -------------------------------------------------------------------------------- /.github/workflows/push_doc.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | name: Build the documentation 8 | 9 | on: 10 | push: 11 | branches: 12 | - main 13 | paths: 14 | - 'docs/**' 15 | pull_request_target: 16 | paths: 17 | - 'docs/**' 18 | workflow_dispatch: 19 | 20 | jobs: 21 | build_doc: 22 | name: Build 23 | uses: ./.github/workflows/_build_doc.yaml 24 | -------------------------------------------------------------------------------- /.github/workflows/release.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | name: Build and deploy a release 8 | 9 | on: 10 | workflow_dispatch: 11 | inputs: 12 | deploy: 13 | type: boolean 14 | default: false 15 | 16 | jobs: 17 | lint_check: 18 | name: Build 19 | uses: ./.github/workflows/_lint.yaml 20 | 21 | build_doc: 22 | name: Build 23 | needs: lint_check 24 | uses: ./.github/workflows/_build_doc.yaml 25 | 26 | build_wheel: 27 | name: Build 28 | needs: lint_check 29 | uses: ./.github/workflows/_build_wheel.yaml 30 | with: 31 | matrix: | 32 | { 33 | py: ['3.8', '3.9', '3.10'], 34 | build_variant: ['cpu', 'cu117', 'cu118'], 35 | sanitizer: ['nosan'], 36 | include: [ 37 | { 38 | py: '3.8', 39 | build_variant: 'cpu', 40 | sanitizer: 'asan_ubsan' 41 | } 42 | ] 43 | } 44 | 45 | build_conda: 46 | name: Build 47 | needs: lint_check 48 | uses: ./.github/workflows/_build_conda.yaml 49 | with: 50 | matrix: | 51 | { 52 | py: ['3.8', '3.9', '3.10'], 53 | build_variant: ['cpu', 'cu117', 'cu118'], 54 | sanitizer: ['nosan'], 55 | include: [ 56 | { 57 | py: '3.8', 58 | build_variant: 'cpu', 59 | sanitizer: 'asan_ubsan' 60 | } 61 | ] 62 | } 63 | 64 | test_wheel_cpu: 65 | name: Test (CPU) 66 | needs: build_wheel 67 | uses: ./.github/workflows/_test_wheel.yaml 68 | with: 69 | matrix: | 70 | { 71 | py: ['3.8', '3.9', '3.10'], 72 | build_variant: ['cpu'], 73 | sanitizer: ['nosan'], 74 | include: [ 75 | { 76 | py: '3.8', 77 | build_variant: 'cpu', 78 | sanitizer: 'asan_ubsan' 79 | } 80 | ] 81 | } 82 | 83 | test_wheel_cu117: 84 | name: Test (CUDA 11.7) 85 | needs: test_wheel_cpu 86 | uses: ./.github/workflows/_test_wheel.yaml 87 | with: 88 | matrix: | 89 | { 90 | py: ['3.8', '3.9', '3.10'], 91 | build_variant: ['cu117'], 92 | sanitizer: ['nosan'] 93 | } 94 | 95 | test_wheel_cu118: 96 | name: Test (CUDA 11.8) 97 | needs: test_wheel_cu117 98 | uses: ./.github/workflows/_test_wheel.yaml 99 | with: 100 | matrix: | 101 | { 102 | py: ['3.8', '3.9', '3.10'], 103 | build_variant: ['cu118'], 104 | sanitizer: ['nosan'] 105 | } 106 | 107 | test_conda_cpu: 108 | name: Test (CPU) 109 | needs: build_conda 110 | uses: ./.github/workflows/_test_conda.yaml 111 | with: 112 | matrix: | 113 | { 114 | py: ['3.8', '3.9', '3.10'], 115 | build_variant: ['cpu'], 116 | sanitizer: ['nosan'], 117 | include: [ 118 | { 119 | py: '3.8', 120 | build_variant: 'cpu', 121 | sanitizer: 'asan_ubsan' 122 | } 123 | ] 124 | } 125 | 126 | test_conda_cu117: 127 | name: Test (CUDA 11.7) 128 | needs: test_conda_cpu 129 | uses: ./.github/workflows/_test_conda.yaml 130 | with: 131 | matrix: | 132 | { 133 | py: ['3.8', '3.9', '3.10'], 134 | build_variant: ['cu117'], 135 | sanitizer: ['nosan'] 136 | } 137 | 138 | test_conda_cu118: 139 | name: Test (CUDA 11.8) 140 | needs: test_conda_cu117 141 | uses: ./.github/workflows/_test_conda.yaml 142 | with: 143 | matrix: | 144 | { 145 | py: ['3.8', '3.9', '3.10'], 146 | build_variant: ['cu118'], 147 | sanitizer: ['nosan'] 148 | } 149 | 150 | deploy: 151 | name: Deploy 152 | if: github.event.inputs.deploy == 'true' 153 | needs: [build_doc, test_wheel_cu118, test_conda_cu118] 154 | uses: ./.github/workflows/_deploy.yaml 155 | with: 156 | matrix: | 157 | { 158 | py: ['3.8', '3.9', '3.10'], 159 | build_variant: ['cpu', 'cu117', 'cu118'], 160 | } 161 | s3_wheel_path: pytorch/whl 162 | secrets: 163 | anaconda_token: ${{ secrets.ANACONDA_TOKEN }} 164 | aws_key_id: ${{ secrets.AWS_PYTORCH_KEY_ID }} 165 | aws_access_key: ${{ secrets.AWS_PYTORCH_ACCESS_KEY }} 166 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | *.egg-info 2 | *.so 3 | *.whl 4 | __pycache__ 5 | build-*/ 6 | build/ 7 | dist/ 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third-party/pybind11"] 2 | path = third-party/pybind11 3 | url = https://github.com/pybind/pybind11.git 4 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [0.3.0] - 2022-mm-dd 8 | ### Added 9 | - Adds a `fake_cuda` parameter to `fake_mode()` that allows constructing fake 10 | CUDA tensors even if CUDA is not available. 11 | 12 | ## [0.2.0] - 2022-06-23 13 | ### Added 14 | - Moves to PyTorch v1.12 15 | - Adds support for Python 3.10 16 | 17 | ### Fixed 18 | - Addresses a minor bug in Fake tensor caused by the API changes in PyTorch 19 | 20 | ## [0.1.0] - 2022-04-14 21 | ### Added 22 | - Initial release with Fake Tensor and Deferred Module Initialization 23 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | cmake_minimum_required(VERSION 3.21.0) 8 | 9 | project(torchdistx VERSION 0.3.0 LANGUAGES CXX) 10 | 11 | if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES) 12 | set_property(CACHE CMAKE_BUILD_TYPE PROPERTY VALUE RelWithDebInfo) 13 | endif() 14 | 15 | include(cmake/Helpers.cmake) 16 | 17 | # ------------------------------------------------------------ 18 | # Options 19 | # ------------------------------------------------------------ 20 | 21 | include(CMakeDependentOption) 22 | 23 | if(PROJECT_IS_TOP_LEVEL) 24 | include(CTest) 25 | endif() 26 | 27 | option(TORCHDIST_BUILD_PYTHON "Generates build target for the Python C extension." ON) 28 | option(TORCHDIST_BUILD_FOR_NATIVE "Builds for the processor type of the compiling machine." OFF) 29 | option(TORCHDIST_TREAT_WARNINGS_AS_ERRORS "Treats compilation warnings as errors." OFF) 30 | option(TORCHDIST_PERFORM_LTO "Performs link-time optimization." OFF) 31 | option(TORCHDIST_INSTALL_STANDALONE "Installs with rpath." OFF) 32 | option(TORCHDIST_RUN_CLANG_TIDY "Runs clang-tidy as static analyzer." OFF) 33 | 34 | cmake_dependent_option(TORCHDIST_DEVELOP_PYTHON 35 | #DESCRIPTION 36 | "Copies the Python C extension to the source tree for `setup.py develop` mode." 37 | #VALUE 38 | ON 39 | #DEPENDS_ON 40 | TORCHDIST_BUILD_PYTHON 41 | #FORCE 42 | OFF 43 | ) 44 | 45 | set(TORCHDIST_SANITIZERS 46 | #DEFAULT 47 | "" 48 | #TYPE 49 | CACHE STRING 50 | #DESCRIPTION 51 | "The list of sanitizers with which to build. 
The supported values are 'asan', 'ubsan', and 'tsan'." 52 | ) 53 | 54 | set_property(CACHE TORCHDIST_SANITIZERS PROPERTY 55 | STRINGS 56 | "" "asan" "ubsan" "tsan" 57 | ) 58 | 59 | # ------------------------------------------------------------ 60 | # Dependencies 61 | # ------------------------------------------------------------ 62 | 63 | if(TORCHDIST_BUILD_PYTHON) 64 | find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module) 65 | 66 | if(Python3_VERSION VERSION_LESS 3.7) 67 | message(FATAL_ERROR "Only CPython 3.7 and later versions are supported.") 68 | endif() 69 | 70 | # We have to ensure that we use a version of pybind11 that is ABI compatible 71 | # with PyTorch's version. 72 | torchdist_add_third_party(pybind11) 73 | else() 74 | find_package(Python3 OPTIONAL COMPONENTS Interpreter) 75 | endif() 76 | 77 | if(Python3_EXECUTABLE) 78 | execute_process( 79 | COMMAND 80 | ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" 81 | OUTPUT_VARIABLE 82 | torch_cmake_prefix_path 83 | OUTPUT_STRIP_TRAILING_WHITESPACE 84 | ) 85 | endif() 86 | 87 | find_package(Torch 1.13 REQUIRED PATHS ${torch_cmake_prefix_path}) 88 | 89 | # ------------------------------------------------------------ 90 | # Targets 91 | # ------------------------------------------------------------ 92 | 93 | add_subdirectory(src/cc/torchdistx) 94 | 95 | if(TORCHDIST_BUILD_PYTHON) 96 | add_subdirectory(src/python/torchdistx/_C) 97 | endif() 98 | 99 | torchdist_install_package(torchdistx 100 | #CONFIG_FILE 101 | src/cc/torchdistx-config.cmake.in 102 | ) 103 | 104 | if(PROJECT_IS_TOP_LEVEL AND BUILD_TESTING) 105 | #TODO: Add catch2 tests. 106 | endif() 107 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to torchdistX 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to text, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /LSan.supp: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | leak:libtorch_python 8 | leak:numpy 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torchdistX - Torch Distributed Experimental 2 | 3 | [**Installation**](#installation) | [**Getting Started**](#getting-started) | [**Documentation**](#documentation) 4 | 5 | Torch Distributed Experimental, or in short torchdistX, contains a collection of 6 | experimental features for which our team wants to gather feedback from our users 7 | before introducing them in the core PyTorch Distributed package. In a sense 8 | features included in torchdistX can be considered in an incubation period. 9 | 10 | Please be advised though that all features in torchdistX are subject to change 11 | and, although our team will make its best effort, we do not guarantee any API 12 | or ABI compatibility between releases. This means you should exercise caution if 13 | you plan to use torchdistX in production. 14 | 15 | As of today the following features are available in torchdistX: 16 | 17 | - [Fake Tensor](https://pytorch.org/torchdistx/latest/fake_tensor.html) 18 | - [Deferred Module Initialization](https://pytorch.org/torchdistx/latest/deferred_init.html) 19 | 20 | ## Dependencies 21 | torchdistX versions corresponding to each PyTorch release: 22 | 23 | | `torch` | `torchdistx` | `python` | 24 | | ------------ | ------------ | ----------------- | 25 | | `main` | `main` | `>=3.8`, `<=3.10` | 26 | | `1.12.0` | `0.2.0` | `>=3.7`, `<=3.10` | 27 | | `1.11.0` | `0.1.0` | `>=3.7`, `<=3.9` | 28 | 29 | ## Installation 30 | As of today only Linux and macOS operating systems are supported. Please note 31 | that pre-built Conda and PyPI packages are *only* available for Linux though. 32 | For installation on macOS you can follow the instructions in the [From Source](#from-source) 33 | section. At this time there are no plans to introduce Windows support. 34 | 35 | ### Conda 36 | Conda is the recommended way to install torchdistX. Running the following 37 | command in a Conda environment will install torchdistX and all its dependencies. 
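After running any of the commands below you can sanity-check the result with a
short snippet like the following. This is only an illustration based on the Fake
Tensor example further down in this README, not an official verification step; it
should behave the same on the CPU and CUDA variants because no real storage is
allocated.

```python
# Minimal post-install check (illustrative only).
import torch
from torchdistx import fake

with fake.fake_mode():
    t = torch.ones([4])

print(t)  # shows a tensor reported as fake=True, carrying no data
```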
38 | 39 | **Stable** 40 | 41 | For PyTorch CPU: 42 | ``` 43 | conda install -c pytorch -c conda-forge torchdistx cpuonly 44 | ``` 45 | 46 | For PyTorch with CUDA 10.2: 47 | ``` 48 | conda install -c pytorch -c conda-forge torchdistx cudatoolkit=10.2 49 | ``` 50 | 51 | For PyTorch with CUDA 11.3: 52 | ``` 53 | conda install -c pytorch -c conda-forge torchdistx cudatoolkit=11.3 54 | ``` 55 | 56 | For PyTorch with CUDA 11.6: 57 | ``` 58 | conda install -c pytorch -c conda-forge torchdistx cudatoolkit=11.6 59 | ``` 60 | 61 | **Nightly** 62 | 63 | For PyTorch CPU 64 | ``` 65 | conda install -c pytorch-nightly -c conda-forge torchdistx cpuonly 66 | ``` 67 | 68 | For PyTorch with CUDA 10.2 69 | ``` 70 | conda install -c pytorch-nightly -c conda-forge torchdistx cudatoolkit=10.2 71 | ``` 72 | 73 | For PyTorch with CUDA 11.3 74 | ``` 75 | conda install -c pytorch-nightly -c conda-forge torchdistx cudatoolkit=11.3 76 | ``` 77 | 78 | For PyTorch with CUDA 11.6 79 | ``` 80 | conda install -c pytorch-nightly -c conda-forge torchdistx cudatoolkit=11.6 81 | ``` 82 | 83 | In fact torchdistX offers several Conda packages that you can install 84 | independently based on your needs: 85 | 86 | | Package | Description | 87 | |-------------------------------------------------------------------------|--------------------------------------------------| 88 | | [torchdistx](https://anaconda.org/pytorch/torchdistx) | torchdistX Python Library | 89 | | [torchdistx-cc](https://anaconda.org/pytorch/torchdistx-cc) | torchdistX C++ Runtime Library | 90 | | [torchdistx-cc-devel](https://anaconda.org/pytorch/torchdistx-cc-devel) | torchdistX C++ Runtime Library Development Files | 91 | | [torchdistx-cc-debug](https://anaconda.org/pytorch/torchdistx-cc-debug) | torchdistX C++ Runtime Library Debug Symbols | 92 | 93 | ### PyPI 94 | 95 | **Stable** 96 | 97 | For PyTorch CPU: 98 | ``` 99 | pip install torchdistx --extra-index-url https://download.pytorch.org/whl/cpu 100 | ``` 101 | 102 | For PyTorch with CUDA 10.2: 103 | ``` 104 | pip install torchdistx --extra-index-url https://download.pytorch.org/whl/cu102 105 | ``` 106 | 107 | For PyTorch with CUDA 11.3: 108 | ``` 109 | pip install torchdistx --extra-index-url https://download.pytorch.org/whl/cu113 110 | ``` 111 | 112 | For PyTorch with CUDA 11.6: 113 | ``` 114 | pip install torchdistx --extra-index-url https://download.pytorch.org/whl/cu116 115 | ``` 116 | 117 | **Nightly** 118 | 119 | For PyTorch CPU: 120 | ``` 121 | pip install torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cpu 122 | ``` 123 | 124 | For PyTorch with CUDA 10.2: 125 | ``` 126 | pip install torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu102 127 | ``` 128 | 129 | For PyTorch with CUDA 11.3: 130 | ``` 131 | pip install torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu113 132 | ``` 133 | 134 | For PyTorch with CUDA 11.6: 135 | ``` 136 | pip install torchdistx --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu116 137 | ``` 138 | 139 | ### From Source 140 | 141 | #### Prerequisites 142 | - After cloning the repository make sure to initialize all submodules by 143 | executing `git submodule update --init --recursive`. 
144 | - Create a Python virtual environment and install the build dependencies: 145 | ``` 146 | # Build against PyTorch CPU 147 | pip install --upgrade -r requirements.txt -r use-cpu.txt 148 | 149 | # Build against PyTorch with CUDA 10.2 150 | pip install --upgrade -r requirements.txt -r use-cu102.txt 151 | 152 | # Build against PyTorch with CUDA 11.3 153 | pip install --upgrade -r requirements.txt -r use-cu113.txt 154 | 155 | # Build against PyTorch with CUDA 11.6 156 | pip install --upgrade -r requirements.txt -r use-cu116.txt 157 | ``` 158 | - The build process requires CMake 3.21 or later. You can install an up-to-date 159 | version by executing `pip install cmake`. For other environments please refer 160 | to your package manager or [cmake.org](https://cmake.org/download/). 161 | 162 | Once you have all prerequisites run the following commands to install the 163 | torchdistX Python package: 164 | 165 | ``` 166 | cmake -DTORCHDIST_INSTALL_STANDALONE=ON -B build 167 | cmake --build build 168 | pip install . 169 | ``` 170 | 171 | For advanced build options you can check out [CMakeLists.txt](./CMakeLists.txt). 172 | 173 | #### Development 174 | In case you would like to contribute to the project you can slightly modify the 175 | commands listed above: 176 | 177 | ``` 178 | cmake -B build 179 | cmake --build build 180 | pip install -e . 181 | ``` 182 | 183 | With `pip install -e .` you enable the edit mode (a.k.a. develop mode) that 184 | allows you to modify the Python files in-place without requiring to repeatedly 185 | install the package. If you are working in C++, whenever you modify a header or 186 | implementation file, executing `cmake --build build` alone is sufficient. You do 187 | not have to call `pip install` again. 188 | 189 | The project also comes with a [requirements-devel.txt](./requirements-devel.txt) 190 | to set up a Python virtual environment for development. 191 | 192 | ``` 193 | # Build against PyTorch CPU 194 | pip install --upgrade -r requirements-devel.txt -r use-cpu.txt 195 | 196 | # Build against PyTorch with CUDA 10.2 197 | pip install --upgrade -r requirements-devel.txt -r use-cu102.txt 198 | 199 | # Build against PyTorch with CUDA 11.3 200 | pip install --upgrade -r requirements-devel.txt -r use-cu113.txt 201 | 202 | # Build against PyTorch with CUDA 11.6 203 | pip install --upgrade -r requirements-devel.txt -r use-cu116.txt 204 | ``` 205 | 206 | #### Tip 207 | Note that using the Ninja build system and the ccache tool can significatly 208 | speed up your build times. To use them you can replace the initial CMake command 209 | listed above with the following version: 210 | 211 | ``` 212 | cmake -GNinja -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -B build 213 | ``` 214 | 215 | ## Getting Started 216 | 217 | ### Fake Tensor 218 | Fake tensors, similar to meta tensors, carry no data; however, unlike meta 219 | tensors which report `meta` as their device, fake tensors act as if they were 220 | allocated on a real device. In the example below we construct two fake tensors 221 | with the `fake_mode` context manager. 222 | 223 | ```python 224 | >>> import torch 225 | >>> from torchdistx import fake 226 | >>> 227 | >>> with fake.fake_mode(): 228 | ... a = torch.ones([10]) 229 | ... b = torch.ones([20], device="cuda") 230 | ... 
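>>> # (comment added for this write-up, not part of the original session: neither
>>> # assignment above allocated real storage; the tensors only record their sizes
>>> # and requested devices, as the reprs below show)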
231 | >>> a 232 | tensor(..., size=(10,), fake=True) 233 | >>> b 234 | tensor(..., size=(20,), device=cuda, fake=True) 235 | ``` 236 | 237 | ### Deferred Module Initialization 238 | This feature forces all tensors of a module to be constructed as fake while also 239 | recording all operations performed on them. The module, its submodules, and its 240 | tensors can later be materialized by calling the `materialize_module()` and 241 | `materialize_tensor()` functions. 242 | 243 | ```python 244 | >>> import torch 245 | >>> from torchdistx import deferred_init 246 | >>> 247 | >>> m = deferred_init.deferred_init(torch.nn.Linear, 10, 20) 248 | >>> m.weight 249 | Parameter containing: 250 | tensor(..., size=(20, 10), requires_grad=True, fake=True) 251 | >>> 252 | >>> deferred_init.materialize_module(m) 253 | >>> m.weight 254 | Parameter containing: 255 | tensor([[-0.1838, -0.0080, 0.0747, -0.1663, -0.0936, 0.0587, 0.1988, -0.0977, 256 | -0.1433, 0.2620], 257 | ..., requires_grad=True) 258 | ``` 259 | 260 | ## Documentation 261 | For more documentation, see [our docs website](https://pytorch.org/torchdistx/latest). 262 | 263 | ## Contributing 264 | Please refer to [CONTRIBUTING.md](./CONTRIBUTING.md). 265 | 266 | ## License 267 | This project is BSD licensed, as found in the [LICENSE](LICENSE) file. 268 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.3.0-dev 2 | -------------------------------------------------------------------------------- /cmake/Helpers.cmake: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | include_guard(GLOBAL) 8 | 9 | include(CMakePackageConfigHelpers) 10 | include(GNUInstallDirs) 11 | 12 | function(torchdist_add_target target) 13 | cmake_parse_arguments(arg 14 | #OPTIONS 15 | "EXECUTABLE;LIBRARY;SHARED_LIBRARY;STATIC_LIBRARY;PYTHON_MODULE" 16 | #KEYWORDS 17 | "OUTPUT_NAME" 18 | #MULTI_VALUE_KEYWORDS 19 | "" 20 | #ARGUMENTS 21 | ${ARGN} 22 | ) 23 | 24 | if(arg_EXECUTABLE) 25 | add_executable(${target}) 26 | elseif(arg_PYTHON_MODULE) 27 | if(NOT COMMAND Python3_add_library) 28 | message(FATAL_ERROR "Python3 must be loaded before calling torchdist_add_target()!") 29 | endif() 30 | 31 | Python3_add_library(${target} WITH_SOABI) 32 | else() 33 | if(arg_LIBRARY) 34 | set(lib_type) 35 | elseif(arg_SHARED_LIBRARY) 36 | set(lib_type SHARED) 37 | elseif(arg_STATIC_LIBRARY) 38 | set(lib_type STATIC) 39 | else() 40 | message(FATAL_ERROR "torchdist_add_target() has an invalid target type!") 41 | endif() 42 | 43 | add_library(${target} ${lib_type}) 44 | endif() 45 | 46 | cmake_path(GET CMAKE_CURRENT_SOURCE_DIR 47 | PARENT_PATH 48 | source_parent_dir 49 | ) 50 | 51 | if(arg_LIBRARY OR arg_SHARED_LIBRARY OR arg_STATIC_LIBRARY) 52 | if(PROJECT_IS_TOP_LEVEL) 53 | set(system) 54 | else() 55 | set(system SYSTEM) 56 | endif() 57 | 58 | target_include_directories(${target} ${system} 59 | INTERFACE 60 | $ 61 | ) 62 | endif() 63 | 64 | # ------------------------------------------------------------ 65 | # Properties 66 | # ------------------------------------------------------------ 67 | 68 | set_target_properties(${target} PROPERTIES 69 | C_EXTENSIONS 70 | OFF 71 | C_VISIBILITY_PRESET 72 | hidden 73 | CXX_EXTENSIONS 74 | OFF 75 | CXX_VISIBILITY_PRESET 76 | hidden 77 | CUDA_EXTENSIONS 78 | OFF 79 | CUDA_VISIBILITY_PRESET 80 | hidden 81 | POSITION_INDEPENDENT_CODE 82 | ON 83 | EXPORT_COMPILE_COMMANDS 84 | ON 85 | ) 86 | 87 | if(arg_SHARED_LIBRARY AND NOT TORCHDIST_INSTALL_STANDALONE) 88 | set_target_properties(${target} PROPERTIES 89 | VERSION 90 | ${PROJECT_VERSION} 91 | SOVERSION 92 | ${PROJECT_VERSION_MAJOR} 93 | ) 94 | endif() 95 | 96 | if(arg_OUTPUT_NAME) 97 | set_target_properties(${target} PROPERTIES 98 | OUTPUT_NAME 99 | ${arg_OUTPUT_NAME} 100 | ) 101 | endif() 102 | 103 | if(TORCHDIST_PERFORM_LTO) 104 | set_target_properties(${target} PROPERTIES 105 | INTERPROCEDURAL_OPTIMIZATION 106 | ON 107 | ) 108 | 109 | if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 110 | torchdist_set_macos_lto_path(${target}) 111 | endif() 112 | endif() 113 | 114 | if(arg_PYTHON_MODULE AND TORCHDIST_DEVELOP_PYTHON) 115 | set_target_properties(${target} PROPERTIES 116 | BUILD_RPATH_USE_ORIGIN 117 | OFF 118 | ) 119 | 120 | add_custom_command( 121 | TARGET 122 | ${target} 123 | POST_BUILD 124 | COMMAND 125 | ${CMAKE_COMMAND} -E copy "$" "${source_parent_dir}" 126 | VERBATIM 127 | ) 128 | endif() 129 | 130 | torchdist_enable_clang_tidy(${target}) 131 | 132 | # ------------------------------------------------------------ 133 | # Compiler Settings 134 | # ------------------------------------------------------------ 135 | 136 | if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") 137 | target_compile_options(${target} 138 | PRIVATE 139 | -fasynchronous-unwind-tables -fstack-protector-strong 140 | ) 141 | 142 | if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 143 | if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7) 144 | message(FATAL_ERROR "Only GCC 7 and later versions are supported!") 145 | endif() 146 | 147 | target_compile_options(${target} 148 | PRIVATE 149 | -Wall 150 | -Wcast-align 151 | -Wconversion 152 | -Wdouble-promotion 153 | 
-Wextra 154 | -Wfloat-equal 155 | -Wformat=2 156 | -Winit-self 157 | -Wlogical-op 158 | -Wno-unknown-pragmas 159 | -Wpointer-arith 160 | -Wshadow 161 | -Wsign-conversion 162 | -Wswitch-enum 163 | -Wunused 164 | $<$:-Wnon-virtual-dtor> 165 | $<$:-Wold-style-cast> 166 | $<$:-Woverloaded-virtual> 167 | $<$:-Wuseless-cast> 168 | ) 169 | 170 | target_compile_definitions(${target} 171 | PRIVATE 172 | $<$:_GLIBCXX_ASSERTIONS> 173 | ) 174 | else() 175 | if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7) 176 | message(FATAL_ERROR "Only Clang 7 and later versions are supported!") 177 | endif() 178 | 179 | target_compile_options(${target} 180 | PRIVATE 181 | -fsized-deallocation 182 | -Weverything 183 | -Wno-c++98-compat 184 | -Wno-c++98-compat-pedantic 185 | -Wno-exit-time-destructors 186 | -Wno-extra-semi-stmt 187 | -Wno-global-constructors 188 | -Wno-padded 189 | -Wno-return-std-move-in-c++11 190 | -Wno-shadow-uncaptured-local 191 | ) 192 | endif() 193 | 194 | if(TORCHDIST_TREAT_WARNINGS_AS_ERRORS) 195 | target_compile_options(${target} 196 | PRIVATE 197 | -Werror 198 | ) 199 | endif() 200 | 201 | if(TORCHDIST_BUILD_FOR_NATIVE) 202 | target_compile_options(${target} 203 | PRIVATE 204 | -march=native -mtune=native 205 | ) 206 | endif() 207 | 208 | target_compile_definitions(${target} 209 | PRIVATE 210 | $<$>:_FORTIFY_SOURCE=2> 211 | ) 212 | else() 213 | message(FATAL_ERROR "Only GCC and Clang toolchains are supported!") 214 | endif() 215 | 216 | # ------------------------------------------------------------ 217 | # Linker Settings 218 | # ------------------------------------------------------------ 219 | 220 | if(CMAKE_SYSTEM_NAME STREQUAL "Linux") 221 | target_link_options(${target} 222 | PRIVATE 223 | LINKER:--as-needed 224 | LINKER:--build-id=sha1 225 | LINKER:-z,noexecstack 226 | LINKER:-z,now 227 | LINKER:-z,relro 228 | ) 229 | 230 | if(NOT arg_PYTHON_MODULE) 231 | target_link_options(${target} 232 | PRIVATE 233 | LINKER:-z,defs 234 | ) 235 | endif() 236 | 237 | if(TORCHDIST_TREAT_WARNINGS_AS_ERRORS) 238 | target_link_options(${target} 239 | PRIVATE 240 | LINKER:--fatal-warnings 241 | ) 242 | endif() 243 | elseif(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 244 | target_link_options(${target} 245 | PRIVATE 246 | LINKER:-bind_at_load 247 | ) 248 | 249 | if(arg_PYTHON_MODULE) 250 | target_link_options(${target} 251 | PRIVATE 252 | LINKER:-undefined,dynamic_lookup 253 | ) 254 | else() 255 | target_link_options(${target} 256 | PRIVATE 257 | LINKER:-undefined,error 258 | ) 259 | endif() 260 | 261 | # Conda Build sets the `-pie` option in `LDFLAGS` which causes a linker warning for library 262 | # targets. When warnings are treated as errors, this becomes a build failure. 
263 | if(NOT arg_EXECUTABLE) 264 | target_link_options(${target} 265 | PRIVATE 266 | LINKER:-no_pie 267 | ) 268 | endif() 269 | 270 | if(TORCHDIST_TREAT_WARNINGS_AS_ERRORS) 271 | target_link_options(${target} 272 | PRIVATE 273 | LINKER:-fatal_warnings 274 | ) 275 | endif() 276 | else() 277 | message(FATAL_ERROR "Only Linux and macOS operating systems are supported!") 278 | endif() 279 | 280 | # ------------------------------------------------------------ 281 | # Sanitizers 282 | # ------------------------------------------------------------ 283 | 284 | if(TORCHDIST_SANITIZERS) 285 | string(TOLOWER "${TORCHDIST_SANITIZERS}" 286 | #OUTPUT 287 | sanitizer_types 288 | ) 289 | 290 | foreach(sanitizer_type IN ITEMS ${sanitizer_types}) 291 | if(sanitizer_type STREQUAL "asan") 292 | if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 293 | target_compile_definitions(${target} 294 | PRIVATE 295 | _GLIBCXX_SANITIZE_VECTOR 296 | ) 297 | endif() 298 | 299 | list(APPEND sanitizers -fsanitize=address) 300 | elseif(sanitizer_type STREQUAL "ubsan") 301 | list(APPEND sanitizers -fsanitize=undefined) 302 | elseif(sanitizer_type STREQUAL "tsan") 303 | list(APPEND sanitizers -fsanitize=thread) 304 | else() 305 | message(FATAL_ERROR "The specified sanitizer type is invalid!") 306 | endif() 307 | endforeach() 308 | 309 | target_compile_options(${target} 310 | PRIVATE 311 | ${sanitizers} -fno-omit-frame-pointer 312 | ) 313 | 314 | target_link_options(${target} 315 | PRIVATE 316 | ${sanitizers} 317 | ) 318 | endif() 319 | endfunction() 320 | 321 | # When performing ThinLTO on macOS, mach-o object files are generated under a 322 | # temporary directory that gets deleted by the linker at the end of the build 323 | # process. Thus tools such as dsymutil cannot access the DWARF info contained 324 | # in those files. To ensure that the object files still exist after the build 325 | # process we have to set the `object_path_lto` linker option. 326 | function(torchdist_set_macos_lto_path target) 327 | get_target_property( 328 | #OUT 329 | target_type 330 | #TARGET 331 | ${target} 332 | #PROPERTY 333 | TYPE 334 | ) 335 | 336 | if(target_type STREQUAL "STATIC_LIBRARY") 337 | return() 338 | endif() 339 | 340 | set(lto_dir ${CMAKE_CURRENT_BINARY_DIR}/lto.d/${target}/${CMAKE_CFG_INTDIR}) 341 | 342 | add_custom_command( 343 | TARGET 344 | ${target} 345 | PRE_BUILD 346 | COMMAND 347 | ${CMAKE_COMMAND} -E make_directory "${lto_dir}" 348 | VERBATIM 349 | ) 350 | 351 | # See man ld(1). 
352 | target_link_options(${target} 353 | PRIVATE 354 | LINKER:-object_path_lto "${lto_dir}" 355 | ) 356 | 357 | set_property(DIRECTORY APPEND PROPERTY 358 | ADDITIONAL_MAKE_CLEAN_FILES 359 | ${lto_dir} 360 | ) 361 | endfunction() 362 | 363 | function(torchdist_add_third_party) 364 | foreach(project IN ITEMS ${ARGV}) 365 | add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/third-party/${project} EXCLUDE_FROM_ALL) 366 | endforeach() 367 | endfunction() 368 | 369 | function(torchdist_enable_clang_tidy) 370 | if(NOT TORCHDIST_RUN_CLANG_TIDY) 371 | return() 372 | endif() 373 | 374 | if(NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang") 375 | message(FATAL_ERROR "clang-tidy can only be used with the Clang toolchain!") 376 | endif() 377 | 378 | find_program(TORCHDIST_CLANG_TIDY_PROG NAMES clang-tidy REQUIRED) 379 | 380 | mark_as_advanced(TORCHDIST_CLANG_TIDY_PROG) 381 | 382 | foreach(target IN ITEMS ${ARGV}) 383 | set_target_properties(${target} PROPERTIES 384 | C_CLANG_TIDY 385 | ${TORCHDIST_CLANG_TIDY_PROG} 386 | CXX_CLANG_TIDY 387 | ${TORCHDIST_CLANG_TIDY_PROG} 388 | CUDA_CLANG_TIDY 389 | ${TORCHDIST_CLANG_TIDY_PROG} 390 | ) 391 | endforeach() 392 | endfunction() 393 | 394 | function(torchdist_install target) 395 | cmake_parse_arguments(arg "" "PACKAGE" "HEADERS" ${ARGN}) 396 | 397 | # Set rpath if we are installing in standalone mode. 398 | if(TORCHDIST_INSTALL_STANDALONE) 399 | set(install_bindir bin) 400 | set(install_libdir lib) 401 | 402 | if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 403 | set(rpath_origin @loader_path) 404 | else() 405 | set(rpath_origin \$ORIGIN) 406 | endif() 407 | 408 | get_target_property( 409 | #OUT 410 | target_type 411 | #TARGET 412 | ${target} 413 | #PROPERTY 414 | TYPE 415 | ) 416 | 417 | if(target_type STREQUAL "EXECUTABLE") 418 | set(target_rpath ${rpath_origin}/../lib) 419 | else() 420 | set(target_rpath ${rpath_origin}) 421 | endif() 422 | 423 | set_target_properties(${target} PROPERTIES 424 | INSTALL_RPATH 425 | ${target_rpath} 426 | ) 427 | else() 428 | set(install_bindir ${CMAKE_INSTALL_BINDIR}) 429 | set(install_libdir ${CMAKE_INSTALL_LIBDIR}) 430 | endif() 431 | 432 | install( 433 | TARGETS 434 | ${target} 435 | EXPORT 436 | ${arg_PACKAGE}-targets 437 | RUNTIME 438 | DESTINATION 439 | ${install_bindir} 440 | COMPONENT 441 | runtime 442 | LIBRARY 443 | DESTINATION 444 | ${install_libdir} 445 | COMPONENT 446 | runtime 447 | NAMELINK_COMPONENT 448 | devel 449 | ARCHIVE 450 | DESTINATION 451 | ${install_libdir} 452 | COMPONENT 453 | devel 454 | INCLUDES DESTINATION 455 | ${CMAKE_INSTALL_INCLUDEDIR} 456 | ) 457 | 458 | cmake_path(GET CMAKE_CURRENT_SOURCE_DIR 459 | PARENT_PATH 460 | source_parent_dir 461 | ) 462 | 463 | foreach(header IN ITEMS ${arg_HEADERS}) 464 | cmake_path(REMOVE_FILENAME header 465 | OUTPUT_VARIABLE 466 | relative_header_dir 467 | ) 468 | 469 | set(header_dir ${CMAKE_CURRENT_SOURCE_DIR}/${relative_header_dir}) 470 | 471 | cmake_path(RELATIVE_PATH header_dir 472 | BASE_DIRECTORY 473 | ${source_parent_dir} 474 | ) 475 | 476 | install( 477 | FILES 478 | ${header} 479 | DESTINATION 480 | ${CMAKE_INSTALL_INCLUDEDIR}/${header_dir} 481 | COMPONENT 482 | devel 483 | ) 484 | endforeach() 485 | endfunction() 486 | 487 | function(torchdist_install_python_module target) 488 | # Set rpath if we are installing in standalone mode. 
489 | if(TORCHDIST_INSTALL_STANDALONE) 490 | if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") 491 | set(rpath_origin @loader_path) 492 | else() 493 | set(rpath_origin \$ORIGIN) 494 | endif() 495 | 496 | set_target_properties(${target} PROPERTIES 497 | INSTALL_RPATH 498 | ${rpath_origin}/lib 499 | ) 500 | endif() 501 | 502 | install( 503 | TARGETS 504 | ${target} 505 | LIBRARY 506 | DESTINATION 507 | . 508 | COMPONENT 509 | python 510 | EXCLUDE_FROM_ALL 511 | ) 512 | endfunction() 513 | 514 | function(torchdist_install_package package config_file) 515 | if(TORCHDIST_INSTALL_STANDALONE) 516 | set(install_libdir lib) 517 | else() 518 | set(install_libdir ${CMAKE_INSTALL_LIBDIR}) 519 | endif() 520 | 521 | set(package_dir ${install_libdir}/cmake/${package}-${PROJECT_VERSION}) 522 | 523 | configure_package_config_file( 524 | #INPUT 525 | ${config_file} 526 | #OUTPUT 527 | ${CMAKE_CURRENT_BINARY_DIR}/${package}/lib/cmake/${package}/${package}-config.cmake 528 | INSTALL_DESTINATION 529 | ${package_dir} 530 | NO_SET_AND_CHECK_MACRO 531 | ) 532 | 533 | write_basic_package_version_file( 534 | #OUTPUT 535 | ${CMAKE_CURRENT_BINARY_DIR}/${package}/lib/cmake/${package}/${package}-config-version.cmake 536 | VERSION 537 | ${PROJECT_VERSION} 538 | COMPATIBILITY 539 | AnyNewerVersion 540 | ) 541 | 542 | install( 543 | FILES 544 | ${CMAKE_CURRENT_BINARY_DIR}/${package}/lib/cmake/${package}/${package}-config.cmake 545 | ${CMAKE_CURRENT_BINARY_DIR}/${package}/lib/cmake/${package}/${package}-config-version.cmake 546 | DESTINATION 547 | ${package_dir} 548 | COMPONENT 549 | devel 550 | ) 551 | 552 | install( 553 | EXPORT 554 | ${package}-targets 555 | FILE 556 | ${package}-targets.cmake 557 | DESTINATION 558 | ${package_dir} 559 | COMPONENT 560 | devel 561 | NAMESPACE 562 | ${package}:: 563 | ) 564 | 565 | export( 566 | EXPORT 567 | ${package}-targets 568 | FILE 569 | ${CMAKE_CURRENT_BINARY_DIR}/${package}/lib/cmake/${package}/${package}-targets.cmake 570 | NAMESPACE 571 | ${package}:: 572 | ) 573 | endfunction() 574 | -------------------------------------------------------------------------------- /docker/ci-base/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ubuntu:18.04 8 | 9 | COPY install-common install-python install-git /root/ 10 | 11 | RUN /root/install-common 12 | RUN /root/install-python 13 | RUN /root/install-git 14 | -------------------------------------------------------------------------------- /docker/ci-base/install-common: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | apt-get update 12 | 13 | apt-get install --yes curl libxml2 make software-properties-common zip 14 | 15 | rm -rf /var/lib/apt/lists/* 16 | -------------------------------------------------------------------------------- /docker/ci-base/install-git: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 
5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | add-apt-repository ppa:git-core/ppa 12 | 13 | apt-get update 14 | 15 | apt-get install --yes git 16 | 17 | rm -rf /var/lib/apt/lists/* 18 | -------------------------------------------------------------------------------- /docker/ci-base/install-python: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | add-apt-repository ppa:deadsnakes/ppa 12 | 13 | apt-get update 14 | 15 | apt-get install --yes\ 16 | python3.8 python3.8-dev python3.8-venv\ 17 | python3.9 python3.9-dev python3.9-venv\ 18 | python3.10 python3.10-dev python3.10-venv 19 | 20 | rm -rf /var/lib/apt/lists/* 21 | -------------------------------------------------------------------------------- /docker/ci-clang/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/pytorch/torchdistx-ci-base:2 8 | 9 | COPY install-clang install-cmake-ninja /root/ 10 | 11 | RUN /root/install-clang 12 | RUN /root/install-cmake-ninja 13 | -------------------------------------------------------------------------------- /docker/ci-clang/install-clang: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output llvm.cert https://apt.llvm.org/llvm-snapshot.gpg.key 12 | 13 | apt-key add llvm.cert 14 | 15 | rm llvm.cert 16 | 17 | add-apt-repository "deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic-13 main" 18 | 19 | apt-get update 20 | 21 | apt-get install --yes clang-13 clang++-13 clang-tidy-13 clang-format-13 22 | 23 | rm -rf /var/lib/apt/lists/* 24 | -------------------------------------------------------------------------------- /docker/ci-clang/install-cmake-ninja: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output cmake.sh\ 12 | https://github.com/Kitware/CMake/releases/download/v3.21.6/cmake-3.21.6-linux-x86_64.sh 13 | 14 | sh cmake.sh --skip-license 15 | 16 | rm cmake.sh 17 | 18 | curl --location --fail --output ninja.zip\ 19 | https://github.com/ninja-build/ninja/releases/download/v1.10.2/ninja-linux.zip 20 | 21 | unzip ninja.zip -d /usr/bin 22 | 23 | rm ninja.zip 24 | -------------------------------------------------------------------------------- /docker/ci-conda/Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/pytorch/torchdistx-ci-base:2 8 | 9 | ENV PATH=/root/miniconda3/bin:$PATH 10 | 11 | COPY install-conda /root/ 12 | 13 | RUN /root/install-conda 14 | -------------------------------------------------------------------------------- /docker/ci-conda/Dockerfile.cu117: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/pytorch/torchdistx-ci-conda:2-cpu 8 | 9 | COPY install-cuda-11.7 /root/ 10 | 11 | RUN /root/install-cuda-11.7 12 | -------------------------------------------------------------------------------- /docker/ci-conda/Dockerfile.cu118: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/pytorch/torchdistx-ci-conda:2-cpu 8 | 9 | COPY install-cuda-11.8 /root/ 10 | 11 | RUN /root/install-cuda-11.8 12 | -------------------------------------------------------------------------------- /docker/ci-conda/install-conda: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output miniconda3.sh\ 12 | https://repo.anaconda.com/miniconda/Miniconda3-py39_4.11.0-Linux-x86_64.sh 13 | 14 | sh miniconda3.sh -b 15 | 16 | rm miniconda3.sh 17 | 18 | conda install --yes anaconda-client==1.9.0 conda==4.12.0 conda-build==3.21.8 conda-verify==3.4.2 19 | 20 | conda clean --all 21 | -------------------------------------------------------------------------------- /docker/ci-conda/install-cuda-11.7: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | -------------------------------------------------------------------------------- /docker/ci-conda/install-cuda-11.8: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | -------------------------------------------------------------------------------- /docker/ci-wheel/Dockerfile.cpu: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM quay.io/pypa/manylinux2014_x86_64 8 | 9 | COPY install-devtoolset-10 install-awscli /root/ 10 | 11 | RUN /root/install-devtoolset-10 12 | RUN /root/install-awscli 13 | -------------------------------------------------------------------------------- /docker/ci-wheel/Dockerfile.cu117: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/pytorch/torchdistx-ci-wheel:2-cpu 8 | 9 | # CUDA 11.7 requires GCC 11.x. 10 | ENV PATH=/usr/local/cuda-11.7/bin:/opt/rh/devtoolset-11/root/usr/bin:$PATH 11 | 12 | ENV LD_LIBRARY_PATH=/usr/local/cuda-11.7/lib64:/opt/rh/devtoolset-11/root/usr/lib64:$LD_LIBRARY_PATH 13 | 14 | COPY install-devtoolset-11 install-cuda-11.7 install-cudnn-8.3.2 /root/ 15 | 16 | RUN /root/install-devtoolset-11 17 | RUN /root/install-cuda-11.7 18 | RUN /root/install-cudnn-8.3.2 19 | -------------------------------------------------------------------------------- /docker/ci-wheel/Dockerfile.cu118: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | FROM ghcr.io/pytorch/torchdistx-ci-wheel:2-cpu 8 | 9 | # CUDA 11.8 requires GCC 11.x. 
10 | ENV PATH=/usr/local/cuda-11.8/bin:/opt/rh/devtoolset-11/root/usr/bin:$PATH 11 | 12 | ENV LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:/opt/rh/devtoolset-11/root/usr/lib64:$LD_LIBRARY_PATH 13 | 14 | COPY install-devtoolset-11 install-cuda-11.8 install-cudnn-8.3.2 /root/ 15 | 16 | RUN /root/install-devtoolset-11 17 | RUN /root/install-cuda-11.8 18 | RUN /root/install-cudnn-8.3.2 19 | -------------------------------------------------------------------------------- /docker/ci-wheel/install-awscli: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output awscli.zip\ 12 | https://awscli.amazonaws.com/awscli-exe-linux-x86_64-2.5.4.zip 13 | 14 | unzip awscli.zip 15 | 16 | aws/install --bin-dir /usr/bin 17 | 18 | rm -rf aws awscli.zip 19 | -------------------------------------------------------------------------------- /docker/ci-wheel/install-cuda-11.7: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/11.7.0/local_installers/cuda_11.7.0_515.43.04_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | -------------------------------------------------------------------------------- /docker/ci-wheel/install-cuda-11.8: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | curl --location --fail --output cuda.run\ 12 | https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run 13 | 14 | sh cuda.run --silent --toolkit --override --no-man-page 15 | 16 | rm cuda.run 17 | -------------------------------------------------------------------------------- /docker/ci-wheel/install-cudnn-8.3.2: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement 12 | 13 | mkdir cudnn && cd cudnn 14 | 15 | # Taken from https://github.com/pytorch/builder/blob/main/common/install_cuda.sh. 
16 | curl --location --fail --output cudnn.tar.xz\ 17 | https://developer.download.nvidia.com/compute/redist/cudnn/v8.3.2/local_installers/11.5/cudnn-linux-x86_64-8.3.2.44_cuda11.5-archive.tar.xz 18 | 19 | tar xf cudnn.tar.xz 20 | 21 | cp cudnn-linux-x86_64-8.3.2.44_cuda11.5-archive/include/* /usr/local/cuda/include 22 | cp cudnn-linux-x86_64-8.3.2.44_cuda11.5-archive/lib/* /usr/local/cuda/lib64 23 | 24 | cd .. 25 | 26 | rm -rf cudnn 27 | -------------------------------------------------------------------------------- /docker/ci-wheel/install-devtoolset-10: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | # devtoolset-10's gcc and g++ are already installed on manylinux2014. 12 | 13 | yum --assumeyes install\ 14 | devtoolset-10-libasan-devel\ 15 | devtoolset-10-liblsan-devel\ 16 | devtoolset-10-libubsan-devel\ 17 | devtoolset-10-libtsan-devel 18 | 19 | yum clean all 20 | -------------------------------------------------------------------------------- /docker/ci-wheel/install-devtoolset-11: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | yum --assumeyes install\ 12 | devtoolset-11-gcc\ 13 | devtoolset-11-gcc-c++\ 14 | devtoolset-11-libasan-devel\ 15 | devtoolset-11-liblsan-devel\ 16 | devtoolset-11-libubsan-devel\ 17 | devtoolset-11-libtsan-devel 18 | 19 | yum clean all 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = src 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | --editable git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 2 | 3 | sphinx==4.3.0 4 | -------------------------------------------------------------------------------- /docs/src/_static/img/fake-tensor-dispatch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchdistx/9c1b9f5cb2fa36bfb8b70ec07c40ed42a33cc87a/docs/src/_static/img/fake-tensor-dispatch.png -------------------------------------------------------------------------------- /docs/src/_static/img/fake-tensor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchdistx/9c1b9f5cb2fa36bfb8b70ec07c40ed42a33cc87a/docs/src/_static/img/fake-tensor.png -------------------------------------------------------------------------------- /docs/src/_static/img/variable-hooks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchdistx/9c1b9f5cb2fa36bfb8b70ec07c40ed42a33cc87a/docs/src/_static/img/variable-hooks.png -------------------------------------------------------------------------------- /docs/src/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | import pytorch_sphinx_theme 8 | import torchdistx 9 | 10 | # -- Project Information ----------------------------------------------------- 11 | 12 | project = "torchdistX" 13 | 14 | copyright = "Meta Platforms, Inc. and affiliates" 15 | 16 | author = "Pytorch Distributed Team" 17 | 18 | version = torchdistx.__version__ 19 | release = torchdistx.__version__ 20 | 21 | # -- General Configuration --------------------------------------------------- 22 | 23 | needs_sphinx = "4.3.0" 24 | 25 | extensions = [ 26 | "sphinx.ext.autodoc", 27 | "sphinx.ext.autosummary", 28 | "sphinx.ext.coverage", 29 | "sphinx.ext.intersphinx", 30 | "sphinx.ext.napoleon", 31 | "sphinx.ext.todo", 32 | "sphinx.ext.viewcode", 33 | ] 34 | 35 | autodoc_typehints = "description" 36 | autodoc_typehints_format = "short" 37 | 38 | todo_include_todos = True 39 | 40 | intersphinx_mapping = { 41 | "torch": ("https://pytorch.org/docs/stable/", None), 42 | } 43 | 44 | # -- Options for HTML Output ------------------------------------------------- 45 | 46 | html_theme = "pytorch_sphinx_theme" 47 | 48 | html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()] 49 | 50 | html_theme_options = { 51 | "analytics_id": "UA-117752657-2", 52 | "collapse_navigation": False, 53 | "logo_only": True, 54 | "pytorch_project": "torchdistx", 55 | } 56 | -------------------------------------------------------------------------------- /docs/src/deferred_init.rst: -------------------------------------------------------------------------------- 1 | .. 
currentmodule:: torchdistx.deferred_init 2 | 3 | Deferred Module Initialization 4 | ============================== 5 | TL;DR 6 | ------- 7 | Deferred Module Initialization feature consists of a :func:`deferred_init` 8 | function that constructs ``Module`` instances without allocating storage for 9 | their tensors, and the accompanying :func:`materialize_module` and 10 | :func:`materialize_tensor` functions that can fully or partially materialize 11 | modules constructed by :func:`deferred_init`. The feature is meant to be used if 12 | a module is memory-wise too big or computationally too expensive to construct on 13 | a single machine, but needs to be inspected for various reasons before being 14 | initialized. 15 | 16 | Problem 17 | ------- 18 | With ever increasing model sizes, it is becoming increasingly common for models 19 | to exceed the memory or compute capacity of a single machine or accelerator. 20 | This means training such models requires some sharding (a.k.a. partitioning) 21 | strategy to distribute parts of the model onto different computing nodes. 22 | However techniques such as 3D parallelism used to apply these strategies often 23 | need access to the model architecture to decide on the optimal strategy and this 24 | represents a chicken-egg problem. 25 | 26 | Automated parallelism libraries (e.g. FSDP, DeepSpeed) either completely ignore 27 | this problem, meaning they expect the model to fit on a single machine, or they 28 | have some rudimentary workarounds to partially overcome it. For instance they 29 | use a technique that sequentially initializes model parameters while sharding 30 | them on-the-fly based on some predefined memory-size threshold. However the 31 | limitation of such workarounds is that these libraries are not able to see the 32 | whole architecture of the model that would enable them to make smarter sharding 33 | decisions. 34 | 35 | What is Deferred Module Initialization? 36 | --------------------------------------- 37 | Deferred Module Initialization addresses the problem mentioned above by offering 38 | three functions. :func:`deferred_init` is a non-intrusive function that enables 39 | users to defer the initialization of a ``Module`` by skipping storage allocation 40 | for its parameters and buffers while also keeping a record of the operations 41 | performed on them in an in-memory graph. :func:`materialize_module` and 42 | :func:`materialize_tensor` are the accompanying functions that materialize 43 | (i.e. initialize) tensors or modules constructed within a previous 44 | :func:`deferred_init` call by re-playing the operations recorded at that time. 45 | 46 | API 47 | --- 48 | Initialization 49 | ^^^^^^^^^^^^^^ 50 | As mentioned above ``deferred_init()`` is the "entry point" of the API and has 51 | the following signature: 52 | 53 | .. autofunction:: deferred_init 54 | 55 | .. note:: 56 | The graph structure generated by ``deferred_init()`` is fairly simple, albeit 57 | holds information that is specifically meant to materialize in-memory tensors 58 | as if they were initialized without deferral. In that sense its 59 | implementation and its purpose diverges from the much larger and feature rich 60 | solutions such as torch.fx and TorchScript. 61 | 62 | Materialization 63 | ^^^^^^^^^^^^^^^ 64 | Modules, parameters, and buffers constructed within a :func:`deferred_init` call 65 | can later be materialized using the ``materialize_module()`` and 66 | ``materialize_tensor()`` functions. 67 | 68 | .. autofunction:: materialize_module 69 | .. 
autofunction:: materialize_tensor 70 | 71 | Examples 72 | -------- 73 | The simplest use case is to construct a module using :func:`deferred_init` and 74 | then later materialize it after some form of inspection using 75 | :func:`materialize_module`: 76 | 77 | :: 78 | 79 | >>> import torch 80 | >>> 81 | >>> from torchdistx.deferred_init import deferred_init, materialize_module 82 | >>> 83 | >>> # Notice that `m` does not have any storage even though it appears to 84 | >>> # be a module allocated on CPU. 85 | >>> m = deferred_init(torch.nn.Linear, 5, 1) 86 | >>> m.weight 87 | Parameter containing: 88 | tensor(..., device='cpu', requires_grad=True, fake=True) 89 | >>> 90 | >>> # Do some form of inspection. 91 | >>> ... 92 | >>> 93 | >>> # At the end materialize the module. 94 | >>> materialize_module(m) 95 | >>> m.weight 96 | Parameter containing: 97 | tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, 98 | -1.4677e+24]], requires_grad=True) 99 | 100 | It is also possible to materialize only a subset of modules, parameters, or 101 | buffers of a large model: 102 | 103 | :: 104 | 105 | >>> import torch 106 | >>> 107 | >>> from torchdistx.deferred_init import ( 108 | ... deferred_init, 109 | ... materialize_module, 110 | ... materialize_tensor, 111 | ... ) 112 | >>> 113 | >>> class MyLargeModel(torch.nn.Module): 114 | ... ... 115 | >>> 116 | >>> m = deferred_init(MyLargeModel) 117 | >>> 118 | >>> # Do some form of inspection (e.g. determine sharding strategy). 119 | >>> ... 120 | >>> 121 | >>> # Only materialize `sublayer1` and `sublayer2`. 122 | >>> materialize_module(m.sublayer1) 123 | >>> materialize_module(m.sublayer2) 124 | >>> 125 | >>> # Or materialize an individual parameter or buffer. 126 | >>> materialized_param = materialize_tensor(m.sublayer1.param1) 127 | 128 | :func:`deferred_init` skips storage allocation even for explicitly passed device 129 | arguments: 130 | 131 | :: 132 | 133 | >>> import torch 134 | >>> 135 | >>> from torchdistx.deferred_init import deferred_init, materialize_module 136 | >>> 137 | >>> class MyModule(torch.nn.Module): 138 | ... def __init__(self): 139 | ... super().__init__() 140 | ... self.param = torch.nn.Parameter(torch.ones([3], device="cpu")) 141 | ... 142 | >>> m = deferred_init(MyModule) 143 | >>> m.param 144 | Parameter containing: 145 | tensor(..., device='cpu', size=(3,), requires_grad=True, fake=True) 146 | >>> 147 | >>> materialize_module(m) 148 | >>> m.param 149 | Parameter containing: 150 | tensor([1., 1., 1.], requires_grad=True) 151 | 152 | Lazy modules can be used along with :func:`deferred_init()` by wrapping the 153 | module construction and the dry-run call in a single function as demonstrated 154 | below: 155 | 156 | :: 157 | 158 | >>> import torch 159 | >>> 160 | >>> from torchdistx.deferred_init import deferred_init 161 | >>> 162 | >>> def MyLazyModule(out_features: int): 163 | ... lazy_m = torch.nn.LazyLinear(out_features) 164 | ... 165 | ... # Dry-run the module to infer the parameter and buffer shapes. 166 | ... lazy_m(torch.ones([10, 10])) 167 | ... 168 | ... return lazy_m 169 | >>> 170 | >>> m = deferred_init(MyLazyModule, 10) 171 |
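Because fake parameters and buffers still carry their shape and dtype metadata,
a deferred module can be inspected much like a regular one before any storage is
allocated. The sketch below is only an illustration (the choice of
``torch.nn.Transformer`` is arbitrary and the actual parameter counts are
elided); it shows the kind of information a sharding decision could be based on
before anything is materialized:

::

    >>> import torch
    >>>
    >>> from torchdistx.deferred_init import deferred_init, materialize_module
    >>>
    >>> m = deferred_init(torch.nn.Transformer)
    >>>
    >>> # Shapes are known even though no storage has been allocated yet.
    >>> {name: sum(p.numel() for p in child.parameters())
    ...  for name, child in m.named_children()}
    {'encoder': ..., 'decoder': ...}
    >>>
    >>> # Materialize the module once a decision has been made.
    >>> materialize_module(m)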
172 | However, note that :func:`deferred_init` and the materialization functions use a "best 173 | effort" approach and are not guaranteed to always succeed. See the 174 | `Common Failure Patterns`_ section below to learn more. 175 | 176 | Common Failure Patterns 177 | ----------------------- 178 | **A module using an operator that is not supported by the meta backend:** 179 | Internally :func:`deferred_init` relies on the meta backend. If the module to be 180 | constructed by :func:`deferred_init` uses an operator that is not yet supported 181 | by the meta backend, the operator call will fail. Fortunately such failures are 182 | easy to spot since the returned error message will clearly indicate which 183 | operator was the culprit. The solution in such a case is to introduce meta backend 184 | support for the failed operation. 185 | 186 | **Mutable operator arguments:** Although almost all PyTorch operators use either 187 | primitives (e.g. integers, floating-point numbers) or tensors as parameter 188 | types, if an operator accepts a mutable argument other than ``Tensor`` (e.g. a 189 | storage, blob, or future), :func:`deferred_init` will deliberately fail 190 | the operation since we cannot guarantee that the argument will have the same 191 | state during materialization. 192 | 193 | **In-place updated external tensors and inference tensors:** As a follow-up to 194 | mutable arguments, if a tensor constructed from external data (e.g. via 195 | ``torch.load()``, ``torch.from_numpy()``) is used as an argument to a meta 196 | operation within :func:`deferred_init`, its version counter will be tracked 197 | similarly to Autograd. A change to the version counter, which practically means 198 | an in-place update to the tensor, will be checked during materialization and, if 199 | detected, an error will be raised since that would prevent the correct 200 | materialization. The rules are stricter for inference tensors; since in-place 201 | updates cannot be tracked for them, any materialization call using an inference 202 | tensor as an argument will raise an error. 203 | 204 | **A module using tolist() or numpy() functions in its constructor:** Currently 205 | Deferred Module Initialization does not support tracing calls to the ``tolist()`` 206 | and ``numpy()`` functions. We consider this a temporary limitation and will work 207 | with the PyTorch core team to mitigate it in future releases. 208 | -------------------------------------------------------------------------------- /docs/src/fake_tensor.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: torchdistx.fake 2 | 3 | Fake Tensor 4 | =========== 5 | Fake tensors, similar to meta tensors, carry no data; however, unlike meta 6 | tensors which report ``meta`` as their device, fake tensors act as if they were 7 | allocated on a real device. The following example shows how the two tensor 8 | types differ: 9 | 10 | :: 11 | 12 | >>> import torch 13 | >>> 14 | >>> from torchdistx.fake import fake_mode 15 | >>> 16 | >>> # Meta tensors are always "allocated" on the `meta` device. 17 | >>> a = torch.ones([10], device="meta") 18 | >>> a 19 | tensor(..., device='meta', size=(10,)) 20 | >>> a.device 21 | device(type='meta') 22 | >>> 23 | >>> # Fake tensors are always "allocated" on the specified device. 24 | >>> with fake_mode(): 25 | ... b = torch.ones([10]) 26 | ... 27 | >>> b 28 | tensor(..., size=(10,), fake=True) 29 | >>> b.device 30 | device(type='cpu') 31 | 32 | Fake tensors, like meta tensors, rely on the meta backend for their operation. 33 | In that sense meta tensors and fake tensors can be considered close cousins. 34 | Fake tensors are just an alternative interface to the meta backend and have 35 | mostly the same tradeoffs as meta tensors. 36 | 37 | API 38 | --- 39 | The API consists mainly of the ``fake_mode()`` function that acts as a Python 40 | context manager. Any tensor constructed within its scope will be forced to be 41 | fake. 42 | 43 | .. autofunction:: fake_mode 44 | 45 | There are also two convenience functions offered as part of the API: 46 | 47 | .. autofunction:: is_fake 48 | .. autofunction:: meta_like 49 | 50 | Use Cases 51 | --------- 52 | Fake tensors were originally meant as a building block for :doc:`deferred_init`. 53 | However, they are not necessarily bound to that use case and can also be used for 54 | other purposes. For instance, they serve as a surprisingly good learning tool for 55 | inspecting large model architectures that cannot fit on a consumer-grade PC: 56 | 57 | :: 58 | 59 | >>> import torch 60 | >>> 61 | >>> from transformers import BlenderbotModel, BlenderbotConfig 62 | >>> 63 | >>> from torchdistx.fake import fake_mode 64 | >>> 65 | >>> # Instantiate Blenderbot on a personal laptop with 8GB RAM. 66 | >>> with fake_mode(): 67 | ... m = BlenderbotModel(BlenderbotConfig()) 68 | ... 69 | >>> # Check out the model layers and their parameters. 70 | >>> m 71 | BlenderbotModel(...) 72 | -------------------------------------------------------------------------------- /docs/src/fake_tensor_and_deferred_init.rst: -------------------------------------------------------------------------------- 1 | .. currentmodule:: torchdistx.deferred_init 2 | 3 | Fake Tensors & Deferred Module Initialization 4 | ============================================= 5 | This design note assumes that you have already read the documentation of 6 | :doc:`deferred_init` and :doc:`fake_tensor`. In addition, you are expected to be 7 | familiar with the c10 and ATen libraries of PyTorch. 8 | 9 | Introduction 10 | ------------ 11 | Deferred Module Initialization essentially relies on two new dispatch keys: 12 | ``Fake`` and ``DeferredInit``. 13 | 14 | ``Fake``, which will be described in detail below, is a post-autograd dispatch 15 | key and introduces the concept of a fake tensor. Although implemented as part of 16 | this work, it is not necessarily bound to Deferred Module Initialization and can 17 | be used independently. On the other hand, ``DeferredInit``, a pre-autograd 18 | dispatch key, is specifically implemented for Deferred Module Initialization. It 19 | leverages the fake tensors to skip memory allocations and at the same time 20 | records the operations performed on those tensors in an in-memory graph. In a 21 | sense it is a lightweight symbolic tracer built on top of fake tensors. 22 | 23 | Fake Tensors 24 | ------------ 25 | Before diving into the technical details of the ``Fake`` dispatch key and the 26 | fake tensors, let us first go over the motivation for why they are needed. 27 | 28 | Problem with Meta Tensors 29 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 30 | A naive implementation of ``deferred_init()`` could intercept the tensor factory 31 | operations and replace all ``device`` arguments with the meta device to force 32 | tensors to be allocated on the meta backend. Although this approach would work 33 | fairly well if our goal was to solely skip initialization instead of deferring 34 | it, there is one major problem with it once materialization comes into play. 35 | See the following simple code snippet: 36 | 37 | :: 38 | 39 | >>> class MyModule(Module): 40 | ... def __init__(self): 41 | ...
super().__init__() 42 | ... self.buf1 = torch.ones([3], device="cpu") 43 | ... self.buf2 = torch.zeros_like(self.buf1) 44 | 45 | Assuming we construct ``MyModule`` inside the scope of a ``deferred_init()`` 46 | call with the aforementioned naive approach, both ``buf1`` and ``buf2`` will be 47 | successfully allocated on the meta device as expected. However, when we attempt 48 | to materialize them, we hit a problem: 49 | 50 | :: 51 | 52 | >>> materialize_tensor(my_module.buf1) 53 | tensor([1., 1., 1.]) 54 | >>> materialize_tensor(my_module.buf2) 55 | tensor(..., device='meta') 56 | 57 | ``buf1`` will be successfully materialized on CPU; however, ``buf2`` will remain 58 | on the meta device. The problem is that the implementation of 59 | ``torch.zeros_like()`` looks effectively like this: 60 | 61 | :: 62 | 63 | def zeros_like(src: Tensor): 64 | return torch.zeros(src.shape, dtype=src.dtype, device=src.device, ...) 65 | 66 | This means when we record the operation in our internal graph the ``device`` 67 | argument that we capture for ``buf2`` will be ``Meta``, not ``CPU``. 68 | 69 | Another similar problem happens if the module initialization has some 70 | device-specific logic: 71 | 72 | :: 73 | 74 | def foo(self, device: Device) -> Tensor: 75 | a = torch.ones([1], device=device) 76 | 77 | return a if a.is_cuda else a + 1 78 | 79 | With the naive approach the materialized version of ``a`` will always contain 80 | ``[2.]`` even if the specified real ``device`` was ``CUDA``. This is 81 | because ``a`` will always be allocated on the meta device and ``is_cuda`` will 82 | never return ``True``. 83 | 84 | In summary, in order for materialization to work properly we need a more 85 | sophisticated approach, and this is where the ``Fake`` dispatch key and the fake 86 | tensor (i.e. ``FakeTensorImpl``) come into play. 87 | 88 | Solution 89 | ^^^^^^^^ 90 | ``FakeTensorImpl`` is a subclass of ``TensorImpl`` and behaves very similarly to 91 | ``OpaqueTensorImpl``, meaning that, although it is associated with a real device, it 92 | has no storage allocated to it. However, unlike ``OpaqueTensorImpl``, it also 93 | holds an internal ``TensorImpl``, allocated on the meta backend, that acts 94 | as a "shadow" of the actual tensor. 95 | 96 | .. image:: _static/img/fake-tensor.png 97 | :alt: FakeTensorImpl 98 | :scale: 50% 99 | :align: center 100 | 101 | The ``Fake`` dispatch key sits in-between the Autograd and backend keys, where its 102 | fallback (i.e. catch-all) handler replaces any fake tensor that is passed as an 103 | argument with its shadow meta tensor and forwards the operation to the meta 104 | backend. Once the meta backend call returns, it performs the reverse and 105 | replaces any shadow meta tensor with its fake tensor. Effectively, dispatch keys 106 | above ``Fake`` such as Autograd see fake tensor arguments as regular real 107 | tensors, while dispatch keys below it see them as meta tensors. 108 | 109 | .. image:: _static/img/fake-tensor-dispatch.png 110 | :alt: Fake Tensor Dispatch 111 | :scale: 50% 112 | :align: center 113 | 114 | Shortcomings 115 | ^^^^^^^^^^^^ 116 | Since internally fake tensors use the meta backend, they have the same 117 | shortcoming as regular meta tensors. If an operator has no support for the meta 118 | backend, it will fail in a similar way for a fake tensor as well. 119 |
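A typical way to hit this limitation is an operator whose output shape depends
on the values stored in the tensor. The sketch below uses ``nonzero()`` as an
example; whether a particular operator has a meta kernel, and the exact error
raised, depend on your PyTorch version, so treat it only as an illustration of
the failure mode:

::

    >>> import torch
    >>>
    >>> from torchdistx.fake import fake_mode
    >>>
    >>> with fake_mode():
    ...     t = torch.ones([4])
    ...     # The result shape of `nonzero()` depends on the tensor's data, which
    ...     # a fake tensor does not have, so the meta backend cannot execute it.
    ...     t.nonzero()
    ...
    NotImplementedError: ...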
120 | Another shortcoming that is unique to fake tensors is the support for 121 | mixed-device operators. Since the ``Fake`` handler never dispatches to the 122 | actual backend, we determine the output tensor(s) of an operator using the 123 | following logic: 124 | 125 | 1. If the operator has a ``BackendSelect`` kernel and a ``device`` argument, we 126 | consider the ``device`` argument the device of the output tensor(s). 127 | 2. Otherwise, if a ``TensorOptions`` can be extracted from the arguments of the 128 | operator, its ``device`` is considered the device of the output tensor(s). 129 | 3. Otherwise, we consider the device of the first tensor in the arguments (or 130 | of the first element if the argument is a tensor list) the device of the output 131 | tensor(s). 132 | 4. If none of the above is available, we default to CPU. 133 | 134 | Although we are not aware of any native PyTorch operator that contradicts 135 | this logic, it is still a heuristic and can pick the wrong device for an 136 | unconventional operator. In the future we are considering improving this implementation 137 | by leveraging some form of tagging mechanism. 138 | 139 | Deferred Module Initialization 140 | ------------------------------ 141 | The second dispatch key, ``DeferredInit``, is where the core logic of Deferred 142 | Module Initialization lies. The operations performed on tensors are recorded 143 | to a lightweight in-memory graph inside the fallback (i.e. catch-all) handler of 144 | ``DeferredInit``. In addition to recording operations, the handler also ensures 145 | that tensor factory operations are diverted to the ``Fake`` handler by 146 | modifying the ``DispatchKeySet`` of the call. This way all tensors constructed 147 | within a ``deferred_init()`` call are forced to be fake. 148 | 149 | Although this simplified description gives the main intuition behind the 150 | ``DeferredInit`` handler, there are two topics worth mentioning since they 151 | introduce some complexity to the overall implementation. 152 | 153 | Variable Methods 154 | ^^^^^^^^^^^^^^^^ 155 | There are three main categories of functions that construct and modify tensors in 156 | PyTorch: (1) conventional operators based on the dispatcher mechanism, (2) a 157 | small set of regular functions such as ``torch.Tensor()``, 158 | ``torch.from_numpy()``, or ``torch.Tensor.numpy()`` that are part of the Python 159 | API, but that do not go through the dispatch mechanism, and (3) 160 | ``Variable`` methods such as ``torch.Tensor.set_data()`` that, mostly due to 161 | historical reasons, leverage an alternative hook mechanism to separate the Autograd 162 | implementation from the ATen library. 163 | 164 | With ``DeferredInit`` we are able to trace conventional operators as described 165 | above. The non-traceability of regular functions is a pending (low-priority) 166 | work item that we plan to address in the future. The remaining category of 167 | ``Variable`` methods poses a problem though, since there is no straightforward 168 | way to trace them, yet they are essential for the materialization of tensors. In 169 | particular, any read or write access to the ``torch.Tensor.data`` property in the 170 | Python API, which happens quite frequently with the use of 171 | ``torch.nn.Parameter``, requires tracing of the ``variable_data()`` and 172 | ``set_data()`` functions of the ``Variable`` interface. 173 | 174 | In order to be able to trace calls to the ``Variable`` interface, Deferred 175 | Module Initialization uses an additional mechanism beyond just having a 176 | dispatcher handler.
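The snippet below is a minimal illustration of the kind of access that is
invisible to the dispatcher (it is not taken from the PyTorch sources): the
factory calls are dispatched as usual, but the ``data`` read and the ``data``
assignment go through the ``Variable`` hooks instead:

::

    >>> p = torch.nn.Parameter(torch.empty(3))
    >>>
    >>> # Reading `p.data` goes through `variable_data()`...
    >>> w = p.data
    >>>
    >>> # ...and assigning to it goes through `set_data()`.
    >>> p.data = torch.zeros(3)

Capturing these calls therefore requires hooking the ``Variable`` interface
itself.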
As part of its prologue, the ``deferred_init()`` call
177 | "hijacks" the global ``VariableHooksInterface`` instance that is exposed by
178 | Autograd. It wraps the instance with a proxy implementation of the interface
179 | that records the operations and then forwards them to the original instance.
180 | Technically, this action is completely transparent to both Autograd and ATen. As
181 | part of its epilogue ``deferred_init()`` disposes of its proxy and sets the
182 | original instance back as the global singleton.
183 |
184 | .. image:: _static/img/variable-hooks.png
185 |    :alt: Variable Hooks
186 |    :scale: 50%
187 |    :align: center
188 |
189 | Mutable Tensors
190 | ^^^^^^^^^^^^^^^
191 | Another complexity is introduced by the mutable nature of PyTorch tensors. This
192 | means our materialization logic cannot simply follow a chronological path
193 | through a unidirectional operation graph, since operations performed later in
194 | time can still affect the output of earlier operations. Here is a very simple
195 | example:
196 |
197 | ::
198 |
199 |     >>> a = torch.ones([2, 2])
200 |     >>> b = a.view(-1)
201 |     >>> a.add_(2)
202 |     >>> b
203 |     tensor([3., 3., 3., 3.])
204 |
205 | Although ``a.add_()`` happens later in time than ``a.view()``, the output of
206 | ``b`` is still affected by the in-place operation. In order to correctly handle
207 | this and many similar cases caused by the mutability of PyTorch tensors, we use
208 | a bidirectional graph that still offers a topological order.
209 |
-------------------------------------------------------------------------------- /docs/src/gossip_grad.rst: --------------------------------------------------------------------------------
1 | GossipGraD communication strategy for ``FullyShardedDataParallel`` training with ``NO_SHARD`` strategy
2 | =======================================================================================================
3 | `GossipGraD `_ is a gossip communication protocol
4 | for large-scale training, which can provide better communication efficiency than a global
5 | ``all_reduce`` strategy.
6 |
7 | API
8 | ---
9 |
10 | .. autoclass:: torchdistx.gossip_grad.Topology
11 |
12 | .. autoclass:: torchdistx.gossip_grad.GossipGraDState
13 |
14 | .. autofunction:: torchdistx.gossip_grad.gossip_grad_hook
-------------------------------------------------------------------------------- /docs/src/index.rst: --------------------------------------------------------------------------------
1 | :github_url: https://github.com/pytorch/torchdistx
2 |
3 | Torch Distributed Experimental
4 | ==============================
5 | Torch Distributed Experimental, or in short torchdistX, contains a collection of
6 | experimental features for which our team wants to gather feedback from our users
7 | before introducing them in the core PyTorch Distributed package. In a sense,
8 | features included in torchdistX can be considered in an incubation period.
9 |
10 | .. note::
11 |    Please be advised that all features in torchdistX are subject to change and,
12 |    although our team will make its best effort, we do not guarantee any API or
13 |    ABI compatibility between releases. This means you should exercise caution if
14 |    you plan to use torchdistX in production.
15 |
16 | Installation
17 | ------------
18 | Check out `this section in our README `_
19 | for installation instructions.
20 |
21 | Documentation
22 | -------------
23 | .. toctree::
24 |    :maxdepth: 2
25 |    :hidden:
26 |    :caption: Torch Distributed Experimental
27 |
28 |    Index
29 |
30 | .. 
toctree:: 31 | :maxdepth: 2 32 | :caption: Features 33 | 34 | fake_tensor 35 | deferred_init 36 | slow_momentum_fsdp 37 | gossip_grad 38 | 39 | .. toctree:: 40 | :maxdepth: 1 41 | :caption: Design Notes 42 | 43 | fake_tensor_and_deferred_init 44 | -------------------------------------------------------------------------------- /docs/src/slow_momentum_fsdp.rst: -------------------------------------------------------------------------------- 1 | Slow Momentum for ``FullyShardedDataParallel`` training with ``NO_SHARD`` strategy 2 | =================================================================================== 3 | Slow Momentum is a general framework to improve the accuracy of 4 | communication-efficient distributed training methods. The Slow Momentum algorithm 5 | requires exact-averaging of parameters before a momentum update, which is not feasible 6 | with sharded model parameters. As a result, the current implementation is 7 | available only for the FSDP ``NO_SHARD`` strategy. 8 | 9 | API 10 | --- 11 | 12 | The API consists of ``SlowMoState``, ``slowmo_hook``, and ``SlowMomentumOptimizer``. 13 | 14 | .. autoclass:: torchdistx.slowmo.slowmo_comm.SlowMoState 15 | 16 | .. autofunction:: torchdistx.slowmo.slowmo_comm.slowmo_hook 17 | 18 | .. autoclass:: torchdistx.slowmo.slowmo_optimizer.SlowMomentumOptimizer 19 | -------------------------------------------------------------------------------- /packaging/conda/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | # We perform LTO only if no sanitizer is enabled since they do not play well 12 | # together. 13 | if [[ -z "$TORCHDIST_SANITIZERS" ]]; then 14 | perform_lto=ON 15 | else 16 | perform_lto=OFF 17 | fi 18 | 19 | cmake -GNinja\ 20 | -DCMAKE_BUILD_TYPE=RelWithDebInfo\ 21 | -DCMAKE_INSTALL_PREFIX="$PREFIX"\ 22 | -DCMAKE_INSTALL_LIBDIR=lib\ 23 | -DCMAKE_FIND_FRAMEWORK=NEVER\ 24 | -DTORCHDIST_TREAT_WARNINGS_AS_ERRORS=ON\ 25 | -DTORCHDIST_PERFORM_LTO=$perform_lto\ 26 | -DTORCHDIST_DEVELOP_PYTHON=OFF\ 27 | -DTORCHDIST_SANITIZERS="$TORCHDIST_SANITIZERS"\ 28 | -S "$SRC_DIR"\ 29 | -B "$SRC_DIR/build" 30 | 31 | cmake --build "$SRC_DIR/build" 32 | 33 | # Extract the debug symbols; they will be part of the debug package. 34 | find "$SRC_DIR/build" -type f -name "libtorchdistx*"\ 35 | -exec "$SRC_DIR/scripts/strip-debug-symbols" --extract "{}" ";" 36 | -------------------------------------------------------------------------------- /packaging/conda/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | cmake: 8 | - 3.21.0 9 | cuda: 10 | - None 11 | cuda_home: 12 | - None 13 | cudnn: 14 | - None 15 | cxx_compiler_version: 16 | - 11.2.0 # [linux64] 17 | - 9.0 # [osx] 18 | ninja: 19 | - 1.10.2 20 | pip: 21 | - 22.0.3 22 | python: 23 | - 3.7 24 | - 3.8 25 | - 3.9 26 | - 3.10 27 | pytorch: 28 | - 29 | pytorch_variant: 30 | - cpu 31 | sanitizers: 32 | - None 33 | setuptools: 34 | - 60.9.3 35 | wheel: 36 | - 0.37.1 37 | 38 | zip_keys: 39 | - cuda 40 | - cuda_home 41 | - cudnn 42 | - cxx_compiler_version 43 | - pytorch_variant 44 | 45 | MACOSX_DEPLOYMENT_TARGET: # [osx] 46 | - 10.14 # [osx] 47 | -------------------------------------------------------------------------------- /packaging/conda/install-debug.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | if [[ $(uname -s) == Darwin ]]; then 12 | filter="-type d -name *.dSYM" 13 | else 14 | filter="-type f -name *.debug" 15 | fi 16 | 17 | find "$SRC_DIR/build" $filter -exec cp -a "{}" "$PREFIX/lib" ";" 18 | -------------------------------------------------------------------------------- /packaging/conda/install-devel.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | cmake --install "$SRC_DIR/build" --verbose --component devel 10 | -------------------------------------------------------------------------------- /packaging/conda/install-lib.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | cmake --install "$SRC_DIR/build" --verbose --component runtime --strip 10 | -------------------------------------------------------------------------------- /packaging/conda/install-python.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | pip install "$SRC_DIR" --verbose\ 10 | --ignore-installed\ 11 | --no-compile\ 12 | --no-deps\ 13 | --no-cache-dir\ 14 | --no-build-isolation 15 | -------------------------------------------------------------------------------- /packaging/conda/meta.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | {% set version = "0.3.0.dev0" %} 8 | 9 | {% set build_number = 0 %} 10 | 11 | # Set the build string. 
12 | {% if cuda != "None" %} 13 | {% set build_str = "py{1}_cu{2}_{0}".format(build_number, python, cuda) %} 14 | {% else %} 15 | {% set build_str = "py{1}_cpu_{0}" .format(build_number, python) %} 16 | {% endif %} 17 | 18 | # Remove the version dots from the build string. 19 | {% set build_str = build_str.replace(".", "") %} 20 | 21 | # Append the sanitizer tag to the build string. 22 | {% if sanitizers != "None" %} 23 | {% set build_str = "{0}_{1}".format(build_str, sanitizers).replace(";", "_") %} 24 | {% endif %} 25 | 26 | package: 27 | name: torchdistx-cc 28 | version: {{ version }} 29 | 30 | source: 31 | path: ../../ 32 | 33 | build: 34 | number: {{ build_number}} 35 | string: {{ build_str }} 36 | skip: True # [not unix] 37 | script_env: 38 | - CUDA_HOME={{ cuda_home }} # [cuda != "None"] 39 | - TORCHDIST_SANITIZERS={{ sanitizers }} # [sanitizers != "None"] 40 | run_exports: 41 | # We do not maintain ABI compatibility between releases. 42 | - {{ pin_subpackage("torchdistx-cc", exact=True) }} 43 | ignore_run_exports: 44 | - cudatoolkit 45 | - cudnn 46 | # The `run_export` section of the `libsanitizer` package does not specify 47 | # a valid version range. We override it down below. 48 | - libsanitizer 49 | # Since we need an exact version of PyTorch we don't have to export its 50 | # mutex to our runtime requirements. 51 | - pytorch-mutex 52 | # libc10 and libtorch do not have their own packages. They are distributed 53 | # with the pytorch package and reside under the `lib` sub-directory of the 54 | # Python library. Therefore they are not discoverable by Conda and have to 55 | # be listed here. 56 | missing_dso_whitelist: 57 | - "*/libc10*" 58 | - "*/libtorch*" 59 | 60 | requirements: 61 | build: 62 | - {{ compiler("cxx") }} 63 | - cmake 64 | - ninja 65 | - nvcc_linux-64 {{ cuda }} # [cuda != "None"] 66 | host: 67 | - cudatoolkit {{ cuda }} # [cuda != "None"] 68 | - cudnn {{ cudnn }} # [cuda != "None"] 69 | - libsanitizer {{ cxx_compiler_version }} # [linux64 and sanitizers != "None"] 70 | - python {{ python }} 71 | - pytorch {{ pytorch }} 72 | - pytorch-mutex 1.0 {{ pytorch_variant }} 73 | run: 74 | # We include ASan, LSan, UBSan, and TSan libraries if necessary. 75 | - {{ pin_compatible("libsanitizer", max_pin="x.x.x") }} # [linux64 and sanitizers != "None"] 76 | # We require the exact same version of PyTorch during runtime since PyTorch 77 | # does not offer ABI compatibility. 78 | - {{ pin_compatible("pytorch", exact=True) }} 79 | 80 | test: 81 | commands: 82 | - test -f "$PREFIX/lib/libtorchdistx.so.0" # [linux] 83 | - test -f "$PREFIX/lib/libtorchdistx.0.dylib" # [osx] 84 | 85 | outputs: 86 | # This package contains the DSO (i.e. libtorchdistx.so). 87 | - name: torchdistx-cc 88 | script: install-lib.sh 89 | 90 | # This package contains the header files, CMake package configuration, and 91 | # soname symbolic link required for development. 
92 | - name: torchdistx-cc-devel 93 | script: install-devel.sh 94 | build: 95 | string: {{ build_str }} 96 | run_exports: 97 | - {{ pin_subpackage("torchdistx-cc", exact=True) }} 98 | requirements: 99 | build: 100 | - cmake 101 | run: 102 | - {{ pin_subpackage("torchdistx-cc", exact=True) }} 103 | test: 104 | commands: 105 | - test -f "$PREFIX/lib/libtorchdistx.so" # [linux] 106 | - test -f "$PREFIX/lib/libtorchdistx.dylib" # [osx] 107 | about: 108 | home: https://github.com/pytorch/torchdistx 109 | license: BSD 110 | license_file: LICENSE 111 | summary: torchdistX C++ Runtime Library Development Files 112 | 113 | # This package contains the debug (i.e. DWARF) symbols of the DSO. 114 | - name: torchdistx-cc-debug 115 | script: install-debug.sh 116 | build: 117 | string: {{ build_str }} 118 | run_exports: 119 | - {{ pin_subpackage("torchdistx-cc", exact=True) }} 120 | requirements: 121 | build: 122 | - cmake 123 | run: 124 | - {{ pin_subpackage("torchdistx-cc", exact=True) }} 125 | about: 126 | home: https://github.com/pytorch/torchdistx 127 | license: BSD 128 | license_file: LICENSE 129 | summary: torchdistX C++ Runtime Library Debug Symbols 130 | 131 | # This package contains the Python library. 132 | - name: torchdistx 133 | script: install-python.sh 134 | build: 135 | string: {{ build_str }} 136 | # These environment variables are used by setup.py. 137 | run_exports: 138 | - {{ pin_subpackage("torchdistx", exact=True) }} 139 | # See the torchdistx-cc package above for why we need this list. 140 | missing_dso_whitelist: 141 | - "*/libc10*" 142 | - "*/libtorch*" 143 | requirements: 144 | build: 145 | # We need the compiler here to implicitly export the platform-specific 146 | # C++ standard library to the runtime requirements. This is needed for 147 | # our Python C extension. 148 | - {{ compiler("cxx") }} 149 | - cmake 150 | host: 151 | # We import PyTorch in setup.py to retrieve its version information. 
152 | - {{ pin_compatible("pytorch", exact=True) }} 153 | - pip 154 | - python {{ python }} 155 | - setuptools 156 | - wheel 157 | run: 158 | - {{ pin_compatible("pytorch", exact=True) }} 159 | - {{ pin_subpackage("torchdistx-cc", exact=True) }} 160 | test: 161 | imports: 162 | - torchdistx.deferred_init 163 | - torchdistx.fake 164 | about: 165 | home: https://github.com/pytorch/torchdistx 166 | license: BSD 167 | license_file: LICENSE 168 | summary: torchdistX Python Library 169 | 170 | about: 171 | home: https://github.com/pytorch/torchdistx 172 | license: BSD 173 | license_file: LICENSE 174 | summary: torchdistX C++ Runtime Library 175 | 176 | extra: 177 | maintainers: 178 | - PyTorch Distributed Team 179 | -------------------------------------------------------------------------------- /packaging/conda/variants/cu117.yaml: -------------------------------------------------------------------------------- 1 | cuda: 2 | - 11.7 # [linux64] 3 | cuda_home: 4 | - /usr/local/cuda-11.7 # [linux64] 5 | cudnn: 6 | - 8.3.2 # [linux64] 7 | cxx_compiler_version: 8 | - 11.2.0 # [linux64] 9 | pytorch_variant: 10 | - cuda # [linux64] 11 | -------------------------------------------------------------------------------- /packaging/conda/variants/cu118.yaml: -------------------------------------------------------------------------------- 1 | cuda: 2 | - 11.8 # [linux64] 3 | cuda_home: 4 | - /usr/local/cuda-11.8 # [linux64] 5 | cudnn: 6 | - 8.3.2 # [linux64] 7 | cxx_compiler_version: 8 | - 11.2.0 # [linux64] 9 | pytorch_variant: 10 | - cuda # [linux64] 11 | -------------------------------------------------------------------------------- /requirements-devel.txt: -------------------------------------------------------------------------------- 1 | --requirement requirements.txt 2 | 3 | black==22.3.0 4 | expecttest==0.1.3 5 | flake8==4.0.1 6 | isort==5.10.1 7 | mypy==0.931 8 | numpy 9 | pytest==7.0.1 10 | shellcheck-py==0.8.0.4 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pip==22.0.3 2 | setuptools==60.9.3 3 | torch 4 | types-setuptools==57.4.9 5 | wheel==0.37.1 6 | -------------------------------------------------------------------------------- /scripts/set-version: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | 9 | set -o errexit 10 | 11 | function print_usage 12 | { 13 | printf "Usage: %s MAJOR.MINOR.PATCH [PRE_RELEASE [REV]]\n" "$(basename "$0")" 14 | } 15 | 16 | function exit_with_usage 17 | { 18 | print_usage >&2 && exit 0 19 | } 20 | 21 | function exit_with_error 22 | { 23 | print_usage >&2 && exit 1 24 | } 25 | 26 | function build_mmp_version 27 | { 28 | echo "$1" 29 | } 30 | 31 | function build_sem_version 32 | { 33 | echo "$1${2:+-$2${3:+.$3}}" 34 | } 35 | 36 | function build_pep_version 37 | { 38 | local -A pre_map=([alpha]=a [beta]=b [dev]=.dev) 39 | 40 | local pre=${2:+${pre_map[$2]:-$2}} 41 | 42 | echo "$1${pre:+$pre${3:-0}}" 43 | } 44 | 45 | function replace_match 46 | { 47 | sed --in-place --expression "$2" "$1" 48 | } 49 | 50 | function main 51 | { 52 | local src_dir 53 | local mmp_version 54 | local sem_version 55 | local pep_version 56 | 57 | if [[ $# -eq 0 || $# -gt 3 ]]; then 58 | exit_with_error 59 | fi 60 | 61 | if [[ $1 == -h || $1 == --help ]]; then 62 | if [[ $# -eq 1 ]]; then 63 | exit_with_usage 64 | else 65 | exit_with_error 66 | fi 67 | fi 68 | 69 | src_dir=$(cd "$(dirname "$0")" && pwd)/.. 70 | 71 | # Build the major.minor.patch, semantic, and PEP440 version strings. 72 | mmp_version=$(build_mmp_version "$@") 73 | sem_version=$(build_sem_version "$@") 74 | pep_version=$(build_pep_version "$@") 75 | 76 | # Update CMake 77 | replace_match "$src_dir/CMakeLists.txt"\ 78 | "s/VERSION .* LANGUAGES/VERSION $mmp_version LANGUAGES/" 79 | 80 | # Update Python 81 | replace_match "$src_dir/src/python/torchdistx/__init__.py"\ 82 | "s/__version__ = \".*\"/__version__ = \"$pep_version\"/" 83 | 84 | # Update Setuptools 85 | replace_match "$src_dir/setup.py"\ 86 | "s/version = \".*\"/version = \"$pep_version\"/" 87 | 88 | # Update Conda 89 | replace_match "$src_dir/packaging/conda/meta.yaml"\ 90 | "s/version = \".*\"/version = \"$pep_version\"/" 91 | 92 | # Update the VERSION file 93 | echo "$sem_version" > "$src_dir/VERSION" 94 | } 95 | 96 | main "$@" 97 | -------------------------------------------------------------------------------- /scripts/strip-debug-symbols: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -o errexit 10 | 11 | function print_usage 12 | { 13 | printf "Usage: %s [--extract] PATHNAME\n" "$(basename "$0")" 14 | } 15 | 16 | function exit_with_usage 17 | { 18 | print_usage >&1 && exit 0 19 | } 20 | 21 | function exit_with_error 22 | { 23 | print_usage >&2 && exit 1 24 | } 25 | 26 | function main 27 | { 28 | local target 29 | local should_extract 30 | 31 | if [[ $# -eq 0 || $# -gt 2 ]]; then 32 | exit_with_error 33 | fi 34 | 35 | if [[ $# -eq 1 ]]; then 36 | if [[ $1 == -h || $1 == --help ]]; then 37 | exit_with_usage 38 | fi 39 | else 40 | if [[ $1 != --extract ]]; then 41 | exit_with_error 42 | fi 43 | 44 | should_extract=true 45 | 46 | shift 47 | fi 48 | 49 | target=$1 50 | 51 | if [[ $(uname -s) == Darwin ]]; then 52 | if [[ $should_extract == true ]]; then 53 | # Extract the debug symbols. 54 | dsymutil --minimize -o "$target.dSYM" "$target" 55 | fi 56 | 57 | strip -r -x "$target" 58 | else 59 | if [[ $should_extract == true ]]; then 60 | # Extract the debug symbols. 
61 | objcopy --only-keep-debug "$target" "$target.debug" 62 | 63 | # Associate the debug file with the DSO. 64 | objcopy --add-gnu-debuglink="$target.debug" "$target" 65 | fi 66 | 67 | objcopy --strip-unneeded "$target" 68 | fi 69 | } 70 | 71 | main "$@" 72 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import warnings 9 | from typing import List 10 | 11 | import torch 12 | from setuptools import Command, find_packages, setup 13 | from setuptools.command.install import install as install_base 14 | from setuptools.dist import Distribution as DistributionBase 15 | from setuptools.errors import FileError # type: ignore[attr-defined] 16 | 17 | package_path = "src/python" 18 | 19 | package_name = "torchdistx" 20 | 21 | 22 | class Distribution(DistributionBase): 23 | # Since we are injecting our Python C extension into the package instead 24 | # of building it we need to mark the package as non-pure. 25 | def has_ext_modules(self) -> bool: 26 | return True 27 | 28 | 29 | class install(install_base): 30 | install_base.sub_commands.append(("install_cmake", lambda self: True)) 31 | 32 | def finalize_options(self) -> None: 33 | install_base.finalize_options(self) 34 | 35 | # Older versions of distutils incorrectly check `ext_modules` to 36 | # determine whether a package is non-pure. We override it here. 37 | if self.distribution.has_ext_modules(): # type: ignore[attr-defined] 38 | self.install_lib = self.install_platlib 39 | 40 | 41 | # We inject our Python C extension and optionally our shared library into the 42 | # package by installing them directly via CMake. 43 | class install_cmake(Command): 44 | description = "install CMake artifacts" 45 | 46 | user_options = [ 47 | ("cmake-build-dir=", "b", "build directory (where to install from)"), 48 | ("install-dir=", "d", "directory to install to"), 49 | ("standalone", "s", "bundle C++ library"), 50 | ("no-standalone", None, "don't bundle C++ library"), 51 | ] 52 | 53 | boolean_options = ["standalone"] 54 | 55 | negative_opt = {"no-standalone": "standalone"} 56 | 57 | def initialize_options(self) -> None: 58 | # This is a required option and specifies the build (a.k.a. binary) 59 | # directory of the CMake project to install. 60 | self.cmake_build_dir = "build" 61 | 62 | # If not specified, the value of this option is copied over from the 63 | # parent `install` command. It specifies the directory into which to 64 | # install the CMake artifacts. 65 | self.install_dir: str = None # type: ignore[assignment] 66 | 67 | # By default we install a non-standalone package containing only the 68 | # Python C extension. For a wheel package this option must be set to 69 | # true to ensure that it also contains the shared library. 70 | self.standalone: bool = None # type: ignore[assignment] 71 | 72 | def finalize_options(self) -> None: 73 | self.ensure_dirname("cmake_build_dir") 74 | 75 | # If not specified, copy the value of `install_dir` from the `install` 76 | # command. 77 | self.set_undefined_options("install", ("install_lib", "install_dir")) 78 | 79 | # If not specified, we infer the value of `standalone` from the CMake 80 | # configuration file. 
81 | if self.standalone is None: 82 | self.standalone = self._should_install_standalone() 83 | 84 | def _should_install_standalone(self) -> bool: 85 | try: 86 | f = open(os.path.join(self.cmake_build_dir, "CMakeCache.txt")) 87 | except FileNotFoundError: 88 | raise FileError("CMakeCache.txt not found. Run CMake first.") 89 | 90 | # Parse the value of the `TORCHDIST_INSTALL_STANDALONE` option from the 91 | # CMake configuration file. 92 | with f: 93 | for line in f: 94 | if line.startswith("TORCHDIST_INSTALL_STANDALONE"): 95 | _, value = line.strip().split("=", 1) 96 | 97 | return value.upper() in ["1", "ON", "TRUE", "YES", "Y"] 98 | 99 | return False 100 | 101 | def run(self) -> None: 102 | # If the user has requested a standalone package, install the shared 103 | # library and other related artifacts into the package. 104 | if self.standalone: 105 | self._cmake_install() 106 | 107 | # Install the Python C extension. 108 | self._cmake_install(component="python") 109 | 110 | def _cmake_install(self, component: str = None) -> None: 111 | prefix_dir = os.path.join(self.install_dir, package_name) 112 | 113 | cmd = ["cmake", "--install", self.cmake_build_dir, "--prefix", prefix_dir] 114 | 115 | if self.verbose: # type: ignore[attr-defined] 116 | cmd += ["--verbose"] 117 | 118 | if component: 119 | cmd += ["--component", component] 120 | 121 | # Ensure that we remove debug symbols from all DSOs. 122 | cmd += ["--strip"] 123 | 124 | # Run `cmake --install` in a subprocess. 125 | self.spawn(cmd) 126 | 127 | def get_inputs(self) -> List[str]: 128 | # We don't take any input files from other commands. 129 | return [] 130 | 131 | def get_outputs(self) -> List[str]: 132 | # Since we don't have an easy way to infer the list of files installed 133 | # by CMake we don't support the `record` option. 134 | warnings.warn("`install_cmake` does not support recording output files.") 135 | 136 | return [] 137 | 138 | 139 | def get_version() -> str: 140 | version = "0.3.0.dev0" 141 | 142 | if torch.version.cuda is None: 143 | return f"{version}+cpu" 144 | else: 145 | return f"{version}+cu{torch.version.cuda.replace('.', '')}" 146 | 147 | 148 | def read_long_description() -> str: 149 | with open("README.md") as f: 150 | return f.read() 151 | 152 | 153 | def main() -> None: 154 | setup( 155 | distclass=Distribution, 156 | cmdclass={ 157 | "install": install, # type: ignore[dict-item] 158 | "install_cmake": install_cmake, 159 | }, 160 | name="torchdistx", 161 | version=get_version(), 162 | description="A collection of experimental features for PyTorch Distributed", 163 | long_description=read_long_description(), 164 | long_description_content_type="text/markdown", 165 | author="PyTorch Distributed Team", 166 | url="https://github.com/pytorch/torchdistx", 167 | license="BSD", 168 | keywords=["pytorch", "machine learning"], 169 | packages=find_packages(where=package_path), 170 | package_dir={"": package_path}, 171 | package_data={"": ["py.typed", "*.pyi"]}, 172 | python_requires=">=3.7", 173 | zip_safe=False, 174 | # Since PyTorch does not offer ABI compatibility we have to make sure 175 | # that we use the same version that was used at build time. 
176 | install_requires=[f"torch=={torch.__version__}"], 177 | classifiers=[ 178 | "Development Status :: 3 - Alpha", 179 | "Intended Audience :: Developers", 180 | "Intended Audience :: Science/Research", 181 | "License :: OSI Approved :: BSD License", 182 | "Programming Language :: Python :: 3", 183 | "Programming Language :: Python :: 3.7", 184 | "Programming Language :: Python :: 3.8", 185 | "Programming Language :: Python :: 3.9", 186 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 187 | ], 188 | ) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /src/cc/torchdistx-config.cmake.in: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | @PACKAGE_INIT@ 8 | 9 | include(CMakeFindDependencyMacro) 10 | 11 | find_dependency(Torch @Torch_VERSION@) 12 | 13 | include(${CMAKE_CURRENT_LIST_DIR}/torchdistx-targets.cmake) 14 | 15 | check_required_components(torchdistx) 16 | -------------------------------------------------------------------------------- /src/cc/torchdistx/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | torchdist_add_target(torchdistx SHARED_LIBRARY) 8 | 9 | add_library(torchdistx::torchdistx ALIAS torchdistx) 10 | 11 | target_sources(torchdistx 12 | PRIVATE 13 | deferred_init.cc 14 | fake.cc 15 | stack_utils.cc 16 | ) 17 | 18 | target_compile_features(torchdistx 19 | PUBLIC 20 | cxx_std_17 21 | ) 22 | 23 | target_link_libraries(torchdistx 24 | PUBLIC 25 | torch 26 | ) 27 | 28 | torchdist_install(torchdistx 29 | PACKAGE 30 | torchdistx 31 | HEADERS 32 | deferred_init.h 33 | fake.h 34 | macros.h 35 | ) 36 | -------------------------------------------------------------------------------- /src/cc/torchdistx/deferred_init.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include "macros.h" 13 | 14 | namespace at { 15 | 16 | class Tensor; 17 | 18 | } // namespace at 19 | 20 | namespace torchdistx { 21 | 22 | // Forces all newly-constructed tensors on the calling thread to be fake while 23 | // also recording all operations performed on them in memory. Such tensors can 24 | // later be materialized by calling `materializeTensor()`. 25 | TDX_API void enterDeferredInit(); 26 | TDX_API void leaveDeferredInit() noexcept; 27 | 28 | // Indicates whether `tensor` has been constructed in a deferred-init context. 29 | TDX_API bool canMaterialize(const at::Tensor& tensor) noexcept; 30 | 31 | // Materializes `tensor`. 32 | TDX_API at::Tensor materializeTensor(const at::Tensor& tensor); 33 | 34 | // Temporarily disables deferred-init. 
35 | class TDX_API NoDeferredInit { 36 | c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::DeferredInit}; 37 | }; 38 | 39 | } // namespace torchdistx 40 | -------------------------------------------------------------------------------- /src/cc/torchdistx/fake.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include "macros.h" 15 | 16 | namespace at { 17 | 18 | class Tensor; 19 | class TensorBase; 20 | 21 | } // namespace at 22 | 23 | namespace torchdistx { 24 | namespace detail { 25 | 26 | class FakeTensorImpl; 27 | 28 | } // namespace detail 29 | 30 | // Forces all newly-constructed tensors on the calling thread to be fake. 31 | // 32 | // When `fake_cuda` is set to true, allows constructing fake CUDA tensors even 33 | // if CUDA is not available. 34 | TDX_API void enterFakeMode(bool fake_cuda = false); 35 | 36 | // Leaves the fake mode in the calling thread. 37 | TDX_API void leaveFakeMode() noexcept; 38 | 39 | // Indicates whether the calling thread is in fake mode. 40 | TDX_API bool isFakeModeActive() noexcept; 41 | 42 | // Indicates whether `tensor` is fake. 43 | TDX_API bool isFake(const at::TensorBase& tensor) noexcept; 44 | 45 | // Provides access to the properties of a fake tensor. 46 | class TDX_API FakeTensor { 47 | public: 48 | explicit FakeTensor(const at::TensorBase& tensor, bool unsafe = false); 49 | 50 | public: 51 | // Returns a meta tensor with the same properties. 52 | at::Tensor toMeta() const; 53 | 54 | void setData(at::DispatchKey key, std::shared_ptr data); 55 | 56 | bool hasData(at::DispatchKey key) const noexcept; 57 | 58 | std::shared_ptr getData(at::DispatchKey key) const; 59 | 60 | template 61 | inline auto getData(at::DispatchKey key) const { 62 | return std::static_pointer_cast(getData(key)); 63 | } 64 | 65 | void* unsafeGetData(at::DispatchKey key) const; 66 | 67 | template 68 | inline auto unsafeGetData(at::DispatchKey key) const { 69 | return static_cast(unsafeGetData(key)); 70 | } 71 | 72 | public: 73 | const at::Storage& meta_storage() const noexcept; 74 | 75 | private: 76 | detail::FakeTensorImpl* impl_; 77 | }; 78 | 79 | // Treats `tensor` as fake. 80 | TDX_API FakeTensor asFake(const at::TensorBase& tensor); 81 | 82 | // Treats `tensor` as fake without performing any type checks. 83 | TDX_API FakeTensor unsafeAsFake(const at::TensorBase& tensor) noexcept; 84 | 85 | } // namespace torchdistx 86 | -------------------------------------------------------------------------------- /src/cc/torchdistx/macros.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #define TDX_API __attribute__((visibility("default"))) 10 | -------------------------------------------------------------------------------- /src/cc/torchdistx/stack_utils.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 
3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "stack_utils.h" 8 | 9 | #include 10 | #include 11 | 12 | namespace torchdistx { 13 | 14 | using at::irange; 15 | using at::IValue; 16 | 17 | using torch::jit::Stack; 18 | 19 | } // namespace torchdistx 20 | 21 | namespace torchdistx::detail { 22 | 23 | void processTensors(const Stack& s, std::size_t n, const TensorProcessor& processor) { 24 | for (auto i : irange(n)) { 25 | const IValue& value = torch::jit::peek(s, i, n); 26 | if (value.isTensor()) { 27 | if (processor(value.toTensor())) { 28 | return; 29 | } 30 | } else if (value.isList()) { 31 | for (const IValue& elem : value.toListRef()) { 32 | if (elem.isTensor()) { 33 | if (processor(elem.toTensor())) { 34 | return; 35 | } 36 | } 37 | } 38 | } 39 | } 40 | } 41 | 42 | void convertTensors(Stack& s, std::size_t n, const TensorConverter& converter) { 43 | for (auto i : irange(n)) { 44 | IValue& value = torch::jit::peek(s, i, n); 45 | if (value.isTensor()) { 46 | converter(value.toTensor()); 47 | } else if (value.isList()) { 48 | for (const IValue& elem : value.toListRef()) { 49 | if (elem.isTensor()) { 50 | // Although technically not mandatory, `ArrayRef` only allows const 51 | // access to the underlying elements. 52 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) 53 | converter(const_cast(elem).toTensor()); 54 | } 55 | } 56 | } 57 | } 58 | } 59 | 60 | } // namespace torchdistx::detail 61 | -------------------------------------------------------------------------------- /src/cc/torchdistx/stack_utils.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | #include 13 | 14 | namespace at { 15 | 16 | class Tensor; 17 | 18 | } // namespace at 19 | 20 | namespace torchdistx::detail { 21 | 22 | using TensorProcessor = std::function; 23 | 24 | // Calls `processor` for all tensors in the last `n` entries of `s`. 25 | void processTensors(const torch::jit::Stack& s, std::size_t n, const TensorProcessor& processor); 26 | 27 | using TensorConverter = std::function; 28 | 29 | // Calls `converter` for all tensors in the last `n` entries of `s`. 30 | void convertTensors(torch::jit::Stack& s, std::size_t n, const TensorConverter& converter); 31 | 32 | } // namespace torchdistx::detail 33 | -------------------------------------------------------------------------------- /src/python/torchdistx/_C.pyi: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | def enter_deferred_init() -> None: ... 10 | def leave_deferred_init() -> None: ... 11 | def enter_fake_mode(fake_mode: bool) -> None: ... 12 | def leave_fake_mode() -> None: ... 13 | def is_fake(tensor: torch.Tensor) -> bool: ... 14 | def can_materialize(tensor: torch.Tensor) -> bool: ... 15 | def materialize_tensor(tensor: torch.Tensor) -> torch.Tensor: ... 16 | def meta_like(fake: torch.Tensor) -> torch.Tensor: ... 
17 | -------------------------------------------------------------------------------- /src/python/torchdistx/_C/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | torchdist_add_target(torchdistx-py PYTHON_MODULE 8 | OUTPUT_NAME 9 | _C 10 | ) 11 | 12 | target_sources(torchdistx-py 13 | PRIVATE 14 | deferred_init.cc 15 | fake.cc 16 | module.cc 17 | ) 18 | 19 | target_compile_features(torchdistx-py 20 | PRIVATE 21 | cxx_std_17 22 | ) 23 | 24 | cmake_path(GET 25 | #VAR 26 | TORCH_LIBRARY 27 | PARENT_PATH 28 | torch_library_dir 29 | ) 30 | 31 | # libtorch_python is not exported as part of Torch CMake package, so we have to 32 | # manually find it. 33 | find_library(TORCH_PYTHON_LIBRARY 34 | #NAME 35 | torch_python 36 | PATHS 37 | ${torch_library_dir} 38 | ) 39 | 40 | mark_as_advanced(TORCH_PYTHON_LIBRARY) 41 | 42 | target_link_libraries(torchdistx-py 43 | PRIVATE 44 | pybind11::module torch torchdistx ${TORCH_PYTHON_LIBRARY} 45 | ) 46 | 47 | torchdist_install_python_module(torchdistx-py) 48 | -------------------------------------------------------------------------------- /src/python/torchdistx/_C/deferred_init.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "module.h" 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace py = pybind11; 17 | 18 | namespace torchdistx { 19 | 20 | using at::MaybeOwned; 21 | using at::Tensor; 22 | 23 | using c10::impl::PyInterpreterStatus; 24 | 25 | using torch::TypeError; 26 | 27 | } // namespace torchdistx 28 | 29 | namespace torchdistx::python { 30 | namespace { 31 | 32 | // Creates a new Python variable (i.e. tensor) that holds `data`. 33 | py::object makeVariable(PyTypeObject* type, Tensor data) { 34 | PyObject* naked_obj = type->tp_alloc(type, 0); 35 | 36 | TORCH_CHECK(naked_obj != nullptr, 37 | "Failed to construct the `Variable` object."); 38 | 39 | auto obj = py::reinterpret_steal(naked_obj); 40 | 41 | constexpr auto s = PyInterpreterStatus::DEFINITELY_UNINITIALIZED; 42 | 43 | // Associate ATen and Python tensor instances. 44 | data.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(getPyInterpreter(), naked_obj, s); 45 | 46 | // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) 47 | auto* var = reinterpret_cast(naked_obj); 48 | 49 | // `THPVariable` is a plain C struct, so we need to use placement new to 50 | // construct `cdata`. 51 | new (&var->cdata) MaybeOwned{}; 52 | 53 | var->cdata = MaybeOwned::owned(std::move(data)); 54 | 55 | return obj; 56 | } 57 | 58 | // Materializing a tensor in Python requires an extra step. We need to ensure 59 | // that the materialized tensor has the same Python class (e.g. `Variable` or 60 | // `Parameter`) as the original tensor. 
61 | py::object materializeVariable(const py::object& var) { 62 | PyObject* naked_var = var.ptr(); 63 | 64 | if (!THPVariable_Check(naked_var)) { 65 | throw TypeError{"`var` has to be a `Variable`, but got `%s`.", Py_TYPE(naked_var)->tp_name}; 66 | } 67 | 68 | const Tensor& data = THPVariable_Unpack(naked_var); 69 | 70 | auto materialize = [](const Tensor& tensor) { 71 | py::gil_scoped_release guard{}; 72 | 73 | return materializeTensor(tensor); 74 | }; 75 | 76 | Tensor materialized_data = materialize(data); 77 | 78 | // Check if we have really materialized `data`. Materializing a regular tensor 79 | // is a no-op, so we can simply return. 80 | if (materialized_data.is_same(data)) { 81 | return var; 82 | } 83 | 84 | // We might have already materialized `data`. Make sure that we preserve its 85 | // identity on the Python side and avoid creating a new Python tensor. 86 | c10::optional opt_materialized_var = 87 | materialized_data.unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(getPyInterpreter()); 88 | if (opt_materialized_var.has_value()) { 89 | return py::reinterpret_borrow(*opt_materialized_var); 90 | } 91 | 92 | // Otherwise ensure that our materialized tensor has the same Python class as 93 | // the original tensor. 94 | return makeVariable(Py_TYPE(naked_var), std::move(materialized_data)); 95 | } 96 | 97 | } // namespace 98 | 99 | void initDeferredInitFunctions(py::module& m) { 100 | m.def("enter_deferred_init", enterDeferredInit); 101 | m.def("leave_deferred_init", leaveDeferredInit); 102 | 103 | m.def("can_materialize", canMaterialize); 104 | m.def("materialize_tensor", materializeVariable); 105 | } 106 | 107 | } // namespace torchdistx::python 108 | -------------------------------------------------------------------------------- /src/python/torchdistx/_C/fake.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "module.h" 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace torchdistx::python { 16 | namespace { 17 | 18 | void pyEnterFakeMode(bool fake_cuda) { 19 | enterFakeMode(fake_cuda); 20 | 21 | // If CUDA is not available, suppress PyTorch's attempt to initialize its CUDA 22 | // subsystem which would fail and prevent us from instantiating CUDA devices. 23 | if (fake_cuda) { 24 | if (!at::hasCUDA()) { 25 | torch::utils::set_requires_cuda_init(false); 26 | } 27 | } 28 | } 29 | 30 | void pyLeaveFakeMode() { 31 | leaveFakeMode(); 32 | 33 | if (!isFakeModeActive() && !at::hasCUDA()) { 34 | torch::utils::set_requires_cuda_init(true); 35 | } 36 | } 37 | 38 | } // namespace 39 | 40 | void initFakeFunctions(pybind11::module& m) { 41 | m.def("enter_fake_mode", pyEnterFakeMode); 42 | m.def("leave_fake_mode", pyLeaveFakeMode); 43 | 44 | m.def("is_fake", [](const at::Tensor& tensor) { 45 | return isFake(tensor); // cast to `TensorBase`. 46 | }); 47 | 48 | m.def("meta_like", [](const at::Tensor& fake) { 49 | return FakeTensor{fake}.toMeta(); 50 | }); 51 | } 52 | 53 | } // namespace torchdistx::python 54 | -------------------------------------------------------------------------------- /src/python/torchdistx/_C/module.cc: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #include "module.h" 8 | 9 | #include 10 | 11 | #include 12 | 13 | namespace py = pybind11; 14 | 15 | namespace torchdistx::python { 16 | namespace { 17 | 18 | void registerExceptionTranslator() { 19 | // NOLINTNEXTLINE(performance-unnecessary-value-param) 20 | py::register_exception_translator([](std::exception_ptr ex) { 21 | try { 22 | if (ex) { 23 | std::rethrow_exception(ex); 24 | } 25 | } 26 | CATCH_TH_ERRORS() // NOLINT 27 | }); 28 | } 29 | 30 | } // namespace 31 | 32 | // NOLINTNEXTLINE(clang-diagnostic-reserved-identifier) 33 | PYBIND11_MODULE(_C, m) { 34 | registerExceptionTranslator(); 35 | 36 | initDeferredInitFunctions(m); 37 | 38 | initFakeFunctions(m); 39 | } 40 | 41 | } // namespace torchdistx::python 42 | -------------------------------------------------------------------------------- /src/python/torchdistx/_C/module.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | // 4 | // This source code is licensed under the BSD-style license found in the 5 | // LICENSE file in the root directory of this source tree. 6 | 7 | #pragma once 8 | 9 | #include 10 | 11 | namespace torchdistx::python { 12 | 13 | void initDeferredInitFunctions(pybind11::module& m); 14 | 15 | void initFakeFunctions(pybind11::module& m); 16 | 17 | } // namespace torchdistx::python 18 | -------------------------------------------------------------------------------- /src/python/torchdistx/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | __version__ = "0.3.0.dev0" 8 | -------------------------------------------------------------------------------- /src/python/torchdistx/deferred_init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Callable, Dict, Optional, TypeVar, Union 8 | 9 | from torch import Tensor 10 | from torch.nn import Module 11 | 12 | # We import `fake` to monkey-patch `repr()` of `Tensor`. 13 | from . import fake # noqa: F401 14 | from . import _C 15 | 16 | T = TypeVar("T", bound=Module) 17 | 18 | 19 | def deferred_init(module_fn: Callable[..., T], *args, **kwargs) -> T: 20 | """Defers the initialization of a ``Module``. 21 | 22 | This function forces all tensors constructed within ``module_fn`` to be 23 | fake while also recording all operations performed on them. The modules 24 | and tensors returned from ``module_fn`` can later be instantiated using 25 | the :func:`materialize_tensor` and :func:`materialize_module` functions. 26 | 27 | Args: 28 | module_fn: 29 | A callable that takes arbitrary number of arguments and returns a 30 | ``Module`` instance. 31 | args, kwargs: 32 | The positional and keyword arguments to be passed to ``module_fn``. 33 | 34 | .. 
Warning:: 35 | The operations performed on the parameters and buffers of a module will 36 | only be recorded while inside ``deferred_init()``. Avoid making changes 37 | to a module after its returned from ``deferred_init()``; otherwise it 38 | cannot be correctly materialized. 39 | """ 40 | _C.enter_deferred_init() 41 | try: 42 | return module_fn(*args, **kwargs) 43 | finally: 44 | _C.leave_deferred_init() 45 | 46 | 47 | def is_deferred(obj: Union[Tensor, Module]) -> bool: 48 | """Indicates whether the provided tensor or module has been constructed in 49 | a deferred-init context. 50 | 51 | Args: 52 | obj: 53 | A ``Tensor`` or ``Module`` instance. 54 | """ 55 | if isinstance(obj, Tensor): 56 | return _C.can_materialize(obj) 57 | 58 | if isinstance(obj, Module): 59 | for prm in obj.parameters(): 60 | if _C.can_materialize(prm): 61 | return True 62 | 63 | for buf in obj.buffers(): 64 | if _C.can_materialize(buf): 65 | return True 66 | 67 | return False 68 | 69 | raise ValueError("`obj` must be of type `Tensor` or `Module`.") 70 | 71 | 72 | def materialize_tensor(tensor: Tensor) -> Tensor: 73 | """Materializes ``tensor``. 74 | 75 | Args: 76 | tensor: 77 | The tensor instance to materialize. 78 | 79 | .. Warning:: 80 | Once materialized a fake tensor will hold a reference to its 81 | materialized version. In order to avoid memory leaks make sure to 82 | dispose it when it is no longer required. 83 | """ 84 | return _C.materialize_tensor(tensor) 85 | 86 | 87 | def materialize_module( 88 | module: Module, 89 | buffers_only: bool = False, 90 | check_fn: Optional[Callable[[Module], bool]] = None, 91 | ) -> None: 92 | """Materializes ``module`` and its descendant modules. 93 | 94 | Args: 95 | module: 96 | The module instance to materialize. 97 | buffers_only: 98 | A boolean value indicating whether to materialize the buffer tensors 99 | only. 100 | check_fn: 101 | An optional callable which takes a ``Module`` instance and returns a 102 | boolean value indicating whether to materialize it. 103 | """ 104 | 105 | def materialize_tensors(tensors: Dict[str, Optional[Tensor]]) -> None: 106 | for key, tensor in tensors.items(): 107 | if tensor is None: 108 | continue 109 | 110 | try: 111 | tensors[key] = _C.materialize_tensor(tensor) 112 | except ValueError: 113 | raise ValueError(f"'{key}' has already been materialized.") from None 114 | 115 | # Materialize the child modules recursively. 116 | for m in module.children(): 117 | materialize_module(m, buffers_only, check_fn) 118 | 119 | # Materialize this module, possibly based on a check. 120 | if check_fn is None or check_fn(module): 121 | if not buffers_only: 122 | materialize_tensors(module._parameters) # type: ignore[arg-type] 123 | 124 | materialize_tensors(module._buffers) 125 | -------------------------------------------------------------------------------- /src/python/torchdistx/fake.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from contextlib import contextmanager 8 | from typing import Callable, Generator 9 | 10 | import torch 11 | 12 | from . import _C 13 | 14 | 15 | # Since the `repr()` method of `Tensor` is not extensible we monkey-patch it 16 | # to support fake tensors. 
17 | def _patch_tensor_repr() -> Callable[[torch.Tensor], str]: 18 | tensor_repr = torch.Tensor.__repr__ 19 | 20 | def patched_repr(tensor: torch.Tensor) -> str: 21 | if _C.is_fake(tensor): 22 | s = f"tensor(..., size={tuple(tensor.shape)}" 23 | 24 | if tensor.dtype != torch.get_default_dtype(): 25 | s += f", dtype={tensor.dtype}" 26 | 27 | if tensor.device.type != "cpu": 28 | s += f", device={tensor.device}" 29 | 30 | if tensor.requires_grad: 31 | s += ", requires_grad=True" 32 | 33 | return s + ", fake=True)" 34 | else: 35 | return tensor_repr(tensor) 36 | 37 | return patched_repr 38 | 39 | 40 | torch.Tensor.__repr__ = _patch_tensor_repr() # type: ignore[assignment] 41 | 42 | 43 | @contextmanager 44 | def fake_mode(*, fake_cuda: bool = False) -> Generator: 45 | """Instantiates all tensors within its context as fake. 46 | 47 | Args: 48 | fake_cuda: 49 | If ``True``, allows constructing fake CUDA tensors even if CUDA is 50 | not available. Ignored if CUDA is already available. 51 | """ 52 | _C.enter_fake_mode(fake_cuda) 53 | try: 54 | yield 55 | finally: 56 | _C.leave_fake_mode() 57 | 58 | 59 | def is_fake(tensor: torch.Tensor) -> bool: 60 | """Indicates whether ``tensor`` is fake. 61 | 62 | Args: 63 | tensor: 64 | The tensor to check. 65 | """ 66 | return _C.is_fake(tensor) 67 | 68 | 69 | def meta_like(fake: torch.Tensor) -> torch.Tensor: 70 | """Returns a meta tensor with the same properties as ``fake``. 71 | 72 | This function has the same Autograd behavior as ``detach()`` meaning the 73 | returned tensor won't be part of the Autograd graph. 74 | 75 | Args: 76 | fake: 77 | The fake tensor to copy from. 78 | """ 79 | try: 80 | return _C.meta_like(fake) 81 | except ValueError: 82 | raise ValueError("`fake` was expected to be a fake tensor.") 83 | -------------------------------------------------------------------------------- /src/python/torchdistx/gossip_grad.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import random 9 | from enum import Enum, auto 10 | from itertools import cycle 11 | 12 | import torch 13 | import torch.distributed as dist 14 | from torch._C._distributed_c10d import ProcessGroup 15 | from torch.distributed.algorithms._comm_hooks import default 16 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 17 | 18 | # Setting a constant for situations, when communication peer 19 | # is not present in a current environment. This may happen in CUBE topology, 20 | # when a number of nodes is not equal to a power of 2. In this case, both 21 | # send and receive peers are equal to INVALID_PEER and no communication is 22 | # performed. 23 | INVALID_PEER = -1 24 | 25 | 26 | class Topology(Enum): 27 | r""" 28 | Specifies which topology will be used as a base for gradient communication. 29 | For more information, please refer to the original 30 | `paper `_ 31 | 32 | CUBE: 33 | A hypercube topology - a hierarchical virtual organization of compute nodes. 34 | For this topology gossiping is happening with a neighboring vertex. 35 | 36 | >>> *----* 37 | >>> /| /| 38 | >>> *----* | 39 | >>> | * -|-* 40 | >>> |/ |/ 41 | >>> *----* 42 | 43 | DISSEMINATION: 44 | A dissemination topology has similar property 45 | as hypercube virtual topology. 
46 | For this topology gossiping is happening with the neighboring node, 47 | then every 2nd node, every 4th, etc. 48 | 49 | >>> . * . 50 | >>> * * 51 | >>> . . 52 | >>> * * 53 | >>> . . 54 | >>> * * 55 | >>> . * . 56 | 57 | .. note:: 58 | Current implementation does not support uneven number of nodes for a CUBE 59 | topology. 60 | 61 | """ 62 | CUBE = auto() 63 | DISSEMINATION = auto() 64 | 65 | 66 | class GossipGraDState(default.DefaultState): 67 | r""" 68 | Stores state needed to perform GossipGraD algorithm within a communication hook. 69 | 70 | .. note:: Note that this hook should be used with the NCCL PG backend and users 71 | must set the current GPU device with `torch.cuda.set_device` prior to 72 | ``GossipGraDState`` initialization, otherwise it will lead to 73 | unexpected hang issues during the gossiping stage. 74 | 75 | Args: 76 | num_modules (int): Number of FSDP modules to identify how many communication 77 | calls will be performed during a backpropagation pass. 78 | topology (Topology): A virtual topology to be used for gradient communication. 79 | (default: DISSEMINATION) 80 | local_process_group (ProcessGroup): Stores local subgroup, 81 | where intra-node communication will happen, 82 | by default a subgroup is initialized to workers, belonging to the same node. 83 | Should be provided together with `num_nodes`. When every local process group 84 | contains only one worker, then this worker is considered to be a separate 85 | node and local ``all_reduce`` and ``broadcast`` are not performed. 86 | (default: None) 87 | num_nodes (int): Number of nodes in a compute environment. 88 | Should be provided together with `local_process_group`. 89 | By default is initialized to the number of generated local subgroups. 90 | (default: None) 91 | master_process_group (ProcessGroup): Stores main workers, 92 | which are involved in inter-node communication. By default, will be 93 | composed from the workers with rank 0 in the local process group. 94 | (default: None) 95 | proc_per_node (int): Number of workers in each node. By default is initialized 96 | to the size of a local subgroup. 97 | (default: None) 98 | random_seed (int): A random seed, so that randomly generated topologies 99 | were the same on every worker. 100 | (default: 2403) 101 | 102 | """ 103 | 104 | def __init__( 105 | self, 106 | num_modules, 107 | topology=None, 108 | local_process_group=None, 109 | num_nodes=None, 110 | master_process_group=None, 111 | proc_per_node=None, 112 | random_seed=2403, 113 | ): 114 | if num_modules is None or num_modules < 1: 115 | raise ValueError("`num_nodes` should bea positive integer.") 116 | self.num_modules = num_modules 117 | self.topology = topology or Topology.DISSEMINATION 118 | if local_process_group is None and num_nodes is None: 119 | self.local_process_group, subgroups = dist.new_subgroups() 120 | self.num_nodes = len(subgroups) 121 | else: 122 | if ( 123 | local_process_group is not None 124 | and num_nodes is None 125 | or local_process_group is None 126 | and num_nodes is not None 127 | ): 128 | raise ValueError( 129 | "`local_process_group` and `num_nodes` should be provided together." 130 | ) 131 | self.local_process_group = local_process_group 132 | if num_nodes < 1: 133 | raise ValueError("`num_nodes` should be equal to 1 or more.") 134 | self.num_nodes = num_nodes 135 | 136 | if self.num_nodes % 2 != 0 and self.topology == Topology.CUBE: 137 | raise ValueError( 138 | "Current implementation doesn't support uneven number" 139 | " of nodes for CUBE topology." 
140 | ) 141 | 142 | super().__init__(self.local_process_group) 143 | self.proc_per_node = ( 144 | proc_per_node 145 | if proc_per_node is not None 146 | else self.local_process_group.size() 147 | ) 148 | if self.proc_per_node < 1: 149 | raise ValueError("`proc_per_node` should be equal to 1 or more.") 150 | 151 | self.master_process_group = ( 152 | master_process_group 153 | if master_process_group is not None 154 | else self._create_master_group() 155 | ) 156 | 157 | self.random_seed = random_seed 158 | self.topologies = self._generate_topologies(self.random_seed) 159 | self.cur_topology = next(self.topologies) 160 | 161 | # For `num_nodes` != power of 2 `gossip_period` should still be an int. 162 | # If we only have 1 node, `gossip_period` should be equal to 1. 163 | self.gossip_period = max(1, math.ceil(math.log(self.num_nodes, 2))) 164 | self.iter = 0 165 | 166 | # Get rank for current device 167 | self.rank = dist.get_rank() 168 | 169 | # Master worker for a current local `process_group` 170 | self.master_worker = dist.distributed_c10d._get_global_rank( 171 | self.local_process_group, 0 172 | ) 173 | 174 | def _create_master_group(self): 175 | r""" 176 | Creates master process group, i.e. a group of workers, 177 | which communicate gradients between different nodes. 178 | """ 179 | # Every 0th worker on every node will be assigned to a master group, 180 | # i.e. if number of rocesses per node is 8, master group contains 181 | # 0th, 8th, 16th, 24th, 32nd, ... ranks 182 | ranks = [i * self.proc_per_node for i in range(self.num_nodes)] 183 | return dist.new_group(ranks) 184 | 185 | def _generate_topologies(self, random_seed): 186 | r""" 187 | Creates `num_nodes` random topology shuffles and returns an infinite iterator. 188 | Original topology is of the form: 189 | [0*K, 1*K, ... , N*K], 190 | where N is the number of nodes and K - the number of workers on each node. 191 | For example, with N=4 and K=8, original topology is 192 | [0, 8, 16, 24] 193 | 194 | Workers' rank values are used instead of node values for easier peer assignment 195 | in a collective communication stage. 196 | 197 | Returns: 198 | An infinite iterator over created topologies 199 | """ 200 | random.seed(random_seed) 201 | topologies_set = [] 202 | original_list = [i * self.proc_per_node for i in range(self.num_nodes)] 203 | for _ in range(self.num_nodes): 204 | random.shuffle(original_list) 205 | topologies_set.append(original_list.copy()) 206 | 207 | return cycle(topologies_set) 208 | 209 | 210 | def _get_send_recv_peers(state): 211 | r""" 212 | Computes peers for the collective communication stage. 213 | For a ``CUBE`` topology a node sends grads to and receives from 214 | the same neighboring vertex. A pick for a neighboring vertex 215 | depends on the step number and current virtual topology in use. 216 | 217 | For a ``DISSEMINATION`` topology a node typically sends grads 218 | to and receives from different neighbors, but there may be a step 219 | where send and receive peers are the same node. A pick for send and receive peers 220 | depends on the step number and current virtual topology in use. 221 | 222 | For more information, please refer to the original 223 | `paper `_ 224 | 225 | Args: 226 | state (GossipGradState): State for GossipGraD communication hook. 227 | 228 | Returns: 229 | Peers' global ranks to whom a current node sends gradients 230 | and from whom it is received. 231 | """ 232 | assert state.gossip_period > 0, "`gossip_period` should be greater than 0." 
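    # Illustrative walk-through (not part of the algorithm itself), assuming an
    # unshuffled topology of 4 nodes with 8 workers each, i.e.
    # `cur_topology == [0, 8, 16, 24]`, `gossip_period == 2`, and the caller's
    # global rank is 8 (so `node_rank == 1`):
    #   DISSEMINATION, power 0: send to cur_topology[(1 + 1) % 4] == 16,
    #                           recv from cur_topology[(1 - 1) % 4] == 0
    #   DISSEMINATION, power 1: send to and recv from cur_topology[3] == 24
    #                           (send and receive peers coincide on this step)
    #   CUBE,          power 0: peer is cur_topology[1 ^ 1] == 0
    #   CUBE,          power 1: peer is cur_topology[1 ^ 2] == 24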
233 | power = (state.iter // state.num_modules) % state.gossip_period 234 | # Our new node_rank is a position of a global rank in 235 | # a virtual topology 236 | node_rank = state.cur_topology.index(state.rank) 237 | 238 | if state.topology == Topology.CUBE: 239 | peer_idx = node_rank ^ 2**power 240 | if peer_idx >= len(state.cur_topology): 241 | return INVALID_PEER, INVALID_PEER 242 | return state.cur_topology[peer_idx], state.cur_topology[peer_idx] 243 | 244 | elif state.topology == Topology.DISSEMINATION: 245 | send_peer_idx = (node_rank + 2**power) % state.num_nodes 246 | recv_peer_idx = (node_rank - 2**power + state.num_nodes) % state.num_nodes 247 | return state.cur_topology[send_peer_idx], state.cur_topology[recv_peer_idx] 248 | 249 | 250 | def _gossip(state, grad, scaling_factor=0.5): 251 | r""" 252 | Gossiping stage. 253 | 254 | At this step, it obtains communication peers, 255 | stacks ``torch.distributed.irecv`` and ``torch.distributed.isend`` operations, 256 | and performs communication with ``torch.distributed.batch_isend_irecv``. 257 | Finally, received and current gradients are added together 258 | and scaled appropriately, i.e. since communication happens 259 | only between 2 peers at a time, summed gradients are divided 260 | by 2 (or multiplied by 0.5) 261 | 262 | For more information, please refer to the original 263 | `paper `_ 264 | 265 | Args: 266 | state (GossipGradState): State for GossipGraD communication hook. 267 | grad (torch.Tensor): A gradient for the local batch 268 | that needs to be communicated across ranks. 269 | scaling_facto (float): Scaling factor to apply after 270 | received and current gradients are combined. 271 | 272 | """ 273 | send_peer, recv_peer = _get_send_recv_peers(state) 274 | 275 | if send_peer == INVALID_PEER or recv_peer == INVALID_PEER: 276 | return 277 | 278 | assert send_peer is not None and recv_peer is not None, ( 279 | "Failed to calculate send and receive peers: " 280 | f"(`send_peer` is {send_peer} and `recv_peer` is {recv_peer})" 281 | ) 282 | # Need to check that send and receive peers are not equal to a current rank 283 | assert send_peer != state.rank and recv_peer != state.rank, ( 284 | "Expected send and receive peers to differ from a current rank: " 285 | f"(current rank is {state.rank}, `send_peer` is {send_peer}\ 286 | and `recv_peer` is {recv_peer})" 287 | ) 288 | assert ( 289 | send_peer != -1 and recv_peer != -1 290 | ), "Communication peers are not present in a current topology" 291 | recv_grad = torch.empty_like(grad) 292 | ops = [] 293 | 294 | # For ranks not in the `master_process_group`, 295 | # `master_process_group` is an `object` instance 296 | assert isinstance( 297 | state.master_process_group, ProcessGroup 298 | ), "`master_process_group` is not an instance of `ProcessGroup`" 299 | 300 | ops.append( 301 | dist.P2POp( 302 | op=dist.isend, tensor=grad, peer=send_peer, group=state.master_process_group 303 | ) 304 | ) 305 | ops.append( 306 | dist.P2POp( 307 | op=dist.irecv, 308 | tensor=recv_grad, 309 | peer=recv_peer, 310 | group=state.master_process_group, 311 | ) 312 | ) 313 | reqs = dist.batch_isend_irecv(ops) 314 | for req in reqs: 315 | req.wait() 316 | grad.add_(recv_grad).mul_(scaling_factor) 317 | 318 | 319 | def get_num_modules(module: torch.nn.Module): 320 | r""" 321 | Returns number of FSDP modules in a provided FSDP instance. 
322 | 323 | Args: 324 | module (torch.nn.Module): FSDP instance 325 | 326 | Returns: 327 | int: number of FSDP modules that are nested in the input ``module``, 328 | including self. 329 | 330 | """ 331 | return len(FSDP.fsdp_modules(module)) 332 | 333 | 334 | def gossip_grad_hook(state: GossipGraDState, grad: torch.Tensor): 335 | r""" 336 | Communication hook, that follows 337 | `GossipGraD `_ strategy. 338 | 339 | Every ``state.gossip_period`` step a virtual topology is changed. 340 | Before an inter-node communication happens, gradients are reduced locally, 341 | i.e. in an intra-node fashion. 342 | 343 | Only workers from a master process group are participating in a gossiping stage. 344 | Finally, every main worker broadcasts final gradient to its local subgroup 345 | 346 | Args: 347 | state (GossipGradState): State for GossipGraD communication hook. 348 | grad (torch.Tensor): A gradient for the local batch 349 | that needs to be communicated across ranks. 350 | 351 | Here is an example for how to initialize a default ``GossipGraD state`` 352 | and register an fsdp model with a communication hook. 353 | :: 354 | 355 | >>> import torch 356 | >>> import torch.distributed as dist 357 | >>> from torch.distributed.fsdp import( 358 | >>> FullyShardedDataParallel as FSDP 359 | >>> ) 360 | >>> from torchdistx.gossip_grad import( 361 | >>> GossipGraDState, 362 | >>> Topology, 363 | >>> get_num_modules, 364 | >>> gossip_grad_hook 365 | >>> ) 366 | >>> 367 | >>> net = torch.nn.Linear(4, 10) 368 | >>> fsdp_net = FSDP(net) 369 | >>> state = GossipGraDState(num_modules=get_num_modules(fsdp_net)) 370 | >>> fsdp_net.register_comm_hook(state, gossip_grad_hook) 371 | 372 | """ 373 | # Virtual topology changes every `state.gossip_period` step. 374 | # FSDP net can consist of multiple FSDP modules and every module will 375 | # increase `state.iter` during the backward pass. As a result, we need 376 | # to adjust for this behavior and make sure that virtual topology doesn't 377 | # change in the middle of the backward pass. 378 | if (state.iter // state.num_modules) % state.gossip_period == 0: 379 | state.cur_topology = next(state.topologies) 380 | 381 | # Reduce local gradients 382 | default.allreduce_hook(state, grad) 383 | # Perform gossiping step between master nodes (via master workers) 384 | if not dist._rank_not_in_group(state.master_process_group): 385 | _gossip(state, grad) 386 | # Broadcast received gradients in the local process group 387 | dist.broadcast(grad, src=state.master_worker, group=state.local_process_group) 388 | 389 | state.iter += 1 390 | -------------------------------------------------------------------------------- /src/python/torchdistx/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .anyprecision_optimizer import AnyPrecisionAdamW 2 | -------------------------------------------------------------------------------- /src/python/torchdistx/optimizers/anyprecision_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # AnyPrecisionAdamW: a flexible precision AdamW optimizer 8 | # with optional Kahan summation for high precision weight updates. 9 | # Allows direct control over momentum, variance and auxiliary compensation 10 | # buffer dtypes. 
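# (For example -- an illustrative call based only on the constructor signature
# below, not a prescribed configuration -- one could keep both optimizer states
# in BFloat16 with:
#     AnyPrecisionAdamW(model.parameters(),
#                       momentum_dtype=torch.bfloat16,
#                       variance_dtype=torch.bfloat16,
#                       use_kahan_summation=True)
# where `model` is any torch.nn.Module.)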
11 | # Optional Kahan summation is used to offset precision reduction for 12 | # the weight updates. This allows full training in BFloat16 (equal or 13 | # better than FP32 results in many cases) due to high precision weight upates. 14 | 15 | import torch 16 | from torch.optim.optimizer import Optimizer 17 | 18 | 19 | class AnyPrecisionAdamW(Optimizer): 20 | def __init__( 21 | self, 22 | params, 23 | lr=1e-3, 24 | betas=(0.9, 0.999), 25 | eps=1e-8, 26 | weight_decay=0.0, 27 | use_kahan_summation=False, 28 | momentum_dtype=torch.float32, 29 | variance_dtype=torch.bfloat16, 30 | compensation_buffer_dtype=torch.bfloat16, 31 | ): 32 | """ 33 | Args: 34 | params (iterable): iterable of parameters to optimize or dicts defining 35 | parameter groups 36 | lr (float, optional): learning rate (default: 1e-3) 37 | betas (Tuple[float, float], optional): coefficients used for computing 38 | running averages of gradient and its square (default: (0.9, 0.999)) 39 | eps (float, optional): term added to the denominator to improve 40 | numerical stability (default: 1e-8) 41 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 42 | 43 | # Any Precision specific 44 | use_kahan_summation = creates auxiliary buffer to ensure high precision 45 | model param updates (default: False) 46 | momentum_dtype = dtype for momentum (default: BFloat32) 47 | variance_dtype = dtype for uncentered variance (default: BFloat16) 48 | compensation_buffer_dtype = dtype for Kahan summation 49 | buffer (default: BFloat16). Only used if 50 | ``use_kahan_summation=True``. 51 | 52 | # Usage 53 | This optimizer implements optimizer states, and Kahan summation 54 | for high precision updates, all in user controlled dtypes. 55 | Defaults are variance in BF16, Momentum in FP32. 56 | This can be run in FSDP mixed precision, amp, or full precision, 57 | depending on what training pipeline you wish to work with. 58 | 59 | Setting to use_kahan_summation = False, and changing momentum and 60 | variance dtypes to FP32, reverts this to a standard AdamW optimizer. 61 | """ 62 | defaults = dict( 63 | lr=lr, 64 | betas=betas, 65 | eps=eps, 66 | weight_decay=weight_decay, 67 | use_kahan_summation=use_kahan_summation, 68 | momentum_dtype=momentum_dtype, 69 | variance_dtype=variance_dtype, 70 | compensation_buffer_dtype=compensation_buffer_dtype, 71 | ) 72 | 73 | super().__init__(params, defaults) 74 | 75 | @torch.no_grad() 76 | def step(self, closure=None): 77 | """Performs a single optimization step. 78 | Args: 79 | closure (callable, optional): A closure that reevaluates the model 80 | and returns the loss. 81 | """ 82 | 83 | if closure is not None: 84 | with torch.enable_grad(): 85 | # to fix linter, we do not keep the returned loss for use atm. 
86 | closure() 87 | 88 | for group in self.param_groups: 89 | 90 | beta1, beta2 = group["betas"] 91 | lr = group["lr"] 92 | weight_decay = group["weight_decay"] 93 | eps = group["eps"] 94 | use_kahan_summation = group["use_kahan_summation"] 95 | 96 | momentum_dtype = group["momentum_dtype"] 97 | variance_dtype = group["variance_dtype"] 98 | compensation_buffer_dtype = group["compensation_buffer_dtype"] 99 | 100 | for p in group["params"]: 101 | if p.grad is None: 102 | continue 103 | 104 | if p.grad.is_sparse: 105 | raise RuntimeError( 106 | "AnyPrecisionAdamW does not support sparse gradients" 107 | ) 108 | 109 | state = self.state[p] 110 | 111 | # State initialization 112 | if len(state) == 0: 113 | 114 | state["step"] = torch.tensor(0.0) 115 | 116 | # momentum - EMA of gradient values 117 | state["exp_avg"] = torch.zeros_like( 118 | p, 119 | dtype=momentum_dtype, 120 | ) 121 | 122 | # variance uncentered - EMA of squared gradient values 123 | state["exp_avg_sq"] = torch.zeros_like( 124 | p, 125 | dtype=variance_dtype, 126 | ) 127 | 128 | # optional Kahan summation - accumulated error tracker 129 | if use_kahan_summation: 130 | state["compensation"] = torch.zeros_like( 131 | p, 132 | dtype=compensation_buffer_dtype, 133 | ) 134 | 135 | # main processing ------------------------- 136 | 137 | # update the steps for each param group update 138 | state["step"] += 1 139 | step = state["step"] 140 | 141 | exp_avg = state["exp_avg"] 142 | exp_avg_sq = state["exp_avg_sq"] 143 | 144 | grad = p.grad 145 | 146 | # weight decay, AdamW style 147 | if weight_decay: 148 | p.data.mul_(1 - lr * weight_decay) 149 | 150 | # update momentum 151 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 152 | 153 | # update uncentered variance 154 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 155 | 156 | # adjust using bias1 157 | bias_correction1 = 1 - beta1**step 158 | 159 | step_size = lr / bias_correction1 160 | 161 | # adjust using bias2 162 | denom_correction = (1 - beta2**step) ** 0.5 # avoids math import 163 | 164 | centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( 165 | eps, alpha=1 166 | ) 167 | 168 | # lr update to compensation 169 | if use_kahan_summation: 170 | compensation = state["compensation"] 171 | 172 | compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) 173 | 174 | # update weights with compensation (Kahan summation) 175 | # save error back to compensation for next iteration 176 | temp_buffer = p.detach().clone() 177 | p.data.add_(compensation) 178 | compensation.add_(temp_buffer.sub_(p.data)) 179 | 180 | else: 181 | # usual AdamW updates 182 | p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) 183 | -------------------------------------------------------------------------------- /src/python/torchdistx/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchdistx/9c1b9f5cb2fa36bfb8b70ec07c40ed42a33cc87a/src/python/torchdistx/py.typed -------------------------------------------------------------------------------- /src/python/torchdistx/slowmo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from . 
import slowmo_comm, slowmo_optimizer 8 | -------------------------------------------------------------------------------- /src/python/torchdistx/slowmo/slowmo_comm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch.distributed.algorithms._comm_hooks import default 10 | 11 | 12 | class SlowMoState(default.DefaultState): 13 | r""" 14 | State for the `Slow Momentum `_ . 15 | 16 | Args: 17 | subgroup (ProcessGroup): stores subgroups, where communication will happen, 18 | by default a subgroup is initialized to workers, 19 | belonging to the same node. 20 | sync_grads (bool): if `True`, gradients will be communicated 21 | between members of the same subgroup (default: True). 22 | """ 23 | 24 | def __init__(self, subgroup, sync_grads=True): 25 | self.subgroup = subgroup if subgroup is not None else dist.new_subgroups()[0] 26 | super().__init__(self.subgroup) 27 | self.sync_grads = sync_grads 28 | 29 | 30 | def slowmo_hook(state: SlowMoState, grad: torch.Tensor): 31 | r""" 32 | If ``sync_grads`` is enabled in the ``state``, 33 | reduces gradients between workers under the same node. 34 | 35 | Args: 36 | state (SlowMoState): State information, configures 37 | if gradients are going to be communicated or not, 38 | and subgoups for gradient communication 39 | grad (torch.Tensor): A gradient for the local batch 40 | that needs to be communicated across ranks. 41 | """ 42 | if state.sync_grads: 43 | default.allreduce_hook(state, grad) 44 | -------------------------------------------------------------------------------- /src/python/torchdistx/slowmo/slowmo_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.distributed.algorithms.model_averaging.averagers as averagers 9 | 10 | 11 | class SlowMomentumOptimizer(torch.optim.Optimizer): 12 | r""" 13 | Wraps an arbitrary :class:`torch.optim.Optimizer` and runs 14 | FSDP distributed training with 15 | `Slow Momentum `_. 16 | Currently, only available for FSDP modules defined 17 | with a `NO_SHARD` strategy. 
18 | 19 | Args: 20 | base_optim (torch.optim.Optimizer): 21 | The base optimizer, which updates local instance of a model 22 | slowmo_freq (int): Specifies how often (number of iterations) slow momentum 23 | is to be performed (default: 48) 24 | slowmo_factor (float): This specifies the value of slowmo momentum 25 | to be used (default: 0.5) 26 | slowmo_lr (float): This specifies the value of slowmo learning rate 27 | to be used (default: 1.0) 28 | 29 | Example:: 30 | 31 | >>> import torch 32 | >>> import torch.distributed as dist 33 | >>> from torch.distributed.fsdp import( 34 | >>> FullyShardedDataParallel as FSDP 35 | >>> ) 36 | >>> from torchdistx.slowmo import( 37 | >>> slowmo_comm, 38 | >>> slowmo_optimizer 39 | >>> ) 40 | >>> 41 | >>> net = torch.nn.Linear(4, 10) 42 | >>> fsdp_net = FSDP(net) 43 | >>> # This implementation communicates gradients between 44 | >>> # workers of the same node 45 | >>> # before averaging the model's parameters between nodes. 46 | >>> # The following creates intra-node subgroups 47 | >>> # and SlowMoState will take care of storing all required 48 | >>> # parameters for intra-node communication, 49 | >>> # i.e. pre- and post-division factors, and subgroups. 50 | >>> # To disable any communication between workers, 51 | >>> # set `sync_grads` to `False` 52 | >>> cur_subgroup, _ = dist.new_subgroups() 53 | >>> slowmo_state = slowmo_comm.SlowMoState( 54 | >>> cur_subgroup, 55 | >>> sync_grads=True 56 | >>> ) 57 | >>> 58 | >>> # Register SlowMo hook, which only communicates gradients 59 | >>> # in a intra-node fashion. 60 | >>> fsdp_net.register_comm_hook( 61 | >>> slowmo_state, 62 | >>> slowmo_comm.slowmo_hook 63 | >>> ) 64 | >>> 65 | >>> base_optimizer = torch.optim.SGD( 66 | >>> fsdp_net_slowmo.parameters(), 67 | >>> lr=1e-2 68 | >>> ) 69 | >>> # Create a SlowMo optimizer that wraps a local optimizer. 70 | >>> slowmo_optim = slowmo_optimizer.SlowMomentumOptimizer( 71 | >>> base_optim=base_optimizer, 72 | >>> slowmo_freq=6, 73 | >>> slowmo_factor=0.5, 74 | >>> slowmo_lr=0.1 75 | >>> ) 76 | >>> 77 | >>> # SlowMo runs intra-node gradient averaging at every step, 78 | >>> # every 6th step it will run model averaging and 79 | >>> # a slow momentum update. 80 | >>> for step in range(200): 81 | >>> slowmo_optim.zero_grad() 82 | >>> loss = loss_fn(output, labels) 83 | >>> loss.backward() 84 | >>> slowmo_optim.step() 85 | """ 86 | 87 | def __init__( 88 | self, 89 | base_optim: torch.optim.Optimizer, 90 | slowmo_freq: int = 48, 91 | slowmo_factor: float = 0.5, 92 | slowmo_lr: float = 1.0, 93 | ): 94 | if base_optim is None: 95 | raise ValueError("Base optimizer is a required parameter.") 96 | self._base_optim = base_optim 97 | 98 | # check that base optimizer's `param_groups` are present 99 | if not (self._base_optim.param_groups): 100 | raise ValueError( 101 | "Provided base optimizer does not have parameters specified." 102 | ) 103 | for group in self._base_optim.param_groups: 104 | if "lr" not in group: 105 | raise ValueError( 106 | "All parameter groups should have learning rate specified." 107 | ) 108 | 109 | self.param_groups = self._base_optim.param_groups 110 | 111 | if slowmo_freq < 1: 112 | raise ValueError( 113 | "Invalid ``slowmo_freq`` parameter, must be a positive value." 114 | ) 115 | self.slowmo_freq = slowmo_freq 116 | 117 | if slowmo_factor < 0.0: 118 | raise ValueError( 119 | "Invalid ``slowmo_factor`` parameter, must be non-negative." 
120 | ) 121 | self.slowmo_factor = slowmo_factor 122 | 123 | if slowmo_lr < 0.0: 124 | raise ValueError("Invalid ``slowmo_lr`` parameter, must be non-negative.") 125 | self.slowmo_lr = slowmo_lr 126 | 127 | self.averager = averagers.PeriodicModelAverager( 128 | period=slowmo_freq, warmup_steps=0 129 | ) 130 | self.buffers_initialized = False 131 | 132 | # Memorize initial parameters before the first `step()`. 133 | # Can't put them in `self.state`, because some of optimizers rely 134 | # `self.state` being empty during the `step()` 135 | # to initialize optimizer states. 136 | # `self._prev_parameters` must be in sync with 137 | # the flattened version of `self.param_groups`, 138 | # since this implementation relies on `self._prev_parameters` 139 | # having the same order of parameters as in `self.param_groups` 140 | # to perform a slow momentum update. 141 | self._prev_parameters = [] 142 | for group in self.param_groups: 143 | for param in group["params"]: 144 | self._prev_parameters.append(param.detach().clone()) 145 | 146 | @property 147 | def state(self): 148 | r""" 149 | Forwards to base optimizer's ``state``. 150 | """ 151 | return self._base_optim.state 152 | 153 | def __repr__(self): 154 | return self._base_optim.__repr__() 155 | 156 | def state_dict(self): 157 | r""" 158 | This is the same as :class:`torch.optim.Optimizer` 159 | :meth:`state_dict`, but adds an extra entries to record 160 | Slow Momentum's specific parameters: ``slowmo_freq``, 161 | ``slowmo_factor``, ``slowmo_lr``, and ``step`` for the model's averager. 162 | """ 163 | optim_state_dict = self._base_optim.state_dict() 164 | optim_state_dict["slowmo_freq"] = self.slowmo_freq 165 | optim_state_dict["slowmo_factor"] = self.slowmo_factor 166 | optim_state_dict["slowmo_lr"] = self.slowmo_lr 167 | optim_state_dict["step"] = self.averager.step 168 | 169 | return optim_state_dict 170 | 171 | def load_state_dict(self, state_dict): 172 | r""" 173 | This is the same as :class:`torch.optim.Optimizer` 174 | :meth:`load_state_dict`, but also restores Slow Momentum's 175 | specific parameters, saved in the provided ``state_dict``. 176 | """ 177 | self.slowmo_freq = state_dict["slowmo_freq"] 178 | self.averager.period = state_dict.pop("slowmo_freq") 179 | self.slowmo_factor = state_dict.pop("slowmo_factor") 180 | self.slowmo_lr = state_dict.pop("slowmo_lr") 181 | self.averager.step = state_dict.pop("step") 182 | self._base_optim.load_state_dict(state_dict) 183 | if not self.param_groups: 184 | raise ValueError("Base optimizer does not have parameter groups specified.") 185 | for group in self._base_optim.param_groups: 186 | if "lr" not in group: 187 | raise ValueError( 188 | "All parameter groups should have learning rate specified." 189 | ) 190 | 191 | @torch.no_grad() 192 | def step(self): 193 | r""" 194 | Performs a single optimization step (parameter update) 195 | and a slow momentum update. Slow momentum update involves 196 | model's exact averaging of parameters and a momentum update, 197 | which happens every `slowmo_freq` step. 198 | """ 199 | self._base_optim.step() 200 | # Averager averages parameters between workers every `slowmo_freq` step. 201 | # At other times it just increases step counter. 202 | self.averager.average_parameters(params=self.param_groups) 203 | # Since at this point averager has increased its step, 204 | # we need to check (self.averager.step - 1). 205 | # No need to do momentum step at step 0. 
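        # In update form, the block below performs, per parameter and with
        # lr = group["lr"] (a restatement of the in-place ops, not extra logic):
        #   slow_momentum <- slowmo_factor * slow_momentum + (prev_param - param) / lr
        #   prev_param    <- prev_param - slowmo_lr * lr * slow_momentum
        #   param         <- prev_param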
206 | if (self.averager.step - 1) % self.slowmo_freq == 0 and self.averager.step != 1: 207 | prev_param_idx = 0 208 | for group in self.param_groups: 209 | for param in group["params"]: 210 | # Initialize momentums if they were not initialized 211 | if "slow_momentum" not in self.state[param]: 212 | self.state[param]["slow_momentum"] = torch.zeros( 213 | param.shape, device=torch.cuda.current_device() 214 | ) 215 | 216 | # Update the slow momentum 217 | p_state = self.state[param] 218 | factor = 1 / group["lr"] 219 | p_state["slow_momentum"].mul_(self.slowmo_factor).sub_( 220 | param, alpha=factor 221 | ).add_(self._prev_parameters[prev_param_idx], alpha=factor) 222 | # Update parameters 223 | self._prev_parameters[prev_param_idx].add_( 224 | p_state["slow_momentum"], alpha=-self.slowmo_lr * group["lr"] 225 | ) 226 | param.copy_(self._prev_parameters[prev_param_idx]) 227 | prev_param_idx += 1 228 | 229 | def zero_grad(self, set_to_none: bool = False): # type: ignore[override] 230 | self._base_optim.zero_grad(set_to_none=set_to_none) 231 | 232 | def add_param_group(self, param_group): 233 | self._base_optim.add_param_group(param_group) 234 | for param in param_group["params"]: 235 | self._prev_parameters.append(param.detach().clone()) 236 | -------------------------------------------------------------------------------- /tests/cc/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchdistx/9c1b9f5cb2fa36bfb8b70ec07c40ed42a33cc87a/tests/cc/.gitkeep -------------------------------------------------------------------------------- /tests/python/test_anyprecision_optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import unittest 8 | from copy import deepcopy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.testing._internal.common_utils import ( 14 | TestCase, 15 | instantiate_parametrized_tests, 16 | parametrize, 17 | run_tests, 18 | ) 19 | 20 | from torchdistx.optimizers import AnyPrecisionAdamW 21 | 22 | 23 | class TestAnyPrecisionOptimizer(TestCase): 24 | def _test_adam_equivalence(self, model, model_clone): 25 | # Test non-default options 26 | betas = (0.8, 0.88) 27 | weight_decay = 0.03 28 | 29 | adam_opt = optim.AdamW( 30 | model_clone.parameters(), betas=betas, weight_decay=weight_decay 31 | ) 32 | anyprecision_adam = AnyPrecisionAdamW( 33 | model.parameters(), 34 | variance_dtype=torch.float32, 35 | betas=betas, 36 | weight_decay=weight_decay, 37 | ) 38 | 39 | # Verify params are equal initially 40 | model_orig_params = [p.clone() for p in model.parameters()] 41 | for p1, p2 in zip(model_clone.parameters(), model_orig_params): 42 | self.assertEqual(p1, p2) 43 | 44 | for i in range(6): 45 | adam_opt.zero_grad() 46 | anyprecision_adam.zero_grad() 47 | inp = torch.randn(5, 5, device=next(model.parameters()).device) 48 | model(inp).sum().backward() 49 | model_clone(inp).sum().backward() 50 | adam_opt.step() 51 | anyprecision_adam.step() 52 | 53 | # Ensure params are modified from original 54 | if i == 0: 55 | for p1, p2 in zip(model.parameters(), model_orig_params): 56 | self.assertNotEqual(p1, p2) 57 | 58 | for p1, p2 in zip(model.parameters(), model_clone.parameters()): 59 | self.assertEqual(p1, p2) 60 | 61 | @parametrize("device", ["cpu", "cuda"]) 62 | def test_adam_equivalence(self, device): 63 | """ 64 | Tests that AnyPrecisionAdamW is equivalent to AdamW when 65 | kahan summation and different dtypes for momentum, variance, 66 | and compensation buffer are turned off (i.e. all float32). 67 | """ 68 | if device == "cuda" and not torch.cuda.is_available(): 69 | raise unittest.SkipTest("CUDA not available") 70 | 71 | model = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5), nn.Linear(5, 5)) 72 | if device == "cuda": 73 | model.cuda() 74 | 75 | model_clone = deepcopy(model) 76 | 77 | self._test_adam_equivalence(model, model_clone) 78 | 79 | 80 | instantiate_parametrized_tests(TestAnyPrecisionOptimizer) 81 | 82 | if __name__ == "__main__": 83 | run_tests() 84 | -------------------------------------------------------------------------------- /tests/python/test_deferred_init.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import cast 8 | 9 | import torch 10 | from torch import Tensor 11 | from torch.nn import Module, Parameter 12 | 13 | from torchdistx.deferred_init import ( 14 | deferred_init, 15 | is_deferred, 16 | materialize_module, 17 | materialize_tensor, 18 | ) 19 | 20 | 21 | def test_materialize_tensor_is_noop_for_real_tensors() -> None: 22 | a = torch.ones([10]) 23 | 24 | e = materialize_tensor(a) 25 | 26 | assert a is e 27 | 28 | 29 | def test_materialize_tensor_returns_same_tensor() -> None: 30 | class FooModule(Module): 31 | def __init__(self): 32 | super().__init__() 33 | 34 | self.param1 = Parameter(torch.ones([5])) 35 | self.param2 = self.param1 36 | 37 | module = deferred_init(FooModule) 38 | 39 | a = materialize_tensor(cast(Tensor, module.param1)) 40 | b = materialize_tensor(cast(Tensor, module.param1)) 41 | c = materialize_tensor(cast(Tensor, module.param2)) 42 | 43 | assert a is b 44 | assert a is c 45 | 46 | 47 | def test_is_deferred_returns_right_value() -> None: 48 | class FooModule(Module): 49 | def __init__(self): 50 | super().__init__() 51 | 52 | self.param1 = Parameter(torch.ones([5])) 53 | self.param2 = Parameter(torch.ones([5])) 54 | 55 | module = FooModule() 56 | 57 | assert not is_deferred(module) 58 | 59 | module = deferred_init(FooModule) 60 | 61 | assert is_deferred(module) 62 | 63 | materialize_module(module) 64 | 65 | assert not is_deferred(module) 66 | 67 | module = deferred_init(FooModule) 68 | 69 | module.param1 = materialize_tensor(module.param1) 70 | 71 | assert is_deferred(module) 72 | 73 | module.param2 = materialize_tensor(module.param2) 74 | 75 | assert not is_deferred(module) 76 | -------------------------------------------------------------------------------- /tests/python/test_fake.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import pytest 8 | import torch 9 | 10 | from torchdistx.fake import fake_mode, is_fake, meta_like 11 | 12 | 13 | def test_fake_mode_returns_cuda_tensor_if_fake_cuda_is_true() -> None: 14 | if torch.cuda.is_available(): 15 | pytest.skip("Can only be tested if CUDA is not available.") 16 | 17 | with fake_mode(fake_cuda=True): 18 | a = torch.ones([10], device="cuda") 19 | 20 | assert a.device.type == "cuda" 21 | 22 | 23 | def test_fake_mode_raises_error_if_fake_cuda_is_false() -> None: 24 | if torch.cuda.is_available(): 25 | pytest.skip("Can only be tested if CUDA is not available.") 26 | 27 | with pytest.raises((AssertionError, RuntimeError)): 28 | with fake_mode(): 29 | torch.ones([10], device="cuda") 30 | 31 | 32 | def test_cuda_tensor_raises_error_after_fake_mode() -> None: 33 | if torch.cuda.is_available(): 34 | pytest.skip("Can only be tested if CUDA is not available.") 35 | 36 | with fake_mode(fake_cuda=True): 37 | torch.ones([10], device="cuda") 38 | 39 | with pytest.raises((AssertionError, RuntimeError)): 40 | torch.ones([10], device="cuda") 41 | 42 | 43 | def test_meta_like_returns_meta_tensor() -> None: 44 | with fake_mode(): 45 | a = torch.ones([10]) 46 | 47 | b = meta_like(a) 48 | 49 | assert not is_fake(b) 50 | assert b.device.type == "meta" 51 | assert b.dtype == a.dtype 52 | assert b.size() == a.size() 53 | assert b.stride() == a.stride() 54 | 55 | 56 | def test_meta_like_raises_error_if_tensor_is_not_fake() -> None: 57 | a = torch.ones([10]) 58 | 59 | with pytest.raises(ValueError): 60 | meta_like(a) 61 | -------------------------------------------------------------------------------- /use-cpu.txt: -------------------------------------------------------------------------------- 1 | --pre --extra-index-url=https://download.pytorch.org/whl/nightly/cpu 2 | -------------------------------------------------------------------------------- /use-cu117.txt: -------------------------------------------------------------------------------- 1 | --pre --extra-index-url=https://download.pytorch.org/whl/nightly/cu117 2 | -------------------------------------------------------------------------------- /use-cu118.txt: -------------------------------------------------------------------------------- 1 | --pre --extra-index-url=https://download.pytorch.org/whl/nightly/cu118 2 | --------------------------------------------------------------------------------