├── .gitignore ├── .flake8 ├── autoparallel ├── __init__.py ├── dtensor_util │ └── __init__.py ├── _passes │ ├── graph_partition.py │ ├── split_fsdp_collectives.py │ ├── graph_multiplex.py │ └── split_di_dw_graph.py ├── init_weights.py ├── collective_runtime_estimation.py ├── collectives.py ├── auto_bucketing.py ├── cast_parametrization.py ├── graph_clustering.py ├── autobucketing_util │ ├── estimation.py │ └── reorder.py ├── debug_helpers.py └── graph_utils.py ├── requirements-test.txt ├── mast ├── .torchxconfig ├── run_torchtitan.sh ├── mount.sh └── sweep.py ├── pyproject.toml ├── .github └── workflows │ ├── lint.yml │ ├── test_cuda.yml │ └── test_torchtitan.yml ├── .pre-commit-config.yaml ├── README.md ├── LICENSE ├── CONTRIBUTING.md ├── tests ├── test_aot_eager.py ├── test_propagation_rules.py └── test_ordered_sharding.py ├── CLAUDE.md ├── examples ├── native_ds3 │ └── test_simple_batched_mm.py ├── run_ds3_numerics_check.py ├── example_autoparallel.py ├── example_local_map.py ├── example_ds3_local_map.py └── example_llama3.py ├── CODE_OF_CONDUCT.md └── partitioned_shard_proposal.md /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | *.swp 3 | 4 | *.pyc 5 | *.pyo 6 | *.so 7 | 8 | .mypy_cache/ 9 | *.egg-info/ 10 | 11 | build/ 12 | dist/ 13 | tmp/ 14 | out/ 15 | 16 | .vscode/ 17 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = 3 | .git 4 | max-line-length = 140 5 | copyright-check = True 6 | select = E,F,W,C 7 | copyright-regexp=Copyright \(c\) Facebook, Inc. and its affiliates. All Rights Reserved 8 | ignore=W503,E203 9 | -------------------------------------------------------------------------------- /autoparallel/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | torch >= 2.7.0 2 | numpy 3 | pulp 4 | pytest >= 8.1 5 | expecttest 6 | psutil 7 | 8 | black == 22.3.0 9 | flake8 == 6.1.0 10 | flake8-copyright 11 | isort == 5.7.0 12 | mypy == 1.10.0 13 | tabulate 14 | types-tabulate 15 | -------------------------------------------------------------------------------- /autoparallel/dtensor_util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # functions to expose 7 | from .utils import ( 8 | batch_shard_strategy, 9 | get_op_strategy, 10 | op_strategy_context, 11 | replicate_op_strategy, 12 | with_implicit_strategies, 13 | ) 14 | 15 | __all__ = [ 16 | "replicate_op_strategy", 17 | "batch_shard_strategy", 18 | "get_op_strategy", 19 | "with_implicit_strategies", 20 | "op_strategy_context", 21 | ] 22 | -------------------------------------------------------------------------------- /mast/.torchxconfig: -------------------------------------------------------------------------------- 1 | [mast_conda] 2 | conda_path_in_fbpkg = conda 3 | activate_conda = False 4 | fbpkg_ids = fb-py-spy:prod 5 | hpcIdentity = pytorch_distributed 6 | rmAttribution = msl_infra_pytorch_dev 7 | workspace_fbpkg_name = torchtitan_workspace 8 | conda_pack_ignore_missing_files = True 9 | git = False 10 | hpcJobOncall = meta_conda 11 | modelTypeName = gen_ai_conda 12 | hpcClusterUuid = MastGenAICluster 13 | localityConstraints = region;gtn 14 | forceSingleRegion = False 15 | use_caf = False 16 | 17 | [component:mast.py:train] 18 | name = torchtitan 19 | ; other hardware options can be found at 20 | ; https://www.internalfb.com/code/fbsource/fbcode/torchx/specs/fb/named_resources.py 21 | h=grandteton 22 | run_as_root = True 23 | 24 | [cli:run] 25 | scheduler = mast_conda 26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "autoparallel" 7 | version = "0.0.1" 8 | authors = [ 9 | { name="Francisco Massa", email="fmassa@meta.com" }, 10 | ] 11 | description = "Automatic PyTorch model sharding" 12 | readme = "README.md" 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "Operating System :: OS Independent", 16 | ] 17 | license = { file = "LICENSE" } 18 | requires-python = ">=3.10" 19 | dependencies = [ 20 | "torch>=2.7.0", 21 | "typing_extensions>=4.0.0", 22 | "filecheck", 23 | "pulp" 24 | ] 25 | 26 | [project.urls] 27 | Homepage = "https://github.com/pytorch-labs/autoparallel" 28 | Issues = "https://github.com/pytorch-labs/autoparallel/issues" 29 | 30 | [tool.hatch.build.targets.wheel] 31 | packages = ["autoparallel"] 32 | 33 | [tool.hatch.build] 34 | include = [ 35 | "autoparallel/**/*.py", 36 | "autoparallel/**/*.pyi", 37 | "LICENSE", 38 | ] 39 | exclude = [ 40 | "tests/**/*", 41 | "examples/**/*", 42 | ] 43 | 44 | [tool.hatch.metadata] 45 | allow-direct-references = true 46 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | - release/* 9 | 10 | jobs: 11 | linters: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - name: Check out code 16 | uses: actions/checkout@v4 17 | 18 | - name: Setup Python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: '3.10' 22 | - name: Install deps 23 | run: | 24 | pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu 25 | pip install -r requirements-test.txt 26 | - name: isort 27 | if: success() || failure() 28 | run: python -m isort . --check --profile black 29 | - name: black 30 | if: success() || failure() 31 | run: python -m black --check . 
32 | - name: mypy 33 | if: success() || failure() 34 | run: | 35 | python -m mypy --version 36 | python -m mypy --ignore-missing-imports --scripts-are-modules --pretty --exclude "(docs|examples)" . 37 | - name: flake8 38 | if: success() || failure() 39 | run: python -m flake8 --config .flake8 --show-source --statistics 40 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'build|stubs' 2 | 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v3.4.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: check-ast 12 | - id: check-merge-conflict 13 | - id: no-commit-to-branch 14 | args: ['--branch=master'] 15 | - id: check-added-large-files 16 | args: ['--maxkb=500'] 17 | - id: end-of-file-fixer 18 | 19 | - repo: https://github.com/ambv/black 20 | rev: 22.3.0 21 | hooks: 22 | - id: black 23 | 24 | - repo: https://github.com/pycqa/flake8 25 | rev: 6.1.0 26 | hooks: 27 | - id: flake8 28 | 29 | - repo: https://github.com/pycqa/isort 30 | rev: 5.12.0 31 | hooks: 32 | - id: isort 33 | exclude: README.md 34 | additional_dependencies: [toml] 35 | args: ["--profile", "black"] 36 | 37 | - repo: local 38 | hooks: 39 | - id: mypy 40 | require_serial: true 41 | name: mypy 42 | entry: mypy 43 | language: system 44 | types: [python] 45 | exclude: (docs|examples) 46 | args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"] 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoParallel 2 | 3 | > ⚠️ **Early Development Warning** Autoparallel is currently in an experimental 4 | > stage. You should expect bugs, incomplete features, and APIs that may change 5 | > in future versions. The project welcomes bugfixes, but to make sure things are 6 | > well coordinated you should discuss any significant change before starting the 7 | > work. It's recommended that you signal your intention to contribute in the 8 | > issue tracker, either by filing a new issue or by claiming an existing one. 9 | 10 | This currently works on PyTorch 2.8.0.dev20250506. 11 | 12 | ## Installing it 13 | 14 | ``` 15 | pip install git+ssh://git@github.com/pytorch-labs/autoparallel.git 16 | ``` 17 | 18 | ## Developing it 19 | ``` 20 | cd autoparallel 21 | pip install -e . 22 | ``` 23 | Changes to Python files are reflected immediately. 24 | 25 | Run the linters before submitting a PR: 26 | ``` 27 | pip install pre-commit 28 | pre-commit run --all-files 29 | ``` 30 | 31 | If you get ``An unexpected error has occurred: ... 'python3.11')``, try modifying `language_version: python3.11` in `.pre-commit-config.yaml` to match your Python version. 32 | 33 | ## Running it 34 | 35 | ``` 36 | python examples/example_autoparallel.py 37 | ``` 38 | 39 | ## License 40 | 41 | Autoparallel is BSD-3 licensed, as found in the [LICENSE](LICENSE) file. 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | Redistribution and use in source and binary forms, with or without modification, 4 | are permitted provided that the following conditions are met: 5 | 6 | * Redistributions of source code must retain the above copyright notice, this 7 | list of conditions and the following disclaimer. 8 | 9 | * Redistributions in binary form must reproduce the above copyright notice, 10 | this list of conditions and the following disclaimer in the documentation 11 | and/or other materials provided with the distribution. 12 | 13 | * Neither the name Meta nor the names of its contributors may be used to 14 | endorse or promote products derived from this software without specific 15 | prior written permission. 16 | 17 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 18 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 19 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 20 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 21 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 22 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 23 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 24 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 26 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | -------------------------------------------------------------------------------- /.github/workflows/test_cuda.yml: -------------------------------------------------------------------------------- 1 | name: Test CUDA 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | - release/* 9 | 10 | concurrency: 11 | group: test-cuda-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | test-cuda: 16 | name: Test CUDA (cuda12.6-py3.12) 17 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 18 | strategy: 19 | fail-fast: true 20 | matrix: 21 | include: 22 | - name: 12xlargegpu 23 | runs-on: linux.g5.12xlarge.nvidia.gpu 24 | torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126' 25 | gpu-arch-type: "cuda" 26 | gpu-arch-version: "12.6" 27 | with: 28 | timeout: 60 29 | runner: ${{ matrix.runs-on }} 30 | gpu-arch-type: ${{ matrix.gpu-arch-type }} 31 | gpu-arch-version: ${{ matrix.gpu-arch-version }} 32 | submodules: recursive 33 | script: | 34 | conda create --yes --quiet --name py312 python=3.12 35 | source $(conda info --base)/etc/profile.d/conda.sh 36 | conda activate py312 37 | 38 | pip install --quiet -r requirements-test.txt 39 | # For some reason the spec above isnt working 40 | pip uninstall -y torch 41 | pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 42 | pip install --quiet . 
43 | pytest tests 44 | python examples/example_autoparallel.py 45 | python examples/example_llama3.py 46 | python examples/example_dcp.py 47 | python examples/example_local_map.py 48 | python examples/example_pp_graph_passes.py 49 | torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py 50 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Meta Open Source Projects 2 | 3 | We want to make contributing to this project as easy and transparent as 4 | possible. 5 | 6 | ## Pull Requests 7 | We actively welcome your pull requests. 8 | 9 | Note: pull requests are not merged into the GitHub repository in the usual way. There is an internal Meta repository that is the "source of truth" for the project. The GitHub repository is generated *from* the internal Meta repository. So we don't merge GitHub PRs directly to the GitHub repository -- they must first be imported into the internal Meta repository. When Meta employees look at the GitHub PR, there is a special button visible only to them that executes that import. The changes are then automatically reflected from the internal Meta repository back to GitHub. This is why you won't see your PR being merged directly, but you will still see your changes in the repository once it reflects the imported changes. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Meta's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## License 33 | By contributing to this project, you agree that your contributions will be licensed 34 | under the LICENSE file in the root directory of this source tree. 
35 | -------------------------------------------------------------------------------- /.github/workflows/test_torchtitan.yml: -------------------------------------------------------------------------------- 1 | name: Test TorchTitan Integration 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | - release/* 9 | 10 | concurrency: 11 | group: test-torchtitan-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 12 | cancel-in-progress: true 13 | 14 | jobs: 15 | test-torchtitan: 16 | name: Test TorchTitan Integration (cuda12.6-py3.12) 17 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 18 | strategy: 19 | fail-fast: true 20 | matrix: 21 | include: 22 | - name: 12xlargegpu 23 | runs-on: linux.g5.12xlarge.nvidia.gpu 24 | torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu126' 25 | gpu-arch-type: "cuda" 26 | gpu-arch-version: "12.6" 27 | with: 28 | timeout: 60 29 | runner: ${{ matrix.runs-on }} 30 | gpu-arch-type: ${{ matrix.gpu-arch-type }} 31 | gpu-arch-version: ${{ matrix.gpu-arch-version }} 32 | submodules: recursive 33 | script: | 34 | conda create --yes --quiet --name py312 python=3.12 35 | source $(conda info --base)/etc/profile.d/conda.sh 36 | conda activate py312 37 | 38 | pip install --quiet -r requirements-test.txt 39 | # For some reason the spec above isnt working 40 | pip uninstall -y torch 41 | pip install --no-input --quiet --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 42 | pip install --quiet . 43 | 44 | # Clone TorchTitan 45 | git clone https://github.com/pytorch/torchtitan.git 46 | cd torchtitan 47 | pip install --quiet -r requirements.txt 48 | 49 | # Run TorchTitan training with AutoParallel 50 | NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh \ 51 | --model.name autoparallel.llama3 \ 52 | --parallelism.tensor_parallel_degree 4 \ 53 | --training.dataset c4 \ 54 | --compile.enable \ 55 | --job.custom_config_module=torchtitan.experiments.autoparallel.job_config 56 | -------------------------------------------------------------------------------- /tests/test_aot_eager.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import pytest 7 | import torch 8 | from torch.utils._debug_mode import DebugMode 9 | 10 | from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs 11 | 12 | # TODO: make device generic 13 | 14 | 15 | @pytest.fixture(scope="module") 16 | def llama3_debug_model(): 17 | torch.manual_seed(1999) 18 | model_args = TransformerModelArgs( 19 | dim=256, n_layers=6, n_heads=16, vocab_size=2048, rope_theta=500000 20 | ) 21 | return Transformer(model_args).cuda() 22 | 23 | 24 | def test_deterministic(llama3_debug_model): 25 | batch_size = 8 26 | seqlen = 2048 27 | vocab_size = llama3_debug_model.model_args.vocab_size 28 | torch.manual_seed(2999) 29 | x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda") 30 | torch.manual_seed(3999) 31 | r1 = llama3_debug_model(x) 32 | torch.manual_seed(3999) 33 | r2 = llama3_debug_model(x) 34 | assert torch.equal(r1, r2) # bitwise equal 35 | 36 | 37 | def test_debug_mode_bitwise_equivalent(llama3_debug_model): 38 | batch_size = 8 39 | seqlen = 2048 40 | vocab_size = llama3_debug_model.model_args.vocab_size 41 | torch.manual_seed(2999) 42 | x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda") 43 | torch.manual_seed(3999) 44 | r1 = llama3_debug_model(x) 45 | torch.manual_seed(3999) 46 | with DebugMode() as debug_mode: 47 | r2 = llama3_debug_model(x) 48 | print(debug_mode.debug_string()) 49 | assert torch.equal(r1, r2) # bitwise equal 50 | 51 | 52 | @pytest.mark.xfail 53 | def test_aot_eager_bitwise_equivalent(llama3_debug_model): 54 | batch_size = 8 55 | seqlen = 2048 56 | vocab_size = llama3_debug_model.model_args.vocab_size 57 | torch.manual_seed(2999) 58 | x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda") 59 | torch.manual_seed(3999) 60 | r1 = llama3_debug_model(x) 61 | grads1 = torch.autograd.grad(r1.sum(), llama3_debug_model.parameters()) 62 | torch.manual_seed(3999) 63 | r2 = torch.compile(backend="aot_eager")(llama3_debug_model)(x) 64 | grads2 = torch.autograd.grad(r2.sum(), llama3_debug_model.parameters()) 65 | assert torch.equal(r1, r2) # bitwise equal 66 | for g1, g2 in zip(grads1, grads2): 67 | assert torch.equal(g1, g2) 68 | -------------------------------------------------------------------------------- /mast/run_torchtitan.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | # 4 | # This source code is licensed under the BSD license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | set -x 8 | 9 | if [[ $# -lt 1 ]]; then 10 | echo "Incorrect number of arguments (0)" 11 | echo "Usage: $0 config_file " 12 | exit 1 13 | fi 14 | 15 | # consume config file and leave remaining args to 'overrides' 16 | CONFIG_FILE=${1} 17 | shift 18 | 19 | overrides="" 20 | if [ $# -gt 0 ]; then 21 | overrides="$*" 22 | fi 23 | 24 | edir="${DUMP_DIR}" 25 | ename="${JOB_ID}_v${MAST_HPC_JOB_VERSION}_a${MAST_HPC_JOB_ATTEMPT_INDEX}" 26 | dataset_path="/mnt/mffuse/c4" 27 | save_tb_folder="/mnt/wsfuse/outputs/${JOB_ID}/tb" 28 | 29 | 30 | echo dump_dir=$edir 31 | echo experiment_name=$ename 32 | 33 | 34 | LIBCUDA="/usr/local/fbcode/platform010/lib/libcuda.so" 35 | export LIBCUDA_DIR="${LIBCUDA%/*}" 36 | export TRITON_LIBCUDA_PATH="/usr/local/fbcode/platform010/lib/" 37 | export LD_PRELOAD="${PRELOAD_PATH:=$LIBCUDA:/usr/local/fbcode/platform010/lib/libnvidia-ml.so}" 38 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${CONDA_DIR}/lib" 39 | export PYTHONPATH="$PYTHONPATH:$TORCHX_RUN_PYTHONPATH" 40 | 41 | source ${CONDA_DIR}/bin/activate 42 | 43 | cd /packages/torchtitan_additional_packages/torchtitan 44 | 45 | ############### 46 | # do whatever you like below 47 | ############### 48 | 49 | if [ -n "${WANDB_API_KEY}" ]; then 50 | wandb login --host=https://meta.wandb.io 51 | fi 52 | 53 | if [ -n "$LIGHTHOUSE_SMC_TIER" ]; then 54 | # Run smcc command until it returns a host:port pair 55 | while true; do 56 | service=$(/packages/torchft_smcc/smcc list-services --enabled "$LIGHTHOUSE_SMC_TIER" | head -n 1) 57 | if [ -n "$service" ]; then 58 | break 59 | fi 60 | sleep 1 61 | done 62 | 63 | # Set TORCHFT_LIGHTHOUSE environment variable 64 | export TORCHFT_LIGHTHOUSE="http://$service" 65 | echo "TORCHFT_LIGHTHOUSE set to $TORCHFT_LIGHTHOUSE" 66 | else 67 | echo "LIGHTHOUSE_SMC_TIER env not set, skipping..." 68 | fi 69 | 70 | 71 | PYTORCH_KERNEL_CACHE_PATH="/mnt/mffuse/.cache/torch/kernels" \ 72 | PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ 73 | TORCH_DISABLE_ADDR2LINE=1 \ 74 | python torchtitan/train.py \ 75 | --job.config_file "${CONFIG_FILE}" \ 76 | --job.dump_folder "${edir}" \ 77 | --training.dataset_path "${dataset_path}" \ 78 | --validation.dataset_path "${dataset_path}" \ 79 | --metrics.save_tb_folder "${save_tb_folder}" \ 80 | --metrics.disable_color_printing \ 81 | --job.save_config_file "params.json" \ 82 | $overrides 83 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 | 5 | ## Project Overview 6 | 7 | AutoParallel is a PyTorch library for automatic model sharding and parallelization. It analyzes PyTorch models and automatically determines optimal sharding strategies for distributed training. 8 | 9 | **WARNING**: This project is highly under development. See README.md for current PyTorch version requirements. 
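A minimal end-to-end sketch of the intended workflow, condensed from `examples/example_autoparallel.py` and `tests/test_propagation_rules.py` in this repository. The model class, `input_fn`, mesh size, and sharding constraints below are illustrative placeholders, not a prescribed configuration:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Shard

from autoparallel.api import AutoParallel

# Build the model on the meta device; AutoParallel works from fake tensors,
# so no real parameters are allocated at this point.
with torch.device("meta"):
    model = MyModel()  # placeholder: any nn.Module

# Assumes torch.distributed has already been initialized.
# 1D data-parallel mesh; input_fn must return a *global* (unsharded) input batch.
mesh = torch.distributed.device_mesh.init_device_mesh(
    "cuda", (torch.distributed.get_world_size(),), mesh_dim_names=("dp",)
)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)

with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
    autop.add_input_constraints([(Shard(0),)])   # shard inputs on the batch dim
    autop.add_output_constraints([(Shard(0),)])
    sharding_placement = autop.optimize_placement()       # solve the linear program
    parallel_mod = autop.apply_placement(sharding_placement)

# Materialize the sharded parameters before initializing weights / training.
parallel_mod.to_empty(device="cuda")
```
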
10 | 11 | ## Core Architecture 12 | 13 | The library consists of several key components that work together: 14 | 15 | - **api.py**: Main entry point with `AutoParallel` class that orchestrates the sharding process 16 | - **optimize_sharding.py**: Contains `ShardingOptimizer` that uses PuLP (linear programming) to find optimal sharding strategies 17 | - **apply_sharding.py**: Applies computed sharding strategies to PyTorch models using DTensor specs 18 | - **propagation_rules.py**: Defines how tensor sharding propagates through different PyTorch operations 19 | - **compute_estimation.py**: Estimates runtime costs for different sharding strategies 20 | - **export_module.py**: Handles AOT (Ahead-of-Time) compilation and module export 21 | 22 | The optimization flow: Model → FX Graph → Sharding Options → Linear Program → Optimal Strategy → Apply Sharding 23 | 24 | ## Development Commands 25 | 26 | ### Setup 27 | ```bash 28 | # Install in development mode 29 | uv pip install -e . 30 | ``` 31 | 32 | ### Linting and Code Quality 33 | ```bash 34 | # Install pre-commit hooks 35 | uv pip install pre-commit 36 | 37 | # Run all linters and formatters 38 | pre-commit run --all-files 39 | ``` 40 | 41 | The pre-commit setup includes: 42 | - Black (code formatting) 43 | - flake8 (linting) 44 | - isort (import sorting) 45 | - mypy (type checking) 46 | 47 | ### Running Examples 48 | ```bash 49 | # Basic autoparallel example 50 | python examples/example_autoparallel.py 51 | 52 | # LLaMA-3 example 53 | python examples/example_llama3.py 54 | ``` 55 | 56 | ### Testing 57 | ```bash 58 | # Run tests (check for pytest or unittest patterns) 59 | python -m pytest tests/ 60 | ``` 61 | 62 | ## Key Dependencies 63 | 64 | - **torch**: Core PyTorch functionality and distributed tensor support 65 | - **pulp**: Linear programming solver for optimization 66 | - **filecheck**: Testing utilities 67 | 68 | ## Development Notes 69 | 70 | - Requires Python ≥3.10 71 | - Uses PyTorch's FX graph representation for model analysis 72 | - Leverages DTensor for distributed tensor operations 73 | - Uses linear programming (PuLP) to solve sharding optimization problems 74 | - Includes fake tensor mode for shape inference without actual computation 75 | -------------------------------------------------------------------------------- /examples/native_ds3/test_simple_batched_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import torch 7 | import torch.utils._pytree as pytree 8 | from torch.utils._python_dispatch import TorchDispatchMode 9 | 10 | 11 | class TestMode(TorchDispatchMode): 12 | def __torch_dispatch__(self, func, types, args=..., kwargs=None): 13 | kwargs = kwargs if kwargs else {} 14 | out = func(*args, **kwargs) 15 | print("Op:", func) 16 | print("Args:") 17 | pytree.tree_map_only(torch.Tensor, lambda x: print(x.shape), args) 18 | print("Out:") 19 | pytree.tree_map_only(torch.Tensor, lambda x: print(x.shape), out) 20 | print( 21 | "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^" 22 | ) 23 | return out 24 | 25 | 26 | @torch.library.custom_op("autoparallel::batched_mm", mutates_args=()) 27 | def batched_mm( 28 | mat1: torch.Tensor, 29 | mat2: torch.Tensor, 30 | ) -> torch.Tensor: 31 | assert mat1.ndim == 3 32 | assert mat2.ndim == 2 or mat2.ndim == 3 33 | if mat2.ndim == 2: 34 | assert mat1.shape[2] == mat2.shape[0] 35 | mat2_expanded = mat2.expand(mat1.shape[0], -1, -1) 36 | else: 37 | assert mat1.shape[0] == mat2.shape[0] 38 | assert mat1.shape[2] == mat2.shape[1] 39 | mat2_expanded = mat2 40 | out = torch.bmm(mat1, mat2_expanded) 41 | return out 42 | 43 | 44 | def setup_context_batched_mm(ctx, inputs, output): 45 | mat1, mat2 = inputs 46 | ctx.save_for_backward(mat1, mat2) 47 | 48 | 49 | def backward_batched_mm(ctx, grad): 50 | assert grad.ndim == 3 51 | mat1, mat2 = ctx.saved_tensors 52 | grad1 = batched_mm(grad, mat2.transpose(-2, -1)) 53 | grad2 = torch.sum(batched_mm(mat1.transpose(-2, -1), grad), dim=0) 54 | return grad1, grad2 55 | 56 | 57 | torch.library.register_autograd( 58 | "autoparallel::batched_mm", 59 | backward_batched_mm, 60 | setup_context=setup_context_batched_mm, 61 | ) 62 | 63 | 64 | if __name__ == "__main__": 65 | DEVICE = "cuda" 66 | 67 | mat1 = torch.rand( 68 | 10, 32, 16, device=DEVICE, dtype=torch.float32, requires_grad=True 69 | ) 70 | mat2 = torch.rand(48, 16, device=DEVICE, dtype=torch.float32, requires_grad=True) 71 | 72 | out = batched_mm(mat1, mat2.transpose(-2, -1)) 73 | out = out.sum() 74 | with TestMode(): 75 | out.backward() 76 | mat1grad = mat1.grad 77 | mat2grad = mat2.grad 78 | mat1.grad = None 79 | mat2.grad = None 80 | out3 = mat1 @ mat2.transpose(-2, -1) 81 | out3 = out3.sum() 82 | out3.backward() 83 | print(torch.allclose(out, out3)) 84 | print(torch.allclose(mat1.grad, mat1grad)) 85 | print(torch.allclose(mat2.grad, mat2grad)) 86 | -------------------------------------------------------------------------------- /autoparallel/_passes/graph_partition.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | from typing import Any, Callable 7 | 8 | import torch 9 | from torch._functorch._aot_autograd.graph_compile import ( 10 | _aot_stage2a_partition, 11 | _apply_tensorify_python_scalars, 12 | ) 13 | from torch._functorch.aot_autograd import ( 14 | AOTConfig, 15 | AOTGraphCapture, 16 | AOTState, 17 | JointWithDescriptors, 18 | OutputType, 19 | ViewAndMutationMeta, 20 | boxed_nop_preserve_node_meta, 21 | default_partition, 22 | ) 23 | 24 | 25 | def partition_joint_with_descriptors( 26 | jd: JointWithDescriptors, 27 | *, 28 | partition_fn: Callable = default_partition, 29 | fw_compiler: Callable = boxed_nop_preserve_node_meta, 30 | bw_compiler: Callable = boxed_nop_preserve_node_meta, 31 | ) -> tuple[ 32 | torch.fx.GraphModule, 33 | torch.fx.GraphModule, 34 | int, 35 | int, 36 | int, 37 | int, 38 | int, 39 | list[int], 40 | list[Any], 41 | ]: 42 | aot_state: AOTState = jd._aot_state 43 | aot_graph_capture: AOTGraphCapture = jd._aot_graph_capture 44 | # Update the AOTState with the provided compilers 45 | aot_state.aot_config.partition_fn = partition_fn 46 | aot_state.aot_config.fw_compiler = fw_compiler 47 | aot_state.aot_config.bw_compiler = bw_compiler 48 | aot_state.aot_config.inference_compiler = fw_compiler 49 | 50 | fx_g: torch.fx.GraphModule = aot_graph_capture.graph_module 51 | maybe_subclass_meta: Any = aot_graph_capture.maybe_subclass_meta 52 | fw_metadata: ViewAndMutationMeta = aot_state.fw_metadata 53 | aot_config: AOTConfig = aot_state.aot_config 54 | 55 | # AOTAutogradStage2a: Partition the graph into forward and backward graphs and 56 | # return the some metadata about the partitioning. 57 | 58 | _apply_tensorify_python_scalars(fx_g) 59 | 60 | ( 61 | fw_module, 62 | bw_module, 63 | num_fw_outs_saved_for_bw, 64 | num_symints_saved_for_bw, 65 | _indices_of_inps_to_detach, 66 | adjusted_flat_args, 67 | ) = _aot_stage2a_partition( 68 | fx_g, 69 | aot_graph_capture.updated_flat_args, 70 | maybe_subclass_meta, 71 | fw_metadata, 72 | aot_config, 73 | ) 74 | 75 | num_user_outputs = ( 76 | len( 77 | [ 78 | x 79 | for x in fw_metadata.output_info 80 | if x.output_type 81 | in (OutputType.non_alias, OutputType.alias_of_intermediate) 82 | ] 83 | ) 84 | + fw_metadata.num_intermediate_bases 85 | ) 86 | 87 | num_mutate_inputs = len( 88 | [x for x in fw_metadata.input_info if x.mutates_data or x.mutates_metadata] 89 | ) 90 | num_params_buffers = aot_config.num_params_buffers 91 | return ( 92 | fw_module, 93 | bw_module, 94 | num_params_buffers, 95 | num_user_outputs, 96 | num_mutate_inputs, 97 | num_fw_outs_saved_for_bw, 98 | num_symints_saved_for_bw, 99 | _indices_of_inps_to_detach, 100 | adjusted_flat_args, 101 | ) 102 | -------------------------------------------------------------------------------- /examples/run_ds3_numerics_check.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Script to run DS3 numerics check by comparing outputs from local_map and pipeline parallel. 
8 | """ 9 | import shutil 10 | import subprocess 11 | import tempfile 12 | import warnings 13 | from pathlib import Path 14 | 15 | 16 | def run_command(cmd, cwd): 17 | """Run a shell command in the specified directory.""" 18 | print(f"Running: {cmd}") 19 | print(f"In directory: {cwd}") 20 | result = subprocess.run(cmd, shell=True, cwd=cwd, capture_output=True, text=True) 21 | print(result.stdout) 22 | if result.stderr: 23 | print("STDERR:", result.stderr) 24 | if result.returncode != 0: 25 | warnings.warn(f"Command failed with return code {result.returncode}") 26 | return result 27 | 28 | 29 | def main(args): 30 | schedule_name = args.schedule_name 31 | 32 | # Create a temporary directory 33 | temp_dir = tempfile.mkdtemp(prefix="ds3_numerics_check_") 34 | print(f"Created temporary directory: {temp_dir}") 35 | 36 | try: 37 | examples_dir = Path(__file__).parent 38 | 39 | print("\n" + "=" * 80) 40 | print("Running non-PP example with 4 GPUs...") 41 | print("=" * 80) 42 | cmd1 = f"torchrun --standalone --nproc-per-node 4 {examples_dir}/example_ds3_local_map.py --rng-seed 42" 43 | run_command(cmd1, temp_dir) 44 | 45 | print("\n" + "=" * 80) 46 | print("Running PP example with 8 GPUs...") 47 | print("=" * 80) 48 | cmd2 = f"torchrun --standalone --nproc-per-node 8 {examples_dir}/example_ds3_pp.py --rng-seed 42 --schedule-name={schedule_name}" 49 | run_command(cmd2, temp_dir) 50 | 51 | out_dir = Path(temp_dir) / "out" 52 | if not out_dir.exists(): 53 | raise RuntimeError(f"Output directory {out_dir} does not exist") 54 | 55 | print("\n" + "=" * 80) 56 | print("Comparing weights.log files...") 57 | print("=" * 80) 58 | run_command("diff out/0/weights.log out/1/pp_weights.log", temp_dir) 59 | 60 | print("\n" + "=" * 80) 61 | print("Comparing diff.log files...") 62 | print("=" * 80) 63 | run_command("diff out/0/diff.log out/1/diff.log", temp_dir) 64 | 65 | print("\n" + "=" * 80) 66 | print("Numerics check completed successfully!") 67 | print(f"Output directory: {temp_dir}/out") 68 | print("=" * 80) 69 | 70 | except Exception as e: 71 | print(f"\nError occurred: {e}") 72 | print(f"Temporary directory preserved at: {temp_dir}") 73 | raise 74 | 75 | print(f"\nTemporary directory location: {temp_dir}") 76 | response = input("Do you want to delete the temporary directory? 
(y/n): ") 77 | if response.lower() == "y": 78 | shutil.rmtree(temp_dir) 79 | print("Temporary directory deleted.") 80 | else: 81 | print(f"Temporary directory preserved at: {temp_dir}") 82 | 83 | 84 | if __name__ == "__main__": 85 | import argparse 86 | 87 | parser = argparse.ArgumentParser( 88 | description="Run DeepSeek V3 pipeline parallel example" 89 | ) 90 | parser.add_argument( 91 | "--schedule-name", 92 | type=str, 93 | default="ZBVZeroBubble", 94 | help="Schedule to use for PP", 95 | ) 96 | args = parser.parse_args() 97 | main(args) 98 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 
58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /tests/test_propagation_rules.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import pytest 7 | import torch 8 | from torch import nn 9 | from torch.distributed.fsdp import MixedPrecisionPolicy 10 | from torch.distributed.tensor.placement_types import Shard 11 | from torch.testing._internal.distributed.fake_pg import FakeStore 12 | 13 | from autoparallel.api import AutoParallel 14 | 15 | 16 | @pytest.fixture(scope="module", autouse=True) 17 | def init_pg(): 18 | world_size = 256 19 | fake_store = FakeStore() 20 | if torch.distributed.is_initialized(): 21 | return 22 | torch.distributed.init_process_group( 23 | "fake", store=fake_store, rank=0, world_size=world_size 24 | ) 25 | 26 | 27 | @pytest.fixture(scope="module") 28 | def device_mesh_1d(): 29 | world_size = torch.distributed.get_world_size() 30 | mesh = torch.distributed.device_mesh.init_device_mesh( 31 | "cuda", (world_size,), mesh_dim_names=("dp",) 32 | ) 33 | return mesh 34 | 35 | 36 | def test_permute_layernorm_stride_handling(device_mesh_1d): 37 | """Test that permute + layernorm handles non-contiguous to contiguous stride transitions. 38 | 39 | This test reproduces the stride mismatch bug in ConvNeXt-style architectures where: 40 | 1. First permute creates a non-contiguous tensor (view) with stride (301056, 56, 1, 3136) 41 | 2. LayerNorm receives non-contiguous input but returns a contiguous tensor 42 | 3. 
Second permute creates another non-contiguous tensor (view) 43 | """ 44 | 45 | class PermuteLayerNormNet(nn.Module): 46 | """Network with permute -> LayerNorm -> permute.""" 47 | 48 | def __init__(self, channels): 49 | super().__init__() 50 | self.norm = nn.LayerNorm(channels, eps=1e-6) 51 | 52 | def forward(self, x): 53 | # (N, C, H, W) -> (N, H, W, C) 54 | x = x.permute(0, 2, 3, 1) 55 | # LayerNorm on last dim (C) 56 | x = self.norm(x) 57 | # (N, H, W, C) -> (N, C, H, W) 58 | x = x.permute(0, 3, 1, 2) 59 | return x 60 | 61 | batch_size = 256 62 | channels = 96 63 | height = 56 64 | width = 56 65 | 66 | def input_fn(): 67 | return torch.rand(batch_size, channels, height, width, device="cuda") 68 | 69 | # Create model on meta device 70 | with torch.device("meta"): 71 | model = PermuteLayerNormNet(channels=channels) 72 | 73 | # Mixed precision policy 74 | mp_policy = MixedPrecisionPolicy( 75 | param_dtype=torch.float32, reduce_dtype=torch.float32 76 | ) 77 | 78 | # This should not raise an AssertionError about tensor_meta stride mismatch. 79 | with AutoParallel( 80 | model, input_fn, device_mesh_1d, mp_policy, compile=True 81 | ) as autop: 82 | x_sharding = (Shard(0),) 83 | y_sharding = (Shard(0),) 84 | 85 | autop.add_input_constraints([x_sharding]) 86 | autop.add_output_constraints([y_sharding]) 87 | 88 | sharding_placement = autop.optimize_placement() 89 | 90 | # Apply the optimized placement 91 | parallel_mod = autop.apply_placement(sharding_placement) 92 | 93 | # Initialize the parallel module 94 | parallel_mod.to_empty(device="cuda") 95 | 96 | for name, param in parallel_mod.named_parameters(): 97 | if "weight" in name: 98 | torch.nn.init.ones_(param) 99 | elif "bias" in name: 100 | torch.nn.init.zeros_(param) 101 | 102 | # Test forward pass execution works 103 | local_batch_size = batch_size // torch.distributed.get_world_size() 104 | x_test = torch.rand(local_batch_size, channels, height, width, device="cuda") 105 | out = parallel_mod(x_test) 106 | 107 | # Verify output shape (should match input after permute -> norm -> permute) 108 | assert out.shape == (local_batch_size, channels, height, width) 109 | # Output may be non-contiguous due to final permute (view operation) 110 | 111 | # Verify forward execution produces correct output 112 | assert out.abs().sum() > 0 113 | -------------------------------------------------------------------------------- /mast/mount.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 3 | # 4 | # This source code is licensed under the BSD license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | export PS4=' + [$(date +"%Y-%m-%d %H:%M:%S,%3N")] ' 8 | set -eExu -o pipefail 9 | 10 | # want REGION_DATACENTER_PREFIX 11 | source /etc/fbwhoami 12 | 13 | ####################################### 14 | # Set up oilfs and airstore mounts if set. 15 | # Globals: 16 | # DISABLE_OILFS kill switch to skip setting up oilfs and airstore entirely (default unset) 17 | # AI_RM_ATTRIBUTION should be set by mast for attribution 18 | ####################################### 19 | function setup_oilfs { 20 | if [[ -n "${DISABLE_OILFS-}" ]]; then 21 | echo "OilFS disabled through env DISABLE_OILFS=$DISABLE_OILFS. Skipping mounts." 
22 | return 0 23 | fi 24 | 25 | if [ -n "$ENABLE_AIRSTORE" ]; then 26 | FUSE_SRC="ws://ws.ai.pci0ai/genai_fair_llm" 27 | else 28 | FUSE_SRC="ws://ws.ai.nha0genai/checkpoint/infra" 29 | fi 30 | FUSE_DST="/mnt/wsfuse" 31 | 32 | mkdir -p "$FUSE_DST" 33 | /packages/oil.oilfs/oilfs-wrapper --profile="${OILFS_PROFILE-genai}" --user="$AI_RM_ATTRIBUTION" --log-level=debug "$FUSE_SRC" "$FUSE_DST" 34 | } 35 | 36 | 37 | ####################################### 38 | # Set up ManifoldFS mount if configured. 39 | # Globals: 40 | # DISABLE_MANIFOLDFS kill switch to skip setting up manifoldfs (default unset) 41 | # MANIFOLDFS_FUSE_DST path on host to mount to; defaults to /mnt/mffuse 42 | # MANIFOLDFS_BUCKET which Manifold bucket to mount; if unset will skip setup 43 | ####################################### 44 | function setup_manifoldfs { 45 | if [[ -n "${DISABLE_MANIFOLDFS-}" ]]; then 46 | echo "ManifoldFS disabled through env DISABLE_MANIFOLDFS=$DISABLE_MANIFOLDFS. Skipping mounts." 47 | return 0 48 | fi 49 | 50 | if [[ -z "${MANIFOLDFS_BUCKET-}" ]]; then 51 | echo "Manifold bucket is not set (MANIFOLDFS_BUCKET is empty), skipping setting up ManifoldFS" 52 | return 0 53 | fi 54 | 55 | MANIFOLDFS_FUSE_DST="${MANIFOLDFS_FUSE_DST:-/mnt/mffuse}" 56 | mkdir -p "${MANIFOLDFS_FUSE_DST}" 57 | 58 | if [[ -n "${ENABLE_MANIFUSE_OVER_MANIFOLDFS-}" ]]; then 59 | MANIFOLD_FUSE_SRC="manifold://$MANIFOLDFS_BUCKET/tree" 60 | /packages/oil.oilfs/oilfs-wrapper --profile="manifold" --user="$AI_RM_ATTRIBUTION" --log-level=debug "${MANIFOLD_FUSE_SRC}" "${MANIFOLDFS_FUSE_DST}" 61 | else 62 | MANIFOLDFS_BINARY=${MANIFOLDFS_BINARY:-"/packages/manifold.manifoldfs/manifoldfs"} 63 | "${MANIFOLDFS_BINARY}" "manifold.blobstore" "${MANIFOLDFS_BUCKET}" "${MANIFOLDFS_FUSE_DST}" 64 | fi 65 | } 66 | 67 | ####################################### 68 | # Mounts airstore with the right setup. 69 | # Globals: 70 | # ENABLE_AIRSTORE enable airstore (default unset) 71 | # AIRSTORE_URI allows overriding the oilfs region used for airstore mount. 72 | ####################################### 73 | function mount_airstore { 74 | if [[ -z "${ENABLE_AIRSTORE-}" ]]; then 75 | echo "Airstore has not been enabled through env ENABLE_AIRSTORE. Skipping mounts." 76 | return 0 77 | fi 78 | 79 | local airstore_uri="${AIRSTORE_URI-}" 80 | if [[ -z "$airstore_uri" ]]; then 81 | local host 82 | host="$(hostname)" 83 | 84 | case $host in 85 | *.pci* ) 86 | airstore_uri="ws://ws.ai.pci0ai/airstore" 87 | ;; 88 | *.eag* ) 89 | airstore_uri="ws://ws.ai.eag0genai/airstore" 90 | ;; 91 | *.gtn* ) 92 | airstore_uri="ws://ws.ai.gtn0genai/airstore" 93 | ;; 94 | *.nha* ) 95 | airstore_uri="ws://ws.ai.nha0genai/airstore" 96 | ;; 97 | *.snb* ) 98 | airstore_uri="ws://ws.ai.snb0genai/airstore" 99 | ;; 100 | *.vcn* ) 101 | airstore_uri="ws://ws.ai.vcn0genai/airstore" 102 | ;; 103 | *.zas* ) 104 | airstore_uri="ws://ws.ai.zas0genai/airstore" 105 | ;; 106 | *.nao* ) 107 | airstore_uri="ws://ws.ai.nao0ai/airstore" 108 | ;; 109 | * ) 110 | echo -e "\e[31mNo airstore source available based on region of $host, only available in pci, eag, gtn, nha. You can mount a cross-region airstore by passing in the AIRSTORE_URI environment variable\e[0m" 1>&2 111 | exit 1 112 | ;; 113 | esac 114 | fi 115 | 116 | local mount_dir="${AIRSTORE_LOCAL_MOUNT_ROOT:-/data/users/airstore}" 117 | if [ ! 
-d "$mount_dir" ] ; then 118 | mkdir -p "$mount_dir" 119 | fi 120 | 121 | # Enable privacy logging for airstore mount unless pretraining 122 | if [[ "${OILFS_PROFILE-genai}" != "pretraining" ]]; then 123 | export OILFS_ENABLE_PRIVACY_LIB_LOGGER_AIRSTORE=1; 124 | fi 125 | 126 | echo "WS-Airstore: mount from $airstore_uri to $mount_dir" 127 | if [[ ${OILFS_USE_LEGACY_SCRIPT+set} && "${OILFS_USE_LEGACY_SCRIPT}" == 1 ]]; then 128 | /packages/oil.oilfs/scripts/airstore_wrapper.sh "$airstore_uri" "$mount_dir" 129 | else 130 | /packages/oil.oilfs/oilfs-wrapper --log-level debug --profile=airstore "$airstore_uri" "$mount_dir" --user "airstore-${AI_RM_ATTRIBUTION-}" 131 | fi 132 | } 133 | 134 | setup_oilfs 135 | setup_manifoldfs 136 | mount_airstore 137 | -------------------------------------------------------------------------------- /autoparallel/init_weights.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | from typing import Any, Union 6 | 7 | import torch 8 | from torch._dynamo.utils import warn_once 9 | from torch.distributed.tensor import DTensor 10 | 11 | 12 | def _submod_setattr(model: torch.nn.Module, fqn: str, value: Any): 13 | module_path, _, buffer_name = fqn.rpartition(".") 14 | submod: torch.nn.Module = model.get_submodule(module_path) 15 | setattr(submod, buffer_name, value) 16 | 17 | 18 | def _copy_set_value_to_dtensor( 19 | fqn: str, parallel_value: DTensor, set_value: torch.Tensor 20 | ): 21 | # We expect the user wrote their module's init_weights in terms of a single-gpu model, so we do not expect 22 | # set_value to be a DTensor already (since this would imply init_weights was written in a 'distributed' way), 23 | # and we interpret it as a global tensor which we map to a Replicated DTensor. 24 | assert not isinstance( 25 | set_value, DTensor 26 | ), "Expected local/full tensor from setattr in init_weights, not DTensor." 27 | 28 | # This creates a replicated DTensor 29 | new_parallel_value = DTensor.from_local( 30 | set_value, device_mesh=parallel_value.device_mesh 31 | ) 32 | if parallel_value.placements != new_parallel_value.placements: 33 | # no harm done if the parallel value is replicated, e.g. freqs_cis in llama3, but it would be 34 | # noticeably wasteful if we do this for all the sharded parameters. 35 | warn_once( 36 | f"init_weights set a new value for {fqn}, " 37 | f"but the existing value is already sharded ({parallel_value.placements=}, " 38 | "and it is wasteful to materialize the new value as a global tensor. " 39 | "Change init_weights to perform an inplace initialization instead if possible." 40 | ) 41 | with torch.no_grad(): 42 | # This ensures that we faithfully redistribute the replicated new_parallel_value into whatever placement 43 | # the autoparallel engine decided for parallel_value. Note: this should in general be comm free, since it 44 | # would be going from Replicate -> Shard. 
45 | parallel_value.copy_(new_parallel_value) 46 | 47 | 48 | def _build_param_property(parallel_model: torch.nn.Module, fqn: str): 49 | def getter(self) -> torch.nn.Parameter: 50 | param = parallel_model.get_parameter(fqn) 51 | return param 52 | 53 | def setter(self, value: Union[torch.Tensor, torch.nn.Parameter]) -> None: 54 | parallel_value = parallel_model.get_parameter(fqn) 55 | assert isinstance( 56 | parallel_value, DTensor 57 | ), "Expected parallel_module params to be DTensors" 58 | _copy_set_value_to_dtensor(fqn, parallel_value, value) 59 | 60 | return property(getter, setter) 61 | 62 | 63 | def _build_buffer_property(parallel_model: torch.nn.Module, fqn: str): 64 | def getter(self) -> torch.Tensor: 65 | return parallel_model.get_buffer(fqn) 66 | 67 | def setter(self, value: torch.Tensor) -> None: 68 | parallel_value = parallel_model.get_buffer(fqn) 69 | assert isinstance( 70 | parallel_value, DTensor 71 | ), "Expected parallel_module params to be DTensors" 72 | _copy_set_value_to_dtensor(fqn, parallel_value, value) 73 | 74 | return property(getter, setter) 75 | 76 | 77 | def hook_params_setters( 78 | init_weights_model: torch.nn.Module, parallel_model: torch.nn.Module 79 | ) -> None: 80 | """ 81 | Replaces init_weights_model's parameters with hooked properties that let us 82 | (a) return a new parameter (from our parallel_mod) instead of the one on the original model, 83 | similar to using stateless.reparametrize 84 | (b) also, detect if anyone tries to assign a new value to the parameter, e.g. 85 | self.layer.weight = nn.Parameter(torch.randn(10, 10)) 86 | would not be properly captured if relying on parametrization alone 87 | 88 | Assumes init_weights_model is a deepcopy of the user's original model, with all fake params. This way we can 89 | modify the model to enable init_weights to work, without affecting the user's original model. 90 | 91 | Adds one 'property' (e.g. getter+setter) obj for each parameter name at the right spot in 92 | the module hierarchy. For self.layer.weight, this would install a 'weight' property on the self.layer 93 | submodule. 94 | """ 95 | for mod_name, mod in sorted(init_weights_model.named_modules()): 96 | params_dict = dict(mod.named_parameters(recurse=False)) 97 | buffers_dict = dict(mod.named_buffers(recurse=False)) 98 | 99 | namespace = {} 100 | for p_name in params_dict: 101 | fqn = mod_name + "." + p_name 102 | namespace[p_name] = _build_param_property(parallel_model, fqn) 103 | 104 | for b_name in buffers_dict: 105 | fqn = mod_name + "." + b_name 106 | namespace[b_name] = _build_buffer_property(parallel_model, fqn) 107 | 108 | cls = mod.__class__ 109 | # nn.Module.__setattr__ gets in the way 110 | namespace["__setattr__"] = object.__setattr__ 111 | mod.__class__ = type(f"HookedInit{cls.__name__}", (cls,), namespace) 112 | -------------------------------------------------------------------------------- /autoparallel/_passes/split_fsdp_collectives.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | import dataclasses 7 | from contextlib import contextmanager 8 | from copy import deepcopy 9 | from functools import partial 10 | from typing import Any 11 | 12 | import torch 13 | import torch.fx.node 14 | import torch.utils._pytree as pytree 15 | from torch._functorch._aot_autograd.descriptors import AOTOutput 16 | from torch._functorch.partitioners import _extract_graph_with_inputs_outputs 17 | from torch._inductor.fx_passes.bucketing import ( 18 | is_all_gather_into_tensor, 19 | is_reduce_scatter_tensor, 20 | is_wait_tensor, 21 | ) 22 | 23 | 24 | @contextmanager 25 | def exclude_from_fx_side_effectful(exclude_vals: set[Any]): 26 | original_val = torch.fx.node._side_effectful_functions.copy() 27 | try: 28 | torch.fx.node._side_effectful_functions -= exclude_vals 29 | yield 30 | finally: 31 | torch.fx.node._side_effectful_functions.clear() 32 | torch.fx.node._side_effectful_functions.update(original_val) 33 | 34 | 35 | exclude_wait_from_fx_side_effectful = partial( 36 | exclude_from_fx_side_effectful, 37 | { 38 | torch.ops._c10d_functional.wait_tensor, 39 | torch.ops._c10d_functional.wait_tensor.default, 40 | }, 41 | ) 42 | 43 | 44 | @dataclasses.dataclass(frozen=True) 45 | class PrefetchOutput(AOTOutput): 46 | pass 47 | 48 | 49 | @dataclasses.dataclass(frozen=True) 50 | class EpilogueInput(AOTOutput): 51 | pass 52 | 53 | 54 | def split_fsdp_prefetch( 55 | gm: torch.fx.GraphModule, 56 | num_params: int, 57 | ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: 58 | g = deepcopy(gm.graph) 59 | all_g_ins = g.find_nodes(op="placeholder") 60 | param_g_ins = all_g_ins[:num_params] 61 | rem_g_ins = all_g_ins[num_params:] 62 | 63 | prefetch_g_outs_map = [] 64 | 65 | for param_g_in in param_g_ins: 66 | n = param_g_in 67 | last_ag = None 68 | while True: 69 | if len(n.users) != 1: 70 | break 71 | user = next(iter(n.users)) 72 | if len(user.all_input_nodes) > 1: 73 | break 74 | n = user 75 | if is_all_gather_into_tensor(n): 76 | last_ag = n 77 | if last_ag is None: 78 | prefetch_g_outs_map.append(param_g_in) 79 | else: 80 | w_n = next(iter(last_ag.users)) 81 | assert is_wait_tensor(w_n) 82 | prefetch_g_outs_map.append(w_n) 83 | 84 | prefetch_g_outs = prefetch_g_outs_map 85 | prefetch_g_outs_descs: list[AOTOutput] = [ 86 | PrefetchOutput() for _ in range(len(prefetch_g_outs)) 87 | ] 88 | g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) 89 | g_outs_descs = pytree.arg_tree_leaves( 90 | next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs)) 91 | ) 92 | with exclude_wait_from_fx_side_effectful(): 93 | prefetch_g = _extract_graph_with_inputs_outputs( 94 | g, 95 | param_g_ins, 96 | prefetch_g_outs, 97 | prefetch_g_outs_descs, 98 | ignore_must_be_in_fw_bw=True, 99 | ) 100 | 101 | main_g = _extract_graph_with_inputs_outputs( 102 | g, 103 | prefetch_g_outs + rem_g_ins, 104 | g_outs, 105 | g_outs_descs, 106 | ignore_must_be_in_fw_bw=True, 107 | ) 108 | prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g) 109 | main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g) 110 | return prefetch_gm, main_gm 111 | 112 | 113 | def split_fsdp_reduce_scatters_epilogue( 114 | gm: torch.fx.GraphModule, 115 | num_grads: int, 116 | ) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]: 117 | g = deepcopy(gm.graph) 118 | g_ins = g.find_nodes(op="placeholder") 119 | g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output"))) 120 | grad_outs = g_outs[:num_grads] 121 | rem_g_outs = g_outs[num_grads:] 122 | 
out_descs = pytree.arg_tree_leaves( 123 | next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(grad_outs)) 124 | ) 125 | grad_outs_descs = out_descs[:num_grads] 126 | rem_g_outs_descs = out_descs[num_grads:] 127 | 128 | grad_outs_map = [] 129 | for grad_out in grad_outs: 130 | n = grad_out 131 | earliest_rs = None 132 | while n is not None: 133 | if len(n.all_input_nodes) != 1: 134 | break 135 | n_in = n.all_input_nodes[0] 136 | if len(n_in.users) > 1: 137 | break 138 | prev_n = n 139 | n = n_in 140 | # Maybe we also need to track all_reduce? 141 | if is_reduce_scatter_tensor(prev_n): 142 | # In AP for mesh dim > 1 143 | # The reduction of gradients happen in multiple steps 144 | earliest_rs = n 145 | if earliest_rs is not None: 146 | grad_outs_map.append(earliest_rs) 147 | else: 148 | grad_outs_map.append(grad_out) 149 | 150 | epi_g_ins = grad_outs_map 151 | epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))] 152 | 153 | with exclude_wait_from_fx_side_effectful(): 154 | main_g = _extract_graph_with_inputs_outputs( 155 | g, 156 | g_ins, 157 | epi_g_ins + rem_g_outs, 158 | epi_g_ins_descs + rem_g_outs_descs, 159 | ignore_must_be_in_fw_bw=True, 160 | ) 161 | epi_g = _extract_graph_with_inputs_outputs( 162 | g, 163 | epi_g_ins, 164 | grad_outs, 165 | grad_outs_descs, 166 | ignore_must_be_in_fw_bw=True, 167 | ) 168 | epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g) 169 | main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g) 170 | return main_gm, epi_gm 171 | -------------------------------------------------------------------------------- /examples/example_autoparallel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
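
The example below drives selective activation checkpointing (SAC) through AutoParallel: its `policy_fn` forces scaled-dot-product attention to be recomputed in the backward pass while every other op is saved. As a point of reference, here is a stripped-down sketch of the same SAC wiring on its own, without AutoParallel; the toy `block` function and tensor shapes are assumptions for illustration:

import functools
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def sac_policy(ctx, op, *args, **kwargs):
    # Recompute attention in backward, keep everything else saved.
    # Note: the exact aten attention op dispatched depends on the device/backend.
    if op in (
        torch.ops.aten._scaled_dot_product_flash_attention.default,
        torch.ops.aten._scaled_dot_product_efficient_attention.default,
    ):
        return CheckpointPolicy.MUST_RECOMPUTE
    return CheckpointPolicy.MUST_SAVE

sac_context_fn = functools.partial(create_selective_checkpoint_contexts, sac_policy)

def block(x):
    # (batch, seq, dim) -> (batch, heads, seq, head_dim) -> attention -> back
    q = k = v = x.unflatten(-1, (8, -1)).transpose(1, 2)
    o = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    return o.transpose(1, 2).flatten(-2)

x = torch.randn(2, 16, 64, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False, context_fn=sac_context_fn)
out.sum().backward()
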
5 | 6 | 7 | import functools 8 | 9 | import torch 10 | from torch import nn 11 | from torch.distributed.fsdp import MixedPrecisionPolicy 12 | from torch.distributed.tensor.placement_types import Replicate, Shard 13 | from torch.testing._internal.distributed.fake_pg import FakeStore 14 | from torch.utils.checkpoint import create_selective_checkpoint_contexts 15 | 16 | from autoparallel.api import AutoParallel 17 | 18 | 19 | def policy_fn(ctx, op, *args, **kwargs): 20 | if ( 21 | op == torch.ops.aten._scaled_dot_product_flash_attention.default 22 | or op == torch.ops.aten._scaled_dot_product_efficient_attention.default 23 | ): 24 | return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE 25 | return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE 26 | 27 | 28 | context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) 29 | 30 | 31 | class Block(nn.Module): 32 | def __init__(self, nheads, dim1, dim2): 33 | super().__init__() 34 | self.nheads = nheads 35 | bias = False 36 | self.wq = nn.Linear(dim1, dim1, bias=bias) 37 | self.wk = nn.Linear(dim1, dim1, bias=bias) 38 | self.wv = nn.Linear(dim1, dim1, bias=bias) 39 | self.wo = nn.Linear(dim1, dim1, bias=bias) 40 | self.w1 = nn.Linear(dim1, dim2, bias=bias) 41 | self.w2 = nn.Linear(dim2, dim1, bias=bias) 42 | 43 | def init_weights(self): 44 | for lin in [self.wq, self.wk, self.wv, self.wo, self.w1, self.w2]: 45 | torch.nn.init.normal_(lin.weight) 46 | if lin.bias is not None: 47 | torch.nn.init.normal_(lin.bias) 48 | 49 | def _compute_attention(self, x): 50 | q = self.wq(x) 51 | k = self.wk(x) 52 | v = self.wv(x) 53 | 54 | q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) 55 | k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) 56 | v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) 57 | 58 | o = nn.functional.scaled_dot_product_attention(q, k, v) 59 | o = o.permute(0, 2, 1, 3).flatten(-2) 60 | 61 | o = self.wo(o) 62 | return o 63 | 64 | def forward(self, x): 65 | o = torch.utils.checkpoint.checkpoint( 66 | self._compute_attention, x, use_reentrant=False, context_fn=context_fn 67 | ) 68 | 69 | o0 = o + x 70 | 71 | o = self.w1(o0) 72 | o = torch.nn.functional.relu(o) 73 | o = self.w2(o) 74 | 75 | o = o0 + o 76 | 77 | return o 78 | 79 | 80 | world_size = 256 81 | 82 | fake_store = FakeStore() 83 | torch.distributed.init_process_group( 84 | "fake", store=fake_store, rank=0, world_size=world_size 85 | ) 86 | 87 | use_1d_mesh = False 88 | 89 | if use_1d_mesh: 90 | mesh = torch.distributed.device_mesh.init_device_mesh( 91 | "cuda", (world_size,), mesh_dim_names=("dp",) 92 | ) 93 | else: 94 | mesh = torch.distributed.device_mesh.init_device_mesh( 95 | "cuda", 96 | (world_size // 8, 8), 97 | mesh_dim_names=( 98 | "dp", 99 | "tp", 100 | ), 101 | ) 102 | 103 | bs = 8 * mesh.shape[0] 104 | seq_len = 256 105 | nheads = 48 106 | dim1 = 6144 107 | dim2 = dim1 * 4 108 | 109 | 110 | def input_fn(): 111 | print(f"global input shape: {(bs, seq_len, dim1)}") 112 | return torch.rand(bs, seq_len, dim1, device="cuda") 113 | 114 | 115 | # parallelize the model 116 | with torch.device("meta"): 117 | model = Block(nheads, dim1, dim2) 118 | 119 | mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) 120 | # mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) 121 | 122 | with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop: 123 | assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes) 124 | assert any(n.meta.get("fwd_nn_module_stack") for n 
in autop.gm.graph.nodes) 125 | autop.add_parameter_memory_constraint(low=None, high=None) 126 | 127 | x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) 128 | 129 | autop.add_input_constraints([x_sharding]) 130 | autop.add_output_constraints([x_sharding]) 131 | 132 | sharding_placement = autop.optimize_placement() 133 | 134 | # AutoParallel produces a module with meta-DTensor parameters that need to be initialized 135 | parallel_mod = autop.apply_placement(sharding_placement) 136 | 137 | parallel_mod.to_empty(device="cuda") 138 | parallel_mod.init_weights() 139 | 140 | # now let's run it 141 | x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),) 142 | out = parallel_mod(*x) 143 | out.backward(torch.randn_like(out)) 144 | 145 | # Validate 146 | seqs = set() 147 | for n in autop.gm.graph.nodes: 148 | if "checkpoint" in n.meta.get( 149 | "stack_trace", "" 150 | ): # placeholders don't have stack trace 151 | is_bwd = n.meta.get("partitioner_tag", "") == "is_backward" 152 | if not is_bwd: 153 | if "getitem" in str(n.target): 154 | # getitem nodes are tagged same as their parent 155 | expected = policy_fn(None, n.args[0].target, (), ()) 156 | elif "alias" in str(n.target) and "getitem" in str(n.args[0].target): 157 | # alias nodes that depend on getitem are tagged same as their parent 158 | expected = policy_fn(None, n.args[0].args[0].target, (), ()) 159 | else: 160 | expected = policy_fn(None, n.target, (), ()) 161 | actual = n.meta.get("recompute") 162 | # NOTE: this assert only supports policy_fns on op alone 163 | assert actual == expected, f"{n} {actual} {expected}" 164 | seqs.add(n.meta["seq_nr"]) 165 | else: 166 | # fwd counterpart should have already populated seqs 167 | assert n.meta["seq_nr"] in seqs 168 | 169 | mm_nodes = autop.gm.graph.find_nodes( 170 | op="call_function", target=torch.ops.aten.mm.default 171 | ) 172 | 173 | assert ( 174 | mm_nodes[0].meta.get("recompute") 175 | == torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE 176 | ) 177 | 178 | print("All good!") 179 | -------------------------------------------------------------------------------- /autoparallel/collective_runtime_estimation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
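
This module prices the communication needed to redistribute a tensor between two sharding specs: `redistribute_cost` (defined below) walks the mesh dimensions in a caller-chosen order, and `estimate_strategy_comms_cost` picks that order, e.g. reducing over mesh dim 1 before dim 0 for a (Partial, Partial) -> (Shard, Shard) transition. A hedged sketch of how a caller might query these helpers; `src_spec` and the candidate target specs are assumed to be `DTensorSpec` objects obtained elsewhere (for instance from an `OpStrategy` during sharding propagation) and are not constructed here:

from autoparallel.collective_runtime_estimation import (
    estimate_strategy_comms_cost,
    redistribute_cost,
)

def cheapest_redistribution(src_spec, candidate_tgt_specs):
    # Rank candidate target shardings by their modelled comms cost; the model
    # returns float("inf") for transitions it rules out (e.g. redistributing
    # into a Partial placement).
    costs = [estimate_strategy_comms_cost(src_spec, tgt) for tgt in candidate_tgt_specs]
    best = min(range(len(costs)), key=costs.__getitem__)
    return candidate_tgt_specs[best], costs[best]

# The lower-level entry point also takes an explicit mesh-dim iteration order:
# cost = redistribute_cost(src_spec, tgt_spec, order=[1, 0])
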
5 | 6 | from typing import cast 7 | 8 | import torch.distributed.tensor._dtensor_spec as dtensor_spec 9 | from torch._prims_common import check_contiguous_sizes_strides 10 | from torch.distributed.tensor._collective_utils import ( 11 | MeshTopoInfo, 12 | allgather_cost, 13 | allreduce_cost, 14 | reduce_scatter_cost, 15 | spec_to_bytes, 16 | ) 17 | from torch.distributed.tensor.placement_types import Partial, Shard 18 | 19 | from .compute_estimation import compute_read_write_time 20 | 21 | 22 | def all_to_all_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: 23 | num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] 24 | mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] 25 | num_hops = num_devices_on_mesh_dim - 1 26 | # base latency + comm latency 27 | latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us 28 | bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s 29 | total_time = latency + bw * 1e6 # rescale to us 30 | # FIXME: this is a hack, we need to spend some more effort on the cost model 31 | total_time *= 5 32 | return total_time 33 | 34 | 35 | # this is a copy-paste from https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_collective_utils.py 36 | # with iteration order introduced 37 | def redistribute_cost( 38 | current_spec: "dtensor_spec.DTensorSpec", 39 | target_spec: "dtensor_spec.DTensorSpec", 40 | order: list[int], 41 | ) -> float: 42 | """ 43 | This function returns the cost of redistribute from current to target DTensorSpec. 44 | 45 | NOTE: 46 | 1. Only consider communication cost here, since computation costs for redistribute 47 | are quite trivial (i.e. we only need to narrow or simple division) 48 | 2. Only consider redistribute cost on same mesh, cross mesh communication cost is 49 | not quite needed for operator strategy estimation/selection. 50 | """ 51 | if current_spec.mesh != target_spec.mesh: 52 | # make infinite cost if meshes are not same 53 | # TODO: see if we want to support this once there's cross mesh communication 54 | return float("inf") 55 | 56 | if current_spec.is_replicated(): 57 | # short-cut: 58 | # comm cost is 0 if current spec is already full replication 59 | # except if output is partial, which doesn't make sense for us 60 | if any(p.is_partial() for p in target_spec.placements): 61 | return float("inf") 62 | return 0.0 63 | 64 | mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) 65 | cost = 0.0 66 | comm_bytes_gb = ( 67 | spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 68 | ) 69 | # Transformation that considered for redistribute cost: 70 | # 1. allgather 2. alltoall 71 | # 3. allreduce 4. 
reduce_scatter 72 | curr_placements = [current_spec.placements[i] for i in order] 73 | tgt_placements = [target_spec.placements[i] for i in order] 74 | is_contiguous: bool = check_contiguous_sizes_strides( 75 | current_spec.shape, current_spec.stride 76 | ) 77 | for i, current, target in zip(order, curr_placements, tgt_placements): 78 | if current == target: 79 | continue 80 | num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] 81 | if not is_contiguous: 82 | cost += compute_read_write_time(comm_bytes_gb * 2 * 1024**3) 83 | if current.is_shard() and target.is_replicate(): 84 | current = cast(Shard, current) 85 | # allgather gives larger comm bytes 86 | comm_bytes_gb *= num_devices_on_mesh_dim 87 | # add up allgather comm cost 88 | cost += allgather_cost(comm_bytes_gb, mesh_topo, i) 89 | if current.dim != 0: 90 | # penalize cases like S(1) -> R as there are additional compute cost 91 | # which corresponds to reshuffling the whole output tensor 92 | # we multiply the cost by 2 because we need to count input and output 93 | # reads for the reshuffle 94 | compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3) 95 | cost += compute_cost 96 | elif current.is_shard() and target.is_shard(): 97 | current = cast(Shard, current) 98 | target = cast(Shard, target) 99 | # should be alltoall comm, since we haven't implement it yet, add penalty 100 | # to favor allgather instead 101 | cost += all_to_all_cost(comm_bytes_gb, mesh_topo, i) # us 102 | 103 | num_copies = 0 104 | if current.dim != 0: 105 | num_copies += 1 106 | 107 | if target.dim != 0: 108 | num_copies += 1 109 | 110 | compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3) 111 | cost += num_copies * compute_cost 112 | 113 | elif current.is_partial() and target.is_replicate(): 114 | # add up allreduce comm cost 115 | cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) 116 | elif current.is_partial() and target.is_shard(): 117 | target = cast(Shard, target) 118 | # add up reduce_scatter comm cost 119 | cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) 120 | if target.dim != 0: 121 | # penalize cases like P -> S(1) as there are additional compute cost 122 | # which corresponds to reshuffling the whole input tensor 123 | # we multiply the cost by 2 because we need to count input and output 124 | # reads for the reshuffle 125 | compute_cost = compute_read_write_time(comm_bytes_gb * 2 * 1024**3) 126 | cost += compute_cost 127 | # after reduce_scatter the comm bytes for further collectives halved. 128 | comm_bytes_gb /= num_devices_on_mesh_dim 129 | elif current.is_shard() and target.is_partial(): 130 | # ban shard -> partial as it does not make sense to perform 131 | # this redistribute 132 | return float("inf") 133 | elif current.is_replicate() and target.is_partial(): 134 | # ban replicate -> partial as it does not make sense to perform 135 | # this redistribute in our case 136 | return float("inf") 137 | 138 | # once we redistribute across one mesh dim, assume the output 139 | # is now contiguous. This is generally the case for most operations, 140 | # except when we fuse nd collectives into a 1d collective. 
141 | is_contiguous = True 142 | 143 | return cost 144 | 145 | 146 | def estimate_strategy_comms_cost(src_spec, tgt_spec): 147 | order = list(range(src_spec.mesh.ndim)) 148 | if src_spec.placements == (Partial(), Partial()) and all( 149 | p.is_shard() for p in tgt_spec.placements 150 | ): 151 | order = [1, 0] 152 | comms_cost = redistribute_cost(src_spec, tgt_spec, order) 153 | return comms_cost 154 | -------------------------------------------------------------------------------- /autoparallel/collectives.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Any, Optional 7 | 8 | import torch 9 | import torch.distributed.distributed_c10d as c10d 10 | from torch.distributed._tensor.experimental import local_map as _local_map 11 | 12 | # Import GroupName for type checking 13 | GroupName = c10d.GroupName 14 | 15 | _local_map_device_mesh = None 16 | 17 | 18 | def local_map(*args, **kwargs): 19 | # TODO: ideally after we get out of the local map region we should 20 | # just reset the global device mesh to None. For now we just keep it 21 | # around. 22 | global _local_map_device_mesh 23 | _local_map_device_mesh = kwargs.get("device_mesh", None) 24 | return _local_map(*args, **kwargs) 25 | 26 | 27 | def get_mesh_from_global(): 28 | global _local_map_device_mesh 29 | if _local_map_device_mesh is None: 30 | raise RuntimeError( 31 | "No mesh found, make sure to call this collective in a local_map region" 32 | ) 33 | return _local_map_device_mesh 34 | 35 | 36 | def _get_group_name_from_axis_name(mesh_name): 37 | mesh = get_mesh_from_global() 38 | group = mesh.get_group(mesh_name) 39 | return group.group_name 40 | 41 | 42 | def axis_size(axis_name): 43 | mesh = get_mesh_from_global() 44 | assert axis_name in mesh.mesh_dim_names 45 | axis_dim = mesh.mesh_dim_names.index(axis_name) 46 | return mesh.size(axis_dim) 47 | 48 | 49 | def axis_index(axis_name): 50 | mesh = get_mesh_from_global() 51 | return mesh.get_local_rank(mesh_dim=axis_name) 52 | 53 | 54 | def _all_gather_tensor( 55 | x: torch.Tensor, 56 | gather_dim: int, 57 | group_name: GroupName, 58 | ) -> torch.Tensor: 59 | x = x.contiguous() 60 | group_size = c10d._get_group_size_by_name(group_name) 61 | tensor = torch.ops._c10d_functional.all_gather_into_tensor( 62 | x, group_size, group_name 63 | ) 64 | res = torch.ops._c10d_functional.wait_tensor(tensor) 65 | if gather_dim != 0: 66 | # torch.cat access the data so we already need to wait here, first do wait 67 | # and then chunk + cat avoid us going through ACT dispatching logic again 68 | res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim) 69 | return res 70 | 71 | 72 | def _reduce_scatter_tensor( 73 | self: torch.Tensor, reduceOp: str, scatter_dim: int, group_name: GroupName 74 | ): 75 | group_size = c10d._get_group_size_by_name(group_name) 76 | 77 | assert ( 78 | self.size(scatter_dim) % group_size == 0 79 | ), f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})" 80 | if scatter_dim != 0: 81 | tensor_list = torch.chunk(self, group_size, dim=scatter_dim) 82 | self = torch.cat(tensor_list) 83 | 84 | tensor = torch.ops._c10d_functional.reduce_scatter_tensor( 85 | self, 86 | reduceOp.lower(), 87 | group_size, 88 | group_name, 89 | ) 90 | res = 
torch.ops._c10d_functional.wait_tensor(tensor) 91 | return res 92 | 93 | 94 | def _all_reduce(self: torch.Tensor, reduceOp: str, group_name: GroupName): 95 | tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name) 96 | res = torch.ops._c10d_functional.wait_tensor(tensor) 97 | return res 98 | 99 | 100 | def _all_to_all( 101 | self: torch.Tensor, 102 | output_split_sizes: Optional[list[int]], 103 | input_split_sizes: Optional[list[int]], 104 | group_name: GroupName, 105 | ): 106 | group_size = c10d._get_group_size_by_name(group_name) 107 | if output_split_sizes is None or input_split_sizes is None: 108 | assert output_split_sizes is None and input_split_sizes is None, ( 109 | "output_split_sizes and input_split_sizes must either be " 110 | "specified together or both set to None" 111 | ) 112 | output_split_sizes = [self.shape[0] // group_size] * group_size 113 | input_split_sizes = output_split_sizes 114 | 115 | tensor = torch.ops._c10d_functional.all_to_all_single( 116 | self, output_split_sizes, input_split_sizes, group_name 117 | ) 118 | res = torch.ops._c10d_functional.wait_tensor(tensor) 119 | return res 120 | 121 | 122 | class _AllGather(torch.autograd.Function): 123 | @staticmethod 124 | def forward(ctx: Any, x: torch.Tensor, gather_dim: int, axis_name: str): 125 | group_name = _get_group_name_from_axis_name(axis_name) 126 | ctx.group_name = group_name 127 | ctx.gather_dim = gather_dim 128 | return _all_gather_tensor(x, gather_dim, group_name) 129 | 130 | @staticmethod 131 | def backward(ctx: Any, grad_output: torch.Tensor): # type: ignore[override] 132 | return ( 133 | _reduce_scatter_tensor(grad_output, "sum", ctx.gather_dim, ctx.group_name), 134 | None, 135 | None, 136 | ) 137 | 138 | 139 | class _ReduceScatter(torch.autograd.Function): 140 | @staticmethod 141 | def forward(ctx: Any, x: torch.Tensor, scatter_dim: int, axis_name: str): 142 | group_name = _get_group_name_from_axis_name(axis_name) 143 | ctx.group_name = group_name 144 | ctx.scatter_dim = scatter_dim 145 | return _reduce_scatter_tensor(x, "sum", scatter_dim, group_name) 146 | 147 | @staticmethod 148 | def backward(ctx: Any, grad_output: torch.Tensor): # type: ignore[override] 149 | return ( 150 | _all_gather_tensor(grad_output, ctx.scatter_dim, ctx.group_name), 151 | None, 152 | None, 153 | ) 154 | 155 | 156 | class _AllReduce(torch.autograd.Function): 157 | @staticmethod 158 | def forward(ctx: Any, x: torch.Tensor, axis_name: str): 159 | group_name = _get_group_name_from_axis_name(axis_name) 160 | ctx.group_name = group_name 161 | return _all_reduce(x, "sum", group_name) 162 | 163 | @staticmethod 164 | def backward(ctx: Any, grad_output: torch.Tensor): # type: ignore[override] 165 | # TODO: split this into a function that does all-reduce and one which is the identity 166 | return _all_reduce(grad_output, "sum", ctx.group_name), None 167 | 168 | 169 | class _AllToAll(torch.autograd.Function): 170 | @staticmethod 171 | def forward( 172 | ctx: Any, 173 | x: torch.Tensor, 174 | output_split_sizes: Optional[list[int]], 175 | input_split_sizes: Optional[list[int]], 176 | axis_name: str, 177 | ): 178 | group_name = _get_group_name_from_axis_name(axis_name) 179 | ctx.group_name = group_name 180 | ctx.output_split_sizes = output_split_sizes 181 | ctx.input_split_sizes = input_split_sizes 182 | return _all_to_all(x, output_split_sizes, input_split_sizes, group_name) 183 | 184 | @staticmethod 185 | def backward(ctx: Any, grad_output: torch.Tensor): # type: ignore[override] 186 | return _all_to_all( 187 
| grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.group_name 188 | ) 189 | 190 | 191 | all_gather = _AllGather.apply 192 | all_reduce = _AllReduce.apply 193 | reduce_scatter = _ReduceScatter.apply 194 | all_to_all = _AllToAll.apply 195 | -------------------------------------------------------------------------------- /tests/test_ordered_sharding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from unittest.mock import patch 7 | 8 | import pytest 9 | import torch 10 | from torch import nn 11 | from torch.testing._internal.distributed.fake_pg import FakeStore 12 | 13 | from autoparallel.api import AutoParallel 14 | from autoparallel.ordered_sharding import compute_optimal_placement_order_for_parameters 15 | 16 | 17 | @pytest.fixture(scope="module", autouse=True) 18 | def init_pg(): 19 | world_size = 256 20 | fake_store = FakeStore() 21 | if torch.distributed.is_initialized(): 22 | return 23 | torch.distributed.init_process_group( 24 | "fake", store=fake_store, rank=0, world_size=world_size 25 | ) 26 | 27 | 28 | @pytest.fixture(scope="module") 29 | def device_mesh_2d(): 30 | world_size = torch.distributed.get_world_size() 31 | mesh = torch.distributed.device_mesh.init_device_mesh( 32 | "cuda", 33 | (world_size // 8, 8), 34 | mesh_dim_names=( 35 | "dp", 36 | "tp", 37 | ), 38 | ) 39 | return mesh 40 | 41 | 42 | class ModelWithNonTrainableParams(nn.Module): 43 | """A model with both trainable and non-trainable parameters to test the grad is None case.""" 44 | 45 | def __init__(self, dim): 46 | super().__init__() 47 | # Trainable parameter (requires_grad=True by default) 48 | self.linear = nn.Linear(dim, dim, bias=False) 49 | 50 | # Non-trainable parameters (requires_grad=False) 51 | self.register_parameter( 52 | "non_trainable_weight", 53 | nn.Parameter(torch.randn(dim, dim), requires_grad=False), 54 | ) 55 | self.register_buffer("buffer", torch.randn(dim)) 56 | 57 | # Another trainable parameter 58 | self.linear2 = nn.Linear(dim, dim, bias=False) 59 | 60 | def forward(self, x): 61 | # Use both trainable and non-trainable parameters 62 | x = self.linear(x) 63 | x = x + torch.mm(x, self.non_trainable_weight) # Use non-trainable parameter 64 | x = x + self.buffer # Use buffer 65 | x = self.linear2(x) 66 | return x 67 | 68 | 69 | class ModelWithAllNonTrainableParams(nn.Module): 70 | """A model where all parameters don't require gradients.""" 71 | 72 | def __init__(self, dim): 73 | super().__init__() 74 | # Create linear layers but set requires_grad=False for all params 75 | self.linear1 = nn.Linear(dim, dim, bias=False) 76 | self.linear2 = nn.Linear(dim, dim, bias=False) 77 | 78 | # Set all parameters to not require gradients 79 | for param in self.parameters(): 80 | param.requires_grad = False 81 | 82 | def forward(self, x): 83 | x = self.linear1(x) 84 | x = self.linear2(x) 85 | return x 86 | 87 | 88 | @patch("torch.cuda.device_count", lambda: 8) 89 | @patch("torch.cuda.get_device_name", lambda device: "H100") 90 | def test_compute_optimal_placement_order_with_non_trainable_params(device_mesh_2d): 91 | """Test that compute_optimal_placement_order_for_parameters handles parameters with grad=None.""" 92 | 93 | dim = 128 94 | device = "cuda" 95 | 96 | def model_fn(): 97 | return ModelWithNonTrainableParams(dim) 98 | 99 | def 
input_fn(): 100 | return torch.randn(512, dim, device=device, requires_grad=True) 101 | 102 | with torch.device("meta"): 103 | model = model_fn() 104 | 105 | # Verify our test setup: some params should have requires_grad=False 106 | trainable_params = [p for p in model.parameters() if p.requires_grad] 107 | non_trainable_params = [p for p in model.parameters() if not p.requires_grad] 108 | 109 | assert ( 110 | len(trainable_params) > 0 111 | ), "Test setup error: should have some trainable params" 112 | assert ( 113 | len(non_trainable_params) > 0 114 | ), "Test setup error: should have some non-trainable params" 115 | 116 | with AutoParallel(model, input_fn, device_mesh_2d) as autop: 117 | autop.add_parameter_memory_constraint(low=0, high=None) 118 | sharding_placement = autop.optimize_placement() 119 | 120 | # This should not raise an exception due to grad=None 121 | # Before the fix, this would fail when trying to process non-trainable parameters 122 | placement_order = compute_optimal_placement_order_for_parameters( 123 | autop.gm, sharding_placement 124 | ) 125 | 126 | # The function should return successfully 127 | assert isinstance(placement_order, dict) 128 | assert len(placement_order) == 0 129 | 130 | # Verify we can examine the graph structure to understand param/grad relationships 131 | from torch._functorch._aot_autograd.fx_utils import get_param_and_grad_nodes 132 | 133 | param_and_grad_nodes = list(get_param_and_grad_nodes(autop.gm.graph).values()) 134 | 135 | # Should have param/grad pairs where some grads are None 136 | assert len(param_and_grad_nodes) > 0 137 | 138 | # At least one should have grad=None (the non-trainable param) 139 | has_none_grad = any(grad is None for param, grad in param_and_grad_nodes) 140 | assert has_none_grad, "Expected at least one parameter to have grad=None" 141 | 142 | # At least one should have a valid grad (the trainable param) 143 | has_valid_grad = any(grad is not None for param, grad in param_and_grad_nodes) 144 | assert ( 145 | has_valid_grad 146 | ), "Expected at least one parameter to have a valid gradient" 147 | 148 | 149 | @patch("torch.cuda.device_count", lambda: 8) 150 | @patch("torch.cuda.get_device_name", lambda device: "H100") 151 | def test_compute_optimal_placement_order_with_all_non_trainable_params(device_mesh_2d): 152 | """Test edge case where ALL parameters don't require gradients.""" 153 | 154 | dim = 64 155 | device = "cuda" 156 | 157 | def model_fn(): 158 | return ModelWithAllNonTrainableParams(dim) 159 | 160 | def input_fn(): 161 | return torch.randn(256, dim, device=device, requires_grad=True) 162 | 163 | with torch.device("meta"): 164 | model = model_fn() 165 | 166 | # Verify test setup: all params should have requires_grad=False 167 | non_trainable_params = [p for p in model.parameters() if not p.requires_grad] 168 | trainable_params = [p for p in model.parameters() if p.requires_grad] 169 | 170 | assert ( 171 | len(non_trainable_params) > 0 172 | ), "Test setup error: should have non-trainable params" 173 | assert ( 174 | len(trainable_params) == 0 175 | ), "Test setup error: should have NO trainable params" 176 | 177 | with AutoParallel(model, input_fn, device_mesh_2d) as autop: 178 | autop.add_parameter_memory_constraint(low=0, high=None) 179 | sharding_placement = autop.optimize_placement() 180 | 181 | # This should not raise an exception even when ALL gradients are None 182 | placement_order = compute_optimal_placement_order_for_parameters( 183 | autop.gm, sharding_placement 184 | ) 185 | 186 | # Should return 
successfully with empty or minimal result 187 | assert isinstance(placement_order, dict) 188 | assert len(placement_order) == 0 189 | -------------------------------------------------------------------------------- /autoparallel/auto_bucketing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from functools import partial 7 | 8 | import torch 9 | from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing 10 | 11 | from .autobucketing_util import bucket_func, bucket_plan, bucket_utils, reorder 12 | 13 | 14 | class simplefsdp_autobucketing_config: 15 | """ 16 | Config for simplefsdp's autobucketing pass, which by default would give good performance. 17 | To make the results tunable, we expose the following parameters: 18 | - relax_ratio: relax comp time to include more comm in one bucket 19 | with this config, comp is updated as comp * (1 + relax_ratio) 20 | - peak_memory_offset: relax peak_memory to include more comm in one bucket 21 | with this config, peak_memory is updated as (peak_memory + peak_memory_offset) 22 | - load_cache: set to True to load cache from save_estimation_path 23 | - enable_bucket_ir: set to True to bucket all_gather/reduce_scatter 24 | - enable_reorder_ir: set to True to reorder all_gather/reduce_satter 25 | - calibrate_number: number of samples to calibrate during comm estimation 26 | """ 27 | 28 | relax_ratio = 0 29 | peak_memory_offset = 0 30 | load_cache = False 31 | save_estimation_path = "/mnt/mffuse/cache_ruisi/estimation_mast.pkl" 32 | enable_bucket_ir = True 33 | enable_reorder_ir = True 34 | calibrate_number = 40 35 | 36 | 37 | def simple_fsdp_autobucketing_reordering_pass( 38 | snodes: list["torch._inductor.scheduler.BaseSchedulerNode"], 39 | configs: "simplefsdp_autobucketing_config", 40 | ) -> list["torch._inductor.scheduler.BaseSchedulerNode"]: 41 | scheduler = snodes[0].scheduler 42 | bucketable_nodes = bucket_utils.get_bucketable_ir_nodes( 43 | snodes, scheduler.name_to_fused_node, scheduler.name_to_buf 44 | ) 45 | 46 | assert ( 47 | not torch._inductor.config.allow_buffer_reuse 48 | ), "bucketing algorithm requires torch._inductor.config.allow_buffer_reuse to be False" 49 | 50 | if configs.enable_bucket_ir: 51 | all_gather_plan, reduce_scatter_plan = bucket_plan.get_simplefsdp_auto_plan( 52 | scheduler, 53 | snodes, 54 | scheduler.name_to_buf, 55 | scheduler.name_to_fused_node, 56 | bucketable_nodes, 57 | configs, 58 | ) 59 | 60 | snodes = bucket_func.bucket_fsdp_all_gather_with_plan( 61 | scheduler, 62 | snodes, 63 | scheduler.name_to_buf, 64 | scheduler.name_to_fused_node, 65 | all_gather_plan, 66 | bucketable_nodes, 67 | ) 68 | if len(reduce_scatter_plan) > 0: 69 | snodes = bucket_func.bucket_fsdp_reduce_scatter_with_plan( 70 | scheduler, 71 | snodes, 72 | scheduler.name_to_buf, 73 | scheduler.name_to_fused_node, 74 | reduce_scatter_plan, 75 | bucketable_nodes, 76 | ) 77 | 78 | if configs.enable_reorder_ir: 79 | print("Reorder scheduler nodes with autobucketing algroithm") 80 | node_length = len(snodes) 81 | snodes = reorder.reorder_all_gather( 82 | snodes, bucketable_nodes, all_gather_before_last_wait=False 83 | ) 84 | assert node_length == len( 85 | snodes 86 | ), f"Missed nodes in reordering all gather: expected {node_length}, but got {len(snodes)}" 87 | snodes = 
reorder.reorder_reduce_scatter(snodes, bucketable_nodes) 88 | assert node_length == len( 89 | snodes 90 | ), f"Missed nodes in reordering reduce scatter: expected {node_length}, but got {len(snodes)}" 91 | 92 | return snodes 93 | 94 | 95 | class aten_autobucketing_config: 96 | """ 97 | Config for aten level autobucketing pass from stacked PR: https://github.com/pytorch/pytorch/pull/163960 98 | - max_in_flight_gb: maximum GB of concurrent collective data 99 | - compute_overlap_multipler: scale factor for compute time used to hide collectives 100 | - max_coll_distance: maximum node distance for overlap consideration 101 | """ 102 | 103 | max_in_flight_gb = 2.0 104 | compute_overlap_multipler = 1.0 105 | max_coll_distance = 100 106 | custom_runtime_estimation = None 107 | max_compute_pre_fetch = 5 108 | collective_bucketing = False 109 | save_trace = True 110 | _counter = 0 111 | 112 | 113 | def aten_autobucketing_reordering_pass( 114 | gm: torch.fx.Graph, configs: "aten_autobucketing_config" 115 | ) -> torch.fx.GraphModule: 116 | new_gm = schedule_overlap_bucketing( 117 | gm.owning_module, 118 | collective_bucketing=configs.collective_bucketing, 119 | max_compute_pre_fetch=configs.max_compute_pre_fetch, 120 | custom_runtime_estimation=configs.custom_runtime_estimation, 121 | compute_overlap_multipler=configs.compute_overlap_multipler, 122 | max_in_flight_gb=configs.max_in_flight_gb, 123 | max_coll_distance=configs.max_coll_distance, 124 | ) 125 | new_gm.recompile() 126 | 127 | if configs.save_trace: 128 | from autoparallel.debug_helpers import create_execution_trace 129 | 130 | assert configs.custom_runtime_estimation is not None 131 | 132 | create_execution_trace( 133 | new_gm, 134 | configs.custom_runtime_estimation, 135 | f"fake_trace_{configs._counter}.json", 136 | ) 137 | configs._counter += 1 138 | return new_gm 139 | 140 | 141 | def configure_inductor_for_autobucketing(mode: str = "aten"): 142 | # allow configuring inductor comms optimizations from torchtitan commandline 143 | if mode == "aten": 144 | torch._inductor.config.aten_distributed_optimizations.enable_overlap_scheduling = ( 145 | True 146 | ) 147 | torch._inductor.config.aten_distributed_optimizations.collective_bucketing = ( 148 | True 149 | ) 150 | torch._inductor.config.aten_distributed_optimizations.insert_overlap_deps = True 151 | torch._inductor.config.aten_distributed_optimizations.max_compute_pre_fetch = 10 152 | elif mode == "inductor": 153 | from autoparallel.auto_bucketing import ( 154 | simple_fsdp_autobucketing_reordering_pass, 155 | simplefsdp_autobucketing_config, 156 | ) 157 | 158 | torch._inductor.config.allow_buffer_reuse = False 159 | torch._inductor.config.reorder_for_peak_memory = False 160 | torch._inductor.config.reorder_for_compute_comm_overlap = True 161 | simplefsdp_autobucketing_config.calibrate_number = 5 162 | simplefsdp_autobucketing_config.save_estimation_path = "./estimation_mast.pkl" 163 | simple_fsdp_autobucketing_reordering_pass = partial( 164 | simple_fsdp_autobucketing_reordering_pass, 165 | configs=simplefsdp_autobucketing_config, # type: ignore 166 | ) 167 | torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ 168 | simple_fsdp_autobucketing_reordering_pass 169 | ] 170 | elif mode == "none": 171 | torch._inductor.config.reorder_for_peak_memory = False 172 | torch._inductor.config.reorder_for_compute_comm_overlap = False 173 | else: 174 | raise ValueError(f"Unknown comms bucket reorder strategy: {mode}") 175 | 
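
A hedged usage sketch for the helpers above: `configure_inductor_for_autobucketing` flips the relevant inductor flags for the chosen mode before compilation. The `model`/`inputs` scaffolding is an assumption for illustration, not part of this module:

import torch
from autoparallel.auto_bucketing import configure_inductor_for_autobucketing

# Use the aten-level overlap-scheduling and collective-bucketing passes...
configure_inductor_for_autobucketing(mode="aten")
# ...or the scheduler-level simple-FSDP pass ("inductor"), or neither ("none"):
# configure_inductor_for_autobucketing(mode="inductor")

compiled = torch.compile(model)  # `model` and `inputs` are assumed to exist
out = compiled(*inputs)
out.sum().backward()
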
-------------------------------------------------------------------------------- /autoparallel/cast_parametrization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import copy 7 | import copyreg 8 | from contextlib import contextmanager 9 | from typing import Type 10 | 11 | import torch 12 | from torch.distributed.fsdp import MixedPrecisionPolicy 13 | from torch.utils._pytree import tree_map 14 | 15 | 16 | def make_getter(self, p_name, mp_policy): 17 | def getter( 18 | self_mod=self, 19 | _param_name=p_name, 20 | _dtype=mp_policy.param_dtype, 21 | ): 22 | _param = self_mod._parameters[_param_name] 23 | if not active_param(): 24 | return _param 25 | return torch.ops.autoparallel.dtype_cast(_param, _dtype) 26 | 27 | return getter 28 | 29 | 30 | # taken from PyTorch's parametrize module from 31 | # https://github.com/pytorch/pytorch/blob/5d9653d90ee003173dd03f93e09fed236500ef06/torch/nn/utils/parametrize.py#L324-L351 32 | # with some improvements 33 | def default_deepcopy(self, memo): 34 | # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. 35 | obj = memo.get(id(self), None) 36 | if obj is not None: 37 | return obj 38 | replica = self.__new__(self.__class__) 39 | memo[id(self)] = replica 40 | replica.__dict__ = copy.deepcopy(self.__dict__, memo) 41 | 42 | # Fix the parametrization getters to point to the replica instead of the original 43 | if hasattr(replica, "_name_to_dtype_cast_managed_attr_getter") and hasattr( 44 | replica, "_mp_policy" 45 | ): 46 | # Recreate the getter functions to point to the replica 47 | param_properties = {} 48 | for p_name in list(replica._name_to_dtype_cast_managed_attr_getter.keys()): 49 | # Use a function factory to properly capture the loop variable 50 | # def make_getter(param_name): 51 | param_properties[p_name] = make_getter(replica, p_name, replica._mp_policy) 52 | replica._name_to_dtype_cast_managed_attr_getter = param_properties 53 | 54 | # Also save all slots if they exist. 55 | slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined] 56 | for slot in slots_to_save: 57 | if hasattr(self, slot): 58 | setattr(replica, slot, copy.deepcopy(getattr(self, slot), memo)) 59 | return replica 60 | 61 | 62 | def getstate(self): 63 | raise RuntimeError( 64 | "Serialization of parametrized modules is only " 65 | "supported through state_dict(). See:\n" 66 | "https://pytorch.org/tutorials/beginner/saving_loading_models.html" 67 | "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" 68 | ) 69 | 70 | 71 | @torch.library.custom_op("autoparallel::dtype_cast", mutates_args=()) 72 | def dtype_cast(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: 73 | """ 74 | This is a custom op that is used to cast the input tensor to the specified dtype. 75 | We use it to be able to specify a special compute cost for the cast operation, 76 | so that we always favor performing all-gather of small tensors in the smallest 77 | dtype. 
78 | """ 79 | return x.to(dtype) 80 | 81 | 82 | def setup_context(ctx, inputs, output) -> None: 83 | x, _ = inputs 84 | ctx.orig_dtype = x.dtype 85 | 86 | 87 | def backward(ctx, grad): 88 | dtype = ctx.orig_dtype 89 | return torch.ops.autoparallel.dtype_cast(grad, dtype), None 90 | 91 | 92 | torch.library.register_autograd( 93 | "autoparallel::dtype_cast", backward, setup_context=setup_context 94 | ) 95 | 96 | 97 | @dtype_cast.register_fake 98 | def _(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: 99 | out = torch.empty_like(x, dtype=dtype) 100 | return out 101 | 102 | 103 | def create_dtype_cast_managed_attr(p_name): 104 | def getter(self): 105 | # TODO: if this function throws exception, how does it behave? add a unit test for it. 106 | return self._name_to_dtype_cast_managed_attr_getter[p_name]() 107 | 108 | def setter(self, value): 109 | raise RuntimeError( 110 | "Setting DTypeCast-managed attribute is not supported", 111 | ) 112 | 113 | return property(getter, setter) 114 | 115 | 116 | def canonicalize_mp(mp_policy: MixedPrecisionPolicy) -> MixedPrecisionPolicy: 117 | # try and follow standard FSDP behavior 118 | # maybe this should be handled in the MixedPrecisionPolicy class itself 119 | param_dtype = mp_policy.param_dtype 120 | reduce_dtype = mp_policy.reduce_dtype or param_dtype 121 | output_dtype = mp_policy.output_dtype or param_dtype # TODO: check if this is right 122 | cast_forward_inputs = mp_policy.cast_forward_inputs 123 | return MixedPrecisionPolicy( 124 | param_dtype, reduce_dtype, output_dtype, cast_forward_inputs 125 | ) 126 | 127 | 128 | _active_param = False 129 | 130 | 131 | def active_param(): 132 | global _active_param 133 | return _active_param 134 | 135 | 136 | @contextmanager 137 | def set_dtype_cast(val): 138 | global _active_param 139 | prev = _active_param 140 | try: 141 | _active_param = val 142 | yield 143 | finally: 144 | _active_param = prev 145 | 146 | 147 | # taken from https://www.internalfb.com/code/fbsource/[master][history]/fbcode/caffe2/torch/distributed/fb/simple_fsdp/simple_fsdp.py 148 | # with minor modifications 149 | def apply_dtype_cast(model, mp_policy: MixedPrecisionPolicy): 150 | mp_policy = canonicalize_mp(mp_policy) 151 | cls_key_to_dtype_cast_cls: dict[tuple[Type, str], Type] = {} 152 | 153 | for mod_name, mod in sorted(model.named_modules()): 154 | params_dict = dict(mod.named_parameters(recurse=False)) 155 | 156 | # Create new class for this module with all parametrized parameters 157 | cls = mod.__class__ 158 | param_properties_key = "#".join(sorted(params_dict.keys())) 159 | new_cls = cls_key_to_dtype_cast_cls.get((cls, param_properties_key), None) 160 | if not new_cls: 161 | namespace = {"__getstate__": getstate} 162 | # We don't allow serialization of parametrized modules but should still allow deepcopying. 163 | # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. 164 | if not hasattr(cls, "__deepcopy__"): 165 | namespace["__deepcopy__"] = default_deepcopy # type: ignore[assignment] 166 | 167 | for p_name in params_dict.keys(): 168 | # NOTE: it's important to have this indirection, to make sure that: 169 | # Different instances of the same class can resolve their parameter access to instance-specific getters 170 | # (which contains unique objects used in that instance-specific parameter's unshard operation). 
171 | namespace[p_name] = create_dtype_cast_managed_attr(p_name) 172 | cls_t = (DTypeCastModule, cls) if mod is model else (cls,) 173 | new_cls = type(f"DTypeCast{cls.__name__}", cls_t, namespace) 174 | cls_key_to_dtype_cast_cls[(cls, param_properties_key)] = new_cls 175 | mod.__class__ = new_cls 176 | 177 | param_properties = {} 178 | for p_name in params_dict.keys(): 179 | param_properties[p_name] = make_getter(mod, p_name, mp_policy) 180 | 181 | mod._name_to_dtype_cast_managed_attr_getter = param_properties 182 | mod._mp_policy = mp_policy 183 | 184 | return model 185 | 186 | 187 | class DTypeCastModule(torch.nn.Module): 188 | def forward(self, *args, **kwargs): 189 | def cast_fn(x): 190 | if not torch.is_floating_point(x): 191 | return x 192 | return x.to(self._mp_policy.param_dtype) 193 | 194 | if self._mp_policy.cast_forward_inputs: 195 | args, kwargs = tree_map(cast_fn, args), tree_map(cast_fn, kwargs) 196 | output = super().forward(*args, **kwargs) 197 | 198 | def cast_out_fn(x): 199 | return x.to(self._mp_policy.output_dtype) 200 | 201 | output = tree_map(cast_out_fn, output) 202 | return output 203 | -------------------------------------------------------------------------------- /examples/example_local_map.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import functools 7 | 8 | import torch 9 | import torch.fx.traceback as fx_traceback 10 | from torch import nn 11 | from torch.distributed._tensor.experimental import local_map 12 | from torch.distributed.fsdp import MixedPrecisionPolicy 13 | from torch.distributed.tensor.placement_types import Replicate, Shard 14 | from torch.testing._internal.distributed.fake_pg import FakeStore 15 | from torch.utils.checkpoint import create_selective_checkpoint_contexts 16 | 17 | from autoparallel.api import AutoParallel 18 | 19 | world_size = 256 20 | 21 | fake_store = FakeStore() 22 | torch.distributed.init_process_group( 23 | "fake", store=fake_store, rank=0, world_size=world_size 24 | ) 25 | mesh = torch.distributed.device_mesh.init_device_mesh( 26 | "cuda", 27 | (world_size // 32, 8, 4), 28 | mesh_dim_names=( 29 | "dp", 30 | "tp", 31 | "cp", 32 | ), 33 | ) 34 | assert mesh.ndim == 3, "Please also update local_map" 35 | 36 | 37 | def policy_fn(ctx, op, *args, **kwargs): 38 | if ( 39 | op == torch.ops.aten._scaled_dot_product_flash_attention.default 40 | or op == torch.ops.aten._scaled_dot_product_efficient_attention.default 41 | ): 42 | # NOTE: we can't save nondeterministic_seeded ops, the run with rng wrapper is not traceable yet 43 | return torch.utils.checkpoint.CheckpointPolicy.PREFER_SAVE 44 | return torch.utils.checkpoint.CheckpointPolicy.PREFER_RECOMPUTE 45 | 46 | 47 | context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) 48 | 49 | 50 | @local_map( 51 | out_placements=((Replicate(), Replicate(), Replicate()),), 52 | in_placements=( 53 | (Replicate(), Replicate(), Replicate()), 54 | (Replicate(), Replicate(), Replicate()), 55 | ), 56 | redistribute_inputs=True, 57 | in_grad_placements=None, 58 | device_mesh=mesh, 59 | ) 60 | def replicate_linear(w, x): 61 | with fx_traceback.annotate({"inside_local_map": 1}): 62 | return torch.matmul(x, w.t()) 63 | 64 | 65 | @local_map( 66 | out_placements=((Shard(0), Shard(0), Replicate()),), 67 | 
in_placements=((Shard(0), Shard(0), Replicate()),), 68 | redistribute_inputs=True, 69 | in_grad_placements=None, 70 | device_mesh=mesh, 71 | ) 72 | def sharded_pointwise(x): 73 | with fx_traceback.annotate({"inside_local_map": 0}): 74 | return x + 10 75 | 76 | 77 | @local_map( 78 | out_placements=((Shard(0), Shard(1), Shard(2)),), 79 | in_placements=( 80 | (Shard(0), Shard(1), Shard(2)), 81 | (Shard(0), Shard(1), Shard(2)), 82 | (Shard(0), Shard(1), Shard(2)), 83 | ), 84 | redistribute_inputs=True, 85 | in_grad_placements=None, 86 | device_mesh=mesh, 87 | ) 88 | def context_parallel_attention(query, key, value): 89 | with fx_traceback.annotate({"inside_local_map": 2}): 90 | out = nn.functional.scaled_dot_product_attention( 91 | query=query, key=key, value=value, is_causal=False 92 | ) 93 | return out 94 | 95 | 96 | class Block(nn.Module): 97 | def __init__(self, nheads, dim1, dim2): 98 | super().__init__() 99 | self.nheads = nheads 100 | bias = False 101 | self.wq = nn.Linear(dim1, dim1, bias=bias) 102 | self.wk = nn.Linear(dim1, dim1, bias=bias) 103 | self.wv = nn.Linear(dim1, dim1, bias=bias) 104 | self.wo = nn.Linear(dim1, dim1, bias=bias) 105 | self.w1 = nn.Linear(dim1, dim2, bias=bias) 106 | self.w2 = nn.Linear(dim2, dim1, bias=bias) 107 | 108 | def init_weights(self): 109 | for lin in [self.wq, self.wk, self.wv, self.wo, self.w1, self.w2]: 110 | torch.nn.init.normal_(lin.weight) 111 | if lin.bias is not None: 112 | torch.nn.init.normal_(lin.bias) 113 | 114 | def _compute_attention(self, x): 115 | with fx_traceback.annotate({"inside_checkpoint": 0}): 116 | boosted_weight = sharded_pointwise(self.wq.weight) 117 | q = replicate_linear(boosted_weight, x) 118 | k = self.wk(x) 119 | v = self.wv(x) 120 | 121 | q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) 122 | k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) 123 | v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) 124 | 125 | o = context_parallel_attention(q, k, v) 126 | o = o.permute(0, 2, 1, 3).flatten(-2) 127 | 128 | o = self.wo(o) 129 | return o 130 | 131 | def forward(self, x): 132 | with fx_traceback.annotate({"outside_checkpoint": 0}): 133 | o = torch.utils.checkpoint.checkpoint( 134 | self._compute_attention, x, use_reentrant=False, context_fn=context_fn 135 | ) 136 | 137 | o0 = o + x 138 | 139 | o = self.w1(o0) 140 | o = torch.nn.functional.relu(o) 141 | o = self.w2(o) 142 | 143 | o = o0 + o 144 | 145 | return o 146 | 147 | 148 | bs = 8 * mesh.shape[0] 149 | seq_len = 256 150 | nheads = 48 151 | dim1 = 6144 152 | dim2 = dim1 * 4 153 | 154 | 155 | def input_fn(): 156 | print(f"global input shape: {(bs, seq_len, dim1)}") 157 | return torch.rand(bs, seq_len, dim1, device="cuda") 158 | 159 | 160 | # parallelize the model 161 | with torch.device("meta"): 162 | model = Block(nheads, dim1, dim2) 163 | 164 | # MP policy causing some deepcopy issues 165 | # mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) 166 | mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16) 167 | # mp_policy = None 168 | 169 | with torch.fx.traceback.preserve_node_meta(), AutoParallel( 170 | model, input_fn, mesh, mp_policy, compile=True 171 | ) as autop: 172 | assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes) 173 | assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes) 174 | autop.add_parameter_memory_constraint(low=None, high=None) 175 | 176 | x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) 177 | 178 | 
autop.add_input_constraints([x_sharding]) 179 | autop.add_output_constraints([x_sharding]) 180 | 181 | sharding_placement = autop.optimize_placement() 182 | 183 | # AutoParallel produces a module with meta-DTensor parameters that need to be initialized 184 | parallel_mod = autop.apply_placement(sharding_placement) 185 | 186 | parallel_mod.to_empty(device="cuda") 187 | parallel_mod.init_weights() 188 | 189 | # now let's run it 190 | x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),) 191 | out = parallel_mod(*x) 192 | out.backward(torch.randn_like(out)) 193 | 194 | # Validate 195 | seqs = set() 196 | for n in autop.gm.graph.nodes: 197 | if "checkpoint" in n.meta.get( 198 | "stack_trace", "" 199 | ): # placeholders don't have stack trace 200 | is_bwd = n.meta.get("partitioner_tag", "") == "is_backward" 201 | if not is_bwd: 202 | if "getitem" in str(n.target): 203 | # getitem nodes are tagged same as their parent 204 | expected = policy_fn(None, n.args[0].target, (), ()) 205 | else: 206 | expected = policy_fn(None, n.target, (), ()) 207 | actual = n.meta.get("recompute") 208 | # NOTE: this assert only supports policy_fns on op alone 209 | assert actual == expected 210 | seqs.add(n.meta["seq_nr"]) 211 | else: 212 | # fwd counterpart should have already populated seqs 213 | assert n.meta["seq_nr"] in seqs 214 | 215 | mm_nodes = autop.gm.graph.find_nodes( 216 | op="call_function", target=torch.ops.aten.mm.default 217 | ) 218 | 219 | metas = [n.meta.get("custom", None) for n in autop.parallel_gm.graph.nodes] 220 | fwd_sdpa, bwd_sdpa = [ 221 | n 222 | for n in autop.parallel_gm.graph.nodes 223 | if "_scaled_dot_product_flash_attention" in n.name 224 | ] 225 | # TODO: Dynamo HOP body is not preserving the fx_traceback.annotate 226 | # We should expect to also see the "inside_local_map" annotation 227 | assert fwd_sdpa.meta["custom"] == { 228 | "inside_checkpoint": 0, 229 | "inside_local_map": 2, 230 | "outside_checkpoint": 0, 231 | } 232 | assert bwd_sdpa.meta["custom"] == { 233 | "inside_checkpoint": 0, 234 | "inside_local_map": 2, 235 | "outside_checkpoint": 0, 236 | } 237 | 238 | print("All good!") 239 | -------------------------------------------------------------------------------- /examples/example_ds3_local_map.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
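
The example below runs a small DeepSeek-V3 model through AutoParallel on a 2-D (dp, ep) mesh, either symbolically on a fake 256-rank process group or for real on 4 GPUs (dp=2, ep=2). The invocations below are a sketch inferred from the argument parser and the asserts further down; adjust paths and flags to taste:

# Symbolic run on a fake process group (FakeTensorMode, no real collectives or kernels):
#   python examples/example_ds3_local_map.py --fake-evaluate
#
# Real run on 4 GPUs, with deterministic seeding and numerics logging to out/:
#   torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py \
#       --rng-seed 0 --logs-dir out/
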
5 | 6 | import os 7 | from typing import Optional 8 | 9 | import torch 10 | from torch._subclasses.fake_tensor import FakeTensorMode 11 | from torch.distributed.tensor.placement_types import Shard 12 | from torch.fx.experimental.symbolic_shapes import ShapeEnv 13 | from torch.testing._internal.distributed.fake_pg import FakeStore 14 | 15 | from autoparallel._testing.models.dsv3 import ( 16 | DeepSeekV3Model, 17 | DeepSeekV3ModelArgs, 18 | MoEArgs, 19 | ) 20 | from autoparallel.api import AutoParallel 21 | from autoparallel.utils import NumericsLogger 22 | 23 | 24 | def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str): 25 | seq_len = 1024 26 | if fake_evaluate: 27 | # must symbolically evaluate to run on 32 dp ranks 28 | # world_size = 2048 29 | 30 | world_size = 256 31 | 32 | fake_store = FakeStore() 33 | torch.distributed.init_process_group( 34 | "fake", store=fake_store, rank=0, world_size=world_size 35 | ) 36 | local_rank = torch.distributed.get_rank() 37 | mesh = torch.distributed.device_mesh.init_device_mesh( 38 | "cuda", 39 | (world_size // 64, 64), 40 | mesh_dim_names=( 41 | "dp", 42 | "ep", 43 | ), 44 | ) 45 | 46 | config = DeepSeekV3ModelArgs( 47 | vocab_size=102400, 48 | max_seq_len=seq_len, 49 | dim=2048, 50 | inter_dim=10944, 51 | moe_inter_dim=1408, 52 | n_layers=1, # 27, 53 | n_dense_layers=0, # 1, 54 | n_heads=16, 55 | moe_args=MoEArgs( 56 | num_experts=64, 57 | num_shared_experts=2, 58 | top_k=6, 59 | score_func="softmax", 60 | route_norm=False, 61 | score_before_experts=False, 62 | mesh=mesh, 63 | ), 64 | q_lora_rank=0, 65 | kv_lora_rank=512, 66 | qk_nope_head_dim=128, 67 | qk_rope_head_dim=64, 68 | v_head_dim=128, 69 | mscale=0.70, 70 | use_flex_attn=False, 71 | attn_mask_type="causal", 72 | ) 73 | else: 74 | dp_degree = 2 75 | ep_degree = 2 76 | world_size = dp_degree * ep_degree 77 | 78 | assert ( 79 | "WORLD_SIZE" in os.environ 80 | ), f"run with torchrun --standalone --nproc-per-node {world_size}" 81 | assert ( 82 | int(os.getenv("WORLD_SIZE")) == world_size 83 | ), f"Need at least {world_size} GPUs for real evaluation" 84 | local_rank = int(os.getenv("LOCAL_RANK")) 85 | torch.distributed.init_process_group(backend="nccl") 86 | mesh = torch.distributed.device_mesh.init_device_mesh( 87 | "cuda", 88 | (dp_degree, ep_degree), 89 | mesh_dim_names=( 90 | "dp", 91 | "ep", 92 | ), 93 | ) 94 | 95 | config = DeepSeekV3ModelArgs( 96 | vocab_size=2048, 97 | max_seq_len=seq_len, 98 | dim=256, 99 | inter_dim=1024, 100 | moe_inter_dim=256, 101 | n_layers=4, 102 | n_dense_layers=0, 103 | n_heads=16, 104 | moe_args=MoEArgs( 105 | num_experts=4, 106 | num_shared_experts=2, 107 | top_k=2, 108 | score_func="softmax", 109 | route_norm=False, 110 | score_before_experts=False, 111 | mesh=mesh, 112 | ), 113 | q_lora_rank=0, 114 | kv_lora_rank=512, 115 | qk_nope_head_dim=128, 116 | qk_rope_head_dim=64, 117 | v_head_dim=128, 118 | mscale=0.70, 119 | ) 120 | 121 | local_batch_size = 2 122 | global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1] 123 | device = torch.device(f"cuda:{local_rank}") 124 | 125 | # parallelize the model 126 | with torch.device("meta"): 127 | model = DeepSeekV3Model(config).bfloat16() 128 | 129 | def input_fn(): 130 | return torch.randint( 131 | 0, 132 | config.vocab_size, 133 | (global_batch_size, seq_len), 134 | device=device, 135 | ) 136 | 137 | numerics_logger = None 138 | if rng_seed is not None: 139 | numerics_logger = NumericsLogger(logs_dir) 140 | with AutoParallel( 141 | model, input_fn, mesh, dynamic=True, 
numerics_logger=None 142 | ) as autop: 143 | autop.add_parameter_memory_constraint(low=None, high=None) 144 | 145 | # x_sharding = (Shard(0), Replicate()) 146 | x_sharding = (Shard(0), Shard(0)) 147 | 148 | autop.add_input_constraints([x_sharding]) 149 | autop.add_output_constraints([x_sharding]) 150 | 151 | sharding_placement = autop.optimize_placement(verbose=False) 152 | parallel_mod = autop.apply_placement(sharding_placement) 153 | 154 | parallel_mod.to_empty(device=device) 155 | # run weight init on our sharded DTensor params 156 | # TODO: plumb init_std through 157 | # parallel_mod.init_weights( 158 | # init_std=0.02, buffer_device="cuda" 159 | # ) # maybe not correct value 160 | parallel_mod.init_weights(buffer_device=device, seed=rng_seed) 161 | if rng_seed is not None: 162 | numerics_logger.log_model_weights(parallel_mod) 163 | torch.manual_seed(rng_seed) 164 | 165 | n_microbatches = 16 166 | full_batch = torch.randint( 167 | 0, 168 | config.vocab_size, 169 | (local_batch_size * n_microbatches, seq_len), 170 | device=device, 171 | ) 172 | microbatches = torch.split(full_batch, local_batch_size, dim=0) 173 | assert len(microbatches) == n_microbatches 174 | if rng_seed: 175 | numerics_logger.log_diff( 176 | full_batch.to(torch.float32), prefix="full batch input" 177 | ) 178 | 179 | # Symbolically evaluate in case you want to test running a graph bigger than your gpu 180 | if fake_evaluate: 181 | # all gather on the tokens takes 128 GiB (4GiB * 32 ranks) 182 | shape_env = ShapeEnv() 183 | with FakeTensorMode( 184 | allow_non_fake_inputs=True, 185 | shape_env=shape_env, 186 | ): 187 | # now let's run it 188 | for x in microbatches: 189 | out = parallel_mod(x) 190 | out.backward(torch.ones_like(out)) 191 | else: 192 | for i, x in enumerate(microbatches): 193 | assert x.shape[0] == 2 194 | out = parallel_mod(x) 195 | assert not torch.any(torch.isnan(out)), "Found NaNs in forward output" 196 | out.backward(torch.ones_like(out)) 197 | if rng_seed is not None: 198 | numerics_logger.log_diff(out, prefix=f"mb{i} fwd out") 199 | 200 | if rng_seed is not None: 201 | for k, v in parallel_mod.named_parameters(): 202 | numerics_logger.log_diff(v.grad, prefix=f"grad {k}") 203 | 204 | print("All good!") 205 | 206 | if torch.distributed.is_initialized(): 207 | torch.distributed.barrier() 208 | torch.cuda.synchronize() 209 | torch.distributed.destroy_process_group() 210 | 211 | 212 | if __name__ == "__main__": 213 | import argparse 214 | 215 | parser = argparse.ArgumentParser( 216 | description="Run DeepSeek V3 pipeline parallel example" 217 | ) 218 | parser.add_argument( 219 | "--fake-evaluate", 220 | action="store_true", 221 | default=False, 222 | help="Use fake evaluation mode with FakeTensorMode (default: False)", 223 | ) 224 | parser.add_argument( 225 | "--rng-seed", 226 | type=int, 227 | default=None, 228 | help="Use a specific rng seed and deterministic algorithms for run-to-run invariance (default: None).", 229 | ) 230 | parser.add_argument( 231 | "--logs-dir", 232 | type=str, 233 | default="out/", 234 | help="Directory to store logs (default: ./out/).", 235 | ) 236 | args = parser.parse_args() 237 | 238 | if args.rng_seed is not None: 239 | torch.use_deterministic_algorithms(True) 240 | torch.manual_seed(args.rng_seed) 241 | 242 | run_test( 243 | fake_evaluate=args.fake_evaluate, rng_seed=args.rng_seed, logs_dir=args.logs_dir 244 | ) 245 | -------------------------------------------------------------------------------- /autoparallel/graph_clustering.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # This file is adapted from 7 | # https://github.com/pytorch/pytorch/blob/af10f1f86cc4effc93142a447693d8be55966615/torch/_dynamo/graph_region_tracker.py#L278 8 | # with slight modifications 9 | 10 | import logging 11 | import math 12 | import time 13 | from collections import defaultdict 14 | from typing import Optional 15 | 16 | import torch 17 | from torch._dynamo.graph_region_tracker import ( 18 | Any, 19 | IdenticalNodes, 20 | InputPickler, 21 | Node, 22 | Region, 23 | _populate_recursive_ancestor_map, 24 | fully_expand_region_group, 25 | operator, 26 | tree_flatten, 27 | ) 28 | from torch._inductor.codecache import sha256_hash 29 | from torch.distributed.tensor._dtensor_spec import DTensorSpec 30 | from torch.distributed.tensor._op_schema import OpStrategy 31 | 32 | logger: logging.Logger = logging.getLogger(__name__) 33 | logger.setLevel(logging.INFO) 34 | 35 | 36 | def _extract_args(arg: Any) -> Any: 37 | if isinstance(arg, Node): 38 | return arg.meta.get("val") 39 | elif isinstance(arg, (torch.Tensor, int)): 40 | return arg 41 | else: 42 | return None 43 | 44 | 45 | def _normalize_args( 46 | node: Node, 47 | ) -> tuple[tuple[str, ...], tuple[Optional[Any], ...]]: 48 | flat_args, _ = tree_flatten(node.args) 49 | sorted_kwargs = sorted(node.kwargs.items(), key=operator.itemgetter(0)) 50 | sorted_keys = tuple(sorted(node.kwargs.keys())) 51 | flat_kwargs, _ = tree_flatten(sorted_kwargs) 52 | all_args = flat_args + flat_kwargs 53 | return (sorted_keys, tuple(_extract_args(arg) for arg in all_args)) 54 | 55 | 56 | def _print_output_specs(op_strategy): 57 | output = [] 58 | for s in op_strategy.strategies: 59 | output_placements = [] 60 | output_specs = s.output_specs 61 | if isinstance(output_specs, DTensorSpec): 62 | output_specs = [output_specs] 63 | for output_spec in output_specs: 64 | if output_spec is None: 65 | output_placements.append("(None)") 66 | continue 67 | plc_str = ",".join([str(p) for p in output_spec.placements]) 68 | output_placements.append(f"({plc_str})") 69 | output.append(f"({','.join(output_placements)})") 70 | return ", ".join(output) 71 | 72 | 73 | def _prepare_op_strategy(op_strategy, output_only=False): 74 | # hashing op_strategy is expensive, so we hash the string representation 75 | # instead, which is much cheaper and is a reasonable proxy for the 76 | # clustering 77 | # NOTE: ideally, we wouldn't need to pass the op_strategy at all, 78 | # as we would expect that if two nodes have identical inputs, they would 79 | # also have identical op_strategy. This is actually not the case for 80 | # view ops, which propagate the input shardings to the output.
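# (e.g. two view nodes with identical input shapes can end up with different strategies when their producers are sharded differently)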
81 | # So we also add the strategy for a node as a hash key to avoid 82 | # clustering nodes that look the same but have different strategies 83 | if output_only: 84 | return _print_output_specs(op_strategy) 85 | return str(op_strategy) 86 | 87 | 88 | def _hash_node(node, strategies, input_pickler): 89 | key = ( 90 | node.meta.get("stack_trace"), 91 | _normalize_args(node), 92 | _prepare_op_strategy(strategies[node]), 93 | tuple( 94 | _prepare_op_strategy(strategies[s], output_only=True) 95 | for s in node.all_input_nodes 96 | ), 97 | ) 98 | return sha256_hash(input_pickler.dumps(key)) 99 | 100 | 101 | def get_identical_regions( 102 | graph: torch.fx.Graph, strategies: dict[Node, OpStrategy] 103 | ) -> list[list[Region]]: 104 | """ 105 | This function is responsible for extracting the largest regions of identical nodes from the given graph. 106 | **Note**: This function assumes the nodes that have been tracked with track_node are in the provided graph argument. 107 | 108 | The algorithm proceeds as follows: 109 | The nodes tracked via track_node above are organized into region groups. The initial region groups look like this: 110 | [[IdenticalNode1], [IdenticalNode2], [IdenticalNode3]] and each sublist is called a region. For each region group 111 | (starting at the topologically latest region group), the inner regions are gradually expanded one node at time from 112 | the flattened args and kwargs of the node in each region provided that for all regions in the group, the nodes being 113 | added are also identical (ie have the same key computed by track_node). This is checked by verifying that the two 114 | nodes have the same identical node list in node_to_duplicates. 115 | """ 116 | topological_ranking = {node: i for i, node in enumerate(graph.nodes)} 117 | region_groups_with_rank = [] 118 | # needed to detect if replacing a region will create cycles 119 | t = time.time() 120 | node_to_recursive_ancestors = _populate_recursive_ancestor_map(graph) 121 | logger.info(f"Populated recursive ancestors in {time.time() - t} s") 122 | 123 | input_pickler = InputPickler() 124 | hash_to_duplicates: dict[str, IdenticalNodes] = defaultdict(list) 125 | node_to_duplicates: dict[Node, IdenticalNodes] = {} 126 | t = time.time() 127 | for node in graph.nodes: 128 | if node.op == "placeholder": 129 | continue 130 | 131 | duplicates = hash_to_duplicates[_hash_node(node, strategies, input_pickler)] 132 | duplicates.append(node) 133 | node_to_duplicates[node] = duplicates 134 | logger.info(f"Hashed nodes in {time.time() - t} s") 135 | 136 | def _is_identical(n0: Node, n1: Node) -> bool: 137 | return ( 138 | n0 in node_to_duplicates 139 | and n1 in node_to_duplicates 140 | and node_to_duplicates[n0] is node_to_duplicates[n1] 141 | and n0 is not n1 142 | ) 143 | 144 | # Create region groups; a region group is a group 145 | # of regions that are all identical. In this initial state 146 | # each region in the group is a single node, and we discard 147 | # groups that are only a single region. 148 | # We track the topological ranking to start with groups later in the graph 149 | # the reason for this is that we will necessarily create the largest groups first. 150 | for group in hash_to_duplicates.values(): 151 | if len(group) > 1: 152 | region_group = [] 153 | min_rank = math.inf 154 | for node in group: 155 | # some nodes aren't in the topo ranking? 
156 | if node in topological_ranking: 157 | min_rank = min(min_rank, topological_ranking[node]) 158 | region_group.append([node]) 159 | 160 | if len(region_group) > 1: 161 | region_groups_with_rank.append((region_group, min_rank)) 162 | 163 | region_groups_with_rank.sort(key=lambda rg: -rg[1]) 164 | region_groups = [rg for rg, _ in region_groups_with_rank] 165 | 166 | # We start from regions later in the graph and expand them earlier 167 | # as a result, we will create the largest regions first and they won't 168 | # overlap. 169 | t = time.time() 170 | seen_nodes: set[Node] = set() 171 | for region_group in region_groups: 172 | # NOTE: this seems like it's missing in the original implementation 173 | # from PyTorch. Given that fully_expand_region_group doesn't check 174 | # if the root from a region is in a seen node, it might end up 175 | # having duplicate nodes in different clusters 176 | if region_group[0][0] in seen_nodes: 177 | continue 178 | fully_expand_region_group( 179 | region_group, 180 | seen_nodes, 181 | node_to_recursive_ancestors, 182 | _is_identical, 183 | ) 184 | # sort topologically 185 | for region in region_group: 186 | region.sort(key=lambda n: topological_ranking[n]) 187 | 188 | region_groups = [ 189 | region_group for region_group in region_groups if len(region_group[0]) > 1 190 | ] 191 | 192 | # sort everything so that we have nodes in topological ranking 193 | for region_group in region_groups: 194 | region_group.sort(key=lambda rg: topological_ranking[rg[0]]) 195 | region_groups.sort(key=lambda rg: topological_ranking[rg[0][0]]) 196 | logger.info(f"Expanded regions in {time.time() - t} s") 197 | 198 | # sanity check that we don't have duplicate nodes 199 | seen_nodes.clear() 200 | for region_group in region_groups: 201 | for region in region_group: 202 | for node in region: 203 | if node in seen_nodes: 204 | raise RuntimeError(f"Duplicate node {node} in region group") 205 | seen_nodes.add(node) 206 | return region_groups 207 | -------------------------------------------------------------------------------- /autoparallel/_passes/graph_multiplex.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
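# Fuses a forward graph module and a backward graph module into a single multiplexed graph, optionally
# interleaving their nodes so that one graph's communication can overlap with the other's compute
# (see multiplex_fw_bw_graph below).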
5 | 6 | import copy 7 | from itertools import dropwhile 8 | 9 | import torch 10 | import torch.fx as fx 11 | from torch._inductor.fx_passes.bucketing import is_wait_tensor 12 | from torch._logging import trace_structured 13 | 14 | 15 | def _add_compute_annotations(gm: fx.GraphModule, tag: str): 16 | """Add compute_region annotations to nodes without custom metadata.""" 17 | for n in gm.graph.nodes: 18 | if n.op == "placeholder": 19 | continue 20 | if n.meta.get("custom", None) is None: 21 | n.meta["custom"] = {"compute_region": tag} 22 | else: 23 | assert "comm_region" in n.meta["custom"] 24 | val = n.meta["custom"]["comm_region"] 25 | n.meta["custom"]["comm_region"] = tag + " " + val 26 | 27 | 28 | def _move_wait_tensors_to_compute_region(gm: fx.GraphModule, tag: str): 29 | """Move wait_tensor nodes from comm_region to compute_region of their users.""" 30 | for n in gm.graph.nodes: 31 | if n.op == "placeholder": 32 | continue 33 | if "comm_region" in n.meta["custom"] and is_wait_tensor(n): 34 | assert len(n.users) >= 1, "wait tensor must have at least one user" 35 | user: fx.Node = next(iter(n.users)) 36 | if "compute_region" in user.meta["custom"]: 37 | n.meta["custom"].pop("comm_region") 38 | n.meta["custom"].update({"compute_region": tag + " " + "wait"}) 39 | if n.next is not user: 40 | user.prepend(n) 41 | 42 | 43 | def multiplex_fw_bw_graph( 44 | fw_gm: fx.GraphModule, bw_gm: fx.GraphModule, overlap_with_annotations: bool = True 45 | ) -> fx.GraphModule: 46 | """ 47 | Multiplexes forward and backward graphs into a single unified graph module. 48 | 49 | This function combines a forward graph and a backward graph into one multiplexed 50 | graph by merging their nodes and outputs. The resulting graph has: 51 | - All placeholders from both forward and backward graphs (backward followed by forward) 52 | - All computation nodes from both graphs (backward followed by forward) 53 | - Combined outputs (backward outputs followed by forward outputs) 54 | 55 | Args: 56 | fw_gm: The forward graph module containing the forward computation 57 | bw_gm: The backward graph module containing the backward computation 58 | 59 | Returns: 60 | A multiplexed fx.GraphModule containing both forward and backward computations 61 | with backward outputs appearing before forward outputs 62 | 63 | Note: 64 | The function preserves node metadata during the merging process. 
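Example (names are illustrative): if fw_gm has placeholders (x, w) and outputs (out, saved), and bw_gm has placeholders (saved, tangent) and outputs (grad_x, grad_w), the multiplexed graph has placeholders (saved, tangent, x, w) and a single output node returning (grad_x, grad_w, out, saved).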
65 | """ 66 | if overlap_with_annotations: 67 | _add_compute_annotations(fw_gm, "forward") 68 | _add_compute_annotations(bw_gm, "backward") 69 | _move_wait_tensors_to_compute_region(fw_gm, "forward") 70 | _move_wait_tensors_to_compute_region(bw_gm, "backward") 71 | 72 | # Mapping to track correspondence between forward graph nodes and new nodes 73 | old_node_to_new_node: dict[torch.fx.Node, torch.fx.Node] = {} 74 | 75 | # Start with a deep copy of the backward graph as the base 76 | multiplexed_gm = copy.deepcopy(bw_gm) 77 | 78 | # Collect all placeholder nodes from all the graphs 79 | bw_placeholders = bw_gm.graph.find_nodes(op="placeholder") 80 | fw_placeholders = fw_gm.graph.find_nodes(op="placeholder") 81 | insert_point = multiplexed_gm.graph.find_nodes(op="placeholder")[-1] 82 | 83 | # Insert forward placeholders after the backward placeholders of the multiplexed graph 84 | for n in fw_placeholders: 85 | with multiplexed_gm.graph.inserting_after(insert_point): 86 | new_placeholder = multiplexed_gm.graph.placeholder(n.name) 87 | new_placeholder.meta = copy.copy(n.meta) 88 | new_placeholder.target = new_placeholder.name 89 | old_node_to_new_node[n] = new_placeholder 90 | insert_point = new_placeholder 91 | 92 | multiplexed_gm_placeholders = multiplexed_gm.graph.find_nodes(op="placeholder") 93 | assert len(multiplexed_gm_placeholders) == len(fw_placeholders) + len( 94 | bw_placeholders 95 | ) 96 | fw_nodes_iter = iter(fw_gm.graph.nodes) 97 | fw_nodes_iter = dropwhile(lambda n: n.op == "placeholder", fw_nodes_iter) 98 | # Initialize the forward node to be the first non-placeholder node 99 | fn = next(fw_nodes_iter) 100 | if overlap_with_annotations: 101 | # Interleave forward and backward nodes to create overlap pattern: 102 | # bw_compute (if any) -> bw_comm -> fw_compute (if any) -> fw_comm -> [repeat] 103 | # This allows bw_comm to overlap with fw_compute, and fw_comm to overlap with bw_compute 104 | bw_in_comm = False 105 | for bn in multiplexed_gm.graph.nodes: 106 | if bn.op == "placeholder" or bn.op == "output": 107 | continue 108 | # Track when we enter a backward comm region 109 | if "comm_region" in bn.meta["custom"] and not bw_in_comm: 110 | bw_in_comm = True 111 | # When we transition from bw_comm to bw_compute, insert forward nodes 112 | elif "compute_region" in bn.meta["custom"] and bw_in_comm: 113 | bw_in_comm = False 114 | fw_in_comm = False 115 | insert_point = bn 116 | # Insert forward nodes before this bw_compute node 117 | # Note: We cannot reorder nodes within a graph, only their relative order between graphs 118 | while fn.op != "output": 119 | if "comm_region" in fn.meta["custom"] and not fw_in_comm: 120 | fw_in_comm = True 121 | elif "compute_region" in fn.meta["custom"] and fw_in_comm: 122 | # Stop when we reach the next fw_compute after fw_comm 123 | # This ensures we insert one fw_compute + fw_comm cycle per bw_comm -> bw_compute transition 124 | # If fw starts with comm (no compute before it), we still insert it to overlap with future bw_compute 125 | fw_in_comm = False 126 | break 127 | with multiplexed_gm.graph.inserting_before(insert_point): 128 | # Copy node and remap its arguments using the node mapping 129 | new_node = multiplexed_gm.graph.node_copy( 130 | fn, lambda x: old_node_to_new_node[x] 131 | ) 132 | new_node.meta = copy.copy(fn.meta) 133 | old_node_to_new_node[fn] = new_node 134 | fn = next(fw_nodes_iter) 135 | # Insert any remaining forward nodes at the end 136 | # If overlap_with_annotations is False, this concatenates all fw nodes after bw nodes 
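# (the remaining fw nodes land just before the existing output node; the combined output tuple is rebuilt below)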
137 | insert_point = multiplexed_gm.graph.find_nodes(op="output")[-1] 138 | while fn.op != "output": 139 | with multiplexed_gm.graph.inserting_before(insert_point): 140 | # Copy node and remap its arguments using the node mapping 141 | new_node = multiplexed_gm.graph.node_copy( 142 | fn, lambda x: old_node_to_new_node[x] 143 | ) 144 | new_node.meta = copy.copy(fn.meta) 145 | old_node_to_new_node[fn] = new_node 146 | fn = next(fw_nodes_iter) 147 | 148 | # Collect output arguments from forward graph, remapping to new nodes 149 | fw_outputs = fw_gm.graph.find_nodes(op="output") 150 | multiplexed_graph_outputs = multiplexed_gm.graph.find_nodes(op="output") 151 | assert len(multiplexed_graph_outputs) == 1 and len(fw_outputs) == 1 152 | fw_graph_op_node = fw_outputs[0] 153 | fw_op_node_args = [ 154 | old_node_to_new_node[n] if n is not None else None 155 | for n in fw_graph_op_node.args[0] 156 | ] 157 | 158 | # Collect output arguments from multiplexed graph (will contain only bwd_outs) 159 | multiplexed_graph_op_node = multiplexed_graph_outputs[0] 160 | bw_op_node_args = list(multiplexed_graph_op_node.args[0]) 161 | 162 | # Update output node args to prepend backward outputs before forward outputs 163 | multiplexed_graph_op_node.args = (tuple(bw_op_node_args + fw_op_node_args),) 164 | 165 | multiplexed_gm.graph.eliminate_dead_code() 166 | multiplexed_gm.graph.lint() 167 | multiplexed_gm.recompile() 168 | trace_structured( 169 | "artifact", 170 | metadata_fn=lambda: { 171 | "name": "autoparallel_multiplexed_graph", 172 | "encoding": "string", 173 | }, 174 | payload_fn=lambda: multiplexed_gm.print_readable( 175 | print_output=False, include_stride=True, include_device=True 176 | ), 177 | ) 178 | return multiplexed_gm 179 | -------------------------------------------------------------------------------- /partitioned_shard_proposal.md: -------------------------------------------------------------------------------- 1 | # Proposal for PartitionedShard Placement Type 2 | 3 | ## Overview 4 | 5 | This proposal introduces a new placement type `PartitionedShard(Placement)` to the PyTorch distributed tensor framework. This placement type extends the existing sharding capabilities to handle partitioned shards with variable chunk sizes and supports both aligned and unaligned partitioning strategies. 6 | 7 | ## Motivation 8 | 9 | The current `Shard` placement type assumes uniform chunk sizes across all ranks in a mesh dimension. However, for certain use cases such as Mixture of Experts (MoE) models, we need the ability to handle variable-sized partitions within shards. The `PartitionedShard` placement enables efficient distribution of tensors where different partitions may have different sizes while maintaining the semantic structure needed for MoE operations. 10 | 11 | ## Class Definition 12 | 13 | ```python 14 | @dataclass(frozen=True) 15 | class PartitionedShard(Placement): 16 | """ 17 | The PartitionedShard placement describes a DTensor that is sharded on a tensor dimension 18 | where each shard contains multiple partitions of potentially variable sizes. 19 | 20 | This placement type is particularly useful for MoE (Mixture of Experts) models where 21 | different experts may have different sizes, and we need to maintain partition alignment 22 | across different sharding strategies. 
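For example, with a mesh dimension of size 2 and num_partitions=4, rank 0 may hold partition slices of sizes [4, 6, 4, 2] and rank 1 slices of sizes [2, 4, 8, 2] (the example splits used throughout this proposal): in the unaligned layout each rank keeps its own slice of every partition, while in the aligned layout rank 0 keeps all of partitions 0-1 and rank 1 keeps all of partitions 2-3.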
23 | 24 | Args: 25 | dim (int): The tensor dimension that describes how the DTensor is sharded 26 | num_partitions (int): Total number of partitions across all shards 27 | splits (List[Union[int, torch.SymInt]]): Number of elements in each partition 28 | aligned (bool): Whether partitions are aligned across shards or not 29 | """ 30 | 31 | dim: int 32 | num_partitions: int 33 | splits: List[Union[int, torch.SymInt]] 34 | aligned: bool = False 35 | ``` 36 | 37 | ## Semantic Description 38 | 39 | ### Core Concepts 40 | 41 | 1. **Partitions**: Logical subdivisions of the tensor along the specified dimension 42 | 2. **Shards**: Physical distribution units across mesh dimensions 43 | 3. **Splits**: Size specification for each partition (not indices, but element counts) 44 | 4. **Alignment**: Strategy for how partitions are distributed across shards 45 | 46 | ### Chunk Size Calculation 47 | - The chunk sizes are similar to standard `Shard` where the number of chunks equals the mesh dimension size 48 | - Each chunk can have even or uneven sizes depending on partition alignment 49 | - Total elements across all partitions must equal the tensor dimension size 50 | 51 | ### Alignment Strategies 52 | 53 | #### Unaligned Partitioned Shard 54 | - **Definition**: Each shard contains slices of ALL partitions 55 | - **Example**: With 2 shards and 4 partitions (P00, P01, P02, P03 for shard 0; P10, P11, P12, P13 for shard 1): 56 | - Shard 1: [P00, P01, P02, P03] 57 | - Shard 2: [P10, P11, P12, P13] 58 | - **Use Case**: When partitions need to be processed together within each shard 59 | 60 | #### Aligned Partitioned Shard 61 | - **Definition**: Each shard contains complete partitions (one or more full partitions) 62 | - **Partition Distribution**: `num_partitions_per_shard = num_partitions // num_shards` 63 | - **Example**: With same setup: 64 | - Shard 1: [P00, P10, P01, P11] 65 | - Shard 2: [P02, P12, P03, P13] 66 | - **Use Case**: When partitions can be independently processed across shards 67 | 68 | ## Core Operations 69 | 70 | ### 1. Unaligned to Replicate Conversion 71 | 72 | ```python 73 | def _unaligned_to_replicate( 74 | self, 75 | local_tensor: torch.Tensor, 76 | mesh: DeviceMesh, 77 | mesh_dim: int, 78 | current_logical_shape: List[int], 79 | ) -> torch.Tensor: 80 | """ 81 | Convert unaligned partitioned shard to replicated tensor. 82 | 83 | Process: 84 | 1. All-gather to collect all shards: [P00, P01, P02, P03, P10, P11, P12, P13] 85 | 2. Perform partition alignment using splits to get: [P00, P10, P01, P11, P02, P12, P03, P13] 86 | 3. Return fully reconstructed tensor 87 | """ 88 | ``` 89 | 90 | ### 2. Aligned to Replicate Conversion 91 | 92 | ```python 93 | def _aligned_to_replicate( 94 | self, 95 | local_tensor: torch.Tensor, 96 | mesh: DeviceMesh, 97 | mesh_dim: int, 98 | current_logical_shape: List[int], 99 | ) -> torch.Tensor: 100 | """ 101 | Convert aligned partitioned shard to replicated tensor. 102 | 103 | Process: 104 | 1. All-gather with list of tensors (handles dynamic sizes): [P00, P10, P01, P11, P02, P12, P03, P13] 105 | 2. Concatenate to form complete tensor 106 | """ 107 | ``` 108 | 109 | ### 3. Replicate to Shard Conversion 110 | 111 | ```python 112 | def _replicate_to_unaligned_shard( 113 | self, 114 | local_tensor: torch.Tensor, 115 | mesh: DeviceMesh, 116 | mesh_dim: int, 117 | shard_index: int, 118 | ) -> torch.Tensor: 119 | """ 120 | Convert replicated tensor to unaligned partitioned shard. 
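(Inverse of _unaligned_to_replicate: starting from the canonical partition-major layout [P00, P10, P01, P11, ...], each rank selects its own slice of every partition and concatenates them, e.g. rank 0 ends up with [P00, P01, P02, P03].)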
121 | 122 | Requirements: 123 | - num_partitions: Total number of partitions 124 | - splits: Partition sizes within each shard 125 | - Chunk size: Sum of splits (uniform across shards except possibly last) 126 | """ 127 | 128 | def _replicate_to_aligned_shard( 129 | self, 130 | local_tensor: torch.Tensor, 131 | mesh: DeviceMesh, 132 | mesh_dim: int, 133 | shard_index: int, 134 | ) -> torch.Tensor: 135 | """ 136 | Convert replicated tensor to aligned partitioned shard. 137 | 138 | Requirements: 139 | - num_partitions: Total number of partitions 140 | - partitions_per_shard: num_partitions / mesh_size 141 | - Variable partition sizes allowed 142 | - Fixed number of partitions per shard 143 | """ 144 | ``` 145 | 146 | ### 4. Alignment Conversion Operations 147 | 148 | ```python 149 | def _unaligned_to_aligned_shard( 150 | self, 151 | local_tensor: torch.Tensor, 152 | mesh: DeviceMesh, 153 | mesh_dim: int, 154 | ) -> torch.Tensor: 155 | """ 156 | Convert unaligned partitioned shard to aligned partitioned shard. 157 | 158 | Algorithm: 159 | 1. Calculate partitions per shard: num_partitions_per_shard = num_partitions / mesh_size 160 | 2. First all-to-all: Exchange split information 161 | - Input splits (shard1): [4,6,4,2], (shard2): [2,4,8,2] 162 | - Output splits (shard1): [4,6,2,4], (shard2): [4,2,8,2] 163 | 3. Compute boundaries: 164 | - in_boundaries = input_splits.reshape(num_shards, num_partitions_per_shard).sum(dim=1) 165 | - out_boundaries = out_splits.reshape(num_shards, num_partitions_per_shard).sum(dim=1) 166 | 4. Second all-to-all: Exchange tensor data using boundaries 167 | 5. Local reordering using out_splits to achieve final alignment 168 | """ 169 | 170 | def _aligned_to_unaligned_shard( 171 | self, 172 | local_tensor: torch.Tensor, 173 | mesh: DeviceMesh, 174 | mesh_dim: int, 175 | ) -> torch.Tensor: 176 | """ 177 | Convert aligned partitioned shard to unaligned partitioned shard. 178 | 179 | This performs the reverse operation of unaligned_to_aligned_shard. 180 | 181 | Algorithm: 182 | Starting state (aligned): 183 | - Shard 1: [P00, P10, P01, P11] with splits [4,2,6,4] 184 | - Shard 2: [P02, P12, P03, P13] with splits [4,8,2,2] 185 | 186 | Goal (unaligned): 187 | - Shard 1: [P00, P01, P02, P03] with splits [4,6,4,2] 188 | - Shard 2: [P10, P11, P12, P13] with splits [2,4,8,2] 189 | 190 | Steps: 191 | 1. Calculate partitions per shard: num_partitions_per_shard = num_partitions / mesh_size 192 | 2. Prepare current split information (what we currently have per shard) 193 | - current_splits = local tensor partition sizes in aligned order 194 | - Example: Shard1 current_splits = [4,2,6,4], Shard2 current_splits = [4,8,2,2] 195 | 3. First all-to-all: Exchange split information to get target unaligned splits 196 | - We need to transpose the split matrix from aligned to unaligned layout 197 | - Input: splits arranged as [shard][partition_within_shard] 198 | - Output: splits arranged as [partition][shard] 199 | - After exchange: Shard1 gets [4,6,4,2], Shard2 gets [2,4,8,2] 200 | 4. Compute boundaries for data exchange: 201 | - in_boundaries: Current chunk boundaries (aligned layout) 202 | - out_boundaries: Target chunk boundaries (unaligned layout) 203 | - in_boundaries = current_splits.reshape(num_partitions_per_shard, num_shards).sum(dim=0) 204 | - out_boundaries = target_splits # Sequential partitions per shard 205 | 5. 
Second all-to-all: Exchange tensor data using computed boundaries 206 | - Send data from aligned layout to unaligned layout 207 | - Each rank sends: partitions intended for other ranks' unaligned layout 208 | - Each rank receives: sequential partitions for unaligned layout 209 | 6. Local reordering if needed to ensure correct partition order 210 | 211 | Detailed Example: 212 | Input (aligned): Shard1=[P00:4, P10:2, P01:6, P11:4], Shard2=[P02:4, P12:8, P03:2, P13:2] 213 | After step 3: target_splits Shard1=[4,6,4,2], Shard2=[2,4,8,2] 214 | After step 5: Shard1=[P00:4, P01:6, P02:4, P03:2], Shard2=[P10:2, P11:4, P12:8, P13:2] 215 | Final (unaligned): Achieved target layout 216 | """ 217 | ``` 218 | -------------------------------------------------------------------------------- /autoparallel/autobucketing_util/estimation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # mypy: ignore-errors 7 | import os 8 | import pickle 9 | from collections import defaultdict 10 | from typing import Any 11 | 12 | import torch 13 | import torch.distributed as c10d 14 | from torch._inductor import memory, scheduler 15 | from torch._inductor.utils import is_collective 16 | from torch._inductor.virtualized import V 17 | from torch.utils._ordered_set import OrderedSet 18 | 19 | from .bucket_utils import ( 20 | check_ir_node_bucketable, 21 | get_snode_process_group_info, 22 | get_snode_tensor_info, 23 | ) 24 | from .estimation_utils import ( 25 | CommPerfCache, 26 | CompPerfCache, 27 | benchmark_and_cache_comm_dicts, 28 | estimate_comp_time, 29 | ) 30 | 31 | 32 | def sync_dict_across_ranks(runtime_dict, world_size, group=None): 33 | gathered_lists = [None for _ in range(world_size)] 34 | c10d.all_gather_object(gathered_lists, list(runtime_dict.values()), group=group) 35 | median_gathered_time = torch.median(torch.tensor(gathered_lists), dim=0).values 36 | for idx, (key, value) in enumerate(runtime_dict.items()): 37 | runtime_dict[key] = median_gathered_time[idx] 38 | return runtime_dict 39 | 40 | 41 | def benchmark_and_sync_runtime( 42 | sched: "scheduler.Scheduler", 43 | snodes: list["scheduler.BaseSchedulerNode"], 44 | name_to_buf: dict[str, "scheduler.SchedulerBuffer"], 45 | name_to_fused_node: dict[str, "scheduler.BaseSchedulerNode"], 46 | bucketable_nodes: set[str], 47 | configs: Any, 48 | ): 49 | world_size = c10d.distributed_c10d.get_world_size() 50 | 51 | fsdp_ag_input_size_dict = defaultdict(list) 52 | fsdp_rs_output_size_dict = defaultdict(list) 53 | non_fsdp_ag_input_size_dict = defaultdict(list) 54 | non_fsdp_rs_input_size_dict = defaultdict(list) 55 | all_reduce_input_size_dict = defaultdict(list) 56 | all_to_all_input_size_dict = defaultdict(list) 57 | comp_cache, comm_cache = CompPerfCache(), CommPerfCache() 58 | 59 | cali_num_samples = configs.calibrate_number 60 | comp_time_dict = defaultdict(float) 61 | memory_dict = defaultdict(int) 62 | peak_memory_per_step_dict = defaultdict(int) 63 | fsdp_ag_idx = -1 64 | release_steps = [0] 65 | 66 | graph_outputs = OrderedSet(V.graph.get_output_names()) 67 | graph_inputs = OrderedSet(V.graph.graph_inputs.keys()) 68 | _, name_to_freeable_input_buf = memory.prepare_planning_info( 69 | snodes, 70 | name_to_buf, 71 | name_to_fused_node, 72 | graph_inputs, 73 | graph_outputs, 74 | ) 75 | _, 
memories_at_nodes = memory.estimate_peak_memory( 76 | snodes, name_to_freeable_input_buf, graph_outputs 77 | ) 78 | # ensure memory offset is always positive 79 | if min(memories_at_nodes) < 0: 80 | shift_value = abs(min(memories_at_nodes)) 81 | memories_at_nodes = [x + shift_value for x in memories_at_nodes] 82 | 83 | for idx, snode in enumerate(snodes): 84 | if is_collective( 85 | snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor.default 86 | ): 87 | fsdp_ag_idx += 1 88 | release_steps.append(idx) 89 | node_tensor_info = get_snode_tensor_info(snode, return_data_size=True) 90 | node_pg_info = get_snode_process_group_info( 91 | snode, 92 | expected_op=torch.ops._c10d_functional.all_gather_into_tensor.default, 93 | resolve_pg=True, 94 | ) 95 | if node_pg_info is None: 96 | continue 97 | node_info = node_tensor_info[:-2] + node_pg_info 98 | input_size = node_tensor_info[-2] 99 | if check_ir_node_bucketable(snode.node, bucketable_nodes): 100 | # For FSDP, we assume they have all have the 101 | fsdp_ag_input_size_dict[node_info].append(input_size) 102 | else: 103 | non_fsdp_ag_input_size_dict[node_info].append(input_size) 104 | elif is_collective( 105 | snode.node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default 106 | ): 107 | node_tensor_info = get_snode_tensor_info(snode, return_data_size=True) 108 | node_pg_info = get_snode_process_group_info( 109 | snode, 110 | expected_op=torch.ops._c10d_functional.reduce_scatter_tensor.default, 111 | resolve_pg=True, 112 | ) 113 | if node_pg_info is None: 114 | continue 115 | node_info = node_tensor_info[:-2] + node_pg_info 116 | output_size = node_tensor_info[-1] 117 | if check_ir_node_bucketable(snode.node, bucketable_nodes): 118 | # For FSDP, we assume they have all have the same group size 119 | fsdp_rs_output_size_dict[node_info].append(output_size) 120 | else: 121 | non_fsdp_rs_input_size_dict[node_info].append(output_size) 122 | elif is_collective( 123 | snode.node, op=torch.ops._c10d_functional.all_reduce_.default 124 | ): 125 | node_tensor_info = get_snode_tensor_info(snode, return_data_size=True) 126 | node_pg_info = get_snode_process_group_info( 127 | snode, 128 | expected_op=torch.ops._c10d_functional.all_reduce_.default, 129 | resolve_pg=True, 130 | ) 131 | if node_pg_info is None: 132 | continue 133 | node_info = node_tensor_info[:-2] + node_pg_info 134 | input_size = node_tensor_info[-2] 135 | all_reduce_input_size_dict[node_info].append(input_size) 136 | elif is_collective( 137 | snode.node, op=torch.ops._c10d_functional.all_to_all_single.default 138 | ): 139 | node_tensor_info = get_snode_tensor_info(snode, return_data_size=True) 140 | node_pg_info = get_snode_process_group_info( 141 | snode, 142 | expected_op=torch.ops._c10d_functional.all_to_all_single.default, 143 | resolve_pg=True, 144 | ) 145 | if node_pg_info is None: 146 | continue 147 | node_info = node_tensor_info[:-2] + node_pg_info 148 | input_size = node_tensor_info[-2] 149 | all_to_all_input_size_dict[node_info].append(input_size) 150 | else: 151 | if not is_collective(snode.node): 152 | comp_time = estimate_comp_time( 153 | sched, snode, verbose=False, comp_cache=comp_cache 154 | ) 155 | comp_time_dict[fsdp_ag_idx] += comp_time 156 | memory_dict[fsdp_ag_idx] = max( 157 | abs( 158 | memories_at_nodes[idx + 1] 159 | - memories_at_nodes[release_steps[-1]] 160 | ), 161 | memory_dict[fsdp_ag_idx], 162 | ) 163 | peak_memory_per_step_dict[fsdp_ag_idx] = max( 164 | memories_at_nodes[idx + 1], peak_memory_per_step_dict[fsdp_ag_idx] 165 | ) 166 | else: 167 
| print( 168 | "[Relaxed Setting] untracked communication", 169 | snode.node.python_kernel_name, 170 | ) 171 | 172 | # Sync total compute time 173 | comp_time_dict = sync_dict_across_ranks(comp_time_dict, world_size) 174 | memory_dict = sync_dict_across_ranks(memory_dict, world_size) 175 | peak_memory_per_step_dict = sync_dict_across_ranks( 176 | peak_memory_per_step_dict, world_size 177 | ) 178 | 179 | if configs.load_cache and os.path.exists(configs.save_estimation_path): 180 | with open(configs.save_estimation_path, "rb") as file: 181 | cache = pickle.load(file) 182 | comm_cache.cache = cache 183 | comm_cache._update_max_size() 184 | return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict 185 | 186 | benchmark_params = [ 187 | ( 188 | fsdp_ag_input_size_dict, 189 | "torch.ops._c10d_functional.all_gather_into_tensor.default", 190 | cali_num_samples, 191 | ), 192 | ( 193 | fsdp_rs_output_size_dict, 194 | "torch.ops._c10d_functional.reduce_scatter_tensor.default", 195 | cali_num_samples, 196 | ), 197 | ( 198 | non_fsdp_ag_input_size_dict, 199 | "torch.ops._c10d_functional.all_gather_into_tensor.default", 200 | 3, 201 | ), 202 | ( 203 | non_fsdp_rs_input_size_dict, 204 | "torch.ops._c10d_functional.reduce_scatter_tensor.default", 205 | 3, 206 | ), 207 | ( 208 | all_reduce_input_size_dict, 209 | "torch.ops._c10d_functional.all_reduce_.default", 210 | 3, 211 | ), 212 | ( 213 | all_to_all_input_size_dict, 214 | "torch.ops._c10d_functional.all_to_all_single.default", 215 | 3, 216 | ), 217 | ] 218 | for input_size_dict, op_name, num_samples in benchmark_params: 219 | if len(input_size_dict) > 0: 220 | benchmark_and_cache_comm_dicts( 221 | comm_cache, input_size_dict, op_name, num_samples 222 | ) 223 | 224 | median_runtimes = sync_dict_across_ranks(comm_cache.cache, world_size) 225 | comm_cache.cache = median_runtimes 226 | comm_cache._update_max_size() 227 | with open(configs.save_estimation_path, "wb") as file: 228 | pickle.dump(comm_cache.cache, file) 229 | return comm_cache, comp_time_dict, memory_dict, peak_memory_per_step_dict 230 | -------------------------------------------------------------------------------- /examples/example_llama3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 
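# Example script: shards a Llama-3-style Transformer (8B or 70B config) over a fake 64-rank world with a
# 2-D (dp, tp) mesh (or a 1-D dp mesh), optionally adds manual TP/sequence-parallel constraints and
# aten-level auto-bucketing, then runs a single forward/backward pass.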
5 | 6 | import time 7 | from functools import partial 8 | 9 | import torch 10 | from torch.distributed.fsdp import MixedPrecisionPolicy 11 | from torch.distributed.tensor.placement_types import Partial, Replicate, Shard 12 | from torch.testing._internal.distributed.fake_pg import FakeStore 13 | 14 | from autoparallel._testing.models.llama3 import Transformer, TransformerModelArgs 15 | from autoparallel.api import AutoParallel 16 | from autoparallel.auto_bucketing import ( 17 | aten_autobucketing_config, 18 | aten_autobucketing_reordering_pass, 19 | configure_inductor_for_autobucketing, 20 | ) 21 | from autoparallel.debug_helpers import make_custom_runtime_estimation 22 | 23 | world_size = 64 24 | 25 | fake_store = FakeStore() 26 | torch.distributed.init_process_group( 27 | "fake", store=fake_store, rank=0, world_size=world_size 28 | ) 29 | 30 | use_1d_mesh = False 31 | 32 | if use_1d_mesh: 33 | mesh = torch.distributed.device_mesh.init_device_mesh( 34 | "cuda", (world_size,), mesh_dim_names=("dp",) 35 | ) 36 | else: 37 | mesh = torch.distributed.device_mesh.init_device_mesh( 38 | "cuda", 39 | (world_size // 8, 8), 40 | mesh_dim_names=( 41 | "dp", 42 | "tp", 43 | ), 44 | ) 45 | 46 | batch_size = 2 * mesh.shape[0] 47 | seqlen = 2048 * 4 48 | vocab_size = 128256 49 | use_vocab_parallel = not use_1d_mesh 50 | device = torch.device("cuda") 51 | 52 | model_type = "8b" 53 | enable_asynctp = False 54 | 55 | 56 | def model_fn(): 57 | if model_type == "8b": 58 | model_args = TransformerModelArgs( 59 | dim=4096, 60 | n_layers=32, 61 | n_heads=32, 62 | n_kv_heads=8, 63 | ffn_dim_multiplier=1.3, 64 | multiple_of=1024, 65 | rope_theta=500000, 66 | vocab_size=vocab_size, 67 | max_seq_len=seqlen, 68 | ) 69 | elif model_type == "70b": 70 | model_args = TransformerModelArgs( 71 | dim=8192, 72 | n_layers=80, 73 | n_heads=64, 74 | n_kv_heads=8, 75 | ffn_dim_multiplier=1.3, 76 | multiple_of=4096, 77 | rope_theta=500000, 78 | vocab_size=vocab_size, 79 | max_seq_len=seqlen, 80 | ) 81 | else: 82 | raise ValueError(f"{model_type} not available") 83 | m = Transformer(model_args) 84 | return m 85 | 86 | 87 | def input_fn(): 88 | x = torch.randint(0, vocab_size, (batch_size, seqlen), device=device) 89 | return x 90 | 91 | 92 | autobucketing_level = "aten" 93 | 94 | if autobucketing_level == "aten": 95 | aten_autobucketing_config.custom_runtime_estimation = ( 96 | make_custom_runtime_estimation(mesh) 97 | ) 98 | # this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960 99 | torch._inductor.config.reorder_for_peak_memory = False 100 | torch._inductor.config.reorder_for_compute_comm_overlap = False 101 | aten_autobucketing_reordering_pass = partial( 102 | aten_autobucketing_reordering_pass, 103 | configs=aten_autobucketing_config, 104 | ) 105 | torch._inductor.config.post_grad_custom_post_pass = ( 106 | aten_autobucketing_reordering_pass 107 | ) 108 | elif autobucketing_level == "inductor": 109 | configure_inductor_for_autobucketing(autobucketing_level) 110 | else: 111 | raise ValueError(f"Unknown autobucketing_level {autobucketing_level}") 112 | 113 | 114 | # parallelize the model 115 | with torch.device("meta"): 116 | model = model_fn() 117 | 118 | mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) 119 | 120 | 121 | def group_mm_nodes_with_its_gradients(nodes): 122 | fwd_nodes = [n for n in nodes if "nn_module_stack" in n.meta] 123 | bwd_nodes = [n for n in nodes if "fwd_nn_module_stack" in n.meta] 124 | assert len(fwd_nodes) * 2 == len(bwd_nodes) 125 | res = 
{} 126 | for fwd_node in fwd_nodes: 127 | o = [] 128 | for bwd_node in bwd_nodes: 129 | if fwd_node.meta["nn_module_stack"] == bwd_node.meta["fwd_nn_module_stack"]: 130 | o.append(bwd_node) 131 | assert len(o) == 2 132 | res[fwd_node] = o 133 | return res 134 | 135 | 136 | def force_tp_constraints(autop, mm_nodes, feat_dim=1, bwd_constraint=False): 137 | # out = x @ w - S(0)R, RS(1) -> S(0)S(1) 138 | # g_w = g.T @ x - S(1)S(0), S(0)R -> PS(0) 139 | # g_x = g @ w.T - S(0)S(1), RS(0) -> S(0)P 140 | 141 | add_node_constraint = autop.sharding_optimizer.add_node_constraint 142 | fwd_bwd_groups = group_mm_nodes_with_its_gradients(mm_nodes) 143 | fwd_nodes = list(fwd_bwd_groups.keys()) 144 | dim1 = 0 if feat_dim == 1 else 1 145 | dim2 = 1 if feat_dim == 1 else 0 146 | # assume there are 7 mm nodes per transformer block 147 | # skip last mm as it's the final projection layer 148 | assert ( 149 | len(fwd_nodes) - 1 150 | ) % 7 == 0, f"expected 7 mm nodes per transformer block, {len(fwd_nodes) - 1}" 151 | for block in range(0, len(fwd_nodes) - 1, 7): 152 | fwd_nodes_block = fwd_nodes[block : block + 7] 153 | # force the first 3 mm nodes to be S(0)S(1) 154 | the_nodes = fwd_nodes_block[:3] + fwd_nodes_block[4:6] 155 | for n in the_nodes: 156 | add_node_constraint(n, (Shard(0), Shard(feat_dim))) 157 | add_node_constraint(n.all_input_nodes[0], (Shard(0), Replicate())) 158 | add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(1))) 159 | 160 | if bwd_constraint: 161 | bwd_nodes = fwd_bwd_groups[n] 162 | # first is g_w, second is g_x 163 | add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim1))) 164 | add_node_constraint(bwd_nodes[1], (Shard(0), Partial())) 165 | 166 | # add reduction to finish TP, yielding S(0)P 167 | the_nodes = fwd_nodes_block[3:4] + fwd_nodes_block[6:7] 168 | for n in the_nodes: 169 | add_node_constraint(n, (Shard(0), Partial())) 170 | add_node_constraint(n.all_input_nodes[0], (Shard(0), Shard(feat_dim))) 171 | add_node_constraint(n.all_input_nodes[1], (Replicate(), Shard(0))) 172 | 173 | if bwd_constraint: 174 | bwd_nodes = fwd_bwd_groups[n] 175 | # first is g_w, second is g_x 176 | add_node_constraint(bwd_nodes[0], (Partial(), Shard(dim2))) 177 | add_node_constraint(bwd_nodes[1], (Shard(0), Shard(feat_dim))) 178 | 179 | 180 | def add_tp_constraints(autop): 181 | mm_nodes = autop.gm.graph.find_nodes( 182 | op="call_function", target=torch.ops.aten.mm.default 183 | ) 184 | einsum_nodes = autop.gm.graph.find_nodes( 185 | op="call_function", target=torch.ops.aten.einsum.default 186 | ) 187 | assert (len(mm_nodes) > 0) ^ ( 188 | len(einsum_nodes) > 0 189 | ), f"only one should be non-empty, got {len(mm_nodes)} and {len(einsum_nodes)}" 190 | feat_dim = 1 if len(mm_nodes) > 0 else 2 191 | tgt_nodes = mm_nodes + einsum_nodes 192 | force_tp_constraints(autop, tgt_nodes, feat_dim=feat_dim, bwd_constraint=True) 193 | 194 | if einsum_nodes: 195 | # add sequence parallelism if we have einsum nodes 196 | autop.sharding_optimizer.add_node_constraint( 197 | list(tgt_nodes[3].users)[0], (Shard(0), Shard(1)) 198 | ) 199 | autop.sharding_optimizer.add_node_constraint( 200 | list(list(tgt_nodes[3].users)[0].users)[0], (Shard(0), Shard(1)) 201 | ) 202 | 203 | 204 | # parallelize the model 205 | with AutoParallel( 206 | model, input_fn, mesh, mp_policy, compile=True, repeated_subgraphs=True 207 | ) as autop: 208 | autop.add_parameter_memory_constraint(low=None, high=None) 209 | 210 | x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) 211 | out_sharding = x_sharding 212 | if 
use_vocab_parallel: 213 | # add vocab parallel constraint 214 | assert mesh.ndim == 2, "Only 2d mesh supported here" 215 | out_sharding = (Shard(0), Shard(2)) 216 | 217 | autop.add_input_constraints([x_sharding]) 218 | autop.add_output_constraints([out_sharding]) 219 | 220 | enable_manual_constraint = False 221 | if enable_manual_constraint and not use_1d_mesh: 222 | add_tp_constraints(autop) 223 | 224 | if enable_asynctp: 225 | from torch.distributed._symmetric_memory import enable_symm_mem_for_group 226 | 227 | enable_symm_mem_for_group(mesh["dp"].get_group().group_name) 228 | enable_symm_mem_for_group(mesh["tp"].get_group().group_name) 229 | torch._inductor.config._micro_pipeline_tp = False 230 | from autoparallel.asynctp import micro_pipeline_tp_pass 231 | 232 | existing_post_grad_custom_post_pass = ( 233 | torch._inductor.config.post_grad_custom_post_pass 234 | ) 235 | 236 | def _pass(graph): 237 | if existing_post_grad_custom_post_pass is not None: 238 | existing_post_grad_custom_post_pass(graph) 239 | micro_pipeline_tp_pass(graph) 240 | 241 | torch._inductor.config.post_grad_custom_post_pass = _pass 242 | 243 | t = time.time() 244 | sharding_placement = autop.optimize_placement(verbose=True) 245 | print(f"Took {time.time() - t:.2f} s") 246 | parallel_mod = autop.apply_placement(sharding_placement) 247 | 248 | # run weight init on our sharded DTensor params 249 | parallel_mod.to_empty(device="cuda") 250 | parallel_mod.init_weights() 251 | 252 | # now let's run it 253 | x = ( 254 | torch.randint( 255 | 0, 256 | vocab_size, 257 | (batch_size // mesh.shape[0], seqlen), 258 | device=torch.device("cuda"), 259 | ), 260 | ) 261 | out = parallel_mod(*x) 262 | out.backward(torch.randn_like(out)) 263 | print("All good!") 264 | -------------------------------------------------------------------------------- /autoparallel/debug_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import inspect 7 | import json 8 | import re 9 | from contextlib import ExitStack 10 | from typing import Any, Callable 11 | 12 | import torch 13 | from torch._functorch.aot_autograd import aot_export_joint_with_descriptors 14 | from torch._inductor.fx_passes.bucketing import is_wait_tensor 15 | from torch._inductor.fx_passes.overlap_scheduling import ( 16 | get_group_name, 17 | schedule_overlap_bucketing, 18 | ) 19 | from torch.utils._dtype_abbrs import dtype_abbrs 20 | 21 | from autoparallel.collective_runtime_estimation import ( 22 | MeshTopoInfo, 23 | allgather_cost, 24 | allreduce_cost, 25 | reduce_scatter_cost, 26 | ) 27 | from autoparallel.compute_estimation import estimate_strategy_runtime_cost 28 | 29 | 30 | def parse_tensor_annotation(annotation: str) -> torch.Tensor: 31 | """ 32 | Parse a tensor annotation string and create a PyTorch tensor. 
33 | 34 | Format: dtype[shape][strides]device 35 | Example: f32[384][1]cuda:0 36 | 37 | Args: 38 | annotation: String in format "dtype[shape][strides]device" 39 | 40 | Returns: 41 | A PyTorch tensor with the specified properties 42 | """ 43 | # Parse the annotation string 44 | # Pattern: dtype[shape][strides]device 45 | pattern = r"([a-z0-9]+)(\[[\d,\s]*\])(\[[\d,\s]*\])(.+)" 46 | match = re.match(pattern, annotation) 47 | 48 | if not match: 49 | raise ValueError(f"Invalid tensor annotation format: {annotation}") 50 | 51 | dtype_str, shape_str, strides_str, device_str = match.groups() 52 | 53 | # Map dtype string to PyTorch dtype 54 | dtype_map = {v: k for k, v in dtype_abbrs.items()} 55 | 56 | if dtype_str not in dtype_map: 57 | raise ValueError(f"Unsupported dtype: {dtype_str}") 58 | 59 | dtype = dtype_map[dtype_str] 60 | 61 | # Parse shape: [384] or [384,512] -> (384,) or (384, 512) 62 | shape = ( 63 | tuple(map(int, shape_str.strip("[]").split(","))) 64 | if shape_str.strip("[]") 65 | else () 66 | ) 67 | 68 | # Parse strides: [1] or [512,1] -> (1,) or (512, 1) 69 | strides = ( 70 | tuple(map(int, strides_str.strip("[]").split(","))) 71 | if strides_str.strip("[]") 72 | else () 73 | ) 74 | 75 | # Parse device 76 | device = torch.device(device_str) 77 | 78 | # Create tensor with specified properties 79 | # We create an empty tensor and then use as_strided to set custom strides 80 | if shape: 81 | tensor = torch.empty_strided(shape, stride=strides, dtype=dtype, device=device) 82 | if tensor.dtype.is_floating_point: 83 | tensor.uniform_() 84 | else: 85 | try: 86 | tensor.random_() 87 | tensor = tensor % 128 88 | except NotImplementedError: 89 | tensor.fill_(0) 90 | else: 91 | # Scalar tensor 92 | tensor = torch.empty((), dtype=dtype, device=device) 93 | 94 | return tensor 95 | 96 | 97 | def build_arguments(fn): 98 | sig = inspect.signature(fn) 99 | args = {} 100 | for k, v in sig.parameters.items(): 101 | if k == "self": 102 | continue 103 | anno = v.annotation 104 | args[k] = parse_tensor_annotation(anno) 105 | return args 106 | 107 | 108 | def _is_communication_node(node): 109 | if not node.op == "call_function": 110 | return False 111 | if not isinstance(node.target, torch._ops.OpOverload): 112 | return False 113 | 114 | return node.target.namespace == "_c10d_functional" 115 | 116 | 117 | def make_custom_runtime_estimation(mesh): 118 | def custom_runtime_estimation(node: torch.fx.Node, override_size=None): 119 | if not node.op == "call_function": 120 | return 0 121 | if not isinstance(node.target, torch._ops.OpOverload): 122 | return 0 123 | 124 | if _is_communication_node(node): 125 | target = node.target 126 | if target == torch.ops._c10d_functional.wait_tensor.default: 127 | return 0 128 | # TODO: figure out mesh without reading from global scope 129 | mesh_topo = MeshTopoInfo.build_from_mesh(mesh) 130 | groups_name = tuple(g.group_name for g in mesh.get_all_groups()) 131 | group_name = get_group_name(node) 132 | mesh_dim = groups_name.index(group_name) 133 | t = node.args[0].meta["val"] # type: ignore[union-attr] 134 | comm_bytes_gb = t.numel() * t.itemsize / 2**30 135 | if override_size is not None: 136 | comm_bytes_gb = override_size 137 | if target in { 138 | torch.ops._c10d_functional.all_gather_into_tensor.default, 139 | torch.ops._c10d_functional.all_gather_into_tensor_out.default, 140 | }: 141 | comm_bytes_gb *= mesh.shape[mesh_dim] 142 | return allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim) 143 | elif target == torch.ops._c10d_functional.reduce_scatter_tensor.default: 
144 | return reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim) 145 | elif target == torch.ops._c10d_functional.all_reduce.default: 146 | return allreduce_cost(comm_bytes_gb, mesh_topo, mesh_dim) 147 | else: 148 | # TODO: add all_to_all cost 149 | return 0 150 | return estimate_strategy_runtime_cost(node, None) 151 | 152 | return custom_runtime_estimation 153 | 154 | 155 | def get_graph_module(gm, args): 156 | stack = ExitStack() 157 | with stack: 158 | joint_with_descriptors = aot_export_joint_with_descriptors( 159 | stack, 160 | gm, 161 | tuple(x for x in args.values()), 162 | ) 163 | return joint_with_descriptors.graph_module 164 | 165 | 166 | def apply_schedule_overlap_bucket(gm, custom_runtime_estimation): 167 | new_gm = schedule_overlap_bucketing( 168 | gm, 169 | collective_bucketing=False, 170 | custom_runtime_estimation=custom_runtime_estimation, 171 | max_compute_pre_fetch=5, 172 | max_in_flight_gb=2.0, 173 | ) 174 | new_gm.recompile() 175 | return new_gm 176 | 177 | 178 | def _get_tid(node): 179 | if _is_communication_node(node): 180 | if node.target == torch.ops._c10d_functional.wait_tensor.default: 181 | return 0 182 | return node.args[-1] 183 | return 0 184 | 185 | 186 | def get_repr(arg, mode="full"): 187 | def get_dtype_repr(dtype): 188 | return dtype_abbrs[dtype] 189 | 190 | if isinstance(arg, torch.Tensor): 191 | out = {} 192 | out["shape"] = tuple(arg.shape) 193 | out["dtype"] = get_dtype_repr(arg.dtype) 194 | return out 195 | 196 | if isinstance(arg, (int, float, str)): 197 | return arg 198 | 199 | if isinstance(arg, torch.dtype): 200 | return get_dtype_repr(arg) 201 | 202 | if isinstance(arg, torch.fx.Node): 203 | if mode == "name_only" or "val" not in arg.meta: 204 | return f"fx node {arg.name}" 205 | elif mode == "full": 206 | return {"name": arg.name, "data": get_repr(arg.meta["val"])} 207 | elif mode == "content_only": 208 | return get_repr(arg.meta["val"]) 209 | else: 210 | raise ValueError(f"Unknown mode {mode}") 211 | 212 | if isinstance(arg, (list, tuple)): 213 | return [get_repr(x, mode="name_only") for x in arg] 214 | 215 | if isinstance(arg, dict): 216 | return {k: get_repr(v, mode="name_only") for k, v in arg.items()} 217 | 218 | return f"arg {type(arg)}" 219 | 220 | 221 | def create_execution_trace( 222 | gm: torch.fx.GraphModule, 223 | runtime_estimator: Callable[[torch.fx.Node], float], 224 | file_path: str = "fake_trace.json", 225 | ): 226 | """ 227 | Create a perfetto trace from a GraphModule representing its execution 228 | trace. This is useful for inspecting communication-computation overlapping 229 | for different reordering strategies. 
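Example (illustrative): est = make_custom_runtime_estimation(mesh); create_execution_trace(gm, est, file_path="overlap_trace.json"); the resulting JSON can then be opened in the Perfetto UI (https://ui.perfetto.dev).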
230 | """ 231 | trace: dict[str, Any] = {} 232 | trace_events = [] 233 | curr_time = {0: 0} 234 | global_time: dict[torch.fx.Node, int] = {} 235 | for node_idx, node in enumerate(gm.graph.nodes): 236 | dur = int(runtime_estimator(node)) 237 | tid = _get_tid(node) 238 | if tid not in curr_time: 239 | curr_time[tid] = curr_time[0] 240 | event = {"ph": "X", "cat": "kernel", "name": str(node), "pid": 0, "tid": tid} 241 | if _is_communication_node(node): 242 | if tid == 0 and is_wait_tensor(node) and node.args[0].op != "placeholder": 243 | # if it's wait tensor, let's sync with compute stream 244 | comm_end_time = global_time.pop(node.args[0]) 245 | curr_time[tid] = max(curr_time[tid], comm_end_time) 246 | else: 247 | curr_time[tid] = max(curr_time[0], curr_time[tid]) 248 | 249 | event["ts"] = curr_time[tid] 250 | event["dur"] = dur 251 | launch_overhead = 1 # 1us 252 | curr_time[tid] += dur + launch_overhead 253 | if tid != 0: 254 | curr_time[0] += launch_overhead 255 | # keep track of when a given collective will finish 256 | global_time[node] = curr_time[tid] 257 | 258 | args: dict[str, Any] = {} 259 | args["order"] = node_idx 260 | 261 | args["output"] = get_repr(node, mode="content_only") 262 | node_args = [] 263 | for arg in node.args: 264 | node_args.append(get_repr(arg)) 265 | args["inputs"] = node_args 266 | event["args"] = args 267 | trace_events.append(event) 268 | trace["traceEvents"] = trace_events 269 | trace["traceName"] = "fake_trace.json" 270 | with open(file_path, "w") as fp: 271 | json.dump(trace, fp) 272 | -------------------------------------------------------------------------------- /autoparallel/_passes/split_di_dw_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import copy 7 | import itertools 8 | import operator 9 | 10 | import sympy 11 | import torch 12 | import torch.fx as fx 13 | from torch._functorch.partitioners import ( 14 | SavedForBackwardsAOTOutput, 15 | _extract_fwd_bwd_outputs, 16 | _extract_graph_with_inputs_outputs, 17 | _is_backward_state, 18 | _is_bwd_seed_offset, 19 | _is_fwd_seed_offset, 20 | _is_primal, 21 | _remove_by_name, 22 | find_symbol_binding_fx_nodes, 23 | free_symbols, 24 | is_sym_node, 25 | is_symbol_binding_fx_node, 26 | ) 27 | from torch.utils._ordered_set import OrderedSet 28 | 29 | from autoparallel.apply_sharding import rename_placeholder_node 30 | 31 | # we are running the default partitioner on the bw graph, which requires AC tags being removed. 32 | # At this stage we have already finished running AC anyway, since we have a bw graph 33 | 34 | 35 | def remove_recompute_tags(bw_gm): 36 | for n in bw_gm.graph.nodes: 37 | if "recompute" in n.meta: 38 | del n.meta["recompute"] 39 | 40 | 41 | # We are using the default partitioner to split our backward into dI and dW subgraphs. 42 | # We want to generate the dI subgraph *first*, because: 43 | # - in pipelining we generally want to schedule dI compute before dW 44 | # - the dI compute will potentially compute more activations that we need to plumb into dW compute 45 | # Today, the default partitioner requires that your split on the first K outputs of your combined graph. 46 | # So here, we reorder the outputs of the backward so grad_inputs are first. 
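# Illustrative sketch (gradient counts are hypothetical): if the backward graph originally returns
# (g_w0, g_w1, g_x0, g_x1, g_x2) and num_weight_gradients=2, reorder_output_grads below rewrites the
# output to (g_x0, g_x1, g_x2, g_w0, g_w1) and returns 3, the number of input gradients.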
47 | 48 | 49 | def reorder_output_grads(bw_gm, num_weight_gradients): 50 | outputs = bw_gm.graph.find_nodes(op="output") 51 | assert len(outputs) == 1 52 | output = outputs[0] 53 | assert isinstance(output.args[0], tuple) 54 | grad_weights, grad_inputs = ( 55 | output.args[0][:num_weight_gradients], 56 | output.args[0][num_weight_gradients:], 57 | ) 58 | new_out_tuple = grad_inputs + grad_weights 59 | with bw_gm.graph.inserting_after(output): 60 | # TODO: also set the new node's meta properly 61 | new_out = bw_gm.graph.output(new_out_tuple) 62 | output.replace_all_uses_with(new_out) 63 | bw_gm.graph.erase_node(output) 64 | return len(grad_inputs) 65 | 66 | 67 | # This is a copy of the function used by the default partitioner, 68 | # which does *not* reorder symint activations. 69 | # This reordering is needed by the custom autograd.Function in AOTDispatcher, 70 | # but isn't needed in our dI/dW splitting since there is no autograd in the loop. 71 | # TODO: provide a way to get this behavior automatically out of the default partitioner 72 | def _extract_fwd_bwd_modules( 73 | joint_module: fx.GraphModule, 74 | saved_values: list[fx.Node], 75 | saved_sym_nodes: list[fx.Node], 76 | *, 77 | num_fwd_outputs: int, 78 | ) -> tuple[fx.GraphModule, fx.GraphModule]: 79 | ( 80 | fwd_outputs, 81 | bwd_outputs, 82 | fwd_outputs_descs, 83 | bwd_outputs_descs, 84 | ) = _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs) 85 | placeholders = joint_module.graph.find_nodes(op="placeholder") 86 | primal_inputs = [*filter(_is_primal, placeholders)] 87 | fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)] 88 | bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)] 89 | backward_state_inputs = [*filter(_is_backward_state, placeholders)] 90 | 91 | bwd_graph = _extract_graph_with_inputs_outputs( 92 | joint_module.graph, 93 | saved_values + saved_sym_nodes + bwd_seed_offset_inputs, 94 | bwd_outputs, 95 | bwd_outputs_descs, 96 | "backward", 97 | ignore_must_be_in_fw_bw=True, 98 | ) 99 | 100 | distributed_enabled = torch.distributed.is_available() 101 | 102 | for node in bwd_graph.find_nodes(op="placeholder"): 103 | # This is to filter out saved values that don't actually end up being used by the backwards pass 104 | if not node.users: 105 | _remove_by_name(saved_values, node.name) 106 | _remove_by_name(saved_sym_nodes, node.name) 107 | # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw, 108 | # but this dead activation is actually a collective, 109 | # then the collective will generally be followed by a wait_tensor() call. 110 | # we need to peek one node further to see if this wait_tensor is dead as well. 111 | elif distributed_enabled and all( 112 | n.target is torch.ops._c10d_functional.wait_tensor.default 113 | and len(n.users) == 0 114 | for n in node.users 115 | ): 116 | _remove_by_name(saved_values, node.name) 117 | _remove_by_name(saved_sym_nodes, node.name) 118 | elif _is_backward_state(node): 119 | # BackwardState is saved directly 120 | _remove_by_name(saved_values, node.name) 121 | assert backward_state_inputs 122 | 123 | # Now that we have the finalized list of saved values, we need to ensure 124 | # we propagate all symbols which are referenced by backwards inputs.
125 | # These are not directly used in the graph but are required for downstream 126 | # sizevar assignment 127 | saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet() 128 | saved_sym_nodes_binding = [] 129 | saved_sym_nodes_derived = [] 130 | 131 | # Some symbols may already be bound in the directly saved_sym_nodes, 132 | # keep track of them so we don't re-bind them 133 | for node in saved_sym_nodes: 134 | symbol = is_symbol_binding_fx_node(node) 135 | if symbol: 136 | saved_symbols.add(symbol) 137 | saved_sym_nodes_binding.append(node) 138 | else: 139 | saved_sym_nodes_derived.append(node) 140 | 141 | # Now go through all of the prospective backward inputs and track any 142 | # other symbols we need to bind 143 | symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph) 144 | for node in itertools.chain(saved_sym_nodes_derived, saved_values): 145 | if "val" not in node.meta: 146 | continue 147 | new_symbols = free_symbols(node.meta["val"]) - saved_symbols 148 | # NB: Deterministic order please! 149 | for s in sorted(new_symbols, key=lambda s: s.name): 150 | # NB: For well formed graphs, the symbol should always be present, 151 | # but we also have ways to produce ill-formed graphs, e.g., direct 152 | # make_fx usages, so don't choke in this case 153 | if s not in symbol_bindings: 154 | continue 155 | saved_sym_nodes_binding.append(symbol_bindings[s]) 156 | saved_symbols |= new_symbols 157 | 158 | # Update saved_sym_nodes that are now reordered to have all bindings at 159 | # front. This can also be used later on to figure out the position of saved 160 | # sym nodes in the output of fwd graph. 161 | saved_sym_nodes.clear() 162 | saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived) 163 | 164 | # Now, we re-generate the fwd/bwd graphs. 
165 | # NB: This might increase compilation time, but I doubt it matters 166 | fwd_graph = _extract_graph_with_inputs_outputs( 167 | joint_module.graph, 168 | primal_inputs + fwd_seed_offset_inputs, 169 | fwd_outputs + saved_values + saved_sym_nodes, 170 | fwd_outputs_descs 171 | + [ 172 | SavedForBackwardsAOTOutput(i) 173 | for i in range(len(saved_values) + len(saved_sym_nodes)) 174 | ], 175 | "forward", 176 | ignore_must_be_in_fw_bw=True, 177 | ) 178 | bwd_graph = _extract_graph_with_inputs_outputs( 179 | joint_module.graph, 180 | saved_values + saved_sym_nodes + bwd_seed_offset_inputs + backward_state_inputs, 181 | bwd_outputs, 182 | bwd_outputs_descs, 183 | "backward", 184 | ignore_must_be_in_fw_bw=True, 185 | ) 186 | 187 | fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph) 188 | bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph) 189 | return fwd_module, bwd_module 190 | 191 | 192 | # TODO: in theory we can infer num_weight_gradients from the graph metadata directly 193 | def split_di_dw_graph( 194 | bw_gm_old: fx.GraphModule, *, num_weight_gradients: int 195 | ) -> tuple[fx.GraphModule, fx.GraphModule, int]: 196 | # we could consider doing this is a non-mutating way 197 | bw_gm = copy.deepcopy(bw_gm_old) 198 | placeholders = bw_gm.graph.find_nodes(op="placeholder") 199 | for p in placeholders: 200 | if p.name.startswith("tangent"): 201 | name_suffix = p.name[8:] 202 | rename_placeholder_node(bw_gm, p, f"not_tngnt{name_suffix}") 203 | 204 | remove_recompute_tags(bw_gm) 205 | num_input_gradients = reorder_output_grads(bw_gm, num_weight_gradients) 206 | bw_gm.recompile() 207 | 208 | args = list(bw_gm.graph.find_nodes(op="placeholder")) 209 | 210 | # bw_inputs, bw_weights = default_partition(bw_gm, args, num_fwd_outputs=num_input_gradients) 211 | # return bw_inputs, bw_weights, num_input_gradients 212 | 213 | ( 214 | grad_inps, 215 | grad_weights, 216 | grad_inp_descs, 217 | grad_weight_descs, 218 | ) = _extract_fwd_bwd_outputs(bw_gm, num_fwd_outputs=num_input_gradients) 219 | bw_inputs_gm = _extract_graph_with_inputs_outputs( 220 | bw_gm.graph, 221 | args, 222 | grad_inps, 223 | grad_inp_descs, 224 | "forward", 225 | ignore_must_be_in_fw_bw=True, 226 | ) 227 | bw_inputs_gm_node_names = OrderedSet( 228 | node.name for node in bw_inputs_gm.nodes if node.op != "output" 229 | ) 230 | saved_values = [] 231 | saved_sym_nodes = [] 232 | 233 | for node in bw_gm.graph.nodes: 234 | if node.name not in bw_inputs_gm_node_names: 235 | # Not handling mutations for now, 236 | # we can try to re-use more of and/or consolidate with default partitioner 237 | continue 238 | if is_sym_node(node): 239 | saved_sym_nodes.append(node) 240 | elif ( 241 | "tensor_meta" not in node.meta 242 | and node.op == "call_function" 243 | and not isinstance(node.meta.get("val"), torch._subclasses.FakeTensor) 244 | ): 245 | users = node.users 246 | assert all(user.target == operator.getitem for user in users) 247 | saved_values.extend(users) 248 | else: 249 | backward_usages = [ 250 | n for n in node.users if n.name not in bw_inputs_gm_node_names 251 | ] 252 | if "tensor_meta" in node.meta and all( 253 | is_sym_node(n) for n in backward_usages 254 | ): 255 | saved_sym_nodes.extend(backward_usages) 256 | else: 257 | saved_values.append(node) 258 | saved_values = list(dict.fromkeys(saved_values).keys()) 259 | saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys()) 260 | bw_inputs, bw_weights = _extract_fwd_bwd_modules( 261 | bw_gm, 262 | saved_values, 263 | 
saved_sym_nodes=saved_sym_nodes, 264 | num_fwd_outputs=num_input_gradients, 265 | ) 266 | return bw_inputs, bw_weights, num_input_gradients 267 | -------------------------------------------------------------------------------- /autoparallel/graph_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Union 7 | 8 | import torch 9 | from torch._functorch.aot_autograd import JointWithDescriptors 10 | from torch._inductor.fx_passes.joint_graph import patterns 11 | from torch._inductor.fx_passes.post_grad import remove_assert_ops, remove_noop_ops 12 | from torch._inductor.pattern_matcher import stable_topological_sort 13 | from torch.fx import GraphModule 14 | from torch.fx.experimental._backward_state import BackwardState 15 | 16 | 17 | def cleanup_graph(gm: torch.fx.GraphModule, aggressive: bool = False) -> None: 18 | # TODO: we can switch the default "aggresive" to True and things should 19 | # be even better as we can remove more redundant nodes early on 20 | # I'm keeping compatibility with previous behavior for now, and will 21 | # switch the flag in the future 22 | 23 | # TODO: Make the DCE match exactly the AOTAutograd logic, I don't 24 | # think I trust the default FX DCE logic 25 | gm.graph.eliminate_dead_code() 26 | gm.recompile() 27 | remove_noop_ops(gm.graph) 28 | # TODO: We shouldn't actually remove these 29 | remove_assert_ops(gm.graph) 30 | gm.graph.eliminate_dead_code() 31 | gm.graph.lint() 32 | gm.recompile() 33 | 34 | if aggressive: 35 | maybe_count = patterns.apply(gm) 36 | if maybe_count is not None: 37 | stable_topological_sort(gm.graph) 38 | gm.graph.lint() 39 | gm.recompile() 40 | 41 | 42 | def update_joint_with_descriptors( 43 | joint_with_descriptors: JointWithDescriptors, 44 | updated_gm: GraphModule, 45 | ) -> None: 46 | """ 47 | Assuming we have transformed updated_gm since the time it was captured, 48 | (e.g. by parallelizing it), 49 | this util updates the joint_with_descriptors struct to reference the new gm, and 50 | updates any copies of tensor meta/shape stored in joint_with_descriptors relating to input arguments, 51 | which may have changed shape since the initial trace. 52 | """ 53 | # TODO: should we upstream a util like this? 54 | placeholders = [n for n in updated_gm.graph.nodes if n.op == "placeholder"] 55 | new_local_args = [n.meta["val"] for n in placeholders] 56 | joint_with_descriptors.graph_module = updated_gm 57 | joint_with_descriptors._aot_graph_capture.graph_module = updated_gm 58 | 59 | new_flat_args: list[Union[torch.Tensor, int, torch.SymInt, BackwardState]] = [] 60 | for orig, new in zip(joint_with_descriptors._aot_state.flat_args, new_local_args): 61 | if isinstance(orig, torch.nn.Parameter): 62 | new_flat_args.append(torch.nn.Parameter(new)) 63 | else: 64 | new_flat_args.append(new) 65 | 66 | tangent_idx = len(joint_with_descriptors._aot_state.flat_args) 67 | new_local_tangents = new_local_args[tangent_idx:] 68 | 69 | # For inference mode (no tangents), updated_flat_args should be a list. 70 | # For autograd mode (with tangents), it should be a tuple of (primals, tangents). 
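# Illustrative shapes (hypothetical, for a module with two parameters, one
# non-parameter input and a single tangent): new_flat_args would be
# [Parameter(p0), Parameter(p1), x] and new_local_tangents would be [t0], so
# updated_flat_args becomes ([Parameter(p0), Parameter(p1), x], [t0]) below;
# with no tangents it stays as the flat list [Parameter(p0), Parameter(p1), x].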
71 | if new_local_tangents: 72 | joint_with_descriptors._aot_graph_capture.updated_flat_args = ( 73 | new_flat_args, 74 | new_local_tangents, 75 | ) 76 | else: 77 | joint_with_descriptors._aot_graph_capture.updated_flat_args = new_flat_args 78 | 79 | joint_with_descriptors._aot_state.flat_args = new_flat_args 80 | joint_with_descriptors._aot_state.fw_metadata.traced_tangents = new_local_tangents 81 | 82 | 83 | def _add_alias(gm, version="v1"): 84 | """ 85 | Helper function to add alias nodes to every node in the graph 86 | this gives more configuration opportunities 87 | """ 88 | graph = gm.graph 89 | 90 | nodes = list(graph.nodes) 91 | node_map = {node: idx for idx, node in enumerate(nodes)} 92 | 93 | def _insert_alias(node): 94 | first_user = nodes[min(node_map[n] for n in node.users)] 95 | with graph.inserting_before(first_user): 96 | alias_node = graph.call_function(torch.ops.aten.alias.default, args=(node,)) 97 | alias_node.meta.update(node.meta) 98 | 99 | def delete_user_cb(n): 100 | return n != alias_node 101 | 102 | node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb) 103 | 104 | if version == "v1": 105 | # only on inputs 106 | for node in graph.find_nodes(op="placeholder"): 107 | if len(node.users) == 0: 108 | # node is not used, don't add alias for it 109 | continue 110 | if ( 111 | len(node.users) == 1 112 | and list(node.users)[0].target 113 | == torch.ops.autoparallel.dtype_cast.default 114 | ): 115 | node = list(node.users)[0] 116 | _insert_alias(node) 117 | elif version == "v2": 118 | # for every node that has more than one user 119 | for node in nodes: 120 | if len(node.users) < 2: 121 | continue 122 | # don't add alias for ops which return tuple for now 123 | if not isinstance(node.meta["val"], torch.Tensor): 124 | continue 125 | _insert_alias(node) 126 | else: 127 | raise ValueError(f"Unknown version {version}") 128 | 129 | """ 130 | nodes = [n for n in graph.nodes if n.op == "call_function"] 131 | for node in nodes: 132 | # skip ops which return tuple 133 | if not isinstance(node.meta["val"], torch.Tensor): 134 | continue 135 | with graph.inserting_after(node): 136 | alias_node = graph.call_function(torch.ops.aten.alias.default, args=(node,)) 137 | alias_node.meta.update(node.meta) 138 | 139 | def delete_user_cb(n): 140 | return n != alias_node 141 | 142 | node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb) 143 | 144 | """ 145 | 146 | for node in graph.find_nodes(op="output")[0].all_input_nodes: 147 | with graph.inserting_after(node): 148 | alias_node = graph.call_function(torch.ops.aten.alias.default, args=(node,)) 149 | alias_node.meta.update(node.meta) 150 | 151 | def delete_user_cb(n): 152 | return n != alias_node 153 | 154 | node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb) 155 | 156 | gm.recompile() 157 | return gm 158 | 159 | 160 | def is_collective(node: torch.fx.Node) -> bool: 161 | return ( 162 | node.op == "call_function" 163 | and isinstance(node.target, torch._ops.OpOverload) 164 | and node.target.namespace == "_c10d_functional" 165 | ) 166 | 167 | 168 | def assert_has_no_collectives(gm: torch.fx.GraphModule): 169 | for node in gm.graph.nodes: 170 | if is_collective(node): 171 | raise RuntimeError( 172 | f"AutoParallel expects a single-GPU model " 173 | f"implementation with not collectives in it, but found {node} " 174 | f"operation in \n{node.meta['stack_trace']}.\n" 175 | f"If you want to manually add collectives in the model " 176 | f"(e.g., for optimization purposes), please wrap the region " 177 | 
f"of the code which contains the collectives in an " 178 | f"autoparallel.local_map_hop.apply_local_map, see " 179 | "examples/example_local_map.py for more information." 180 | ) 181 | 182 | 183 | # NOTE: [nn.Linear decomposition] 184 | # PyTorch currently decomposes any 3d-input nn.Linear (and matmul) into a 185 | # sequence of view -> mm -> view operations. 186 | # This has as a consequence of breaking any type of sharding on both the 187 | # batch and the sequence dimension, because the flattening that happens doesn't 188 | # allow to preserve this sharding. 189 | # While we wait for PyTorch to avoid decomposing nn.Linear, we instead take 190 | # the route of pattern-matching the nn.Linear specific occurences, and we replace 191 | # them with an einsum operator. 192 | # We perform this pattern-matching replacement for both the forward as well as 193 | # the backward pass. 194 | # TODO: use graph_patterns to simplify writing this 195 | def _replace_view_mm_view_with_einsum(gm): 196 | mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default) 197 | for node in mm_nodes: 198 | first_input, second_input = node.all_input_nodes 199 | if first_input.target == torch.ops.aten.view.default: 200 | view_input = first_input.all_input_nodes[0] 201 | users = list(node.users) 202 | if ( 203 | len(users) == 1 204 | and users[0].target == torch.ops.aten.view.default 205 | and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1] 206 | and second_input.meta["val"].ndim == 2 207 | ): 208 | print( 209 | f"Found matmul node {node}, {view_input.meta['val'].shape, second_input.meta['val'].shape}" 210 | ) 211 | ndim = view_input.meta["val"].ndim 212 | assert 1 < ndim <= 10, "Only support up to 10D for now" 213 | 214 | # generate the leading dimensions as a, b, c, etc 215 | dims = "".join([chr(97 + i) for i in range(ndim - 1)]) 216 | mm_equation = f"{dims}k,kn->{dims}n" 217 | with gm.graph.inserting_before(node): 218 | new_node = gm.graph.call_function( 219 | torch.ops.aten.einsum.default, 220 | args=(mm_equation, [view_input, second_input]), 221 | ) 222 | new_node.meta.update(users[0].meta) 223 | users[0].replace_all_uses_with(new_node) 224 | 225 | elif second_input.target == torch.ops.aten.view.default: 226 | if first_input.target != torch.ops.aten.permute.default: 227 | continue 228 | if first_input.all_input_nodes[0].target != torch.ops.aten.view.default: 229 | continue 230 | orig_first = first_input.all_input_nodes[0].all_input_nodes[0] 231 | orig_second = second_input.all_input_nodes[0] 232 | users = list(node.users) 233 | if ( 234 | len(users) == 1 235 | and users[0].target == torch.ops.aten.permute.default 236 | and orig_first.meta["val"].shape[:-1] 237 | == orig_second.meta["val"].shape[:-1] 238 | and node.meta["val"].ndim == 2 239 | ): 240 | print( 241 | f"Found matmul node {node} {orig_first.meta['val'].shape, orig_second.meta['val'].shape}" 242 | ) 243 | 244 | ndim = orig_first.meta["val"].ndim 245 | assert 1 < ndim <= 10, "Only support up to 10D for now" 246 | 247 | # generate the leading dimensions as a, b, c, etc 248 | dims = "".join([chr(97 + i) for i in range(ndim - 1)]) 249 | mm_equation = f"{dims}n,{dims}k->kn" 250 | with gm.graph.inserting_before(node): 251 | new_node = gm.graph.call_function( 252 | torch.ops.aten.einsum.default, 253 | args=(mm_equation, [orig_first, orig_second]), 254 | ) 255 | new_node.meta.update(users[0].meta) 256 | users[0].replace_all_uses_with(new_node) 257 | gm.graph.eliminate_dead_code() 258 | gm.recompile() 259 | 
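# Illustrative example of the pattern handled above (a sketch with hypothetical
# shapes, not an excerpt from a real graph): a 3d nn.Linear that PyTorch has
# decomposed as
#   x2 = torch.ops.aten.view.default(x, [B * S, K])   # x: [B, S, K]
#   y2 = torch.ops.aten.mm.default(x2, wt)            # wt: [K, N]
#   y  = torch.ops.aten.view.default(y2, [B, S, N])
# is rewritten by _replace_view_mm_view_with_einsum into a single
#   y = torch.ops.aten.einsum.default("abk,kn->abn", [x, wt])
# so that shardings on the batch and sequence dimensions are preserved.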
-------------------------------------------------------------------------------- /autoparallel/autobucketing_util/reorder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # mypy: ignore-errors 7 | from collections import defaultdict 8 | from enum import IntEnum 9 | from typing import Dict, List, Optional, Tuple 10 | 11 | import torch 12 | from torch._inductor import ir, scheduler 13 | from torch._inductor.utils import contains_collective, contains_wait, is_collective 14 | from torch.utils._ordered_set import OrderedSet 15 | 16 | from .bucket_utils import check_ir_node_bucketable 17 | 18 | 19 | class NodeType(IntEnum): 20 | ALL_GATHER = 0 21 | COMPUTE = 1 22 | REDUCE_SCATTER = 2 23 | AG_WAIT = 3 24 | RS_WAIT = 4 25 | 26 | 27 | def compute_node_users( 28 | snodes: List["scheduler.BaseSchedulerNode"], 29 | ) -> Tuple[ 30 | Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]], 31 | Dict["scheduler.BaseSchedulerNode", OrderedSet["scheduler.BaseSchedulerNode"]], 32 | ]: 33 | """ 34 | Compute the inverse users and users of each node 35 | """ 36 | buf_to_snode: Dict[str, scheduler.BaseSchedulerNode] = {} 37 | for node in snodes: 38 | if isinstance(node, scheduler.FusedSchedulerNode): 39 | for x in node.snodes: 40 | for buf in x.get_outputs(): 41 | buf_to_snode[buf.get_name()] = node 42 | 43 | for buf in node.get_outputs(): 44 | buf_to_snode[buf.get_name()] = node 45 | 46 | inverse_users = {} 47 | keys = list(buf_to_snode.keys()) 48 | for node in snodes: 49 | dep_list = [] 50 | for dep in node.unmet_dependencies: 51 | if dep.name in keys: 52 | dep_list.append(buf_to_snode[dep.name]) 53 | inverse_users.update({node: OrderedSet(dep_list)}) 54 | 55 | node_users: Dict[ 56 | scheduler.BaseSchedulerNode, OrderedSet[scheduler.BaseSchedulerNode] 57 | ] = defaultdict(OrderedSet) 58 | for node, node_inverse_users in inverse_users.items(): 59 | for inverse_user in node_inverse_users: 60 | node_users[inverse_user].add(node) 61 | 62 | return inverse_users, node_users 63 | 64 | 65 | def _get_ir_node_type(ir_node: "ir.Operation", bucketable_ir_nodes) -> NodeType: 66 | """ 67 | Determine the type of a ir node 68 | """ 69 | if isinstance(ir_node, ir._WaitKernel): 70 | # Determine if the wait node is waiting for ALL_GATHER or REDUCE_SCATTER 71 | ir_op_overload = getattr(ir_node.inputs[0], "op_overload", None) 72 | if ( 73 | ir_op_overload == torch.ops._c10d_functional.all_gather_into_tensor.default 74 | and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes) 75 | ): 76 | return NodeType.AG_WAIT 77 | elif ( 78 | ir_op_overload == torch.ops._c10d_functional.reduce_scatter_tensor.default 79 | and check_ir_node_bucketable(ir_node.inputs[0], bucketable_ir_nodes) 80 | ): 81 | return NodeType.RS_WAIT 82 | if isinstance(ir_node, ir._CollectiveKernel): 83 | # Determine if the collective kernel is for ALL_GATHER or REDUCE_SCATTER 84 | ir_op_overload = getattr(ir_node, "op_overload", None) 85 | if is_collective( 86 | ir_node, op=torch.ops._c10d_functional.all_gather_into_tensor.default 87 | ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes): 88 | return NodeType.ALL_GATHER 89 | elif is_collective( 90 | ir_node, op=torch.ops._c10d_functional.reduce_scatter_tensor.default 91 | ) and check_ir_node_bucketable(ir_node, 
bucketable_ir_nodes): 92 | return NodeType.REDUCE_SCATTER 93 | 94 | if isinstance(ir_node, ir.FallbackKernel): 95 | python_kernel_name = ir_node.python_kernel_name 96 | if ( 97 | python_kernel_name == "torch.ops._c10d_functional.wait_tensor.default" 98 | and check_ir_node_bucketable(ir_node, bucketable_ir_nodes) 99 | ): 100 | inputs_rs_kernel_name1 = ( 101 | getattr(ir_node.inputs[0], "python_kernel_name", "") 102 | == "torch.ops._c10d_functional.reduce_scatter_tensor.default" 103 | ) 104 | inputs_rs_kernel_name2 = ( 105 | hasattr(ir_node.inputs[0], "inputs") 106 | and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "") 107 | == "torch.ops._c10d_functional.reduce_scatter_tensor.default" 108 | ) 109 | if inputs_rs_kernel_name1 or inputs_rs_kernel_name2: 110 | return NodeType.RS_WAIT 111 | 112 | inputs_ag_kernel_name1 = ( 113 | getattr(ir_node.inputs[0], "python_kernel_name", "") 114 | == "torch.ops._c10d_functional.all_gather_into_tensor_out.default" 115 | ) 116 | inputs_ag_kernel_name2 = ( 117 | hasattr(ir_node.inputs[0], "inputs") 118 | and getattr(ir_node.inputs[0].inputs[0], "python_kernel_name", "") 119 | == "torch.ops._c10d_functional.all_gather_into_tensor_out.default" 120 | ) 121 | if inputs_ag_kernel_name1 or inputs_ag_kernel_name2: 122 | return NodeType.AG_WAIT 123 | elif ( 124 | python_kernel_name 125 | == "torch.ops._c10d_functional.reduce_scatter_tensor.default" 126 | ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes): 127 | return NodeType.REDUCE_SCATTER 128 | elif ( 129 | python_kernel_name 130 | == "torch.ops._c10d_functional.all_gather_into_tensor_out.default" 131 | ) and check_ir_node_bucketable(ir_node, bucketable_ir_nodes): 132 | return NodeType.ALL_GATHER 133 | return NodeType.COMPUTE 134 | 135 | 136 | def get_node_type(node: "scheduler.BaseSchedulerNode", bucketable_ir_nodes) -> NodeType: 137 | """ 138 | Determine the NodeType of a node 139 | """ 140 | if isinstance(node, scheduler.FusedSchedulerNode): 141 | # Only compute nodes are fused 142 | return NodeType.COMPUTE 143 | 144 | if isinstance(node, scheduler.GroupedSchedulerNode): 145 | # [Only for bucketing]: newly created AG and RS are grouped as GroupedSchedulerNode 146 | child_nodes_type = [ 147 | _get_ir_node_type(n.node, bucketable_ir_nodes) for n in node.snodes 148 | ] 149 | if NodeType.AG_WAIT in child_nodes_type: 150 | return NodeType.AG_WAIT 151 | elif NodeType.RS_WAIT in child_nodes_type: 152 | return NodeType.RS_WAIT 153 | elif NodeType.ALL_GATHER in child_nodes_type: 154 | return NodeType.ALL_GATHER 155 | elif NodeType.REDUCE_SCATTER in child_nodes_type: 156 | return NodeType.REDUCE_SCATTER 157 | else: 158 | return NodeType.COMPUTE 159 | 160 | return _get_ir_node_type(node.node, bucketable_ir_nodes) 161 | 162 | 163 | def reorder_all_gather( 164 | snodes: List["scheduler.BaseSchedulerNode"], 165 | bucketable_ir_nodes: set[str], 166 | all_gather_before_last_wait: Optional[bool] = True, 167 | ) -> List["scheduler.BaseSchedulerNode"]: 168 | """ 169 | Reorder All Gather and Wait in the forward/backward pass; 170 | 1. all_gather_before_last_wait set to True: all_gather_i is reordered before wait_i-1 171 | 2. 
all_gather_before_last_wait set to False: all_gather_i is reordered after wait_i-1 172 | """ 173 | result_list: List[scheduler.BaseSchedulerNode] = [] 174 | all_gather_list: List[scheduler.BaseSchedulerNode] = [] 175 | node_to_type: Dict[scheduler.BaseSchedulerNode, int] = {} 176 | inverse_users, node_users = compute_node_users(snodes) 177 | 178 | for node in snodes: 179 | node_to_type[node] = get_node_type(node, bucketable_ir_nodes) 180 | snodes.reverse() 181 | for idx, node in enumerate(snodes): 182 | node_type = node_to_type[node] 183 | if node_type in [NodeType.REDUCE_SCATTER, NodeType.COMPUTE, NodeType.RS_WAIT]: 184 | # we do not reorder reduce scatter and compute node 185 | if node not in result_list and node not in all_gather_list: 186 | result_list.append(node) 187 | elif node_type == NodeType.ALL_GATHER: 188 | # gather i-th all gather node and its dependencies 189 | all_gather_list.append(node) 190 | inverse_user = list(inverse_users[node]) 191 | inverse_user = [ 192 | n 193 | for n in inverse_user 194 | if node_to_type[n] == NodeType.COMPUTE 195 | and not contains_collective(n) 196 | and not contains_wait(n) 197 | ] 198 | if len(inverse_user) > 0: 199 | all_gather_list.extend(inverse_user) 200 | elif node_type == NodeType.AG_WAIT: 201 | if not all_gather_before_last_wait and len(all_gather_list) > 0: 202 | assert node_to_type[snodes[idx + 1]] == NodeType.ALL_GATHER 203 | # move i-th all gather node and its dependencies after (i-1)-th wait node (bc this is a reverse list) 204 | result_list.extend(all_gather_list) 205 | all_gather_list = [] 206 | 207 | result_list.append(node) 208 | 209 | if all_gather_before_last_wait and len(all_gather_list) > 0: 210 | assert node_to_type[snodes[idx + 1]] == NodeType.ALL_GATHER 211 | # move i-th all gather node and its dependencies before (i-1)-th wait node (bc this is a reverse list) 212 | result_list.extend(all_gather_list) 213 | all_gather_list = [] 214 | if len(all_gather_list) > 0: 215 | result_list.extend(all_gather_list) 216 | result_list.reverse() 217 | 218 | return result_list 219 | 220 | 221 | def reorder_reduce_scatter( 222 | snodes: List["scheduler.BaseSchedulerNode"], 223 | bucketable_ir_nodes: set[str], 224 | ) -> List["scheduler.BaseSchedulerNode"]: 225 | """ 226 | Reorder Reduce Scatter and Wait in the backward pass 227 | reorder wait_i_rs before reduce_scatter_i+1 228 | """ 229 | result_list: List[scheduler.BaseSchedulerNode] = [] 230 | wait_list: List[scheduler.BaseSchedulerNode] = [] 231 | node_to_type: Dict[scheduler.BaseSchedulerNode, int] = {} 232 | inverse_users, node_users = compute_node_users(snodes) 233 | types = [] 234 | for node in snodes: 235 | node_to_type[node] = get_node_type(node, bucketable_ir_nodes) 236 | types.append(get_node_type(node, bucketable_ir_nodes)) 237 | 238 | if NodeType.REDUCE_SCATTER not in types: 239 | return snodes 240 | 241 | for idx, node in enumerate(snodes): 242 | node_type = node_to_type[node] 243 | if node_type in [NodeType.ALL_GATHER, NodeType.COMPUTE, NodeType.AG_WAIT]: 244 | if node not in result_list and node not in wait_list: 245 | result_list.append(node) 246 | elif node_type == NodeType.RS_WAIT: 247 | # there will sometimes be a memory checker node between rs and rs wait 248 | assert node_to_type[snodes[idx - 1]] == NodeType.REDUCE_SCATTER 249 | # gather wait node after reduce scatter 250 | wait_list.append(node) 251 | node_user = node_users[node] 252 | node_user = [n for n in node_user if node_to_type[n] == NodeType.COMPUTE] 253 | # wait_list.extend(node_user) 254 | elif node_type 
== NodeType.REDUCE_SCATTER: 255 | if len(wait_list) > 0: 256 | # move the i-th wait node before (i+1)-th reduce scatter node 257 | result_list.extend(wait_list) 258 | wait_list = [] 259 | # add reduce scatter node 260 | result_list.append(node) 261 | 262 | if len(wait_list) > 0: 263 | result_list.extend(wait_list) 264 | return result_list 265 | -------------------------------------------------------------------------------- /mast/sweep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 2 | # 3 | # This source code is licensed under the BSD license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import argparse 7 | import importlib.util 8 | import logging 9 | import os 10 | import re 11 | import subprocess 12 | from typing import Optional 13 | 14 | import git 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def print_tbm(results: dict[str, str]) -> None: 20 | tbm_str = " ".join(f"{name}:{job_id}" for name, job_id in results.items()) 21 | print(f"tbm {tbm_str}") 22 | 23 | 24 | def get_git_hash(repo_path): 25 | try: 26 | repo = git.Repo(repo_path) 27 | latest_commit = repo.head.commit 28 | return latest_commit.hexsha 29 | except Exception as e: 30 | logger.error(f"Error accessing Git repository at {repo_path}: {e}") 31 | return None 32 | 33 | 34 | def maybe_tabulate(data, headers=()): 35 | if importlib.util.find_spec("tabulate"): 36 | from tabulate import tabulate 37 | 38 | return tabulate(data, headers=headers) 39 | return f"Please pip install `tabulate` for better printing\n{headers}\n{data}" 40 | 41 | 42 | def is_git_repo_clean(repo_path): 43 | try: 44 | repo = git.Repo(repo_path) 45 | # Check for unstaged changes (modified, added, deleted) 46 | if repo.is_dirty(untracked_files=True): 47 | return False 48 | # Check for staged but uncommitted changes 49 | if repo.index.diff(None): 50 | return False 51 | return True 52 | except git.InvalidGitRepositoryError: 53 | logger.error(f"Error: '{repo_path}' is not a valid Git repository.") 54 | return False 55 | except Exception as e: 56 | logger.error(f"An error occurred: {e}") 57 | return False 58 | 59 | 60 | def find_repo(path: str, name: str) -> str: 61 | try: 62 | # error if not a repo or if not a valid path 63 | _ = git.Repo(path) 64 | assert os.path.exists(os.path.join(path, name)) 65 | return os.path.abspath(path) 66 | except Exception: 67 | logger.error(f"Failed to find {name} repo, pass valid path as argument.") 68 | raise 69 | 70 | 71 | def find_torchtitan(maybe_path: Optional[str] = None) -> str: 72 | return find_repo(maybe_path or "../../torchtitan", "torchtitan") 73 | 74 | 75 | def find_autoparallel(maybe_path: Optional[str] = None) -> str: 76 | return find_repo(maybe_path or "../", "autoparallel") 77 | 78 | 79 | def maybe_find_pulp(maybe_path: Optional[str] = None) -> Optional[str]: 80 | 81 | try: 82 | return find_repo(maybe_path or "../../pulp", "pulp") 83 | except Exception: 84 | logger.error( 85 | "Failed to find pulp repo, will not include it in the run. 
" 86 | "This is OK if the fbpkg itself includes pulp, which should be true for latest nightly" 87 | ) 88 | return None 89 | 90 | 91 | llama3_1d_common_opts = [ 92 | "--training.local_batch_size=1", 93 | "--parallelism.tensor_parallel_degree=1", 94 | ] 95 | llama3_2d_common_opts = [ 96 | "--training.local_batch_size=2", 97 | "--parallelism.tensor_parallel_degree=8", 98 | ] 99 | llama3_1d = { 100 | "llama3_FSDP_compile": llama3_1d_common_opts 101 | + [ 102 | "--model.name=llama3", 103 | "--compile.enable", 104 | ], 105 | "llama3_autop_1d_compile": llama3_1d_common_opts 106 | + [ 107 | "--model.name=auto_parallel.llama3", 108 | "--compile.enable", 109 | "--experimental.comms_bucket_reorder_strategy=none", 110 | ], 111 | "llama3_autop_1d_compile_aten_bucket_reorder": llama3_1d_common_opts 112 | + [ 113 | "--model.name=auto_parallel.llama3", 114 | "--compile.enable", 115 | "--experimental.comms_bucket_reorder_strategy=aten", 116 | ], 117 | } 118 | 119 | llama3_2d = { 120 | "llama3_FSDP_tp_compile": llama3_2d_common_opts 121 | + [ 122 | "--model.name=llama3", 123 | "--compile.enable", 124 | ], 125 | "llama3_autop_2d_compile": llama3_2d_common_opts 126 | + [ 127 | "--model.name=auto_parallel.llama3", 128 | "--compile.enable", 129 | "--experimental.comms_bucket_reorder_strategy=none", 130 | ], 131 | "llama3_autop_2d_compile_aten_bucket_reorder": llama3_2d_common_opts 132 | + [ 133 | "--model.name=auto_parallel.llama3", 134 | "--compile.enable", 135 | "--experimental.comms_bucket_reorder_strategy=aten", 136 | ], 137 | } 138 | 139 | all_runs = ( 140 | llama3_1d 141 | | llama3_2d 142 | | { 143 | "llama3_autop_1d_compile_inductor_bucket_reorder": llama3_1d_common_opts 144 | + [ 145 | "--model.name=auto_parallel.llama3", 146 | "--compile.enable", 147 | "--experimental.comms_bucket_reorder_strategy=inductor", 148 | ], 149 | "llama3_autop_2d_compile_inductor_bucket_reorder": llama3_2d_common_opts 150 | + [ 151 | "--model.name=auto_parallel.llama3", 152 | "--compile.enable", 153 | "--experimental.comms_bucket_reorder_strategy=inductor", 154 | ], 155 | } 156 | ) 157 | 158 | 159 | def build_sweep(names): 160 | return {name: all_runs[name] for name in names} 161 | 162 | 163 | sweeps = { 164 | "llama3_1d": llama3_1d, 165 | "llama3_2d": llama3_2d, 166 | "update3": build_sweep( 167 | [ 168 | "llama3_FSDP_compile", 169 | "llama3_autop_1d_compile", 170 | "llama3_autop_1d_compile_inductor_bucket_reorder", 171 | "llama3_FSDP_tp_compile", 172 | "llama3_autop_2d_compile", 173 | "llama3_autop_2d_compile_inductor_bucket_reorder", 174 | ] 175 | ), 176 | "compare_1d_bucketing": build_sweep( 177 | [ 178 | "llama3_FSDP_compile", 179 | "llama3_autop_1d_compile", 180 | "llama3_autop_1d_compile_aten_bucket_reorder", 181 | "llama3_autop_1d_compile_inductor_bucket_reorder", 182 | ] 183 | ), 184 | "compare_2d_bucketing": build_sweep( 185 | [ 186 | "llama3_FSDP_tp_compile", 187 | "llama3_autop_2d_compile", 188 | "llama3_autop_2d_compile_aten_bucket_reorder", 189 | "llama3_autop_2d_compile_inductor_bucket_reorder", 190 | ] 191 | ), 192 | } 193 | 194 | 195 | def run(args: argparse.Namespace) -> None: 196 | 197 | if args.runs: 198 | runs = {name: all_runs[name] for name in args.runs} 199 | else: 200 | runs = {} 201 | for sweep in args.sweep: 202 | runs.update(sweeps[sweep]) 203 | 204 | # overrides values in .torchxconfig 205 | scheduler_args = ",".join([f"conda_fbpkg_id={args.fbpkg}"]) 206 | 207 | base_cmd = [ 208 | "torchx", 209 | "run", 210 | f"--scheduler_args={scheduler_args}", 211 | "mast.py:train", 212 | "--nodes", 213 | 
f"{args.nodes}", 214 | "--additional_folders", 215 | args.torchtitan_dir, 216 | "--twtask_bootstrap_script", 217 | "run_torchtitan.sh", 218 | ] 219 | addl_libs_str = ",".join( 220 | [ 221 | args.autoparallel_dir, 222 | ] 223 | + [args.pulp_dir] 224 | if args.pulp_dir 225 | else [] 226 | ) 227 | addl_libs = [f"--additional_libraries={addl_libs_str}"] 228 | llama3_base = [ 229 | "torchtitan/models/llama3/train_configs/llama3_8b.toml", 230 | "--training.dataset", 231 | "c4", 232 | ] 233 | 234 | def launch_job(cmd: list[str]) -> str: 235 | result = subprocess.run( 236 | cmd, 237 | capture_output=True, 238 | text=True, 239 | # todo move the checking to later, print stdout/err first 240 | check=True, 241 | ) 242 | job_id_pattern = r".*runs\/mast\/([a-zA-Z0-9\-]+)" 243 | for line in result.stdout.splitlines() + result.stderr.splitlines(): 244 | if m := re.match(job_id_pattern, line): 245 | return m.group(1) 246 | 247 | raise RuntimeError( 248 | f"Failed to find job id in torchx launch output. Full stdout:\n {result.stdout}" 249 | ) 250 | 251 | results = {} 252 | autoparallel_hash = get_git_hash(args.autoparallel_dir) 253 | autoparallel_clean = is_git_repo_clean(args.autoparallel_dir) 254 | torchtitan_hash = get_git_hash(args.torchtitan_dir) 255 | torchtitan_clean = is_git_repo_clean(args.torchtitan_dir) 256 | 257 | if not torchtitan_clean or not autoparallel_clean: 258 | logger.warning( 259 | f"Repo is not clean. Please commit your changes before running the script. {autoparallel_clean=} {torchtitan_clean=}" 260 | ) 261 | 262 | extra_torchtitan_args = args.extra_torchtitan_args or [] 263 | extra_torchtitan_name = "_".join(extra_torchtitan_args) 264 | extra_torchtitan_args = ["--" + arg for arg in extra_torchtitan_args] 265 | for name, sub_cmd in runs.items(): 266 | if extra_torchtitan_name: 267 | name += "_" + extra_torchtitan_name 268 | logger.info(f"Launching {name}") 269 | cmd = base_cmd + addl_libs + llama3_base + sub_cmd + extra_torchtitan_args 270 | if args.dry_run: 271 | # TODO configure log levels.. 272 | logger.warning(f"Dry-run: command for {name} is\n" + " ".join(cmd)) 273 | job_id = "dry-run" 274 | else: 275 | job_id = launch_job(cmd) 276 | results[name] = job_id 277 | 278 | print("") 279 | print( 280 | maybe_tabulate( 281 | [ 282 | ["fbpkg", args.fbpkg, "n/a"], 283 | ["autoparallel", autoparallel_hash, autoparallel_clean], 284 | ["torchtitan", torchtitan_hash, torchtitan_clean], 285 | ], 286 | headers=["Repo", "Hash", "Is Clean"], 287 | ) 288 | ) 289 | print("") 290 | print(maybe_tabulate(results.items(), headers=["Name", "Job ID"])) 291 | print("") 292 | print("tbm command:\n") 293 | print_tbm(results) 294 | 295 | 296 | if __name__ == "__main__": 297 | parser = argparse.ArgumentParser( 298 | description="Launch autoparallel runs from a stable configuration. Run from autoparallel/scripts dir." 
299 | ) 300 | parser.add_argument( 301 | "--dry-run", 302 | action="store_true", 303 | help="Only show the commands that would be run, don't actually run them", 304 | ) 305 | parser.add_argument( 306 | "--torchtitan_dir", 307 | type=find_torchtitan, 308 | default=find_torchtitan(), 309 | help="Path to torchtitan repo", 310 | ) 311 | parser.add_argument( 312 | "--autoparallel_dir", 313 | type=find_autoparallel, 314 | default=find_autoparallel(), 315 | help="Path to autoparallel repo", 316 | ) 317 | parser.add_argument( 318 | "--pulp_dir", 319 | type=maybe_find_pulp, 320 | default=maybe_find_pulp(), 321 | help="Path to pulp repo, not strictly required but recommended since not all fbpkgs include pulp dep", 322 | ) 323 | parser.add_argument( 324 | "--fbpkg", 325 | default="torchtitan_conda_prod:latest_conveyor_build", 326 | help="Fbpkg to use for job", 327 | ) 328 | parser.add_argument( 329 | "--sweep", 330 | choices=sweeps.keys(), 331 | default="llama3_1d", 332 | nargs="+", 333 | help="Sweep to run, if not specified will run only specified runs", 334 | ) 335 | parser.add_argument( 336 | "--runs", 337 | nargs="+", 338 | choices=all_runs.keys(), 339 | help="exact list of runs to run, overrides sweep", 340 | ) 341 | parser.add_argument( 342 | "--extra_torchtitan_args", 343 | nargs="+", 344 | help="arguments to pass to torchtitan, e.g. 'training.batch_size=2'", 345 | ) 346 | parser.add_argument( 347 | "--nodes", 348 | type=int, 349 | default=8, 350 | help="How many nodes to use for the job, defaults to 8.", 351 | ) 352 | 353 | args = parser.parse_args() 354 | run(args) 355 | --------------------------------------------------------------------------------
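# Illustrative usage of mast/sweep.py (a hypothetical invocation; it assumes the
# default repo layout expected above, i.e. torchtitan at ../../torchtitan and
# autoparallel at ../, and a recent torchtitan_conda_prod fbpkg):
#
#   python sweep.py --sweep compare_1d_bucketing --nodes 8 --dry-run
#
# With --dry-run the script only logs the torchx commands it would launch; without
# it, each run is submitted and the script prints a table of job ids plus a
# ready-to-paste `tbm` command.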