├── .ci └── docker │ ├── README.md │ ├── build.sh │ ├── common │ ├── install_base.sh │ ├── install_clang.sh │ ├── install_conda.sh │ ├── install_gcc.sh │ ├── install_user.sh │ └── utils.sh │ ├── conda-env-ci.txt │ ├── requirements-dev.txt │ ├── requirements-flux.txt │ ├── requirements.txt │ └── ubuntu │ └── Dockerfile ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug.yml │ └── config.yml └── workflows │ ├── docker-builds.yml │ ├── integration_test_8gpu.yaml │ ├── integration_test_8gpu_flux.yaml │ ├── integration_test_8gpu_h100.yaml │ ├── integration_test_8gpu_simple_fsdp.yaml │ ├── lint.yaml │ ├── unit_test_cpu.yaml │ └── unit_test_cpu_flux.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── images │ └── loss_curves.png ├── license_header.txt └── version.txt ├── docs ├── checkpoint.md ├── composability.md ├── converging.md ├── datasets.md ├── debugging.md ├── extension.md ├── float8.md ├── fsdp.md ├── metrics.md ├── performance.md └── torchft.md ├── multinode_trainer.slurm ├── pyproject.toml ├── requirements-dev.txt ├── requirements.txt ├── run_train.sh ├── scripts ├── convert_llama_to_dcp.py ├── download_tokenizer.py ├── estimate │ ├── estimation.py │ └── run_memory_estimation.sh └── generate │ ├── README.md │ ├── _generation.py │ ├── run_llama_generate.sh │ └── test_generate.py ├── tests ├── README.md ├── __init__.py ├── assets │ ├── c4_test │ │ └── data.json │ ├── custom_schedule.csv │ ├── extend_jobconfig_example.py │ └── test_tiktoken.model ├── integration_tests.py ├── integration_tests_h100.py └── unit_tests │ ├── __init__.py │ ├── test_checkpoint.py │ ├── test_dataset_checkpointing.py │ ├── test_job_config.py │ ├── test_model_converter.py │ └── test_train_spec.py └── torchtitan ├── __init__.py ├── components ├── checkpoint.py ├── dataloader.py ├── ft.py ├── loss.py ├── lr_scheduler.py ├── metrics.py ├── optimizer.py ├── quantization │ ├── __init__.py │ ├── float8.py │ ├── mx.py │ └── utils.py └── tokenizer.py ├── config_manager.py ├── datasets ├── hf_datasets.py └── tokenizer │ └── tiktoken.py ├── distributed ├── __init__.py ├── parallel_dims.py ├── pipeline.py └── utils.py ├── experiments ├── README.md ├── __init__.py ├── deepseek_v3 │ ├── LICENSE-CODE │ ├── README.md │ ├── __init__.py │ ├── attn_mask_utils.py │ ├── checkpoint.py │ ├── download.py │ ├── dsgemm_kernels.py │ ├── dsgemm_utils.py │ ├── generate.py │ ├── group_gemms.py │ ├── inference.sh │ ├── infra │ │ └── parallelize_deepseek.py │ ├── model.py │ ├── model_args.py │ ├── model_config.py │ ├── moe_kernels.py │ ├── requirements.txt │ ├── run_training.sh │ ├── symm_mem_recipes │ │ ├── __init__.py │ │ ├── triton_barrier.py │ │ ├── triton_on_device_all_to_all_v.py │ │ └── triton_utils.py │ ├── tokenizers │ │ └── hf_tokenizer.py │ ├── train_configs │ │ ├── custom_args.py │ │ └── deepseek_v2.toml │ ├── train_ds_dev.py │ ├── train_ds_real.py │ └── unit_testing │ │ ├── benchmark_kernels.py │ │ ├── dsgemm_unit_testing.py │ │ ├── permute_indices_testing.py │ │ └── test_create_m_indices.py ├── flux │ ├── README.md │ ├── __init__.py │ ├── dataset │ │ ├── flux_dataset.py │ │ └── tokenizer.py │ ├── job_config.py │ ├── loss.py │ ├── model │ │ ├── autoencoder.py │ │ ├── hf_embedder.py │ │ ├── layers.py │ │ ├── math.py │ │ └── model.py │ ├── parallelize_flux.py │ ├── requirements-flux.txt │ ├── run_train.sh │ ├── sampling.py │ ├── scripts │ │ └── download_autoencoder.py │ ├── tests │ │ ├── __init__.py │ │ ├── assets │ │ │ ├── cc12m_test │ │ │ │ ├── 
cc12m-train-0000.tar │ │ │ │ └── pack_test_dataset.py │ │ │ ├── clip-vit-large-patch14 │ │ │ │ └── config.json │ │ │ └── t5-v1_1-xxl │ │ │ │ └── config.json │ │ ├── integration_tests.py │ │ ├── test_generate_image.py │ │ └── unit_tests │ │ │ ├── __init__.py │ │ │ └── test_flux_dataloader.py │ ├── train.py │ ├── train_configs │ │ ├── debug_model.toml │ │ ├── flux_dev_model.toml │ │ └── flux_schnell_model.toml │ └── utils.py ├── kernels │ ├── moe │ │ ├── indices.py │ │ └── unit_tests │ │ │ └── permute_indices_testing.py │ ├── triton_contiguous_group_gemm │ │ ├── cg_backward.py │ │ ├── cg_forward.py │ │ ├── cg_reference.py │ │ ├── debug.py │ │ ├── tma_cuda_autotune.py │ │ └── unit_test_cg.py │ └── triton_mg_group_gemm │ │ ├── benchmark.py │ │ ├── simpleMoE.py │ │ └── torchao_pr │ │ ├── __init__.py │ │ ├── fast_debug_ao.py │ │ ├── mg_grouped_gemm.py │ │ ├── reference_utils.py │ │ ├── tma_autotuning.py │ │ ├── unit_test_backwards.py │ │ └── unit_test_forwards.py ├── llama4 │ ├── README.md │ ├── __init__.py │ ├── infra │ │ ├── expert_parallel.py │ │ └── parallelize_llama.py │ ├── model │ │ ├── args.py │ │ ├── model.py │ │ └── moe.py │ ├── scripts │ │ ├── REAME.md │ │ ├── convert_hf_to_dcp_with_gpus.py │ │ ├── convert_hf_to_dcp_with_gpus.sh │ │ ├── convert_meta_to_dcp_with_gpus.py │ │ └── convert_meta_to_dcp_with_gpus.sh │ └── train_configs │ │ ├── debug_model.toml │ │ ├── llama4_17bx128e.toml │ │ └── llama4_17bx16e.toml ├── multimodal │ ├── __init__.py │ ├── check_padding_mm.py │ ├── mm_collator.py │ ├── mm_dataset.py │ ├── model.py │ ├── requirements.txt │ ├── tests │ │ ├── __init__.py │ │ ├── test_multimodal_model.py │ │ └── test_utils.py │ ├── tokenizer │ │ └── tiktoken.py │ ├── transform.py │ └── utils.py └── simple_fsdp │ ├── README.md │ ├── __init__.py │ ├── model.py │ ├── parallelize_llama.py │ ├── simple_fsdp.py │ └── tests │ ├── __init__.py │ ├── integration_tests.py │ └── test_numerics.py ├── models ├── __init__.py ├── attention.py └── llama3 │ ├── __init__.py │ ├── model.py │ ├── parallelize_llama.py │ ├── pipeline_llama.py │ └── train_configs │ ├── debug_model.toml │ ├── llama3_405b.toml │ ├── llama3_70b.toml │ └── llama3_8b.toml ├── protocols ├── model_converter.py └── train_spec.py ├── tools ├── logging.py ├── profiling.py └── utils.py └── train.py /.ci/docker/README.md: -------------------------------------------------------------------------------- 1 | # Docker images for TorchTitan CI 2 | 3 | This directory contains everything needed to build the Docker images 4 | that are used in TorchTitan CI. The content of this directory are copied 5 | from PyTorch CI https://github.com/pytorch/pytorch/tree/main/.ci/docker. 6 | It also uses the same directory structure as PyTorch. 7 | 8 | ## Contents 9 | 10 | * `build.sh` -- dispatch script to launch all builds 11 | * `common` -- scripts used to execute individual Docker build stages 12 | * `ubuntu` -- Dockerfile for Ubuntu image for CPU build and test jobs 13 | 14 | ## Usage 15 | 16 | ```bash 17 | # Generic usage 18 | ./build.sh "${IMAGE_NAME}" "${DOCKER_BUILD_PARAMETERS}" 19 | 20 | # Build a specific image 21 | ./build.sh torchtitan-ubuntu-20.04-clang12 -t myimage:latest 22 | ``` 23 | -------------------------------------------------------------------------------- /.ci/docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -exu 9 | 10 | IMAGE_NAME="$1" 11 | shift 12 | 13 | echo "Building ${IMAGE_NAME} Docker image" 14 | 15 | OS=ubuntu 16 | OS_VERSION=20.04 17 | CLANG_VERSION="" 18 | PYTHON_VERSION=3.11 19 | MINICONDA_VERSION=24.3.0-0 20 | 21 | case "${IMAGE_NAME}" in 22 | torchtitan-ubuntu-20.04-clang12) 23 | CLANG_VERSION=12 24 | ;; 25 | *) 26 | echo "Invalid image name ${IMAGE_NAME}" 27 | exit 1 28 | esac 29 | 30 | docker build \ 31 | --no-cache \ 32 | --progress=plain \ 33 | --build-arg "OS_VERSION=${OS_VERSION}" \ 34 | --build-arg "CLANG_VERSION=${CLANG_VERSION}" \ 35 | --build-arg "PYTHON_VERSION=${PYTHON_VERSION}" \ 36 | --build-arg "MINICONDA_VERSION=${MINICONDA_VERSION}" \ 37 | --shm-size=1g \ 38 | -f "${OS}"/Dockerfile \ 39 | "$@" \ 40 | . 41 | -------------------------------------------------------------------------------- /.ci/docker/common/install_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | install_ubuntu() { 11 | apt-get update 12 | 13 | apt-get install -y --no-install-recommends \ 14 | build-essential \ 15 | ca-certificates \ 16 | curl \ 17 | git \ 18 | wget \ 19 | sudo \ 20 | vim \ 21 | jq \ 22 | vim \ 23 | unzip \ 24 | gdb \ 25 | rsync \ 26 | libssl-dev \ 27 | zip 28 | 29 | # Cleanup package manager 30 | apt-get autoclean && apt-get clean 31 | rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 32 | } 33 | 34 | # Install base packages depending on the base OS 35 | ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') 36 | case "$ID" in 37 | ubuntu) 38 | install_ubuntu 39 | ;; 40 | *) 41 | echo "Unable to determine OS..." 42 | exit 1 43 | ;; 44 | esac 45 | -------------------------------------------------------------------------------- /.ci/docker/common/install_clang.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | set -ex 9 | 10 | install_ubuntu() { 11 | apt-get update 12 | 13 | apt-get install -y --no-install-recommends clang-"$CLANG_VERSION" 14 | apt-get install -y --no-install-recommends llvm-"$CLANG_VERSION" 15 | # Also require LLD linker from llvm and libomp to build PyTorch from source 16 | apt-get install -y lld "libomp-${CLANG_VERSION}-dev" 17 | 18 | # Use update-alternatives to make this version the default 19 | update-alternatives --install /usr/bin/clang clang /usr/bin/clang-"$CLANG_VERSION" 50 20 | update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-"$CLANG_VERSION" 50 21 | # Override cc/c++ to clang as well 22 | update-alternatives --install /usr/bin/cc cc /usr/bin/clang 50 23 | update-alternatives --install /usr/bin/c++ c++ /usr/bin/clang++ 50 24 | 25 | # Cleanup package manager 26 | apt-get autoclean && apt-get clean 27 | rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 28 | } 29 | 30 | if [ -n "$CLANG_VERSION" ]; then 31 | # Install base packages depending on the base OS 32 | ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') 33 | case "$ID" in 34 | ubuntu) 35 | install_ubuntu 36 | ;; 37 | *) 38 | echo "Unable to determine OS..." 39 | exit 1 40 | ;; 41 | esac 42 | fi 43 | -------------------------------------------------------------------------------- /.ci/docker/common/install_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | # shellcheck source=/dev/null 11 | source "$(dirname "${BASH_SOURCE[0]}")/utils.sh" 12 | 13 | install_miniconda() { 14 | BASE_URL="https://repo.anaconda.com/miniconda" 15 | CONDA_FILE="Miniconda3-py${PYTHON_VERSION//./}_${MINICONDA_VERSION}-Linux-x86_64.sh" 16 | 17 | mkdir -p /opt/conda 18 | chown ci-user:ci-user /opt/conda 19 | 20 | pushd /tmp 21 | wget -q "${BASE_URL}/${CONDA_FILE}" 22 | # Install miniconda 23 | as_ci_user bash "${CONDA_FILE}" -b -f -p "/opt/conda" 24 | # Clean up the download file 25 | rm "${CONDA_FILE}" 26 | popd 27 | 28 | sed -e 's|PATH="\(.*\)"|PATH="/opt/conda/bin:\1"|g' -i /etc/environment 29 | export PATH="/opt/conda/bin:$PATH" 30 | } 31 | 32 | install_python() { 33 | pushd /opt/conda 34 | # Install the correct Python version 35 | as_ci_user conda create -n "py_${PYTHON_VERSION}" -y --file /opt/conda/conda-env-ci.txt python="${PYTHON_VERSION}" 36 | popd 37 | } 38 | 39 | install_pip_dependencies() { 40 | pushd /opt/conda 41 | # Install all Python dependencies 42 | pip_install -r /opt/conda/requirements-dev.txt 43 | pip_install -r /opt/conda/requirements.txt 44 | pip_install -r /opt/conda/requirements-flux.txt 45 | popd 46 | } 47 | 48 | fix_conda_ubuntu_libstdcxx() { 49 | cat /etc/issue 50 | # WARNING: This is a HACK from PyTorch core to be able to build PyTorch on 22.04. 51 | # Specifically, ubuntu-20+ all comes lib libstdc++ newer than 3.30+, but anaconda 52 | # is stuck with 3.29. So, remove libstdc++6.so.3.29 as installed by 53 | # https://anaconda.org/anaconda/libstdcxx-ng/files?version=11.2.0 54 | # 55 | # PyTorch sev: https://github.com/pytorch/pytorch/issues/105248 56 | # Ref: https://github.com/pytorch/pytorch/blob/main/.ci/docker/common/install_conda.sh 57 | if grep -e "2[02].04." 
/etc/issue >/dev/null; then 58 | rm "/opt/conda/envs/py_${PYTHON_VERSION}/lib/libstdc++.so.6" 59 | fi 60 | } 61 | 62 | install_miniconda 63 | install_python 64 | install_pip_dependencies 65 | fix_conda_ubuntu_libstdcxx 66 | -------------------------------------------------------------------------------- /.ci/docker/common/install_gcc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | if [ -n "$GCC_VERSION" ]; then 11 | 12 | apt-get update 13 | apt-get install -y g++-"$GCC_VERSION" 14 | update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50 15 | update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50 16 | update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50 17 | 18 | # Cleanup package manager 19 | apt-get autoclean && apt-get clean 20 | rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* 21 | 22 | fi 23 | -------------------------------------------------------------------------------- /.ci/docker/common/install_user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | # Same as ec2-user 11 | echo "ci-user:x:1000:1000::/var/lib/ci-user:" >> /etc/passwd 12 | echo "ci-user:x:1000:" >> /etc/group 13 | # Needed on Focal or newer 14 | echo "ci-user:*:19110:0:99999:7:::" >> /etc/shadow 15 | 16 | # Create $HOME 17 | mkdir -p /var/lib/ci-user 18 | chown ci-user:ci-user /var/lib/ci-user 19 | 20 | # Allow sudo 21 | echo 'ci-user ALL=(ALL) NOPASSWD:ALL' > /etc/sudoers.d/ci-user 22 | 23 | # Test that sudo works 24 | sudo -u ci-user sudo -v 25 | -------------------------------------------------------------------------------- /.ci/docker/common/utils.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | as_ci_user() { 9 | # NB: unsetting the environment variables works around a conda bug 10 | # https://github.com/conda/conda/issues/6576 11 | # NB: Pass on PATH and LD_LIBRARY_PATH to sudo invocation 12 | # NB: This must be run from a directory that the user has access to 13 | sudo -E -H -u ci-user env -u SUDO_UID -u SUDO_GID -u SUDO_COMMAND -u SUDO_USER env "PATH=${PATH}" "LD_LIBRARY_PATH=${LD_LIBRARY_PATH:-}" "$@" 14 | } 15 | 16 | conda_install() { 17 | # Ensure that the install command don't upgrade/downgrade Python 18 | # This should be called as 19 | # conda_install pkg1 pkg2 ... 
[-c channel] 20 | as_ci_user conda install -q -n "py_${PYTHON_VERSION}" -y python="${PYTHON_VERSION}" "$@" 21 | } 22 | 23 | conda_run() { 24 | as_ci_user conda run -n "py_${PYTHON_VERSION}" --no-capture-output "$@" 25 | } 26 | 27 | pip_install() { 28 | as_ci_user conda run -n "py_${PYTHON_VERSION}" pip install --progress-bar off "$@" 29 | } 30 | -------------------------------------------------------------------------------- /.ci/docker/conda-env-ci.txt: -------------------------------------------------------------------------------- 1 | cmake=3.22.1 2 | ninja=1.10.2 3 | -------------------------------------------------------------------------------- /.ci/docker/requirements-dev.txt: -------------------------------------------------------------------------------- 1 | expecttest==0.1.6 2 | pytest==7.3.2 3 | pytest-cov 4 | pre-commit 5 | tomli-w >= 1.1.0 6 | -------------------------------------------------------------------------------- /.ci/docker/requirements-flux.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.51.1 2 | einops 3 | sentencepiece 4 | pillow 5 | -------------------------------------------------------------------------------- /.ci/docker/requirements.txt: -------------------------------------------------------------------------------- 1 | torchdata >= 0.8.0 2 | datasets >= 3.6.0 3 | tomli >= 1.1.0 ; python_version < "3.11" 4 | tensorboard 5 | tiktoken 6 | blobfile 7 | tabulate 8 | wandb 9 | fsspec 10 | tyro 11 | -------------------------------------------------------------------------------- /.ci/docker/ubuntu/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG OS_VERSION 2 | 3 | FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu${OS_VERSION} 4 | 5 | ARG OS_VERSION 6 | 7 | ENV DEBIAN_FRONTEND noninteractive 8 | 9 | # Install common dependencies 10 | COPY ./common/install_base.sh install_base.sh 11 | RUN bash ./install_base.sh && rm install_base.sh 12 | 13 | # Install clang 14 | ARG CLANG_VERSION 15 | COPY ./common/install_clang.sh install_clang.sh 16 | RUN bash ./install_clang.sh && rm install_clang.sh 17 | 18 | # Install gcc 19 | ARG GCC_VERSION 20 | COPY ./common/install_gcc.sh install_gcc.sh 21 | RUN bash ./install_gcc.sh && rm install_gcc.sh 22 | 23 | # Setup user 24 | COPY ./common/install_user.sh install_user.sh 25 | RUN bash ./install_user.sh && rm install_user.sh 26 | 27 | # Install conda and other dependencies 28 | ARG MINICONDA_VERSION 29 | ARG PYTHON_VERSION 30 | ENV PYTHON_VERSION=$PYTHON_VERSION 31 | ENV PATH /opt/conda/envs/py_$PYTHON_VERSION/bin:/opt/conda/bin:$PATH 32 | COPY requirements-dev.txt /opt/conda/ 33 | COPY requirements.txt /opt/conda/ 34 | COPY requirements-flux.txt /opt/conda/ 35 | COPY conda-env-ci.txt /opt/conda/ 36 | COPY ./common/install_conda.sh install_conda.sh 37 | COPY ./common/utils.sh utils.sh 38 | RUN bash ./install_conda.sh && rm install_conda.sh utils.sh /opt/conda/requirements-dev.txt /opt/conda/requirements.txt /opt/conda/requirements-flux.txt /opt/conda/conda-env-ci.txt 39 | 40 | USER ci-user 41 | CMD ["bash"] 42 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # Suggested config from pytorch that we can adapt 3 | select = B,C,E,F,N,P,T4,W,B9,TOR0,TOR1,TOR2 4 | max-line-length = 120 5 | # C408 ignored because we like the dict keyword argument syntax 6 | # E501 is not flexible enough, we're 
using B950 instead 7 | # N812 ignored because import torch.nn.functional as F is PyTorch convention 8 | # N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP) 9 | # E731 allow usage of assigning lambda expressions 10 | # N803,N806 allow caps and mixed case in function params. This is to work with Triton kernel coding style. 11 | ignore = 12 | E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,N812,N817,E731,N803,N806 13 | # shebang has extra meaning in fbcode lints, so I think it's not worth trying 14 | # to line this up with executable bit 15 | EXE001, 16 | # these ignores are from flake8-bugbear; please fix! 17 | B007,B008, 18 | optional-ascii-coding = True 19 | exclude = 20 | ./.git, 21 | ./docs 22 | ./build 23 | ./scripts, 24 | ./venv, 25 | *.pyi 26 | .pre-commit-config.yaml 27 | *.md 28 | .flake8 29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug.yml: -------------------------------------------------------------------------------- 1 | name: Bug Report 2 | description: Create a report to help us reproduce and fix the bug 3 | 4 | body: 5 | - type: textarea 6 | attributes: 7 | label: Bug description 8 | description: | 9 | Please provide a clear and concise description of what the bug is. 10 | validations: 11 | required: true 12 | - type: textarea 13 | attributes: 14 | label: Versions 15 | description: | 16 | Please include the following information if relevant: 17 | 1. Please confirm if your issue can be reproduced on the latest PyTorch nightly, including torch and torchao. 18 | 2. Please attach your .toml config file and command line overrides. 19 | 3. Other relevant information, e.g. runnable code snippet, if you modify torchtitan or work on a fork. 20 | validations: 21 | required: true 22 | - type: markdown 23 | attributes: 24 | value: > 25 | Thanks for submitting the bug report! 
26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true 2 | -------------------------------------------------------------------------------- /.github/workflows/docker-builds.yml: -------------------------------------------------------------------------------- 1 | name: docker-builds 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | paths: 7 | - .ci/docker/** 8 | - .github/workflows/docker-builds.yml 9 | push: 10 | branches: 11 | - main 12 | - release/* 13 | paths: 14 | - .ci/docker/** 15 | - .github/workflows/docker-builds.yml 16 | schedule: 17 | - cron: 1 3 * * 3 18 | 19 | concurrency: 20 | group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} 21 | cancel-in-progress: true 22 | 23 | jobs: 24 | docker-build: 25 | runs-on: [self-hosted, linux.2xlarge] 26 | timeout-minutes: 240 27 | strategy: 28 | fail-fast: false 29 | matrix: 30 | include: 31 | - docker-image-name: torchtitan-ubuntu-20.04-clang12 32 | env: 33 | DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/torchtitan/${{ matrix.docker-image-name }} 34 | steps: 35 | - name: Clean workspace 36 | shell: bash 37 | run: | 38 | echo "${GITHUB_WORKSPACE}" 39 | sudo rm -rf "${GITHUB_WORKSPACE}" 40 | mkdir "${GITHUB_WORKSPACE}" 41 | 42 | - name: Setup SSH (Click me for login details) 43 | uses: pytorch/test-infra/.github/actions/setup-ssh@main 44 | with: 45 | github-secret: ${{ secrets.GITHUB_TOKEN }} 46 | 47 | - name: Checkout the repo 48 | uses: actions/checkout@v3 49 | 50 | - name: Setup Linux 51 | uses: pytorch/test-infra/.github/actions/setup-linux@main 52 | 53 | - name: Build docker image 54 | id: build-docker-image 55 | uses: pytorch/test-infra/.github/actions/calculate-docker-image@main 56 | with: 57 | docker-image-name: ${{ matrix.docker-image-name }} 58 | always-rebuild: true 59 | push: true 60 | force-push: true 61 | 62 | - name: Teardown Linux 63 | uses: pytorch/test-infra/.github/actions/teardown-linux@main 64 | if: always() 65 | -------------------------------------------------------------------------------- /.github/workflows/integration_test_8gpu.yaml: -------------------------------------------------------------------------------- 1 | name: 8 GPU Integration Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | schedule: 8 | # Runs every 6 hours 9 | - cron: '0 */6 * * *' 10 | concurrency: 11 | group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 12 | cancel-in-progress: true 13 | 14 | defaults: 15 | run: 16 | shell: bash -l -eo pipefail {0} 17 | 18 | jobs: 19 | build-test: 20 | uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main 21 | with: 22 | runner: linux.g5.48xlarge.nvidia.gpu 23 | gpu-arch-type: cuda 24 | gpu-arch-version: "12.6" 25 | # This image is faster to clone than the default, but it lacks CC needed by triton 26 | # (1m25s vs 2m37s). 
27 | docker-image: torchtitan-ubuntu-20.04-clang12 28 | repository: pytorch/torchtitan 29 | upload-artifact: outputs 30 | script: | 31 | set -eux 32 | 33 | # The generic Linux job chooses to use base env, not the one setup by the image 34 | CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") 35 | conda activate "${CONDA_ENV}" 36 | 37 | pip config --user set global.progress_bar off 38 | 39 | python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 40 | 41 | USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 42 | 43 | mkdir artifacts-to-be-uploaded 44 | python ./tests/integration_tests.py artifacts-to-be-uploaded --ngpu 8 45 | -------------------------------------------------------------------------------- /.github/workflows/integration_test_8gpu_flux.yaml: -------------------------------------------------------------------------------- 1 | name: Flux 8 GPU Integration Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - 'torchtitan/experiments/flux/**' 8 | pull_request: 9 | paths: 10 | - 'torchtitan/experiments/flux/**' 11 | schedule: 12 | # Runs every 6 hours 13 | - cron: '0 */6 * * *' 14 | concurrency: 15 | group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 16 | cancel-in-progress: true 17 | 18 | defaults: 19 | run: 20 | shell: bash -l -eo pipefail {0} 21 | 22 | jobs: 23 | build-test: 24 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 25 | with: 26 | runner: linux.g5.48xlarge.nvidia.gpu 27 | gpu-arch-type: cuda 28 | gpu-arch-version: "12.6" 29 | # This image is faster to clone than the default, but it lacks CC needed by triton 30 | # (1m25s vs 2m37s). 31 | docker-image: torchtitan-ubuntu-20.04-clang12 32 | repository: pytorch/torchtitan 33 | upload-artifact: outputs 34 | script: | 35 | set -eux 36 | 37 | # The generic Linux job chooses to use base env, not the one setup by the image 38 | CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") 39 | conda activate "${CONDA_ENV}" 40 | 41 | pip config --user set global.progress_bar off 42 | 43 | python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 44 | 45 | mkdir artifacts-to-be-uploaded 46 | python -m torchtitan.experiments.flux.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 47 | -------------------------------------------------------------------------------- /.github/workflows/integration_test_8gpu_h100.yaml: -------------------------------------------------------------------------------- 1 | name: 8 GPU Integration Test at H100 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | schedule: 8 | # Runs every 6 hours 9 | - cron: '0 */6 * * *' 10 | concurrency: 11 | group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 12 | cancel-in-progress: true 13 | 14 | defaults: 15 | run: 16 | shell: bash -l -eo pipefail {0} 17 | 18 | jobs: 19 | build-test: 20 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 21 | with: 22 | runner: linux.aws.h100.8 23 | gpu-arch-type: cuda 24 | gpu-arch-version: "12.6" 25 | # This image is faster to clone than the default, but it lacks CC needed by triton 26 | # (1m25s vs 2m37s). 
27 | docker-image: torchtitan-ubuntu-20.04-clang12 28 | repository: pytorch/torchtitan 29 | upload-artifact: outputs 30 | script: | 31 | set -eux 32 | 33 | # The generic Linux job chooses to use base env, not the one setup by the image 34 | CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") 35 | conda activate "${CONDA_ENV}" 36 | 37 | pip config --user set global.progress_bar off 38 | 39 | python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 40 | 41 | USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 42 | 43 | mkdir artifacts-to-be-uploaded 44 | python ./tests/integration_tests_h100.py artifacts-to-be-uploaded --ngpu 8 45 | -------------------------------------------------------------------------------- /.github/workflows/integration_test_8gpu_simple_fsdp.yaml: -------------------------------------------------------------------------------- 1 | name: SimpleFSDP 8 GPU Integration Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - 'torchtitan/experiments/simple_fsdp/**' 8 | pull_request: 9 | paths: 10 | - 'torchtitan/experiments/simple_fsdp/**' 11 | schedule: 12 | # Runs every 6 hours 13 | - cron: '0 */6 * * *' 14 | concurrency: 15 | group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 16 | cancel-in-progress: true 17 | 18 | defaults: 19 | run: 20 | shell: bash -l -eo pipefail {0} 21 | 22 | jobs: 23 | build-test: 24 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 25 | with: 26 | runner: linux.g5.48xlarge.nvidia.gpu 27 | gpu-arch-type: cuda 28 | gpu-arch-version: "12.6" 29 | # This image is faster to clone than the default, but it lacks CC needed by triton 30 | # (1m25s vs 2m37s). 
31 | docker-image: torchtitan-ubuntu-20.04-clang12 32 | repository: pytorch/torchtitan 33 | upload-artifact: outputs 34 | script: | 35 | set -eux 36 | 37 | # The generic Linux job chooses to use base env, not the one setup by the image 38 | CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") 39 | conda activate "${CONDA_ENV}" 40 | 41 | pip config --user set global.progress_bar off 42 | 43 | python -m pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 44 | 45 | mkdir artifacts-to-be-uploaded 46 | python -m torchtitan.experiments.simple_fsdp.tests.integration_tests artifacts-to-be-uploaded --ngpu 8 47 | -------------------------------------------------------------------------------- /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: lint-${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 8 | cancel-in-progress: true 9 | 10 | defaults: 11 | run: 12 | shell: bash -l -eo pipefail {0} 13 | 14 | jobs: 15 | lint: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | python-version: ['3.10'] 20 | steps: 21 | - name: Check out repo 22 | uses: actions/checkout@v3 23 | - name: Setup python 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Update pip 28 | run: python -m pip install --upgrade pip 29 | - name: Install lint utilities 30 | run: | 31 | python -m pip install pre-commit 32 | pre-commit install-hooks 33 | - name: Get changed files 34 | id: changed-files 35 | uses: tj-actions/changed-files@d6e91a2266cdb9d62096cebf1e8546899c6aa18f # v45.0.6 36 | - name: Lint modified files 37 | run: pre-commit run --files ${{ steps.changed-files.outputs.all_changed_files }} 38 | -------------------------------------------------------------------------------- /.github/workflows/unit_test_cpu.yaml: -------------------------------------------------------------------------------- 1 | name: CPU Unit Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | 8 | concurrency: 9 | group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 10 | cancel-in-progress: true 11 | 12 | jobs: 13 | build-test: 14 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 15 | with: 16 | docker-image: torchtitan-ubuntu-20.04-clang12 17 | repository: pytorch/torchtitan 18 | script: | 19 | set -eux 20 | 21 | # The generic Linux job chooses to use base env, not the one setup by the image 22 | CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") 23 | conda activate "${CONDA_ENV}" 24 | 25 | pip config --user set global.progress_bar off 26 | 27 | pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu 28 | 29 | USE_CPP=0 python -m pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cpu 30 | 31 | pytest tests/unit_tests --cov=. 
--cov-report=xml --durations=20 -vv 32 | -------------------------------------------------------------------------------- /.github/workflows/unit_test_cpu_flux.yaml: -------------------------------------------------------------------------------- 1 | name: Flux Model CPU Unit Test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | - 'torchtitan/experiments/flux/**' 8 | pull_request: 9 | paths: 10 | - 'torchtitan/experiments/flux/**' 11 | 12 | 13 | concurrency: 14 | group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | build-test: 19 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 20 | with: 21 | docker-image: torchtitan-ubuntu-20.04-clang12 22 | repository: pytorch/torchtitan 23 | script: | 24 | set -eux 25 | 26 | # The generic Linux job chooses to use base env, not the one setup by the image 27 | CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]") 28 | conda activate "${CONDA_ENV}" 29 | 30 | pip config --user set global.progress_bar off 31 | 32 | pip install --force-reinstall --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu 33 | pytest torchtitan/experiments/flux/tests/unit_tests/ --cov=. --cov-report=xml --durations=20 -vv 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .DS_Store 4 | *.egg-info 5 | build 6 | outputs 7 | dist/* 8 | .vscode 9 | 10 | # data 11 | data 12 | out 13 | wandb 14 | 15 | torchtitan/datasets/**/*.model 16 | assets/**/*.model 17 | torchtitan/experiments/flux/assets/* 18 | 19 | # temp files 20 | *.log 21 | error.json 22 | _remote_module_non_scriptable.py 23 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: 'build' 2 | 3 | default_language_version: 4 | python: python3 5 | 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: 6306a48f7dae5861702d573c9c247e4e9498e867 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: check-ast 12 | - id: check-merge-conflict 13 | - id: no-commit-to-branch 14 | args: ['--branch=main'] 15 | - id: check-added-large-files 16 | args: ['--maxkb=500'] 17 | - id: end-of-file-fixer 18 | exclude: '^(.*\.svg)$' 19 | 20 | - repo: https://github.com/Lucas-C/pre-commit-hooks 21 | rev: v1.5.4 22 | hooks: 23 | - id: insert-license 24 | files: \.py$ 25 | args: 26 | - --license-filepath 27 | - assets/license_header.txt 28 | 29 | - repo: https://github.com/pycqa/flake8 30 | rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b 31 | hooks: 32 | - id: flake8 33 | additional_dependencies: 34 | - flake8-bugbear == 22.4.25 35 | - pep8-naming == 0.12.1 36 | - torchfix 37 | args: ['--config=.flake8'] 38 | 39 | - repo: https://github.com/omnilib/ufmt 40 | rev: v2.3.0 41 | hooks: 42 | - id: ufmt 43 | additional_dependencies: 44 | - black == 22.12.0 45 | - usort == 1.0.5 46 | 47 | - repo: https://github.com/jsh9/pydoclint 48 | rev: d88180a8632bb1602a4d81344085cf320f288c5a 49 | hooks: 50 | - id: pydoclint 51 | args: [--config=pyproject.toml] 52 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In 
the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, 6 | are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice,this list 9 | of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, this 12 | list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its contributors may 16 | be used to endorse or promote products derived from this software without specific 17 | prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY 20 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 21 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT 22 | SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 23 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 24 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 25 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 27 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 28 | DAMAGE. 29 | -------------------------------------------------------------------------------- /assets/images/loss_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchtitan/1cb1fa19033b42bbd11a64c9a227698949c7740b/assets/images/loss_curves.png -------------------------------------------------------------------------------- /assets/license_header.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) Meta Platforms, Inc. and affiliates. 2 | All rights reserved. 3 | 4 | This source code is licensed under the BSD-style license found in the 5 | LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /assets/version.txt: -------------------------------------------------------------------------------- 1 | 0.0.2 2 | -------------------------------------------------------------------------------- /docs/checkpoint.md: -------------------------------------------------------------------------------- 1 | ## How to convert a Llama 3 checkpoint for use in torchtitan 2 | 3 | If you want to continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager. 
4 | An example script for converting the original Llama3 checkpoints into the expected DCP format can be found in `scripts/convert_llama_to_dcp.py`. 5 | 6 | The script expects a path to the original checkpoint files, and a path to an output directory: 7 | ```bash 8 | python -m scripts.convert_llama_to_dcp 9 | ``` 10 | 11 | 12 | ## How to convert a torchtitan checkpoint for use in torchtune 13 | 14 | This guide will walk you through the steps required to convert a checkpoint from torchtitan so that it can be loaded into torchtune. 15 | 16 | ### Steps 17 | 1. ENABLE CHECKPOINTING 18 | In your torchtitan training config, ensure that `enable_checkpoint` is set to True. 19 | ``` 20 | [checkpoint] 21 | enable_checkpoint = true 22 | folder = "checkpoint" 23 | interval = 500 24 | ``` 25 | 26 | 27 | 2. SAVE ONLY MODEL WEIGHTS 28 | By setting `model_weights_only` to `True`, the checkpoint will only contain the model weights and exclude the optimizer state and extra train states, resulting in a smaller checkpoint size. 29 | ``` 30 | [checkpoint] 31 | enable_checkpoint = true 32 | model_weights_only = true 33 | ``` 34 | 35 | 3. CHOOSE DESIRED EXPORT PRECISION 36 | The default model states are in `float32`. You can choose to export the checkpoint in a lower precision format such as `bfloat16`. 37 | ``` 38 | [checkpoint] 39 | enable_checkpoint = true 40 | model_weights_only = true 41 | export_dtype = "bfloat16" 42 | ``` 43 | 44 | 4. EXAMPLE CHECKPOINT CONFIGURATION 45 | ``` 46 | [checkpoint] 47 | enable_checkpoint = true 48 | folder = "checkpoint" 49 | interval = 10 50 | load_step = 5 51 | model_weights_only = true 52 | export_dtype = "bfloat16" 53 | ``` 54 | 55 | 5. SAVE THE FINAL CHECKPOINT\ 56 | Once the above have been set, the final checkpoint at the end of the training step will consist of model weights only with the desired export dtype. However, if the final step has not been reached yet, full checkpoints will still be saved so that training can be resumed. 57 | 58 | 6. CONVERT SHARDED CHECKPOINTS TO A SINGLE FILE\ 59 | Finally, once you have obtained the last checkpoint, you can use the following command to convert the sharded checkpoints to a single .pt file that can be loaded into torchtune: 60 | 61 | ``` 62 | python -m torch.distributed.checkpoint.format_utils dcp_to_torch torchtitan/outputs/checkpoint/step-1000 checkpoint.pt 63 | ``` 64 | 65 | 7. EXCLUDING SPECIFIC KEYS FROM CHECKPOINT LOADING 66 | In some cases, you may want to partially load from a previous-trained checkpoint and modify certain settings, such as the number of GPUs or the current step. To achieve this, you can use the `exclude_from_loading` parameter to specify which keys should be excluded from loading. 67 | This parameter takes a comma-separated list of keys that should be excluded from loading. 68 | ``` 69 | [checkpoint] 70 | enable_checkpoint = true 71 | exclude_from_loading = "data_loader,lr_scheduler" 72 | ``` 73 | 74 | That's it. You have now successfully converted a sharded torchtitan checkpoint for use in torchtune. 75 | 76 | 77 | ## How to create a seed checkpoint 78 | Sometimes one needs to create a seed checkpoint to initialize a model from step 0. 79 | E.g. it is hard, if not impossible, for meta initialization on multiple devices to reproduce the initialization on a single device. 80 | A seed checkpoint does initialization of the model on a single CPU, and can be loaded from another job on an arbitrary number of GPUs via DCP resharding. 
81 | 82 | To create a seed checkpoint, use the same model config as you use for training. 83 | e.g. 84 | ```bash 85 | NGPU=1 CONFIG= ./run_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --parallelism.data_parallel_replicate_degree 1 --parallelism.data_parallel_shard_degree 1 --parallelism.tensor_parallel_degree 1 --parallelism.pipeline_parallel_degree 1 --parallelism.context_parallel_degree 1 86 | ``` 87 | -------------------------------------------------------------------------------- /docs/datasets.md: -------------------------------------------------------------------------------- 1 | # Custom Datasets in torchtitan 2 | 3 | `torchtitan` is designed to work seamlessly with most HuggingFace datasets. While we provide the C4 dataset for numerics and convergence testing, you can easily add support for your own datasets. Here's how to do it using Wikipedia as an example. 4 | 5 | ## Quick Start 6 | Locate the dataset configuration file: 7 | ``` 8 | torchtitan/datasets/hf_datasets/hf_datasets.py 9 | ``` 10 | 11 | ## Adding Your Dataset 12 | You'll need to add three components: 13 | 1. A dataset loader function 14 | 2. A sample processor function 15 | 3. A dataset configuration entry 16 | 17 | ### 1. Define Dataset Loader 18 | Create a function that specifies how to load your dataset: 19 | 20 | ```python 21 | def load_wikipedia_dataset(dataset_path: str, **kwargs): 22 | """Load Wikipedia dataset with specific configuration.""" 23 | logger.info("Loading Wikipedia dataset...") 24 | return load_dataset( 25 | dataset_path, 26 | name="20220301.en", 27 | split="train", 28 | streaming=True, 29 | trust_remote_code=True, 30 | ) 31 | ``` 32 | 33 | ### 2. Define Sample Processor 34 | Create a function that processes individual samples from your dataset: 35 | 36 | ```python 37 | def process_wikipedia_text(sample: Dict[str, Any]) -> str: 38 | """Process Wikipedia dataset sample text.""" 39 | return f"{sample['title']}\n\n{sample['text']}" 40 | ``` 41 | 42 | ### 3. Register Your Dataset 43 | Add your dataset configuration to the DATASETS dictionary: 44 | 45 | ```python 46 | DATASETS = { 47 | # ... existing datasets ... 48 | "wikipedia": DatasetConfig( 49 | path="wikipedia", # default HuggingFace dataset path 50 | loader=load_wikipedia_dataset, 51 | text_processor=process_wikipedia_text, 52 | ), 53 | } 54 | ``` 55 | 56 | ### 4. Configure Your Training 57 | In your training configuration file (`.toml`), set your dataset: 58 | 59 | ```toml 60 | dataset = "wikipedia" 61 | ``` 62 | 63 | That's it! Your custom dataset is now ready to use with `torchtitan`. 64 | 65 | ## Key Points 66 | - The DatasetConfig contains all necessary components for a dataset: 67 | - `path`: The default path to the dataset (can be overridden during training) 68 | - `loader`: Function to load the dataset 69 | - `text_processor`: Function to process individual samples 70 | - The loader function should return a HuggingFace dataset object 71 | - The processor function should return a string that combines the relevant fields from your dataset 72 | - Use `streaming=True` for large datasets to manage memory efficiently 73 | 74 | Now you can start training with your custom dataset! 
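Before launching a full run, it can also help to sanity-check the sample processor locally. The sketch below repeats the Wikipedia processor from above so it runs standalone, and feeds it a hand-written sample dict rather than data from the real dataset.

```python
# Quick local check of the sample processor, using a made-up sample for illustration.
from typing import Any, Dict


def process_wikipedia_text(sample: Dict[str, Any]) -> str:
    """Same processor as defined above, repeated here so the snippet runs standalone."""
    return f"{sample['title']}\n\n{sample['text']}"


sample = {"title": "PyTorch", "text": "PyTorch is an open source ML framework."}
print(process_wikipedia_text(sample))
# PyTorch
#
# PyTorch is an open source ML framework.
```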
75 | -------------------------------------------------------------------------------- /docs/debugging.md: -------------------------------------------------------------------------------- 1 | ## Enable Memory Profiling 2 | 3 | Launch training job with the following command (or alternatively set configs in toml files) 4 | ``` 5 | CONFIG_FILE="./train_configs/debug_model.toml" ./run_train.sh --profiling.enable_memory_snapshot --profiling.save_memory_snapshot_folder memory_snapshot 6 | ``` 7 | * `--profiling.enable_memory_snapshot`: to enable memory profiling 8 | * `--profiling.save_memory_snapshot_folder`: configures the folder which memory snapshots are dumped into (`./outputs/memory_snapshot/` by default) 9 | + In case of OOMs, the snapshots will be in `./outputs/memory_snapshot/iteration_x_exit`. 10 | + Regular snapshots (taken every `profiling.profile_freq` iterations) will be in `memory_snapshot/iteration_x`. 11 | 12 | You can find the saved pickle files in your output folder. 13 | To visualize a snapshot file, you can drag and drop it to . To learn more details on memory profiling, please visit this [tutorial](https://pytorch.org/blog/understanding-gpu-memory-1/). 14 | 15 | ## Overriding Boolean Flags from `.toml` via CLI 16 | 17 | Boolean flags are treated as **actions**. To disable a flag from the command line, use the `--no` prefix. 18 | 19 | For example, given the following in your `.toml` file: 20 | 21 | ```toml 22 | [profiling] 23 | enable_memory_snapshot = true 24 | 25 | ``` 26 | You can override it at runtime via CLI with: 27 | 28 | ```bash 29 | --profiling.no_enable_memory_snapshot 30 | --profiling.no-enable-memory-snapshot # Equivalent 31 | ``` 32 | 33 | > Note: `--enable_memory_snapshot=False` will **not** work. Use `--no_enable_memory_snapshot` instead. 34 | 35 | ## Debugging Config Values 36 | 37 | To inspect how configuration values are interpreted—including those from `.toml` files and CLI overrides—run the config manager directly: 38 | 39 | ```bash 40 | python -m torchtitan.config_manager [your cli args...] 41 | ``` 42 | 43 | For example, 44 | 45 | ```bash 46 | python -m torchtitan.config_manager --job.config_file ./torchtitan/models/llama3/train_configs/llama3_8b.toml --profiling.enable_memory_snapshot 47 | ``` 48 | 49 | To list all available CLI flags and usage: 50 | 51 | ```bash 52 | python -m torchtitan.config_manager --help 53 | ``` 54 | 55 | This will print a structured configuration to `stdout`, allowing you to verify that overrides are being applied correctly. 56 | 57 | ## Troubleshooting jobs that timeout 58 | 59 | If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs. 60 | When a job times out, Flight Recorder automatically generates dump files on every rank containing valuable debugging data. You can find these dump files in the `job.dump_folder` directory. 61 | To learn how to analyze and diagnose issues using these logs, follow our step-by-step tutorial [link](https://pytorch.org/tutorials/prototype/flight_recorder_tutorial.html). 62 | -------------------------------------------------------------------------------- /docs/extension.md: -------------------------------------------------------------------------------- 1 | To support rapid experimentation with torchtitan, we provide several extension points. 
The principle for adding these extension points is to support various use cases with flexible component swapping and reuse, while trying to keep the code clean and minimal. 2 | 3 | The extension points and protocols mentioned in this note are subject to change. 4 | 5 | 6 | ### `TrainSpec` 7 | 8 | [`TrainSpec`](../torchtitan/protocols/train_spec.py) supports configuring high-level components in model training, including 9 | - definitions of model class and model args config 10 | - model parallelization functions 11 | - loss functions 12 | - factory methods for creating dataloader / tokenizer / optimizer / learning rate scheduler / metrics processor 13 | 14 | The coarse level abstraction tries to hit a balance between flexible component swapping and a straightforward train script ([train.py](../torchtitan/train.py)). 15 | Note that among all training components, currently [`CheckpointManager`](../torchtitan/components/checkpoint.py) and [`FTManager`](../torchtitan/components/ft.py) are not configurable since we do not expect them to be customized, but we are open to requests. 16 | 17 | To register a `TrainSpec`, please follow the example of [Llama 3.1](../torchtitan/models/llama3/__init__.py) to `register_train_spec`. Please make sure the registration code is called before training initialization. In torchtitan, it is performed during [module import](../torchtitan/__init__.py). 18 | 19 | 20 | ### `ModelConverter` 21 | 22 | Originated from a [request](https://github.com/pytorch/torchtitan/issues/790) to unify quantization interface and supports dynamic registration, 23 | [`ModelConverter`](../torchtitan/protocols/model_converter.py) defines the following general interface: 24 | - `convert` is called after model definition and meta device initialization, but before model parallelization. It can perform general module rewrite, e.g. [Float8](../torchtitan/components/float8.py) module swapping, as long as it is compatible with other components. 25 | - `post_optimizer_hook`, as its name suggests, would be registered (via `torch.optim.Optimizer.register_step_post_hook`) to perform necessary post optimizer step operations. As an example, the [Float8](../torchtitan/components/float8.py) component in torchtitan uses this hook to issue a single all-reduce for all FSDP2 parameters (at once for better performance) to calculate the dynamic scale. 26 | 27 | To register a `ModelConverter`, please follow the example of [Float8](../torchtitan/components/float8.py) to `register_model_converter`. Please make sure the registration code is called before training initialization. In torchtitan, it is performed during [module import](../torchtitan/__init__.py). 28 | 29 | 30 | ### Train script 31 | 32 | To perform various tasks, from adding a new model (possibly with a new modality), to trying out a new training paradigm (e.g. async training), a single train script cannot handle all the cases, unless customization points are inserted everywhere to make it less readable. Instead of always starting and maintaining a standalone train script, we group code in [train.py](../torchtitan/train.py) into functions to allow for reuse. 33 | 34 | This is an ongoing effort, and the level of grouping is subject to change. 35 | 36 | 37 | ### Extending `JobConfig` 38 | 39 | [`JobConfig`](../torchtitan/config_manager.py) supports custom extension through the `--experimental.custom_args_module` flag. 40 | This lets you define a custom module that extends `JobConfig` with additional fields. 
41 | 42 | When specified, your custom `JobConfig` is merged with the default: 43 | - If a field exists in both, the custom config’s value replaces the default. 44 | - Fields unique to either config are retained. 45 | 46 | #### Example 47 | 48 | To add a custom `custom_args` section, define your own `JobConfig`: 49 | 50 | ```python 51 | # torchtitan/experiments/your_folder/custom_args.py 52 | from dataclasses import dataclass, field 53 | 54 | @dataclass 55 | class CustomArgs: 56 | how_is_your_day: str = "good" 57 | """Just an example.""" 58 | 59 | @dataclass 60 | class Training: 61 | steps: int = 500 62 | """Replaces the default value""" 63 | 64 | my_mini_steps: int = 10000 65 | """New field is added""" 66 | 67 | ... # Original fields are preserved 68 | 69 | @dataclass 70 | class JobConfig: 71 | custom_args: CustomArgs = field(default_factory=CustomArgs) 72 | training: Training= field(default_factory=Training) 73 | ``` 74 | 75 | Then run your script with: 76 | 77 | ```bash 78 | --experimental.custom_args_module=torchtitan.experiments.your_folder.custom_args 79 | ``` 80 | 81 | Or specify it in your `.toml` config: 82 | 83 | ```toml 84 | [experimental] 85 | custom_args_module = "torchtitan.experiments.your_folder.custom_args" 86 | ``` 87 | -------------------------------------------------------------------------------- /docs/metrics.md: -------------------------------------------------------------------------------- 1 | We support automatically collecting metrics such as 2 | 1. High level system metrics such as MFU, average loss, max loss and words per second along with some 3 | 2. Memory metrics to measure max VRAM consumption and the number of OOMs 4 | 3. Timing metrics to measure data loading bottlenecks 5 | 6 | Those metrics can then be visualized in either a TensorBoard or WandDB dashboard 7 | 8 | ## TensorBoard 9 | 10 | To visualize TensorBoard metrics of models trained on a remote server via a local web browser: 11 | 12 | 1. Make sure `metrics.enable_tensorboard` option is set to true in model training (either from a .toml file or from CLI). 13 | 14 | 2. Set up SSH tunneling, by running the following from local CLI 15 | ``` 16 | ssh -L 6006:127.0.0.1:6006 [username]@[hostname] 17 | ``` 18 | 19 | 3. Inside the SSH tunnel that logged into the remote server, go to the torchtitan repo, and start the TensorBoard backend 20 | ``` 21 | tensorboard --logdir=./outputs/tb 22 | ``` 23 | 24 | 4. In the local web browser, go to the URL it provides OR to http://localhost:6006/. 25 | 26 | ## Weights and Biases 27 | 28 | Weights and Biases will automatically send metrics to a remote server if you login with `wandb login` 29 | 30 | So all you need to do is make sure that `metrics.enable_wandb` is enabled 31 | 32 | For an example you can inspect the Llama 3 [debug_model.toml](../torchtitan/models/llama3/train_configs/debug_model.toml) 33 | 34 | Note that if both W&B and Tensorboard are enabled then we will prioritize W&B. 35 | -------------------------------------------------------------------------------- /docs/performance.md: -------------------------------------------------------------------------------- 1 | We demonstrate the effectiveness of elastic distributed training using torchtitan, via experiments on Llama 3.1 8B, 70B, and 405B models, from 1D parallelism to 4D parallelism, at the scale from 8 GPUs to 512 GPUs. 
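Throughput in the tables below is reported as tokens per second per GPU (TPS/GPU). As a rough aid for interpreting those numbers, the sketch below estimates Model FLOPS Utilization (MFU) from TPS/GPU using the common 6N FLOPs-per-token approximation for dense transformers; the parameter count and peak-TFLOPS figures in the example are illustrative assumptions, not the exact accounting behind the results reported here.

```python
# Back-of-the-envelope MFU estimate from observed throughput.
# Assumes ~6 * n_params FLOPs per token (forward + backward) for a dense transformer.
def estimate_mfu(tps_per_gpu: float, n_params: float, peak_tflops: float) -> float:
    achieved_tflops = 6 * n_params * tps_per_gpu / 1e12
    return achieved_tflops / peak_tflops

# Illustrative numbers: Llama 3.1 8B (~8e9 params) at 6,700 TPS/GPU,
# against an assumed ~989 TFLOPS BF16 peak for an H100 SXM.
print(f"MFU ~ {estimate_mfu(6_700, 8e9, 989):.1%}")
```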
2 | 3 | We ran our performance benchmarks on the [Grand Teton platform](https://engineering.fb.com/2022/10/18/open-source/ocp-summit-2022-grand-teton/), where 4 | - Each host has 8 NVIDIA H100 GPUs fully connected with NVLink. 5 | - Each H100 GPU is equipped with 96GB HBM2e with 2.4 TB/sec peak memory bandwidth. 6 | - Hosts are inter-connected with backend RDMA network with 400 Gb/s per GPU. 7 | - We used the default 500W power limit, although tuning it up to 700W TDP can potentially provide further speedups. 8 | 9 | We note that, throughout our experimentation, memory readings are stable across the whole training process[^1], whereas throughput numbers (TPS/GPU) are calculated and logged every 10 iterations, and always read at the (arbitrarily determined) 90th iteration. 10 | 11 | We do not report Model FLOPS Utilization (MFU) because when Float8 is enabled (on `nn.Linear` modules), both BFLOAT16 Tensor Core and FP8 Tensor Core are involved in model training, but they have different peak FLOPS and the definition of MFU under such scenario is not well-defined. We note that the 1D Llama 3.1 8B model training on 8 or 128 H100 GPUs without Float8 achieves 33% to 39% MFU[^2] (with or without torch.compile, respectively). 12 | 13 | **Table 1** 1D Parallelism (FSDP). Llama 3.1 8B model. 8 GPUs. Local batch size 2, global batch size 16. Selective activation checkpointing. 14 | 15 | | Techniques | TPS/GPU | Memory(GiB) | 16 | | ----- | ----: | ----: | 17 | | FSDP | 5,762 | 82.4 | 18 | | FSDP + torch.compile | 6,667 | 77.0 | 19 | | FSDP + torch.compile + Float8 | 8,532 | 76.8 | 20 | 21 | **Table 2** FSDP + CP + torch.compile + Float8. Llama 3.1 8B model. 8 GPUs. Local batch size 1. Full activation checkpointing. 22 | 23 | | Parallelism | Sequence Length | TPS/GPU | Memory(GiB) | 24 | | ----- | ----: | ----: | ----: | 25 | | FSDP 8, CP 1 | 32768 | 3,890 | 83.9 | 26 | | FSDP 4, CP 2 | 65536 | 2,540 | 84.2 | 27 | | FSDP 2, CP 4 | 131072 | 1,071 | 84.0 | 28 | | FSDP 1, CP 8 | 262144 | 548 | 84.5 | 29 | 30 | **Table 3** 1D Parallelism (FSDP). Llama 3.1 8B model. 128 GPUs. Local batch size 2, global batch size 256. Selective activation checkpointing. 31 | 32 | | Techniques | TPS/GPU | Memory(GiB) | 33 | | ----- | ----: | ----: | 34 | | FSDP | 5,605 | 67.0 | 35 | | FSDP + torch.compile | 6,514 | 62.0 | 36 | | FSDP + torch.compile + Float8 | 8,380 | 61.8 | 37 | 38 | **Table 4** 2D parallelism (FSDP + TP) + torch.compile + Float8. Llama 3.1 70B model. 256 GPUs (FSDP 32, TP 8). Local batch size 16, global batch size 512. Full activation checkpointing. 39 | 40 | | Techniques | TPS/GPU | Memory(GiB) | 41 | | ----- | ----: | ----: | 42 | | 2D | 829 | 71.9 | 43 | | 2D + AsyncTP | 876 | 67.6 | 44 | 45 | **Table 5** 3D parallelism (FSDP + TP + PP) + torch.compile + Float8 + AsyncTP. Llama 3.1 405B model. 512 GPUs (FSDP 8, TP 8, PP8). Local batch size 32, global batch size 256. Full activation checkpointing. 46 | 47 | | Schedule | TPS/GPU | Memory(GiB) | 48 | | ----- | ----: | ----: | 49 | | 1F1B | 100 | 82.5 | 50 | | Interleaved 1F1B | 128 | 72.7 | 51 | 52 | **Table 6** 4D parallelism (FSDP + TP + PP + CP) + torch.compile + Float8 + AsyncTP + 1F1B. Llama 3.1 405B model. 512 GPUs (TP 8, PP8). Local batch size 8. Full activation checkpointing. 
53 | 54 | | Parallelism | Sequence Length | TPS/GPU | Memory(GiB) | 55 | | ----- | ----: | ----: | ----: | 56 | | FSDP 8, CP 1 | 32768 | 76 | 75.3 | 57 | | FSDP 4, CP 2 | 65536 | 47 | 75.9 | 58 | | FSDP 2, CP 4 | 131072 | 31 | 77.1 | 59 | | FSDP 1, CP 8 | 262144 | 16 | 84.9 | 60 | 61 | 62 | #### Versions used for performance testing 63 | | repo | commit | date | 64 | | --- | --- | --- | 65 | | torch | [1963fc8](https://github.com/pytorch/pytorch/commit/1963fc83a1c32e162162e2414f78b043f0674bae) | 2024/12/23 | 66 | | torchao | [eab345c](https://github.com/pytorch/ao/commit/eab345c2268a7506355d506ebfc27b5d28e5e7d0) | 2024/12/23 | 67 | | torchtitan | [9dec370](https://github.com/pytorch/torchtitan/commit/9dec370ad26b5f8e9a7333a0e36165018262644b) | 2024/12/26 | 68 | 69 | 70 | [^1]: Different PP ranks can have different peak memory usages. We take the maximum across all GPUs. 71 | 72 | [^2]: In our test we used HBM2e-based SXM H100 with lower TDP, the actual peak TFLOPs number is between SXM and NVL, and we don't know its exact value. So this MFU number is lower than actual MFU because we use the peak number of SXM directly. 73 | -------------------------------------------------------------------------------- /docs/torchft.md: -------------------------------------------------------------------------------- 1 | # Enabling Fault Tolerance with TorchFT in TorchTitan 2 | 3 | ## Why Use TorchFT with TorchTitan? 4 | 5 | TorchFT is designed to provide fault tolerance when training with replicated weights, such as in DDP or HSDP. By enabling TorchFT in TorchTitan, we can ensure that our training process can continue even if some machines fail. For more information on TorchFT, please refer to the [TorchFT repository](https://github.com/pytorch/torchft/). 6 | 7 | **Note:** This is an ongoing development effort, and everything is subject to change. 8 | 9 | ## Prerequisites for Using TorchFT with TorchTitan 10 | 11 | Before using TorchFT with TorchTitan, you need to install TorchFT by following the instructions in the [TorchFT README](https://github.com/pytorch/torchft/blob/main/README.md) to install TorchFT. 12 | 13 | ## Configuring TorchTitan for Using TorchFT 14 | 15 | When using TorchFT with TorchTitan, you need to launch multiple replica groups, each of which is a separate TorchTitan instance. Each replica group is responsible for maintaining a copy of the model weights. In case of a failure, the other replica groups can continue training without lossing weight information. 16 | 17 | For example, if you want to run HSDP on a single machine with eight GPUs, where weights are sharded within four GPUs with two replica groups (2, 4 device mesh), you can do this with TorchTitan by specifying `--data_parallel_replica_degree=2` and `--data_parallel_shard_degree=4`. However, to utilize TorchFT, you will need to launch two TorchTitan instances, each managing four GPUs and communicating with each other through TorchFT. 18 | 19 | ## Example Configuration 20 | 21 | Let's consider an example where we want to run HSDP on a single machine with eight GPUs, where weights are sharded within four GPUs with two replica groups (2, 4 device mesh). Without using TorchFT, you can launch such a training process by specifying `--parallelism.data_parallel_replica_degree=2 --parallelism.data_parallel_shard_degree=4`. However, in the event of a trainer failure (emulating a real-world machine failure), the entire training process would need to stop and recover from the last checkpoint. 
This can lead to significant downtime and wasted resources. 22 | 23 | With TorchFT, we can tolerate one replica group failure, ensuring that the training process continues uninterrupted. To achieve this, we can launch two TorchTitan instances, each managing four GPUs and communicating with each other through TorchFT. This setup allows for seamless fault tolerance and minimizes the impact of individual trainer failures. 24 | 25 | ### Launching TorchFT with TorchTitan 26 | 27 | To launch TorchFT with TorchTitan, you need to execute the following three commands in different shell sessions: 28 | 29 | 1. Launch TorchFT lighthouse: 30 | ```bash 31 | RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000 32 | ``` 33 | 2. Launch the first TorchTitan instance: 34 | ```bash 35 | NGPU=4 CUDA_VISIBLE_DEVICES=0,1,2,3 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=0 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4 36 | ``` 37 | 3. Launch the second TorchTitan instance: 38 | ```bash 39 | NGPU=4 CUDA_VISIBLE_DEVICES=4,5,6,7 CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --fault_tolerance.enable --fault_tolerance.replica_id=1 --fault_tolerance.group_size=2 --parallelism.data_parallel_shard_degree=4 40 | ``` 41 | 42 | ### Explanation 43 | 44 | * We limit the visibility of GPUs for each TorchTitan instance using environment variables `NGPU` and `CUDA_VISIBLE_DEVICES`, as we are running on a single machine. In reality, each TorchTitan instance will not share machines, so these variables are not required. 45 | * `--fault_tolerance.enable` enables TorchFT functionality. 46 | * `--fault_tolerance.group_size=2` tells TorchTitan that there are two replica groups. 47 | * `--fault_tolerance.replica_id=1` tells TorchTitan that the replica ID of this instance is 1. 48 | * Note that the alive replica group with the smallest replica ID will perform checkpointing saving. 49 | 50 | In a real-world scenario, `torchft_lighthouse` would likely be on a different machine. The `TORCHFT_LIGHTHOUSE` environment variable is used to tell TorchFT how to communicate with `torchft_lighthouse`. The default value is `http://localhost:29510`. 51 | -------------------------------------------------------------------------------- /multinode_trainer.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # --- This script is optimized for AWS with EFA 9 | # --- adjust NCCL_BUFFSIZE if you encounter memory 10 | # --- constraint issues or to tune for improved performance. 
11 | # --- 12 | 13 | #SBATCH --job-name=torchtitan_multi_node 14 | 15 | #SBATCH --ntasks=4 16 | 17 | #SBATCH --nodes=4 18 | 19 | #SBATCH --gpus-per-task=8 20 | 21 | #SBATCH --cpus-per-task=96 22 | 23 | #SBATCH --partition=train 24 | 25 | 26 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 27 | nodes_array=($nodes) 28 | head_node=${nodes_array[0]} 29 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 30 | 31 | echo Node IP: $head_node_ip 32 | export LOGLEVEL=INFO 33 | # Enable for A100 34 | export FI_PROVIDER="efa" 35 | # Ensure that P2P is available 36 | # export NCCL_P2P_DISABLE=1 37 | export NCCL_IB_DISABLE=1 38 | 39 | # debugging flags (optional) 40 | export NCCL_DEBUG=WARN 41 | export PYTHONFAULTHANDLER=1 42 | # optional debug settings 43 | # export NCCL_DEBUG=INFO 44 | # NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV 45 | 46 | export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH 47 | export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH 48 | export CUDA_LAUNCH_BLOCKING=0 49 | 50 | # on your cluster you might need these: 51 | # set the network interface 52 | export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" 53 | export NCCL_BUFFSIZE=2097152 54 | #export TORCH_DIST_INIT_BARRIER=1 55 | export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 56 | 57 | CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama/train_configs/llama3_8b.toml"} 58 | 59 | dcgmi profile --pause 60 | # adjust sbatch --ntasks and sbatch --nodes above and --nnodes below 61 | # to your specific node count, and update target launch file. 62 | srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE} 63 | dcgmi profile --resume 64 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # ---- All project specifications ---- # 2 | [project] 3 | name = "torchtitan" 4 | description = "A PyTorch native library for large-scale model training" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | license = {file = "LICENSE"} 8 | authors = [ 9 | { name = "PyTorch Team", email = "packages@pytorch.org" }, 10 | ] 11 | keywords = ["pytorch", "training", "llm"] 12 | dependencies = [ 13 | # Stateful Dataloader 14 | "torchdata>=0.8.0", 15 | 16 | # Hugging Face integrations 17 | "datasets>=2.21.0", 18 | 19 | # Tokenization 20 | "blobfile", 21 | "tiktoken", 22 | 23 | # Miscellaneous 24 | "tomli>=1.1.0", 25 | "fsspec" 26 | ] 27 | dynamic = ["version"] 28 | 29 | [project.urls] 30 | GitHub = "https://github.com/pytorch/torchtitan" 31 | Documentation = "https://github.com/pytorch/torchtitan/tree/main/docs" 32 | Issues = "https://github.com/pytorch/torchtitan/issues" 33 | 34 | [project.optional-dependencies] 35 | dev = [ 36 | "pre-commit", 37 | "pytest", 38 | "pytest-cov", 39 | "tensorboard", 40 | ] 41 | 42 | [tool.setuptools.dynamic] 43 | version = {file = "assets/version.txt"} 44 | 45 | 46 | # ---- Explicit project build information ---- # 47 | [build-system] 48 | requires = ["setuptools>=61.0"] 49 | build-backend = "setuptools.build_meta" 50 | 51 | [tool.setuptools.packages.find] 52 | where = [""] 53 | include = ["torchtitan*"] 54 | 55 | [tool.pytest.ini_options] 56 | addopts = ["--showlocals"] # show local variables in tracebacks 57 | -------------------------------------------------------------------------------- /requirements-dev.txt: 
-------------------------------------------------------------------------------- 1 | .ci/docker/requirements-dev.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | .ci/docker/requirements.txt -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | # use envs as local overrides for convenience 11 | # e.g. 12 | # LOG_RANK=0,1 NGPU=4 ./run_train.sh 13 | NGPU=${NGPU:-"8"} 14 | export LOG_RANK=${LOG_RANK:-0} 15 | CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} 16 | 17 | overrides="" 18 | if [ $# -ne 0 ]; then 19 | overrides="$*" 20 | fi 21 | 22 | TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} 23 | 24 | PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ 25 | TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ 26 | torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ 27 | --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ 28 | -m torchtitan.train --job.config_file ${CONFIG_FILE} $overrides 29 | -------------------------------------------------------------------------------- /scripts/download_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | from requests.exceptions import HTTPError 10 | 11 | 12 | def hf_download( 13 | repo_id: str, tokenizer_path: str, local_dir: str, hf_token: Optional[str] = None 14 | ) -> None: 15 | from huggingface_hub import hf_hub_download 16 | 17 | tokenizer_path = ( 18 | f"{tokenizer_path}/tokenizer.model" if tokenizer_path else "tokenizer.model" 19 | ) 20 | 21 | try: 22 | hf_hub_download( 23 | repo_id=repo_id, 24 | filename=tokenizer_path, 25 | local_dir=local_dir, 26 | local_dir_use_symlinks=False, 27 | token=hf_token, 28 | ) 29 | except HTTPError as e: 30 | if e.response.status_code == 401: 31 | print( 32 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 33 | ) 34 | else: 35 | raise e 36 | 37 | 38 | if __name__ == "__main__": 39 | import argparse 40 | 41 | parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.") 42 | parser.add_argument( 43 | "--repo_id", 44 | type=str, 45 | default="meta-llama/Meta-Llama-3.1-8B", 46 | help="Repository ID to download from. 
default to Llama-3.1-8B", 47 | ) 48 | parser.add_argument( 49 | "--tokenizer_path", 50 | type=str, 51 | default="original", 52 | help="the tokenizer.model path relative to repo_id", 53 | ) 54 | parser.add_argument( 55 | "--hf_token", type=str, default=None, help="HuggingFace API token" 56 | ) 57 | parser.add_argument( 58 | "--local_dir", 59 | type=str, 60 | default="assets/tokenizer/", 61 | help="local directory to save the tokenizer.model", 62 | ) 63 | 64 | args = parser.parse_args() 65 | hf_download(args.repo_id, args.tokenizer_path, args.local_dir, args.hf_token) 66 | -------------------------------------------------------------------------------- /scripts/estimate/run_memory_estimation.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | # use envs as local overrides for convenience 11 | # e.g. 12 | # NGPU=4 ./run_memory_estimation.sh 13 | NGPU=${NGPU:-"8"} 14 | NNODES=${NNODES:-"1"} 15 | CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} 16 | 17 | overrides="" 18 | if [ $# -ne 0 ]; then 19 | overrides="$*" 20 | fi 21 | 22 | # Calculate WORLD_SIZE as the product of NGPU and NNODES 23 | # Export WORLD_SIZE and LOCAL_RANK 24 | export WORLD_SIZE=$((NGPU * NNODES)) 25 | export LOCAL_RANK=0 26 | python -m scripts.estimate.estimation --job.config_file ${CONFIG_FILE} --memory_estimation.enabled $overrides 27 | -------------------------------------------------------------------------------- /scripts/generate/README.md: -------------------------------------------------------------------------------- 1 | # Model Generation Check 2 | 3 | The `test_generate` script provides a straightforward way to validate models, tokenizers, checkpoints, and device compatibility by running a single forward pass. This script functions as a sanity check to ensure everything is set up correctly. 4 | 5 | While **torchtitan** focuses on advanced features for distributed pre-training, this script acts as a lightweight integration test to verify runtime setup. For more extensive inference and generation capabilities, consider tools like [pytorch/torchchat](https://github.com/pytorch/torchchat/). 6 | 7 | ## Purpose and Use Case 8 | 9 | This script is ideal for users who need to: 10 | 11 | - **Run Sanity Checks**: Confirm that models, tokenizers, and checkpoints load without errors. 12 | - **Test Compatibility**: Execute a forward pass to assess model response and memory usage. 13 | - **Evaluate Device Scaling**: Optionally test distributed generation using tensor parallel (TP) to confirm multi-device functionality. 14 | 15 | ## Usage Instructions 16 | 17 | #### Run on a single GPU. 18 | 19 | ```bash 20 | NGPU=1 CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \ 21 | PROMPT="What is the meaning of life?" \ 22 | ./scripts/generate/run_llama_generate.sh --max_new_tokens=32 --temperature=0.8 --seed=3 23 | ``` 24 | 25 | #### Run on 4 GPUs and pipe results to a json file. 26 | 27 | ```bash 28 | NGPU=4 CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml CHECKPOINT_DIR=./outputs/checkpoint/ \ 29 | PROMPT="What is the meaning of life?" 
\ 30 | ./scripts/generate/run_llama_generate.sh --max_new_tokens=32 --temperature=0.8 --seed=3 --out > output.json 31 | ``` 32 | 33 | #### View Available Arguments 34 | 35 | ```bash 36 | > python -m scripts.generate.test_generate --help 37 | ``` 38 | -------------------------------------------------------------------------------- /scripts/generate/_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | 11 | 12 | def multinomial_sample_one( 13 | probs: torch.Tensor, rng: Optional[torch.Generator] = None 14 | ) -> torch.Tensor: 15 | q = torch.empty_like(probs).exponential_(1, generator=rng) 16 | return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) 17 | 18 | 19 | def logits_to_probs( 20 | logits: torch.Tensor, 21 | temperature: float = 1.0, 22 | top_k: Optional[int] = None, 23 | ) -> torch.Tensor: 24 | logits = logits / max(temperature, 1e-5) 25 | 26 | if top_k is not None: 27 | v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) 28 | pivot = v.select(dim=-1, index=-1).unsqueeze(-1) 29 | logits = torch.where(logits < pivot, -float("Inf"), logits) 30 | 31 | probs = torch.nn.functional.softmax(logits, dim=-1) 32 | return probs 33 | 34 | 35 | def generate_next_token( 36 | model, 37 | x: torch.Tensor, 38 | *, 39 | temperature: float = 1.0, 40 | top_k: Optional[int] = None, 41 | rng: Optional[torch.Generator] = None, 42 | ) -> torch.Tensor: 43 | logits = model(x) # (B, T, vocab_size) 44 | probs = logits_to_probs(logits[:, -1, :], temperature, top_k) 45 | next_token = multinomial_sample_one(probs, rng=rng) 46 | return next_token 47 | 48 | 49 | @torch.no_grad() 50 | def generate( 51 | model, 52 | input_ids: torch.Tensor, 53 | *, 54 | max_new_tokens: int, 55 | temperature: float = 1.0, 56 | top_k: Optional[int] = None, 57 | seed: Optional[int] = None, 58 | ) -> torch.Tensor: 59 | # ensure batch dimension (T,) --> (B, T) 60 | if input_ids.ndim == 1: 61 | input_ids = input_ids.unsqueeze(0) 62 | 63 | rng = None 64 | if seed is not None: 65 | rng = torch.Generator(input_ids.device).manual_seed(seed) 66 | 67 | generated_tokens = input_ids.clone() 68 | 69 | for _ in range(max_new_tokens): 70 | next_token = generate_next_token( 71 | model, 72 | x=generated_tokens, 73 | temperature=temperature, 74 | top_k=top_k, 75 | rng=rng, 76 | ) 77 | 78 | generated_tokens = torch.cat([generated_tokens, next_token], dim=1) 79 | 80 | return generated_tokens 81 | -------------------------------------------------------------------------------- /scripts/generate/run_llama_generate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -e 9 | 10 | # use envs as local overrides for convenience 11 | # e.g. 
12 | # LOG_RANK=0,1 NGPU=4 ./run_llama_generate.sh 13 | NGPU=${NGPU:-"1"} 14 | LOG_RANK=${LOG_RANK:-0} 15 | CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} 16 | CHECKPOINT_DIR=${CHECKPOINT_DIR:-"./outputs/checkpoint/"} 17 | PROMPT=${PROMPT:-""} 18 | 19 | overrides=() 20 | if [ $# -ne 0 ]; then 21 | for arg in "$@"; do 22 | # special case to handle prompt in quotes 23 | if [[ "$arg" == --prompt=* ]]; then 24 | PROMPT="${arg#--prompt=}" 25 | # check if file 26 | if [[ -f "$PROMPT" ]]; then 27 | PROMPT=$(<"$PROMPT") 28 | fi 29 | else 30 | # handle other args 31 | overrides+=("$arg") 32 | fi 33 | done 34 | fi 35 | 36 | set -x 37 | torchrun --standalone \ 38 | --nproc_per_node="${NGPU}" \ 39 | --local-ranks-filter="${LOG_RANK}" \ 40 | -m scripts.generate.test_generate \ 41 | --config="${CONFIG_FILE}" \ 42 | --checkpoint="${CHECKPOINT_DIR}" \ 43 | --prompt="${PROMPT}" \ 44 | "${overrides[@]}" 45 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | This directory contains tests for the TorchTitan project, including unit tests and integration tests. 4 | 5 | ## Test Structure 6 | 7 | - `unit_tests/`: Contains unit tests for individual components 8 | - `integration_tests.py`: Contains integration tests that test multiple components together 9 | - `integration_tests_h100.py`: Contains integration tests specifically designed for H100 GPUs, which utilize symmetric memory and float8. 10 | - `assets/`: Contains test assets and fixtures used by the tests 11 | 12 | ## Running Tests 13 | 14 | ### Prerequisites 15 | 16 | Ensure you have all development dependencies installed: 17 | 18 | ```bash 19 | pip install -r dev-requirements.txt 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ### Running Integration Tests 24 | 25 | To run the integration tests: 26 | 27 | ```bash 28 | python ./tests/integration_tests.py [--config_dir CONFIG_DIR] [--test TEST] [--ngpu NGPU] 29 | ``` 30 | 31 | Arguments: 32 | - `output_dir`: (Required) Directory where test outputs will be stored 33 | - `--config_dir`: (Optional) Directory containing configuration files (default: "./torchtitan/models/llama3/train_configs") 34 | - `--test`: (Optional) Specific test to run, use test names from the `build_test_list()` function (default: "all") 35 | - `--ngpu`: (Optional) Number of GPUs to use for testing (default: 8) 36 | 37 | Examples: 38 | ```bash 39 | # Run all integration tests with 8 GPUs 40 | python ./tests/integration_tests.py ./test_output 41 | 42 | # Run a specific test with 4 GPUs 43 | python ./tests/integration_tests.py ./test_output --test default --ngpu 4 44 | 45 | # Run all tests with a custom config directory 46 | python ./tests/integration_tests.py ./test_output --config_dir ./my_configs 47 | ``` 48 | 49 | ### Running Unit Tests 50 | 51 | To run only the unit tests: 52 | 53 | ```bash 54 | pytest -s tests/unit_tests/ 55 | ``` 56 | 57 | ### Running Specific Unit Test Files 58 | 59 | To run a specific test file: 60 | 61 | ```bash 62 | pytest -s tests/unit_tests/test_job_config.py 63 | ``` 64 | 65 | ### Running Specific Test Functions in Unit Tests 66 | 67 | To run a specific test function: 68 | 69 | ```bash 70 | pytest -s tests/unit_tests/test_job_config.py::TestJobConfig::test_command_line_args 71 | ``` 72 | -------------------------------------------------------------------------------- /tests/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/assets/custom_schedule.csv: -------------------------------------------------------------------------------- 1 | 0F0,0F1,0F2,0F3,0F4,0F5,0F6,0F7,2F0,2F1,2F2,2F3,2F4,2F5,2F6,2F7,2I0,2W0,2I1,2W1,0I0,0W0,0I1,0W1,2I2,2W2,2I3,2W3,0I2,0W2,0I3,0W3,2I4,2W4,2I5,2W5,0I4,0W4,0I5,0W5,2I6,2W6,2I7,2W7,0I6,0W6,0I7,0W7 2 | 1F0,1F1,1F2,1F3,1F4,1F5,1F6,1F7,3F0,3F1,3F2,3F3,3F4,3F5,3F6,3F7,3I0,3W0,3I1,3W1,1I0,1W0,1I1,1W1,3I2,3W2,3I3,3W3,1I2,1W2,1I3,1W3,3I4,3W4,3I5,3W5,1I4,1W4,1I5,1W5,3I6,3W6,3I7,3W7,1I6,1W6,1I7,1W7 3 | -------------------------------------------------------------------------------- /tests/assets/extend_jobconfig_example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from dataclasses import dataclass, field 8 | 9 | 10 | @dataclass 11 | class CustomArgs: 12 | how_is_your_day: str = "good" 13 | """Just an example helptext""" 14 | 15 | num_days: int = 7 16 | """Number of days in a week""" 17 | 18 | 19 | @dataclass 20 | class Training: 21 | steps: int = 99 22 | my_custom_steps: int = 32 23 | 24 | 25 | @dataclass 26 | class JobConfig: 27 | """ 28 | This is an example of how to extend the tyro parser with custom config classes. 29 | """ 30 | 31 | custom_args: CustomArgs = field(default_factory=CustomArgs) 32 | training: Training = field(default_factory=Training) 33 | -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/test_dataset_checkpointing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
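# This test checks dataloader checkpoint/resume: it advances 250 batches, saves the
# dataloader state_dict, restores it into a freshly built dataloader, and asserts that
# the next 500 batches match, for both a map-style and a streaming variant of the
# c4_test dataset across data-parallel world sizes of 2 and 4.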
6 | 7 | import unittest 8 | 9 | import torch 10 | from datasets import load_dataset 11 | from torchtitan.config_manager import ConfigManager 12 | from torchtitan.datasets.hf_datasets import build_hf_dataloader, DatasetConfig, DATASETS 13 | from torchtitan.datasets.tokenizer.tiktoken import TikTokenizer 14 | 15 | 16 | class TestDatasetCheckpointing(unittest.TestCase): 17 | def setUp(self): 18 | DATASETS["c4_test_streaming"] = DatasetConfig( 19 | path="tests/assets/c4_test", 20 | loader=lambda path: load_dataset(path, split="train").to_iterable_dataset( 21 | num_shards=4 22 | ), 23 | text_processor=lambda sample: sample["text"], 24 | ) 25 | 26 | def tearDown(self): 27 | del DATASETS["c4_test_streaming"] 28 | 29 | def test_c4_resumption(self): 30 | for dataset_name in ["c4_test", "c4_test_streaming"]: 31 | for world_size in [2, 4]: 32 | for rank in range(world_size): 33 | batch_size = 1 34 | seq_len = 1024 35 | 36 | dl = self._build_dataloader( 37 | dataset_name, batch_size, seq_len, world_size, rank 38 | ) 39 | 40 | it = iter(dl) 41 | for _ in range(250): 42 | next(it) 43 | state = dl.state_dict() 44 | 45 | # Create new dataloader, restore checkpoint, and check if next data yielded is the same as above 46 | dl_resumed = self._build_dataloader( 47 | dataset_name, batch_size, seq_len, world_size, rank 48 | ) 49 | dl_resumed.load_state_dict(state) 50 | it_resumed = iter(dl_resumed) 51 | 52 | for _ in range(500): 53 | expected_input_ids, expected_labels = next(it) 54 | input_ids, labels = next(it_resumed) 55 | assert torch.equal( 56 | input_ids["input"], expected_input_ids["input"] 57 | ) 58 | assert torch.equal(labels, expected_labels) 59 | 60 | def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank): 61 | tokenizer = TikTokenizer("./tests/assets/test_tiktoken.model") 62 | config_manager = ConfigManager() 63 | config = config_manager.parse_args( 64 | [ 65 | "--training.dataset", 66 | dataset_name, 67 | "--training.batch_size", 68 | str(batch_size), 69 | "--training.seq_len", 70 | str(seq_len), 71 | ] 72 | ) 73 | 74 | return build_hf_dataloader( 75 | tokenizer=tokenizer, 76 | dp_world_size=world_size, 77 | dp_rank=rank, 78 | job_config=config, 79 | ) 80 | -------------------------------------------------------------------------------- /tests/unit_tests/test_model_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
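# These tests exercise build_model_converters: with no converters configured it should
# return an empty ModelConvertersContainer, and with `--model.converters float8
# --float8.emulate` it should contain a single Float8Converter.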
6 | 7 | from torchtitan.components.quantization.float8 import Float8Converter 8 | from torchtitan.config_manager import ConfigManager 9 | from torchtitan.distributed import ParallelDims 10 | from torchtitan.protocols.model_converter import ( 11 | build_model_converters, 12 | ModelConvertersContainer, 13 | ) 14 | 15 | 16 | def build_parallel_dims(job_config, world_size): 17 | parallelism_config = job_config.parallelism 18 | parallel_dims = ParallelDims( 19 | dp_shard=parallelism_config.data_parallel_shard_degree, 20 | dp_replicate=parallelism_config.data_parallel_replicate_degree, 21 | cp=parallelism_config.context_parallel_degree, 22 | tp=parallelism_config.tensor_parallel_degree, 23 | pp=parallelism_config.pipeline_parallel_degree, 24 | world_size=world_size, 25 | enable_loss_parallel=not parallelism_config.disable_loss_parallel, 26 | ) 27 | return parallel_dims 28 | 29 | 30 | def test_build_model_converters_empty_list(): 31 | config_manager = ConfigManager() 32 | config = config_manager.parse_args([]) 33 | parallel_dims = build_parallel_dims(config, 1) 34 | 35 | model_converters = build_model_converters(config, parallel_dims) 36 | assert isinstance(model_converters, ModelConvertersContainer) 37 | assert model_converters.converters == [] 38 | 39 | 40 | def test_build_model_converters_float8_converter(): 41 | config_manager = ConfigManager() 42 | config = config_manager.parse_args( 43 | ["--model.converters", "float8", "--float8.emulate"] 44 | ) 45 | parallel_dims = build_parallel_dims(config, 1) 46 | 47 | model_converters = build_model_converters(config, parallel_dims) 48 | assert isinstance(model_converters, ModelConvertersContainer) 49 | assert len(model_converters.converters) == 1 50 | assert isinstance(model_converters.converters[0], Float8Converter) 51 | -------------------------------------------------------------------------------- /torchtitan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Import to register quantization modules. 8 | import torchtitan.components.quantization # noqa: F401 9 | 10 | # Import the built-in models here so that the corresponding register_model_spec() 11 | # will be called. 12 | import torchtitan.experiments # noqa: F401 13 | import torchtitan.models # noqa: F401 14 | -------------------------------------------------------------------------------- /torchtitan/components/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 8 | 9 | import pickle 10 | from abc import ABC, abstractmethod 11 | from collections.abc import Callable 12 | from typing import Any 13 | 14 | from torch.distributed.checkpoint.stateful import Stateful 15 | from torch.utils.data import IterableDataset 16 | from torchdata.stateful_dataloader import StatefulDataLoader 17 | from torchtitan.tools.logging import logger 18 | 19 | 20 | class BaseDataLoader(Stateful, ABC): 21 | """Base class for all dataloaders. 
22 | 23 | This is used to enforce that all dataloaders have the methods defined in ``Stateful``, 24 | ``state_dict()`` and ``load_state_dict()``. 25 | """ 26 | 27 | @abstractmethod 28 | def __iter__(self): 29 | ... 30 | 31 | 32 | class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): 33 | """Dataloader that is aware of distributed data parallelism. 34 | 35 | This dataloader is used to load data in a distributed data parallel fashion. It also 36 | utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary 37 | methods such as ``__iter__``. 38 | 39 | Args: 40 | dataset (IterableDataset): The dataset to iterate over. 41 | dp_rank: Data parallelism rank for this dataloader. 42 | dp_world_size: The world size of the data parallelism. 43 | batch_size: The batch size to use for each iteration. 44 | collate_fn: Optional function to collate samples in a batch. 45 | """ 46 | 47 | dp_rank: int 48 | dp_world_size: int 49 | batch_size: int 50 | 51 | def __init__( 52 | self, 53 | dataset: IterableDataset, 54 | dp_rank: int, 55 | dp_world_size: int, 56 | batch_size: int, 57 | collate_fn: Callable | None = None, 58 | ): 59 | self.dp_world_size = dp_world_size 60 | self.dp_rank = dp_rank 61 | self.batch_size = batch_size 62 | super().__init__(dataset, batch_size, collate_fn=collate_fn) 63 | self._rank_id = f"dp_rank_{dp_rank}" 64 | 65 | def state_dict(self) -> dict[str, Any]: 66 | # Store state only for dp rank to avoid replicating the same state across other dimensions. 67 | return { 68 | # We don't have to use pickle as DCP will serialize the state_dict. However, 69 | # we have to keep this for backward compatibility. 70 | self._rank_id: pickle.dumps(super().state_dict()), 71 | "world_size": self.dp_world_size, 72 | } 73 | 74 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 75 | # State being empty is valid. 76 | if not state_dict: 77 | return 78 | 79 | if self._rank_id not in state_dict: 80 | logger.warning( 81 | f"DataLoader state is empty for dp rank {self.dp_rank}, " 82 | "expected key {self._rank_id}" 83 | ) 84 | return 85 | 86 | assert self.dp_world_size == state_dict["world_size"], ( 87 | "dp_degree is inconsistent before and after checkpoint, " 88 | "dataloader resharding is not supported yet." 89 | ) 90 | # We don't have to use pickle as DCP will serialize the state_dict. However, we have to 91 | # keep this for backward compatibility. 92 | super().load_state_dict(pickle.loads(state_dict[self._rank_id])) 93 | -------------------------------------------------------------------------------- /torchtitan/components/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
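# Token-level cross-entropy loss used as the default loss function by the built-in
# TrainSpecs; build_cross_entropy_loss wraps it in torch.compile when
# `training.compile` is enabled.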
6 | 7 | from typing import Callable, TypeAlias 8 | 9 | import torch 10 | 11 | from torchtitan.config_manager import JobConfig 12 | from torchtitan.tools.logging import logger 13 | 14 | LossFunction: TypeAlias = Callable[..., torch.Tensor] 15 | 16 | 17 | def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 18 | """Common cross-entropy loss function for Transformer models training.""" 19 | return torch.nn.functional.cross_entropy( 20 | pred.flatten(0, 1).float(), labels.flatten(0, 1) 21 | ) 22 | 23 | 24 | def build_cross_entropy_loss(job_config: JobConfig): 25 | loss_fn = cross_entropy_loss 26 | if job_config.training.compile: 27 | logger.info("Compiling the loss function with torch.compile") 28 | loss_fn = torch.compile(loss_fn) 29 | return loss_fn 30 | -------------------------------------------------------------------------------- /torchtitan/components/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # [Note] Getting the 'torchao' package: 8 | # This script requires the 'torchao' package to function correctly. 9 | # Please ensure you have this package installed from the appropriate repository. 10 | # You can obtain it from https://github.com/pytorch/ao by following the 11 | # installation instructions. 12 | 13 | # Note: Performance 14 | # The quantization modules are intended to be ran under `torch.compile`` for competitive performance 15 | 16 | # Import to register quantization modules as ModelConverter 17 | import torchtitan.components.quantization.float8 # noqa: F401 18 | import torchtitan.components.quantization.mx # noqa: F401 19 | -------------------------------------------------------------------------------- /torchtitan/components/quantization/mx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from functools import partial 8 | from importlib.metadata import version 9 | from importlib.util import find_spec 10 | from typing import Any, List 11 | 12 | import torch.nn as nn 13 | 14 | from torchtitan.config_manager import JobConfig, MX 15 | from torchtitan.distributed import ParallelDims 16 | from torchtitan.protocols.model_converter import ( 17 | ModelConverter, 18 | register_model_converter, 19 | ) 20 | from torchtitan.tools.logging import logger 21 | from torchtitan.tools.utils import has_cuda_capability 22 | 23 | from .utils import module_filter_fn 24 | 25 | # Maps titan recipe names to torchao mx recipe names 26 | NAME_MAP = {"mxfp8": "mxfp8_cublas"} 27 | 28 | 29 | class MXConverter(ModelConverter): 30 | """Converts the linear layers of `model` to `MXLinear`.""" 31 | 32 | enabled: bool 33 | filter_fqns: List[str] 34 | mx_config: Any # MXLinearConfig type when imported 35 | 36 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): 37 | # Ensure minimum torchao versions 38 | if find_spec("torchao") is None: 39 | raise ImportError( 40 | "torchao is not installed. Please install it to use MXFP8 linear layers." 
41 | ) 42 | torchao_version = version("torchao") 43 | mxfp8_min_version = "0.11.0" 44 | if torchao_version < mxfp8_min_version: 45 | raise ImportError( 46 | f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again" 47 | ) 48 | 49 | # Can be removed if we enable the emulated versions 50 | assert has_cuda_capability( 51 | 10, 0 52 | ), "MXFP8 is only supported on SM100 or architectures" 53 | 54 | self.enabled = True 55 | mx_job_config: MX = job_config.mx 56 | self.filter_fqns = mx_job_config.filter_fqns 57 | 58 | # Configure MXFP8 59 | from torchao.prototype.mx_formats.config import MXLinearConfig 60 | 61 | config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name]) 62 | config.use_fp8_dim1_cast_triton_kernel = ( 63 | mx_job_config.use_fp8_dim1_cast_triton_kernel 64 | ) 65 | self.config = config 66 | 67 | logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}") 68 | 69 | def convert(self, model: nn.Module): 70 | """ 71 | Converts the linear layers of `model` to `MXLinear`. 72 | Note that today, only dynamic tensor scaling (the default) is supported. 73 | This will mutate the model inplace. 74 | """ 75 | if not self.enabled: 76 | return 77 | 78 | from torchao.prototype.mx_formats.config import MXLinearConfig 79 | from torchao.quantization import quantize_ 80 | 81 | assert isinstance(self.config, MXLinearConfig) 82 | quantize_( 83 | model, 84 | config=self.config, 85 | filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns), 86 | ) 87 | logger.info("Swapped to MXLinear layers") 88 | 89 | def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): 90 | """ 91 | MXFP8 doesn't require any post-optimizer hooks at the moment 92 | """ 93 | return 94 | 95 | 96 | register_model_converter(MXConverter, "mx") 97 | -------------------------------------------------------------------------------- /torchtitan/components/quantization/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch.nn as nn 8 | 9 | 10 | def module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool: 11 | """ 12 | Filter function to determine which modules should be converted. 13 | For both Float8 and MXFP8, we only convert Linear modules 14 | with dimensions divisible by 16 and not matching any filtered FQNs. 15 | """ 16 | if not isinstance(mod, nn.Linear): 17 | return False 18 | 19 | # All dims must be divisible by 16 due to float8 tensorcore hardware requirements. 20 | dims_multiples_of_16 = ( 21 | mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0 22 | ) 23 | 24 | # If the fqn matches any filtered fqn, then we should not convert this module. 25 | is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns) 26 | 27 | return dims_multiples_of_16 and not is_filtered_fqn 28 | -------------------------------------------------------------------------------- /torchtitan/components/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from abc import ABC, abstractmethod 9 | 10 | 11 | class Tokenizer(ABC): 12 | # basic tokenizer interface, for typing purpose mainly 13 | def __init__(self): 14 | self._n_words = 8 15 | self.eos_id = 0 16 | 17 | @abstractmethod 18 | def encode(self, *args, **kwargs) -> list[int]: 19 | ... 20 | 21 | @abstractmethod 22 | def decode(self, *args, **kwargs) -> str: 23 | ... 24 | 25 | @property 26 | def n_words(self) -> int: 27 | return self._n_words 28 | -------------------------------------------------------------------------------- /torchtitan/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from torchtitan.distributed.parallel_dims import ParallelDims 9 | 10 | 11 | __all__ = ["ParallelDims"] 12 | -------------------------------------------------------------------------------- /torchtitan/distributed/parallel_dims.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections.abc import Callable 8 | from dataclasses import dataclass 9 | from functools import cached_property 10 | 11 | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh 12 | 13 | from torchtitan.tools.logging import logger 14 | 15 | 16 | __all__ = ["ParallelDims"] 17 | 18 | 19 | @dataclass 20 | class ParallelDims: 21 | dp_replicate: int 22 | dp_shard: int 23 | cp: int 24 | tp: int 25 | pp: int 26 | world_size: int 27 | enable_loss_parallel: bool 28 | 29 | def __post_init__(self): 30 | self._validate() 31 | 32 | def _validate(self): 33 | dp_replicate, dp_shard, cp, tp, pp = ( 34 | self.dp_replicate, 35 | self.dp_shard, 36 | self.cp, 37 | self.tp, 38 | self.pp, 39 | ) 40 | for d in (dp_replicate, cp, tp, pp): 41 | assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" 42 | 43 | assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." 
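# dp_shard == -1 means "infer": derive the shard degree from world_size and the
# product of the other parallelism degrees below.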
44 | if dp_shard < 0: 45 | self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp) 46 | assert dp_shard >= 1 47 | 48 | assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, ( 49 | f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " 50 | f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" 51 | ) 52 | 53 | def build_mesh(self, device_type: str) -> DeviceMesh: 54 | dims = [] 55 | names = [] 56 | for d, name in zip( 57 | [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], 58 | ["pp", "dp_replicate", "dp_shard", "cp", "tp"], 59 | ): 60 | if d > 1: 61 | dims.append(d) 62 | names.append(name) 63 | 64 | return self._build_mesh(device_type, dims, names, init_device_mesh) 65 | 66 | def _build_mesh( 67 | self, 68 | device_type: str, 69 | dims: list[int], 70 | names: list[str], 71 | init_device_mesh_fn: Callable, 72 | ) -> DeviceMesh: 73 | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") 74 | mesh = init_device_mesh_fn(device_type, dims, mesh_dim_names=names) 75 | 76 | # Create all the submesh here to ensure all required process groups are 77 | # initialized: 78 | # Mesh for data loading (no communication on this mesh) 79 | dp_mesh_dim_names = [] 80 | # Mesh for param sharding 81 | dp_shard_cp_mesh_dim_names = [] 82 | # Mesh for loss all-reduce 83 | dp_cp_mesh_dim_names = [] 84 | 85 | if self.dp_replicate_enabled: 86 | dp_mesh_dim_names.append("dp_replicate") 87 | dp_cp_mesh_dim_names.append("dp_replicate") 88 | if self.dp_shard_enabled: 89 | dp_mesh_dim_names.append("dp_shard") 90 | dp_shard_cp_mesh_dim_names.append("dp_shard") 91 | dp_cp_mesh_dim_names.append("dp_shard") 92 | if self.cp_enabled: 93 | dp_shard_cp_mesh_dim_names.append("cp") 94 | dp_cp_mesh_dim_names.append("cp") 95 | 96 | if dp_mesh_dim_names != []: 97 | mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") 98 | if dp_shard_cp_mesh_dim_names != []: 99 | mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( 100 | mesh_dim_name="dp_shard_cp" 101 | ) 102 | if dp_cp_mesh_dim_names != []: 103 | mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") 104 | 105 | return mesh 106 | 107 | @property 108 | def dp_enabled(self): 109 | return self.dp_replicate > 1 or self.dp_shard > 1 110 | 111 | @property 112 | def dp_replicate_enabled(self): 113 | return self.dp_replicate > 1 114 | 115 | @property 116 | def dp_shard_enabled(self): 117 | return self.dp_shard > 1 118 | 119 | @property 120 | def cp_enabled(self): 121 | return self.cp > 1 122 | 123 | @property 124 | def tp_enabled(self): 125 | return self.tp > 1 126 | 127 | @property 128 | def pp_enabled(self): 129 | return self.pp > 1 130 | 131 | @property 132 | def loss_parallel_enabled(self): 133 | return self.tp > 1 and self.enable_loss_parallel 134 | 135 | @cached_property 136 | def non_data_parallel_size(self): 137 | return self.cp * self.tp * self.pp 138 | -------------------------------------------------------------------------------- /torchtitan/experiments/README.md: -------------------------------------------------------------------------------- 1 | To accelerate contributions to and innovations around `torchtitan`, we are adding this new, experimental folder. Below are the general contributing guidelines, and we look forward to your contributions! 2 | 3 | ## Contributing Guidelines 4 | 5 | We provide this `experiments/` folder to host experiments that add significant value to `torchtitan`, with the following principles. 
We refer to the part of `torchtitan` outside `experiments` as `core`. 6 | 1. Each subfolder in `experiments` will be an experiment, with a clear theme which can be flexible, such as 7 | - a new model, or preferably a new model architecture, with its training infrastructure including parallelization functions; 8 | - an enhancement or addition to the existing infrastructure of `torchtitan`. 9 | 2. It is the contributors' responsibility to justify the value of an experiment. `torchtitan` team will review proposals on a case-by-case basis. As part of the contribution, the contributors should provide documentation that clearly showcases the motivation and innovation of an experiment, including reports on performance and loss convergence. 10 | 3. An experiment should reuse existing `torchtitan` code as much as possible, such as modules in [`components/`](../components/) (via a new [`TrainSpec`](../protocols/train_spec.py)) and [`train.py`](../train.py). For a list of extension points we provide, please refer to [docs/extension.md](../../docs/extension.md). 11 | - The extension points are subject to change. We kindly request that contributors provide feedback if they encounter issues reusing any components, rather than simply using a copy-and-paste approach. 12 | - The degree to which existing components are reused and whether duplications are legit will also be a criteria of whether an experiment would be accepted. 13 | 4. Each experiment is independent from other experiments, and can have its own dependencies (on top of [core dependencies](../../requirements.txt)), and its own tests. 14 | 5. The dependency from `experiments` to `core` is one-way. Anything in `experiments` is optional for `core` to run successfully. In particular, development in `core` is not blocked by breakage in `experiments`. We will utilize GitHub's [CI mechanism](https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#onpushpull_requestpull_request_targetpathspaths-ignore) to help test an experiment periodically and only if the experiment itself is affected by a PR. 15 | 6. Each experiment needs to have an owner. The owner is responsible to work with `torchtitan` team to maintain the quality and healthiness of an experiment, which includes 16 | - adapting an experiment to changes in `core` and fix broken tests, no later than the next official `torchtitan` release; 17 | - responding to GitHub issues and questions in a timely manner. 18 | 7. `torchtitan` team reserve the right to remove an experiment. In particular, an experiment should be removed if 19 | - it has served its purpose (e.g., providing findings, or getting some features upstreamed to `core` or PyTorch, etc.), or 20 | - it gets stale (e.g. not being maintained). 21 | -------------------------------------------------------------------------------- /torchtitan/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
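# Importing these experiment packages runs their TrainSpec registration at import time
# (see docs/extension.md), making the experimental models available to the trainer.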
6 | 7 | import torchtitan.experiments.llama4 # noqa: F401 8 | import torchtitan.experiments.simple_fsdp # noqa: F401 9 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DeepSeek 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/README.md: -------------------------------------------------------------------------------- 1 | # Running DeepSeek in Titan (experimental) 2 | 3 | This folder contains a DeepSeek model supporting v2 and v3 as well as kernels 4 | and scripts needed to run it. 5 | 6 | ## Inference 7 | 8 | ### Prerequisites: 9 | 10 | You will need to download a DeepSeek model's weights if you want to run a 11 | pre-trained checkpoint. We provided a script to download the weights from 12 | HuggingFace Model Hub: 13 | ```bash 14 | python download.py [vX] 15 | ``` 16 | where `vX` can be v2 or v3, both are supported. You may be required to create a 17 | HuggingFace account and log in first. 18 | 19 | ### Running inference: 20 | 21 | The inference script is in `generate.py`. You can run it with the following 22 | command: 23 | ```bash 24 | torchrun --standalone --nproc-per-node 4 generate.py 25 | ``` 26 | This will run inference on the `DeepSeek-V2-Lite-Chat` model using 4 GPUs by 27 | default. 28 | 29 | Alternatively, you can run inference by using `bash inference.sh`, optionally 30 | followed by your prompt. 31 | 32 | ## Training 33 | 34 | The training script is in `train.py`. You can run it by the following command: 35 | ```bash 36 | torchrun --standalone --nproc-per-node 8 train.py 37 | ``` 38 | 39 | This will run training on the `DeepSeek-V2-Lite-Chat` model using 8 GPUs by 40 | default, with pipeline parallel, expert parallel, and data parallel enabled. 41 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | from torchtitan.components.loss import build_cross_entropy_loss 9 | from torchtitan.components.lr_scheduler import build_lr_schedulers 10 | from torchtitan.components.optimizer import build_optimizers 11 | from torchtitan.datasets.hf_datasets import build_hf_dataloader 12 | from torchtitan.experiments.deepseek_v3.tokenizers.hf_tokenizer import get_hf_tokenizer 13 | 14 | # ToDO - this is not suitable for deepseek but using for now... 15 | from torchtitan.models.llama3 import pipeline_llama 16 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec 17 | 18 | from .infra.parallelize_deepseek import parallelize_deepseek 19 | 20 | from .model import DeepseekForCausalLM 21 | 22 | from .model_args import TransformerModelArgs 23 | 24 | 25 | __all__ = [ 26 | "TransformerModelArgs", 27 | "DeepseekForCausalLM", 28 | "deepseek_configs", 29 | ] 30 | 31 | 32 | deepseek_configs = { 33 | "debugmodel": TransformerModelArgs( 34 | dim=256, 35 | n_layers=6, 36 | n_heads=16, 37 | rope_theta=500000, 38 | ), 39 | } 40 | 41 | 42 | register_train_spec( 43 | TrainSpec( 44 | name="deepseek3", 45 | cls=DeepseekForCausalLM, 46 | config=deepseek_configs, 47 | parallelize_fn=parallelize_deepseek, 48 | pipelining_fn=pipeline_llama, 49 | build_optimizers_fn=build_optimizers, 50 | build_lr_schedulers_fn=build_lr_schedulers, 51 | build_dataloader_fn=build_hf_dataloader, 52 | build_tokenizer_fn=get_hf_tokenizer, 53 | build_loss_fn=build_cross_entropy_loss, 54 | ) 55 | ) 56 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Usage: 8 | # Downloads a given model to the HF Cache. Pass in a listed option ala "v3" or your own custom model path. 
9 | # python download.py {model_id} [custom_model_path] 10 | # Examples: 11 | # python download.py v2 # Use predefined model: deepseek-ai/DeepSeek-V2 12 | # python download.py custom "deepseek-ai/new-model" # Download a custom model path 13 | 14 | # Available models: 15 | # "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat", 16 | # "v2-lite": "deepseek-ai/DeepSeek-V2-Lite", 17 | # "v2": "deepseek-ai/DeepSeek-V2", 18 | # "v3": "deepseek-ai/deepseek-v3", 19 | # "v3-0324": "deepseek-ai/DeepSeek-V3-0324", 20 | # "custom": None, # Placeholder for custom models 21 | 22 | 23 | import sys 24 | 25 | from transformers import AutoModelForCausalLM 26 | 27 | 28 | MODELS = { 29 | "v2-lite-chat": "deepseek-ai/DeepSeek-V2-Lite-Chat", 30 | "v2-lite": "deepseek-ai/DeepSeek-V2-Lite", 31 | "v2": "deepseek-ai/DeepSeek-V2", 32 | "v3": "deepseek-ai/deepseek-v3", 33 | "v3-0324": "deepseek-ai/DeepSeek-V3-0324", 34 | "custom": None, # For custom (any) models 35 | } 36 | 37 | 38 | def print_usage(): 39 | print("Usage:") 40 | print(" python download.py [model_version]") 41 | print(" python download.py custom [custom_model_path]") 42 | print("\nAvailable predefined models:") 43 | for key, model in MODELS.items(): 44 | if key != "custom": # Skip the custom placeholder 45 | print(f" {key}: {model}") 46 | print("\nFor custom models:") 47 | print(" custom: Specify your own model path") 48 | print(' Example: python download.py custom "organization/model-name"') 49 | sys.exit(1) 50 | 51 | 52 | # Process command line arguments 53 | if len(sys.argv) < 2 or sys.argv[1] not in MODELS: 54 | print_usage() 55 | 56 | if sys.argv[1] == "custom": 57 | if len(sys.argv) != 3: 58 | print("Error: Custom model requires a model path") 59 | print_usage() 60 | model_id = sys.argv[2] 61 | print(f"Using custom model: {model_id}") 62 | else: 63 | model_id = MODELS[sys.argv[1]] 64 | print(f"Downloading model: {model_id}") 65 | 66 | model = AutoModelForCausalLM.from_pretrained( 67 | model_id, 68 | device_map="auto", 69 | trust_remote_code=True, 70 | ) 71 | 72 | print(f"{model=}") 73 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/inference.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/bash 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | NGPU=${NGPU:-"4"} 10 | 11 | # Get the prompt from command line argument or use a default 12 | prompt="${1:-What is 2+2?}" 13 | 14 | # Run the model with the prompt 15 | torchrun --standalone --nproc-per-node ${NGPU} generate.py "$prompt" 16 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/infra/parallelize_deepseek.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
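# NOTE (editor sketch, not part of the original file): parallelize_deepseek()
# below indexes the passed-in world mesh by the dimension names "pp", "ep",
# and "fsdp" (and the combined ("ep", "fsdp") sub-mesh for HSDP). One possible
# way a caller could construct such a mesh -- the actual construction lives in
# the training scripts and may differ -- is:
#
#   from torch.distributed.device_mesh import init_device_mesh
#   # e.g. 8 ranks split as pp=2, ep=2, fsdp=2 (hypothetical sizes)
#   world_mesh = init_device_mesh(
#       "cuda", (2, 2, 2), mesh_dim_names=("pp", "ep", "fsdp")
#   )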
6 | 7 | import os 8 | from typing import Optional 9 | 10 | import torch 11 | import torch.distributed as dist 12 | 13 | from torch.distributed.device_mesh import DeviceMesh 14 | from torch.distributed.fsdp import fully_shard 15 | 16 | # from checkpoint import load_weights_from_hf 17 | from torchtitan.experiments.deepseek_v3.model import DeepseekForCausalLM 18 | 19 | from torchtitan.tools.logging import logger 20 | 21 | 22 | # Use DeepSeek-V2-Lite as a proxy 23 | model_id = "deepseek-ai/DeepSeek-V2-Lite" 24 | 25 | 26 | # from ..model.moe import MoE 27 | 28 | 29 | # Get model parallel subgroup by name: 30 | # e.g. "pp", "ep", None 31 | def get_group(dim_name: Optional[str] = None) -> dist.ProcessGroup: 32 | glob = torch.distributed.device_mesh._mesh_resources.get_current_mesh() 33 | return glob.get_group(dim_name) 34 | 35 | 36 | def parallelize_deepseek( 37 | # model: nn.Module, 38 | world_mesh: DeviceMesh, 39 | device: torch.device, 40 | model_args, 41 | rank: int, 42 | # parallel_dims: ParallelDims, 43 | # job_config: JobConfig, 44 | ): 45 | """ 46 | Apply parallelism to the model. 47 | 48 | NOTE: The passed-in model preferably should be on meta device. Otherwise, 49 | the model must fit on GPU or CPU memory. 50 | """ 51 | logger.info("Applying parallelism to the model...") 52 | world_size = int(os.environ["WORLD_SIZE"]) 53 | 54 | pp_mesh = world_mesh["pp"] 55 | ep_mesh = world_mesh["ep"] 56 | pp_rank = pp_mesh.get_local_rank() 57 | ep_rank = ep_mesh.get_local_rank() 58 | pp_size = pp_mesh.size() 59 | ep_size = ep_mesh.size() 60 | 61 | # Apply data parallelism 62 | fsdp_mesh = world_mesh["fsdp"] 63 | hsdp_mesh = world_mesh["ep", "fsdp"] 64 | 65 | hsdp_size = hsdp_mesh.size() 66 | 67 | # Apply model parallelism 68 | model_args.ep_size = ep_size 69 | model_args.num_stages = pp_size 70 | model_args.stage_idx = pp_rank 71 | logger.info( 72 | f"Parallelism: {rank=}, {ep_size=}, {pp_size=}, {model_args.ep_size=}, {model_args.num_stages=}, {model_args.stage_idx=}" 73 | ) 74 | # print(model_args) 75 | # verify world size matches parallelized total 76 | parallelized_world_size = pp_size * hsdp_size 77 | logger.info(f"Total Parallelized World size {parallelized_world_size}") 78 | assert ( 79 | world_size == parallelized_world_size 80 | ), f"mismatch between total world size {world_size=} and parallelized total {parallelized_world_size}" 81 | 82 | # Instantiate model 83 | with device, world_mesh: 84 | model = DeepseekForCausalLM(model_args) 85 | # Load weights 86 | # load_weights_from_hf(model, model_id, device) 87 | model.train() 88 | 89 | # Using `reshard_after_forward=False` to implement Zero-2, i.e. sharding the 90 | # optimizer (Zero-1) and gradients (Zero-2), but not the model weights. 91 | # Reason: the MoE is "sparsely activated" compared to the dense model, thus 92 | # it will be ineconomical re-gather the weights. 
93 | for layer in model.model.layers.values(): 94 | # Apply FSDP to experts 95 | if hasattr(layer.mlp, "experts"): 96 | for expert in layer.mlp.experts.values(): 97 | fully_shard(expert, mesh=fsdp_mesh, reshard_after_forward=False) 98 | # Apply HSDP to other parts such as attention, layernorm, because they 99 | # are doing DDP on EP dimension 100 | fully_shard(layer, mesh=hsdp_mesh, reshard_after_forward=False) 101 | 102 | # Apply HSDP on root model (lm_head, embeddings, etc) 103 | fully_shard(model, mesh=hsdp_mesh, reshard_after_forward=False) 104 | 105 | return ( 106 | model, 107 | pp_size, 108 | pp_rank, 109 | pp_mesh, 110 | ep_size, 111 | ep_rank, 112 | ) 113 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | accelerate 3 | torchdata >= 0.8.0 4 | datasets >= 2.21.0 5 | tomli >= 1.1.0 ; python_version < "3.11" 6 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/run_training.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/bash 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the BSD-style license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | set -ex 10 | 11 | # use envs as local overrides for convenience 12 | # e.g. 13 | # LOG_RANK=0,1 NGPU=4 ./run_train.sh 14 | NGPU=${NGPU:-"4"} 15 | 16 | export LOG_RANK=${LOG_RANK:-0} 17 | CONFIG_FILE=${CONFIG_FILE:-"./train_configs/deepseek_v2.toml"} 18 | 19 | overrides="" 20 | if [ $# -ne 0 ]; then 21 | overrides="$*" 22 | fi 23 | 24 | TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} 25 | 26 | PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ 27 | TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ 28 | torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ 29 | --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ 30 | train_ds_real.py --job.config_file ${CONFIG_FILE} $overrides 31 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/symm_mem_recipes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .triton_on_device_all_to_all_v import OnDeviceAllToAllV 8 | 9 | __all__ = [ 10 | "OnDeviceAllToAllV", 11 | ] 12 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/symm_mem_recipes/triton_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
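# NOTE (editor comment): the helpers in this file wrap NVIDIA PTX inline
# assembly via `tl.inline_asm_elementwise`, so they only work on CUDA devices.
# get_tid()/get_ntid() read the per-block thread index and block shape
# (%tid.*, %ntid.*); get_flat_tid()/get_flat_bid() flatten the 3-D thread and
# program indices into a single linear index; sync_threads() emits a
# block-wide barrier (`bar.sync 0`), analogous to CUDA's __syncthreads().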
6 | 7 | import triton 8 | import triton.language as tl 9 | 10 | 11 | @triton.jit 12 | def get_tid(): 13 | return tl.inline_asm_elementwise( 14 | """ 15 | mov.u32 $0, %tid.x; 16 | mov.u32 $1, %tid.y; 17 | mov.u32 $2, %tid.z; 18 | """, 19 | "=r,=r,=r", 20 | [], 21 | dtype=(tl.uint32, tl.uint32, tl.uint32), 22 | is_pure=True, 23 | pack=1, 24 | ) 25 | 26 | 27 | @triton.jit 28 | def get_ntid(): 29 | return tl.inline_asm_elementwise( 30 | """ 31 | mov.u32 $0, %ntid.x; 32 | mov.u32 $1, %ntid.y; 33 | mov.u32 $2, %ntid.z; 34 | """, 35 | "=r,=r,=r", 36 | [], 37 | dtype=(tl.uint32, tl.uint32, tl.uint32), 38 | is_pure=True, 39 | pack=1, 40 | ) 41 | 42 | 43 | @triton.jit 44 | def get_flat_tid(): 45 | tid_x, tid_y, tid_z = get_tid() 46 | ntid_x, ntid_y, _ = get_ntid() 47 | return tid_z * ntid_y * ntid_x + tid_y * ntid_x + tid_x 48 | 49 | 50 | @triton.jit 51 | def get_flat_bid(): 52 | return ( 53 | tl.program_id(2) * tl.num_programs(1) * tl.num_programs(0) 54 | + tl.program_id(1) * tl.num_programs(0) 55 | + tl.program_id(0) 56 | ) 57 | 58 | 59 | @triton.jit 60 | def sync_threads(): 61 | tl.inline_asm_elementwise( 62 | "bar.sync 0;", "=r", [], dtype=tl.int32, is_pure=False, pack=1 63 | ) 64 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/tokenizers/hf_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | 9 | from torchtitan.tools.logging import logger 10 | from transformers import AutoTokenizer 11 | 12 | 13 | # HF AutoTokenizer will instantiate a root level logger, which will cause 14 | # duplicate logs. We need to disable their root logger to avoid this. 15 | def remove_notset_root_handlers(): 16 | """ 17 | Remove handlers with level NOTSET from root logger. 18 | Titan's logger is set, and thus we can differentiate between these. 19 | """ 20 | for handler in logger.handlers[:]: 21 | if handler.level == logging.NOTSET: 22 | logger.removeHandler(handler) 23 | 24 | 25 | class TokenizerWrapper: 26 | def __init__(self, tokenizer): 27 | self.tokenizer = tokenizer 28 | 29 | def encode(self, text, bos=False, eos=False, **kwargs): 30 | # Handle bos and eos parameters 31 | if bos: 32 | kwargs["add_special_tokens"] = True 33 | if eos: 34 | kwargs["add_special_tokens"] = True 35 | 36 | return self.tokenizer.encode(text, **kwargs) 37 | 38 | def __getattr__(self, name): 39 | # Delegate all other attributes/methods to the underlying tokenizer 40 | return getattr(self.tokenizer, name) 41 | 42 | 43 | def get_hf_tokenizer(model_id: str): 44 | logger.info(f"Instantiating tokenizer for {model_id}") 45 | tokenizer = AutoTokenizer.from_pretrained(model_id) 46 | return TokenizerWrapper(tokenizer) 47 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/train_configs/custom_args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
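# NOTE (editor sketch): the dataclasses below extend the base JobConfig with
# DeepSeek-specific options. They are picked up because the training config
# points at this module via `[experimental] custom_args_module =
# "torchtitan.experiments.deepseek_v3.train_configs.custom_args"` (see
# deepseek_v2.toml below), after which the extra fields can be set like any
# other option, e.g. (hypothetical override):
#
#   ./run_training.sh --parallelism.expert_parallel_degree 4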
6 | 7 | from dataclasses import dataclass, field 8 | 9 | 10 | @dataclass 11 | class Parallelism: 12 | expert_parallel_degree: int = 2 13 | """ degree to parallelize experts """ 14 | 15 | 16 | @dataclass 17 | class Training: 18 | steps: int = 22222222 19 | 20 | 21 | @dataclass 22 | class JobConfig: 23 | parallelism: Parallelism = field(default_factory=Parallelism) 24 | training: Training = field(default_factory=Training) 25 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml: -------------------------------------------------------------------------------- 1 | [job] 2 | dump_folder = "./outputs" 3 | description = "DeepSeek v2 debug training" 4 | print_args = false 5 | use_for_integration_test = true 6 | 7 | [profiling] 8 | enable_profiling = false 9 | save_traces_folder = "profile_trace" 10 | profile_freq = 10 11 | enable_memory_snapshot = false 12 | save_memory_snapshot_folder = "memory_snapshot" 13 | 14 | [metrics] 15 | log_freq = 1 16 | disable_color_printing = false 17 | enable_tensorboard = false 18 | save_tb_folder = "tb" 19 | enable_wandb = false 20 | 21 | [model] 22 | name = "deepseek_v2" 23 | flavor = "deepseek-ai/DeepSeek-V2-Lite" 24 | # test tokenizer.model, for debug purpose only 25 | tokenizer_path = "./tests/assets/test_tiktoken.model" 26 | # converters = ["float8"] 27 | 28 | [optimizer] 29 | name = "AdamW" 30 | lr = 4e-3 31 | eps = 1e-15 32 | implementation = "foreach" 33 | 34 | [lr_scheduler] 35 | warmup_steps = 100 # lr scheduler warm up, normally 20% of the train steps 36 | decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps 37 | decay_type = "linear" 38 | lr_min = 0.1 39 | 40 | [training] 41 | batch_size = 2 # 8 42 | seq_len = 1024 # 2048 43 | max_norm = 1.0 # grad norm clipping 44 | steps = 200 45 | compile = false 46 | dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) 47 | 48 | [parallelism] 49 | data_parallel_replicate_degree = 1 50 | data_parallel_shard_degree = 2 # we use Zero2 so it's not really sharding per se... 51 | fsdp_reshard_after_forward = "default" # default / never / always 52 | tensor_parallel_degree = 1 53 | enable_async_tensor_parallel = false 54 | pipeline_parallel_degree = 1 55 | context_parallel_degree = 1 56 | # expert_parallel_degree = 2 set in custom_args 57 | 58 | [checkpoint] 59 | enable_checkpoint = false 60 | folder = "checkpoint" 61 | interval = 10 62 | model_weights_only = false 63 | export_dtype = "float32" 64 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 65 | 66 | [activation_checkpoint] 67 | mode = "none" # ["none", "selective", "full"] 68 | selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy 69 | 70 | [float8] 71 | enable_fsdp_float8_all_gather = false 72 | precompute_float8_dynamic_scale_for_fsdp = false 73 | filter_fqns = ["output", "router.gate"] 74 | 75 | [experimental] 76 | # expert parallelism is set here (default is 2) 77 | custom_args_module = "torchtitan.experiments.deepseek_v3.train_configs.custom_args" 78 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/unit_testing/benchmark_kernels.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # benchmark quantization kernels 8 | 9 | # sizes to benchmark: 10 | # valid_tokens.shape=torch.Size([2688, 2048]) 11 | # hidden_states.shape=torch.Size([2688, 1408]) 12 | 13 | # benchmark quantization kernels 14 | import time 15 | 16 | import dsgemm_kernels 17 | import dsgemm_utils 18 | import torch 19 | 20 | 21 | def benchmark_quant_kernels(shapes, dtype=torch.bfloat16, warmup=10, iters=100): 22 | results = [] 23 | 24 | for shape in shapes: 25 | m, k = shape 26 | print(f"Benchmarking shape: {shape}") 27 | 28 | # Create input tensor 29 | x = torch.randn(m, k, device="cuda", dtype=dtype) 30 | 31 | # Warmup groupwise_activation_quant 32 | for _ in range(warmup): 33 | _ = dsgemm_kernels.groupwise_activation_quant(x) 34 | torch.cuda.synchronize() 35 | 36 | # Benchmark groupwise_activation_quant 37 | torch.cuda.synchronize() 38 | start = time.time() 39 | for _ in range(iters): 40 | y1 = dsgemm_kernels.groupwise_activation_quant(x) 41 | torch.cuda.synchronize() 42 | groupwise_time = (time.time() - start) / iters * 1000 # ms 43 | 44 | # Benchmark per_token_cast_to_fp8 45 | for _ in range(warmup): 46 | _ = dsgemm_utils.per_token_cast_to_fp8(x) 47 | torch.cuda.synchronize() 48 | torch.cuda.synchronize() 49 | start = time.time() 50 | for _ in range(iters): 51 | y2 = dsgemm_utils.per_token_cast_to_fp8(x) 52 | torch.cuda.synchronize() 53 | per_token_time = (time.time() - start) / iters * 1000 # ms 54 | 55 | # Calculate speedup 56 | groupwise_vs_per_token = per_token_time / groupwise_time 57 | 58 | # Check correctness 59 | max_diff = dsgemm_utils.compare_fp8_tensors(y1[0], y2[0]) 60 | 61 | results.append( 62 | { 63 | "shape": shape, 64 | "groupwise_ms": groupwise_time, 65 | "per_token_ms": per_token_time, 66 | "groupwise_vs_per_token": groupwise_vs_per_token, 67 | "max_diff": max_diff, 68 | } 69 | ) 70 | 71 | print(f" groupwise: {groupwise_time:.3f} ms") 72 | print(f" per_token: {per_token_time:.3f} ms") 73 | print( 74 | f" groupwise vs per_token: Groupwise is {groupwise_vs_per_token:.2f}x faster than per_token" 75 | ) 76 | print(f" max_diff: {max_diff}") 77 | print() 78 | 79 | return results 80 | 81 | 82 | def print_results_table(results): 83 | print("\nResults Summary:") 84 | print( 85 | f"{'Shape':>15} | {'Groupwise (ms)':>15} | {'PyTorch Eager(ms)':>18} | {'Groupwise/PyTorch':>18} " 86 | ) 87 | print("-" * 85) 88 | 89 | for r in results: 90 | print( 91 | f"{str(r['shape']):>15} | {r['groupwise_ms']:>15.3f} | {r['per_token_ms']:>18.3f} | " 92 | f"{r['groupwise_vs_per_token']:>18.2f}x | " 93 | ) 94 | 95 | 96 | if __name__ == "__main__": 97 | # Test various shapes 98 | shapes = [ 99 | (2048, 2048), 100 | (2688, 2048), 101 | (4096, 2048), # Large batch 102 | (8192, 2048), 103 | (2048, 8192), 104 | # Different feature dimensions 105 | (1024, 2048), 106 | (1024, 4096), 107 | (2048, 4096), 108 | (2688, 1408), 109 | # Square matrices 110 | (2048, 2048), 111 | (4096, 4096), 112 | ] 113 | 114 | results = benchmark_quant_kernels(shapes) 115 | print_results_table(results) 116 | -------------------------------------------------------------------------------- /torchtitan/experiments/deepseek_v3/unit_testing/test_create_m_indices.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torchtitan.experiments.deepseek_v3.dsgemm_utils import ( 9 | create_indices_from_offsets_nosync, 10 | ) 11 | 12 | 13 | def test_create_indices_from_offsets_nosync(): 14 | # Test case 1: Regular offsets with increasing values 15 | m_offsets = torch.tensor([128, 256, 384, 512], device="cuda", dtype=torch.int32) 16 | indices = create_indices_from_offsets_nosync(m_offsets) 17 | 18 | # Expected: 128 zeros, 128 ones, 128 twos, 128 threes 19 | expected = torch.cat( 20 | [ 21 | torch.zeros(128, dtype=torch.int32, device="cuda"), 22 | torch.ones(128, dtype=torch.int32, device="cuda"), 23 | 2 * torch.ones(128, dtype=torch.int32, device="cuda"), 24 | 3 * torch.ones(128, dtype=torch.int32, device="cuda"), 25 | ] 26 | ) 27 | 28 | assert torch.all(indices == expected), "Test case 1 failed" 29 | 30 | # Test case 2: Offsets with empty groups 31 | m_offsets = torch.tensor([128, 128, 256, 384], device="cuda", dtype=torch.int32) 32 | indices = create_indices_from_offsets_nosync(m_offsets) 33 | 34 | # Expected: 128 zeros, 0 ones (empty group), 128 twos, 128 threes 35 | expected = torch.cat( 36 | [ 37 | torch.zeros(128, dtype=torch.int32, device="cuda"), 38 | 2 * torch.ones(128, dtype=torch.int32, device="cuda"), 39 | 3 * torch.ones(128, dtype=torch.int32, device="cuda"), 40 | ] 41 | ) 42 | 43 | assert torch.all(indices == expected), "Test case 2 failed" 44 | 45 | print("All tests passed!") 46 | 47 | 48 | if __name__ == "__main__": 49 | test_create_indices_from_offsets_nosync() 50 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # FLUX model in torchtitan 4 | 5 | [![integration tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_flux.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_flux.yaml/badge.svg?branch=main) 6 | 7 |
8 | 9 | ## Overview 10 | This directory contains the implementation of the [FLUX](https://github.com/black-forest-labs/flux/tree/main) model in torchtitan. In torchtitan, we showcase the pre-training process of the text-to-image part of the FLUX model. 11 | 12 | ## Prerequisites 13 | Install the required dependencies: 14 | ```bash 15 | pip install -r requirements-flux.txt 16 | ``` 17 | 18 | ## Usage 19 | First, download the autoencoder model from HuggingFace with your own access token: 20 | ```bash 21 | python torchtitan/experiments/flux/scripts/download_autoencoder.py --repo_id black-forest-labs/FLUX.1-dev --ae_path ae.safetensors --hf_token 22 | ``` 23 | 24 | This step will download the autoencoder model from HuggingFace and save it to the `torchtitan/experiments/flux/assets/autoencoder/ae.safetensors` file. 25 | 26 | Run the following command to start training (the script defaults to 8 GPUs; set `NGPU=1` to train on a single GPU): 27 | ```bash 28 | ./torchtitan/experiments/flux/run_train.sh 29 | 30 | ``` 31 | 32 | If you want to train with another model config, run the following command: 33 | ```bash 34 | CONFIG_FILE="./torchtitan/experiments/flux/train_configs/flux_schnell_model.toml" ./torchtitan/experiments/flux/run_train.sh 35 | ``` 36 | 37 | ## Running Tests 38 | 39 | ### Unit Tests 40 | To run the unit tests for the FLUX model, use the following command: 41 | ```bash 42 | pytest -s torchtitan/experiments/flux/tests/ 43 | ``` 44 | 45 | ### Integration Tests 46 | To run the integration tests for the FLUX model, use the following command: 47 | ```bash 48 | python -m torchtitan.experiments.flux.tests.integration_tests 49 | ``` 50 | 51 | 52 | ## Supported Features 53 | - Parallelism: The model supports FSDP and HSDP for training on multiple GPUs. 54 | - Activation checkpointing: The model uses activation checkpointing to reduce memory usage during training. 55 | - Distributed checkpointing and loading. 56 | - Notes on the current checkpointing implementation: To keep the model weights sharded the same way as when checkpointing, we need to re-shard the model weights before saving a checkpoint. This is done by checking each module at the end of evaluation and re-sharding its weights if it is an FSDPModule. 57 | - CI for the FLUX model: unit tests, plus integration tests that run periodically on 8 GPUs. 58 | 59 | 60 | 61 | ## TODO 62 | - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc.) 63 | - [ ] Implement the num_flops_per_token calculation in the get_nparams_and_flops() function 64 | - [ ] Add `torch.compile` support 65 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
8 | 9 | 10 | from torchtitan.components.lr_scheduler import build_lr_schedulers 11 | from torchtitan.components.optimizer import build_optimizers 12 | from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader 13 | from torchtitan.experiments.flux.loss import build_mse_loss 14 | from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams 15 | from torchtitan.experiments.flux.parallelize_flux import parallelize_flux 16 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec 17 | 18 | from .model.model import FluxModel, FluxModelArgs 19 | 20 | __all__ = [ 21 | "FluxModelArgs", 22 | "FluxModel", 23 | "flux_configs", 24 | "parallelize_flux", 25 | ] 26 | 27 | 28 | flux_configs = { 29 | "flux-dev": FluxModelArgs( 30 | in_channels=64, 31 | out_channels=64, 32 | vec_in_dim=768, 33 | context_in_dim=4096, 34 | hidden_size=3072, 35 | mlp_ratio=4.0, 36 | num_heads=24, 37 | depth=19, 38 | depth_single_blocks=38, 39 | axes_dim=(16, 56, 56), 40 | theta=10_000, 41 | qkv_bias=True, 42 | autoencoder_params=AutoEncoderParams( 43 | resolution=256, 44 | in_channels=3, 45 | ch=128, 46 | out_ch=3, 47 | ch_mult=(1, 2, 4, 4), 48 | num_res_blocks=2, 49 | z_channels=16, 50 | scale_factor=0.3611, 51 | shift_factor=0.1159, 52 | ), 53 | ), 54 | "flux-schnell": FluxModelArgs( 55 | in_channels=64, 56 | out_channels=64, 57 | vec_in_dim=768, 58 | context_in_dim=4096, 59 | hidden_size=3072, 60 | mlp_ratio=4.0, 61 | num_heads=24, 62 | depth=19, 63 | depth_single_blocks=38, 64 | axes_dim=(16, 56, 56), 65 | theta=10_000, 66 | qkv_bias=True, 67 | autoencoder_params=AutoEncoderParams( 68 | resolution=256, 69 | in_channels=3, 70 | ch=128, 71 | out_ch=3, 72 | ch_mult=(1, 2, 4, 4), 73 | num_res_blocks=2, 74 | z_channels=16, 75 | scale_factor=0.3611, 76 | shift_factor=0.1159, 77 | ), 78 | ), 79 | "flux-debug": FluxModelArgs( 80 | in_channels=64, 81 | out_channels=64, 82 | vec_in_dim=768, 83 | context_in_dim=4096, 84 | hidden_size=3072, 85 | mlp_ratio=4.0, 86 | num_heads=24, 87 | depth=2, 88 | depth_single_blocks=2, 89 | axes_dim=(16, 56, 56), 90 | theta=10_000, 91 | qkv_bias=True, 92 | autoencoder_params=AutoEncoderParams( 93 | resolution=256, 94 | in_channels=3, 95 | ch=128, 96 | out_ch=3, 97 | ch_mult=(1, 2, 4, 4), 98 | num_res_blocks=2, 99 | z_channels=16, 100 | scale_factor=0.3611, 101 | shift_factor=0.1159, 102 | ), 103 | ), 104 | } 105 | 106 | 107 | register_train_spec( 108 | TrainSpec( 109 | name="flux", 110 | cls=FluxModel, 111 | config=flux_configs, 112 | parallelize_fn=parallelize_flux, 113 | pipelining_fn=None, 114 | build_optimizers_fn=build_optimizers, 115 | build_lr_schedulers_fn=build_lr_schedulers, 116 | build_dataloader_fn=build_flux_dataloader, 117 | build_tokenizer_fn=None, 118 | build_loss_fn=build_mse_loss, 119 | ) 120 | ) 121 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/job_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
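# NOTE (editor sketch): these dataclasses extend the base JobConfig with
# Flux-specific options. They are activated by pointing the config at this
# module, either in TOML (`[experimental] custom_args_module =
# "torchtitan.experiments.flux.job_config"`) or on the command line, after
# which the extra fields can be overridden like any other option. A minimal
# parsing sketch (mirrors the unit test further below):
#
#   from torchtitan.config_manager import ConfigManager
#   config = ConfigManager().parse_args([
#       "--experimental.custom_args_module=torchtitan.experiments.flux.job_config",
#       "--training.img_size", "256",
#       "--encoder.max_t5_encoding_len", "512",
#   ])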
6 | 7 | from dataclasses import dataclass, field 8 | 9 | 10 | @dataclass 11 | class Training: 12 | classifer_free_guidance_prob: float = 0.0 13 | """Classifier-free guidance with probability p to dropout the text conditioning""" 14 | img_size: int = 256 15 | """Image width to sample""" 16 | test_mode: bool = False 17 | """Whether to use intergration test mode, which will randomly initialize the encoder and use a dummy tokenizer""" 18 | 19 | 20 | @dataclass 21 | class Encoder: 22 | t5_encoder: str = "google/t5-v1_1-small" 23 | """T5 encoder to use, HuggingFace model name. This field could be either a local folder path, 24 | or a Huggingface repo name.""" 25 | clip_encoder: str = "openai/clip-vit-large-patch14" 26 | """Clip encoder to use, HuggingFace model name. This field could be either a local folder path, 27 | or a Huggingface repo name.""" 28 | autoencoder_path: str = ( 29 | "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" 30 | ) 31 | """Autoencoder checkpoint path to load. This should be a local path referring to a safetensors file.""" 32 | max_t5_encoding_len: int = 256 33 | """Maximum length of the T5 encoding.""" 34 | 35 | 36 | @dataclass 37 | class Eval: 38 | enable_classifer_free_guidance: bool = False 39 | """Whether to use classifier-free guidance during sampling""" 40 | classifier_free_guidance_scale: float = 5.0 41 | """Classifier-free guidance scale when sampling""" 42 | denoising_steps: int = 50 43 | """How many denoising steps to sample when generating an image""" 44 | eval_freq: int = 100 45 | """Frequency of evaluation/sampling during training""" 46 | save_img_folder: str = "img" 47 | """Directory to save image generated/sampled from the model""" 48 | 49 | 50 | @dataclass 51 | class JobConfig: 52 | """ 53 | Extend the tyro parser with custom config classe for Flux model. 54 | """ 55 | 56 | training: Training = field(default_factory=Training) 57 | encoder: Encoder = field(default_factory=Encoder) 58 | eval: Eval = field(default_factory=Eval) 59 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Callable, TypeAlias 8 | 9 | import torch 10 | 11 | from torchtitan.config_manager import JobConfig 12 | from torchtitan.tools.logging import logger 13 | 14 | LossFunction: TypeAlias = Callable[..., torch.Tensor] 15 | 16 | 17 | def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 18 | """Common MSE loss function for Transformer models training.""" 19 | return torch.nn.functional.mse_loss(pred.float(), labels.float().detach()) 20 | 21 | 22 | def build_mse_loss(job_config: JobConfig): 23 | loss_fn = mse_loss 24 | if job_config.training.compile: 25 | logger.info("Compiling the loss function with torch.compile") 26 | loss_fn = torch.compile(loss_fn) 27 | return loss_fn 28 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/model/hf_embedder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | 9 | from torch import nn, Tensor 10 | from transformers import CLIPTextModel, T5EncoderModel 11 | 12 | 13 | class FluxEmbedder(nn.Module): 14 | def __init__(self, version: str, random_init=False, **hf_kwargs): 15 | super().__init__() 16 | self.is_clip = "clip" in version.lower() 17 | self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" 18 | if self.is_clip: 19 | if random_init: 20 | # Initialize CLIP model with random weights for test purpose only 21 | self.hf_module = CLIPTextModel._from_config( 22 | CLIPTextModel.config_class.from_pretrained( 23 | os.path.join(version, "config.json"), **hf_kwargs 24 | ) 25 | ) 26 | else: 27 | self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( 28 | version, **hf_kwargs 29 | ) 30 | else: 31 | if random_init: 32 | # Initialize T5 model with random weights for test purpose only 33 | self.hf_module = T5EncoderModel._from_config( 34 | T5EncoderModel.config_class.from_pretrained( 35 | os.path.join(version, "config.json"), **hf_kwargs 36 | ) 37 | ) 38 | else: 39 | self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( 40 | version, **hf_kwargs 41 | ) 42 | 43 | self.hf_module = self.hf_module.eval().requires_grad_(False) 44 | 45 | def forward(self, batch_tokens: Tensor) -> Tensor: 46 | """ 47 | batch_tokens: [bsz, embedding_length] 48 | 49 | For T5 Encoder, embeding_length is 768 50 | For CLIP, embedding_length is 256 51 | """ 52 | outputs = self.hf_module( 53 | input_ids=batch_tokens.to(self.hf_module.device), 54 | attention_mask=None, 55 | output_hidden_states=False, 56 | ) 57 | return outputs[self.output_key] 58 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/model/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
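# NOTE (editor sketch): rope(pos, dim, theta) builds rotary-embedding rotation
# matrices of shape [..., seq_len, dim // 2, 2, 2] (cos/sin packed as 2x2
# rotations), and apply_rope() views q/k as [..., dim // 2, 1, 2] channel pairs
# and rotates each pair -- equivalent to complex multiplication. A minimal
# shape check (illustrative only; the full model may add extra broadcast
# dimensions to `pe`):
#
#   pos = torch.zeros(1, 4)                     # [batch, seq_len]
#   pe = rope(pos, dim=8, theta=10_000)         # -> [1, 4, 4, 2, 2]
#   q = torch.randn(1, 2, 4, 8)                 # [B, H, L, D]
#   k = torch.randn(1, 2, 4, 8)
#   q_r, k_r = apply_rope(q, k, pe)             # same shapes as q and k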
6 | 7 | import torch 8 | from einops import rearrange 9 | from torch import Tensor 10 | 11 | 12 | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: 13 | q, k = apply_rope(q, k, pe) 14 | 15 | x = torch.nn.functional.scaled_dot_product_attention(q, k, v) 16 | x = rearrange(x, "B H L D -> B L (H D)") 17 | 18 | return x 19 | 20 | 21 | def rope(pos: Tensor, dim: int, theta: int) -> Tensor: 22 | assert dim % 2 == 0 23 | scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim 24 | omega = 1.0 / (theta**scale) 25 | out = torch.einsum("...n,d->...nd", pos, omega) 26 | out = torch.stack( 27 | [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1 28 | ) 29 | out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) 30 | return out.float() 31 | 32 | 33 | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: 34 | xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) 35 | xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) 36 | xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] 37 | xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] 38 | return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) 39 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/requirements-flux.txt: -------------------------------------------------------------------------------- 1 | .ci/docker/requirements-flux.txt -------------------------------------------------------------------------------- /torchtitan/experiments/flux/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | # use envs as local overrides for convenience 11 | # e.g. 12 | # LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_train.sh 13 | NGPU=${NGPU:-"8"} 14 | export LOG_RANK=${LOG_RANK:-0} 15 | CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"} 16 | 17 | overrides="" 18 | if [ $# -ne 0 ]; then 19 | overrides="$*" 20 | fi 21 | 22 | 23 | PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ 24 | torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ 25 | --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ 26 | -m torchtitan.experiments.flux.train --job.config_file ${CONFIG_FILE} $overrides 27 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/scripts/download_autoencoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import Optional 8 | 9 | from requests.exceptions import HTTPError 10 | 11 | 12 | def hf_download( 13 | repo_id: str, file_path: str, local_dir: str, hf_token: Optional[str] = None 14 | ) -> None: 15 | from huggingface_hub import hf_hub_download 16 | 17 | try: 18 | hf_hub_download( 19 | repo_id=repo_id, 20 | filename=file_path, 21 | local_dir=local_dir, 22 | local_dir_use_symlinks=False, 23 | token=hf_token, 24 | ) 25 | except HTTPError as e: 26 | if e.response.status_code == 401: 27 | print( 28 | "You need to pass a valid `--hf_token=...` to download private checkpoints." 29 | ) 30 | else: 31 | raise e 32 | 33 | 34 | if __name__ == "__main__": 35 | import argparse 36 | 37 | parser = argparse.ArgumentParser(description="Download tokenizer from HuggingFace.") 38 | parser.add_argument( 39 | "--repo_id", 40 | type=str, 41 | default="black-forest-labs/FLUX.1-dev", 42 | help="Repository ID to download from. default to Flux-dev model", 43 | ) 44 | parser.add_argument( 45 | "--ae_path", 46 | type=str, 47 | default="ae.safetensors", 48 | help="the autoencoder path relative to repo_id", 49 | ) 50 | parser.add_argument( 51 | "--hf_token", type=str, default=None, help="HuggingFace API token" 52 | ) 53 | parser.add_argument( 54 | "--local_dir", 55 | type=str, 56 | default="torchtitan/experiments/flux/assets/autoencoder/", 57 | help="local directory to save the autoencoder", 58 | ) 59 | 60 | args = parser.parse_args() 61 | hf_download(args.repo_id, args.ae_path, args.local_dir, args.hf_token) 62 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/tests/assets/cc12m_test/cc12m-train-0000.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/torchtitan/1cb1fa19033b42bbd11a64c9a227698949c7740b/torchtitan/experiments/flux/tests/assets/cc12m_test/cc12m-train-0000.tar -------------------------------------------------------------------------------- /torchtitan/experiments/flux/tests/assets/cc12m_test/pack_test_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | 10 | import webdataset as wds 11 | 12 | 13 | def pack_wds_dataset(tar_destination, source_folder, number_of_samples): 14 | """Pack cc12m dataset into a tar file using WebDataset format. 15 | This function is used to create the test file containing the cc12m dataset. 16 | 17 | Args: 18 | tar_destination (str): The path to the output tar file. 19 | source_folder (str): The path to the source folder containing the dataset. 20 | number_of_samples (int): The number of samples to pack. 
21 | """ 22 | 23 | # Create a TarWriter object to write the dataset to a tar archive 24 | with wds.TarWriter(tar_destination) as tar: 25 | # Iterate over the files in the dataset directory 26 | samples_cnt = 0 27 | for root, dirs, files in os.walk(source_folder): 28 | # Iterate over the files in each subdirectory 29 | for filename in files: 30 | if not filename.endswith(".jpg") or filename.startswith("."): 31 | continue 32 | # Construct the path to the file 33 | img_path = os.path.join(root, filename) 34 | json_path = os.path.join(root, filename.replace(".jpg", ".json")) 35 | key = json.loads(open(json_path, "r").read())["key"] 36 | print(f"Saved Key to tar file: {key}") 37 | txt_path = os.path.join(root, filename.replace(".jpg", ".txt")) 38 | # Write the file and its metadata to the TarWriter 39 | with open(img_path, "rb") as img_file, open( 40 | txt_path, "r" 41 | ) as txt_file, open(json_path, "r") as json_file: 42 | save_dict = { 43 | "__key__": key, 44 | "txt": txt_file.read(), 45 | "jpg": img_file.read(), 46 | "json": json_file.read(), 47 | } 48 | tar.write(save_dict) 49 | 50 | samples_cnt += 1 51 | if samples_cnt >= number_of_samples: 52 | break 53 | 54 | 55 | if __name__ == "__main__": 56 | tar_destination = ( 57 | "torchtitan/experiments/flux/tests/assets/cc12m_test/cc12m-train-0000.tar" 58 | ) 59 | source_folder = "cc12m_test" 60 | number_of_samples = 32 61 | pack_wds_dataset(tar_destination, source_folder, number_of_samples) 62 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/tests/assets/t5-v1_1-xxl/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "t5-v1_1-xxl", 3 | "architectures": [ 4 | "T5ForConditionalGeneration" 5 | ], 6 | "d_ff": 10240, 7 | "d_kv": 64, 8 | "d_model": 4096, 9 | "decoder_start_token_id": 0, 10 | "dropout_rate": 0.1, 11 | "eos_token_id": 1, 12 | "feed_forward_proj": "gated-gelu", 13 | "initializer_factor": 1.0, 14 | "is_encoder_decoder": true, 15 | "layer_norm_epsilon": 1e-06, 16 | "model_type": "t5", 17 | "num_decoder_layers": 24, 18 | "num_heads": 64, 19 | "num_layers": 24, 20 | "output_past": true, 21 | "pad_token_id": 0, 22 | "relative_attention_num_buckets": 32, 23 | "tie_word_embeddings": false, 24 | "vocab_size": 32128 25 | } 26 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from torchtitan.config_manager import ConfigManager 8 | from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader 9 | from torchtitan.tools.profiling import ( 10 | maybe_enable_memory_snapshot, 11 | maybe_enable_profiling, 12 | ) 13 | 14 | 15 | class TestFluxDataLoader: 16 | def test_load_dataset(self): 17 | for dataset_name in ["cc12m-test"]: 18 | self._test_flux_dataloader(dataset_name) 19 | 20 | def _test_flux_dataloader(self, dataset_name): 21 | batch_size = 4 22 | world_size = 4 23 | rank = 0 24 | 25 | num_steps = 10 26 | 27 | path = "torchtitan.experiments.flux.job_config" 28 | config_manager = ConfigManager() 29 | config = config_manager.parse_args( 30 | [ 31 | f"--experimental.custom_args_module={path}", 32 | # Profiling options 33 | # "--profiling.enable_profiling", 34 | # "--profiling.profile_freq", 35 | # "5", 36 | # "--profiling.enable_memory_snapshot", 37 | # "--profiling.save_memory_snapshot_folder", 38 | # "memory_snapshot_flux", 39 | "--training.img_size", 40 | str(256), 41 | "--training.dataset", 42 | dataset_name, 43 | "--training.batch_size", 44 | str(batch_size), 45 | "--training.seed", 46 | "0", 47 | "--training.classifer_free_guidance_prob", 48 | "0.1", 49 | "--encoder.t5_encoder", 50 | "google/t5-v1_1-small", 51 | "--encoder.clip_encoder", 52 | "openai/clip-vit-large-patch14", 53 | "--encoder.max_t5_encoding_len", 54 | "512", 55 | ] 56 | ) 57 | 58 | with maybe_enable_profiling( 59 | config, global_step=0 60 | ) as torch_profiler, maybe_enable_memory_snapshot( 61 | config, global_step=0 62 | ) as memory_profiler: 63 | dl = self._build_dataloader( 64 | config, 65 | world_size, 66 | rank, 67 | ) 68 | dl = iter(dl) 69 | 70 | for i in range(0, num_steps): 71 | input_data, labels = next(dl) 72 | if torch_profiler: 73 | torch_profiler.step() 74 | if memory_profiler: 75 | memory_profiler.step() 76 | 77 | assert len(input_data) == 2 # (clip_encodings, t5_encodings) 78 | assert labels.shape == (batch_size, 3, 256, 256) 79 | # assert input_data["clip_tokens"].shape[0] == batch_size 80 | # assert input_data["t5_tokens"].shape == (batch_size, 512, 512) 81 | 82 | if torch_profiler: 83 | torch_profiler.step() 84 | if memory_profiler: 85 | memory_profiler.step(exit_ctx=True) 86 | 87 | def _build_dataloader( 88 | self, 89 | job_config, 90 | world_size, 91 | rank, 92 | ): 93 | return build_flux_dataloader( 94 | dp_world_size=world_size, 95 | dp_rank=rank, 96 | job_config=job_config, 97 | tokenizer=None, 98 | infinite=True, 99 | ) 100 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/train_configs/debug_model.toml: -------------------------------------------------------------------------------- 1 | 2 | [job] 3 | dump_folder = "./outputs" 4 | description = "Flux debug model" 5 | print_args = false 6 | use_for_integration_test = true 7 | 8 | [profiling] 9 | enable_profiling = false 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 10 12 | enable_memory_snapshot = false 13 | save_memory_snapshot_folder = "memory_snapshot" 14 | 15 | [metrics] 16 | log_freq = 1 17 | disable_color_printing = false 18 | enable_tensorboard = false 19 | save_tb_folder = "tb" 20 | enable_wandb = false 21 | 22 | [model] 23 | name = "flux" 24 | flavor = "flux-debug" 25 | 26 | [optimizer] 27 | name = "AdamW" 28 | lr = 8e-4 29 | eps = 1e-8 30 | 31 | [lr_scheduler] 32 | warmup_steps = 1 # 10% warmup steps 33 | decay_ratio = 0.0 # no decay, stay stable during training 34 | 35 | [training] 36 | batch_size = 4 37 | 
seq_len = 512 38 | max_norm = 2.0 # grad norm clipping 39 | steps = 10 40 | compile = false 41 | dataset = "cc12m-test" 42 | classifer_free_guidance_prob = 0.1 43 | img_size = 256 44 | 45 | [encoder] 46 | t5_encoder = "google/t5-v1_1-xxl" 47 | clip_encoder = "openai/clip-vit-large-patch14" 48 | max_t5_encoding_len = 256 49 | autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image 50 | 51 | [eval] 52 | enable_classifer_free_guidance = true 53 | classifer_free_guidance_scale = 5.0 54 | denoising_steps = 4 55 | save_img_folder = "img" 56 | eval_freq = 5 57 | 58 | [parallelism] 59 | data_parallel_replicate_degree = 1 60 | data_parallel_shard_degree = -1 61 | 62 | [experimental] 63 | custom_args_module = "torchtitan.experiments.flux.job_config" 64 | 65 | [activation_checkpoint] 66 | mode = "full" 67 | 68 | [checkpoint] 69 | enable_checkpoint = false 70 | folder = "checkpoint" 71 | interval = 5 72 | model_weights_only = false 73 | export_dtype = "float32" 74 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 75 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/train_configs/flux_dev_model.toml: -------------------------------------------------------------------------------- 1 | 2 | [job] 3 | dump_folder = "./outputs" 4 | description = "Flux-dev model" 5 | print_args = false 6 | 7 | [profiling] 8 | enable_profiling = false 9 | save_traces_folder = "profile_trace" 10 | profile_freq = 10 11 | enable_memory_snapshot = false 12 | save_memory_snapshot_folder = "memory_snapshot" 13 | 14 | [metrics] 15 | log_freq = 100 16 | disable_color_printing = false 17 | enable_tensorboard = false 18 | save_tb_folder = "tb" 19 | enable_wandb = false 20 | 21 | [model] 22 | name = "flux" 23 | flavor = "flux-dev" 24 | 25 | [optimizer] 26 | name = "AdamW" 27 | lr = 1e-4 28 | eps = 1e-8 29 | 30 | [lr_scheduler] 31 | warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps 32 | decay_ratio = 0.0 # no decay 33 | 34 | [training] 35 | batch_size = 4 36 | seq_len = 512 37 | max_norm = 1.0 # grad norm clipping 38 | steps = 30_000 39 | compile = false 40 | dataset = "cc12m-wds" 41 | classifer_free_guidance_prob = 0.1 42 | img_size = 256 43 | 44 | [encoder] 45 | t5_encoder = "google/t5-v1_1-xxl" 46 | clip_encoder = "openai/clip-vit-large-patch14" 47 | max_t5_encoding_len = 512 48 | autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image 49 | 50 | [eval] 51 | enable_classifer_free_guidance = true 52 | classifer_free_guidance_scale = 5.0 53 | denoising_steps = 50 54 | save_img_folder = "img" 55 | eval_freq = 1000 56 | 57 | [parallelism] 58 | data_parallel_replicate_degree = 1 59 | data_parallel_shard_degree = -1 60 | 61 | [experimental] 62 | custom_args_module = "torchtitan.experiments.flux.job_config" 63 | 64 | [activation_checkpoint] 65 | mode = "full" 66 | 67 | [checkpoint] 68 | enable_checkpoint = false 69 | folder = "checkpoint" 70 | interval = 1_000 71 | model_weights_only = false 72 | export_dtype = "float32" 73 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 74 | -------------------------------------------------------------------------------- /torchtitan/experiments/flux/train_configs/flux_schnell_model.toml: -------------------------------------------------------------------------------- 1 | 2 | [job] 3 | dump_folder = "./outputs" 4 | description = "Flux-schnell model" 5 | 
print_args = false 6 | 7 | [profiling] 8 | enable_profiling = false 9 | save_traces_folder = "profile_trace" 10 | profile_freq = 10 11 | enable_memory_snapshot = false 12 | save_memory_snapshot_folder = "memory_snapshot" 13 | 14 | [metrics] 15 | log_freq = 100 16 | disable_color_printing = false 17 | enable_tensorboard = false 18 | save_tb_folder = "tb" 19 | enable_wandb = false 20 | 21 | [model] 22 | name = "flux" 23 | flavor = "flux-schnell" 24 | 25 | [optimizer] 26 | name = "AdamW" 27 | lr = 1e-4 28 | eps = 1e-8 29 | 30 | [lr_scheduler] 31 | warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps 32 | decay_ratio = 0.0 # no decay 33 | 34 | [training] 35 | batch_size = 4 36 | seq_len = 512 37 | max_norm = 1.0 # grad norm clipping 38 | steps = 30_000 39 | compile = false 40 | dataset = "cc12m-wds" 41 | classifer_free_guidance_prob = 0.1 42 | img_size = 256 43 | 44 | [encoder] 45 | t5_encoder = "google/t5-v1_1-xxl" 46 | clip_encoder = "openai/clip-vit-large-patch14" 47 | max_t5_encoding_len = 256 48 | autoencoder_path = "torchtitan/experiments/flux/assets/autoencoder/ae.safetensors" # Autoencoder to use for image 49 | 50 | [eval] 51 | enable_classifer_free_guidance = true 52 | classifer_free_guidance_scale = 5.0 53 | denoising_steps = 50 54 | save_img_folder = "img" 55 | eval_freq = 1000 56 | 57 | [parallelism] 58 | data_parallel_replicate_degree = 1 59 | data_parallel_shard_degree = -1 60 | 61 | [experimental] 62 | custom_args_module = "torchtitan.experiments.flux.job_config" 63 | 64 | [activation_checkpoint] 65 | mode = "full" 66 | 67 | [checkpoint] 68 | enable_checkpoint = false 69 | folder = "checkpoint" 70 | interval = 1_000 71 | model_weights_only = false 72 | export_dtype = "float32" 73 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 74 | -------------------------------------------------------------------------------- /torchtitan/experiments/kernels/triton_contiguous_group_gemm/cg_reference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | 10 | # Simple reference implementation for verification 11 | def pytorch_reference( 12 | inputs: torch.Tensor, 13 | expert_weights: torch.Tensor, 14 | expert_indices: torch.Tensor, 15 | group_size_m: int = 128, 16 | ) -> torch.Tensor: 17 | """ 18 | Reference implementation using PyTorch for verification. 19 | """ 20 | M_total, K = inputs.shape 21 | num_experts, N, _ = expert_weights.shape 22 | 23 | output = torch.empty((M_total, N), device=inputs.device, dtype=inputs.dtype) 24 | 25 | # Process each group 26 | for i in range(0, M_total, group_size_m): 27 | end_idx = min(i + group_size_m, M_total) 28 | 29 | # Get expert index for this group 30 | expert_idx = expert_indices[i].item() 31 | 32 | # Get expert weights 33 | expert_weight = expert_weights[expert_idx] 34 | 35 | # Compute output for this group 36 | output[i:end_idx] = torch.matmul(inputs[i:end_idx], expert_weight.T) 37 | 38 | return output 39 | -------------------------------------------------------------------------------- /torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .mg_grouped_gemm import grouped_gemm_forward 8 | from .tma_autotuning import ALIGN_SIZE_M 9 | 10 | __all__ = [ 11 | "grouped_gemm_forward", 12 | "ALIGN_SIZE_M", 13 | ] 14 | -------------------------------------------------------------------------------- /torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/reference_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # pyre-unsafe 8 | import logging 9 | 10 | import numpy as np 11 | import torch 12 | 13 | # Configure logging 14 | logging.basicConfig( 15 | level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" 16 | ) 17 | 18 | 19 | def compute_reference_forward(x, w, m_sizes): 20 | """ 21 | Compute reference forward pass using PyTorch operations. 22 | 23 | Args: 24 | x (torch.Tensor): Input tensor of shape (M, K) 25 | w (torch.Tensor): Weight tensor of shape (N, K) 26 | m_sizes (torch.Tensor): Group sizes tensor of shape (G) 27 | 28 | Returns: 29 | torch.Tensor: Reference output tensor of shape (M, N) 30 | """ 31 | result = torch.zeros((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device) 32 | 33 | m_start = 0 34 | for g in range(len(m_sizes)): 35 | m_size = m_sizes[g].item() 36 | if m_size > 0: 37 | m_end = m_start + m_size 38 | 39 | # Extract group input 40 | x_g = x[m_start:m_end] 41 | 42 | # Compute group output: y_g = x_g @ w.T 43 | y_g = torch.matmul(x_g, w.T) 44 | 45 | # Store result 46 | result[m_start:m_end] = y_g 47 | 48 | # Update start index 49 | m_start = m_end 50 | 51 | return result 52 | 53 | 54 | def compute_reference_backward(x, w, m_sizes, grad_output): 55 | """ 56 | Compute reference backward pass using PyTorch autograd. 57 | 58 | Args: 59 | x (torch.Tensor): Input tensor of shape (M, K) 60 | w (torch.Tensor): Weight tensor of shape (N, K) 61 | m_sizes (torch.Tensor): Group sizes tensor of shape (G) 62 | grad_output (torch.Tensor): Gradient tensor of shape (M, N) 63 | 64 | Returns: 65 | tuple: (grad_x, grad_w) gradient tensors 66 | """ 67 | # Create autograd-enabled copies 68 | x_autograd = x.detach().clone().requires_grad_(True) 69 | w_autograd = w.detach().clone().requires_grad_(True) 70 | 71 | # Compute forward pass 72 | output = compute_reference_forward(x_autograd, w_autograd, m_sizes) 73 | 74 | # Backpropagate 75 | output.backward(grad_output) 76 | 77 | return x_autograd.grad, w_autograd.grad 78 | 79 | 80 | def analyze_tensor_differences(actual, expected, name): 81 | """ 82 | Analyze differences between actual and expected tensors. 
83 | 84 | Args: 85 | actual (torch.Tensor): Actual tensor 86 | expected (torch.Tensor): Expected tensor 87 | name (str): Name of the tensor for logging 88 | 89 | Returns: 90 | bool: True if tensors are close enough 91 | """ 92 | rtol = 0.5 # Relative tolerance for float16 93 | atol = 0.5 # Absolute tolerance for float16 94 | 95 | # Analyze differences 96 | diff = (actual - expected).abs() 97 | max_idx = diff.argmax().item() 98 | idx = np.unravel_index(max_idx, actual.shape) 99 | max_diff = diff.max().item() 100 | 101 | logging.info(f"Largest {name} difference: {max_diff} at {idx}") 102 | logging.info(f"Values: {actual[idx].item()} vs {expected[idx].item()}") 103 | 104 | is_close = torch.allclose(actual, expected, rtol=rtol, atol=atol) 105 | 106 | if is_close: 107 | logging.info(f"✓ SUCCESS: {name} matches PyTorch reference") 108 | else: 109 | logging.error(f"✗ FAILURE: {name} mismatch detected") 110 | 111 | # Count zeros 112 | zeros_actual = (actual == 0).sum().item() 113 | zeros_expected = (expected == 0).sum().item() 114 | logging.info( 115 | f"Zeros in {name} (actual): {zeros_actual}/{actual.numel()} ({zeros_actual/actual.numel()*100:.2f}%)" 116 | ) 117 | logging.info( 118 | f"Zeros in {name} (expected): {zeros_expected}/{expected.numel()} ({zeros_expected/expected.numel()*100:.2f}%)" 119 | ) 120 | 121 | # Check for NaNs 122 | nan_actual = torch.isnan(actual).sum().item() 123 | if nan_actual > 0: 124 | logging.error(f"NaN values detected in {name}: {nan_actual}") 125 | 126 | return is_close 127 | -------------------------------------------------------------------------------- /torchtitan/experiments/kernels/triton_mg_group_gemm/torchao_pr/unit_test_forwards.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # pyre-unsafe 8 | import logging 9 | import unittest 10 | from typing import Tuple 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from mg_grouped_gemm import grouped_gemm_forward 16 | 17 | 18 | class TestMG_GroupedGEMM(unittest.TestCase): 19 | def setUp(self) -> None: 20 | torch.manual_seed(2020) 21 | 22 | def _run_grouped_gemm_test( 23 | self, 24 | shape: Tuple[int, int, int, int], 25 | device: torch.device, 26 | dtype: torch.dtype = torch.bfloat16, 27 | atol: float = 1e-5, 28 | rtol: float = 1.6e-2, 29 | ) -> None: 30 | G, M, N, K = shape 31 | # In M*G grouping, input is [M*G, K] and weights are [N*G, K] 32 | a = torch.randn(M * G, K, dtype=dtype, device=device) 33 | b = torch.randn(N * G, K, dtype=dtype, device=device) 34 | 35 | # Create equal-sized groups for simplicity 36 | m_size = M 37 | m_sizes = torch.full((G,), m_size, device=device, dtype=torch.int32) 38 | 39 | result = grouped_gemm_forward(a, b, m_sizes) 40 | self.assertTrue(result.shape == (M * G, N)) 41 | 42 | expected_result = torch.zeros(M * G, N, dtype=dtype, device=device) 43 | m_start = 0 44 | for g in range(G): 45 | m_end = m_start + m_sizes[g] 46 | b_slice = b[N * g : N * (g + 1), :] 47 | expected_result[m_start:m_end, :] = a[m_start:m_end, :] @ b_slice.T 48 | m_start = m_end 49 | 50 | # Convert result to match input dtype if needed 51 | result = result.to(dtype) 52 | torch.testing.assert_close(result, expected_result, atol=atol, rtol=rtol) 53 | 54 | def test_MG_grouped_gemm_bf16(self) -> None: 55 | for G in (1, 4, 16): 56 | for M in (128, 512, 1024): 57 | print(f"Testing BF16 M*G GroupGeMM with G={G}, M={M}") 58 | self._run_grouped_gemm_test( 59 | (G, M, 1024, 1024), 60 | torch.device("cuda"), 61 | dtype=torch.bfloat16, 62 | atol=1e-5, 63 | rtol=1.6e-2, 64 | ) 65 | 66 | def test_MG_grouped_gemm_deepseek_shapes(self) -> None: 67 | """Test with shapes from Deepseek model.""" 68 | deepseek_shapes = [ 69 | (4, 2048, 4096, 7168), # G, M, N, K 70 | (4, 2048, 7168, 2048), 71 | (8, 512, 4096, 7168), 72 | (8, 512, 7168, 2048), 73 | ] 74 | 75 | device = torch.device("cuda") 76 | 77 | for shape in deepseek_shapes: 78 | G, M, N, K = shape 79 | print(f"Testing BF16 M*G Deepseek shape: G={G}, M={M}, N={N}, K={K}") 80 | self._run_grouped_gemm_test( 81 | shape, device, dtype=torch.bfloat16, atol=1e-5, rtol=1.6e-2 82 | ) 83 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/README.md: -------------------------------------------------------------------------------- 1 | **The Llama 4 folder is still under development.** 2 | 3 | #### Issue tracking 4 | https://github.com/pytorch/torchtitan/issues/1118 5 | 6 | #### Available features 7 | - Llama 4 model (text-only), including a token-choice MoE architecture with efficient bfloat16 Grouped MM kernels and auxiliary-loss-free load balancing 8 | - FSDP, TP, PP, CP support 9 | - DCP checkpoint conversion scripts 10 | 11 | #### Download Llama 4 tokenizer 12 | ```bash 13 | # Llama 4 tokenizer.model 14 | python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E --tokenizer_path "" --hf_token=... 
15 | ``` 16 | 17 | #### To be added 18 | - Modeling 19 | - alternative expert-choice MoE 20 | - multimodal support 21 | - Parallelism 22 | - Context Parallel support for FlexAttention and multimodal inputs 23 | - Expert Parallel support 24 | - torch.compile 25 | - for MoE layers 26 | - Quantization 27 | - efficient float8 Grouped MM kernels (from torchao) 28 | - Testing 29 | - performance and loss convergence tests 30 | - CI integration 31 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torchtitan.components.loss import build_cross_entropy_loss 8 | from torchtitan.components.lr_scheduler import build_lr_schedulers 9 | from torchtitan.components.optimizer import build_optimizers 10 | from torchtitan.datasets.hf_datasets import build_hf_dataloader 11 | from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer 12 | from torchtitan.models.llama3 import pipeline_llama 13 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec 14 | 15 | from .infra.parallelize_llama import parallelize_llama 16 | from .model.args import TransformerModelArgs 17 | from .model.model import Transformer 18 | 19 | __all__ = [ 20 | "TransformerModelArgs", 21 | "Transformer", 22 | "llama4_configs", 23 | ] 24 | 25 | 26 | llama4_configs = { 27 | "debugmodel": TransformerModelArgs( 28 | dim=256, 29 | n_layers=6, 30 | n_heads=16, 31 | rope_theta=500000, 32 | ), 33 | "17bx16e": TransformerModelArgs( 34 | dim=5120, 35 | n_layers=48, 36 | n_heads=40, 37 | n_kv_heads=8, 38 | ffn_dim_multiplier=1.2, 39 | multiple_of=2048, 40 | rope_theta=500000, 41 | num_experts=16, 42 | interleave_moe_layer_step=1, 43 | ), 44 | "17bx128e": TransformerModelArgs( 45 | dim=5120, 46 | n_layers=48, 47 | n_heads=40, 48 | n_kv_heads=8, 49 | ffn_dim_multiplier=1.2, 50 | multiple_of=2048, 51 | rope_theta=500000, 52 | num_experts=128, 53 | ), 54 | "debugmodel_irope": TransformerModelArgs( 55 | dim=256, 56 | n_layers=6, 57 | n_heads=16, 58 | rope_theta=500000, 59 | every_n_layers_nope=4, 60 | fixed_attn_block_size=256, 61 | use_flex_attn=True, 62 | attn_mask_type="block_causal", 63 | ), 64 | "17bx16e_irope": TransformerModelArgs( 65 | dim=5120, 66 | n_layers=48, 67 | n_heads=40, 68 | n_kv_heads=8, 69 | ffn_dim_multiplier=1.2, 70 | multiple_of=2048, 71 | rope_theta=500000, 72 | num_experts=16, 73 | interleave_moe_layer_step=1, 74 | every_n_layers_nope=4, 75 | use_flex_attn=True, 76 | attn_mask_type="block_causal", 77 | ), 78 | "17bx128e_irope": TransformerModelArgs( 79 | dim=5120, 80 | n_layers=48, 81 | n_heads=40, 82 | n_kv_heads=8, 83 | ffn_dim_multiplier=1.2, 84 | multiple_of=2048, 85 | rope_theta=500000, 86 | num_experts=128, 87 | every_n_layers_nope=4, 88 | use_flex_attn=True, 89 | attn_mask_type="block_causal", 90 | ), 91 | } 92 | 93 | 94 | register_train_spec( 95 | TrainSpec( 96 | name="llama4", 97 | cls=Transformer, 98 | config=llama4_configs, 99 | parallelize_fn=parallelize_llama, 100 | pipelining_fn=pipeline_llama, 101 | build_optimizers_fn=build_optimizers, 102 | build_lr_schedulers_fn=build_lr_schedulers, 103 | build_dataloader_fn=build_hf_dataloader, 104 | build_tokenizer_fn=build_tiktoken_tokenizer, 105 | 
build_loss_fn=build_cross_entropy_loss, 106 | ) 107 | ) 108 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/scripts/REAME.md: -------------------------------------------------------------------------------- 1 | ## How to convert a Llama 4 checkpoint for use in torchtitan 2 | 3 | To continue training from an existing model checkpoint, the checkpoint must be in the DCP format expected by the checkpoint manager. 4 | This folder contains the scripts for converting officially released Llama 4 checkpoints into the expected DCP format, from the original Meta format or from the HuggingFace format, using GPUs. 5 | 6 | #### Example usage 7 | 8 | From Meta format: 9 | ```bash 10 | CONFIG_FILE=../train_configs/llama4_17bx16e.toml ./convert_meta_to_dcp_with_gpus.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 11 | ``` 12 | 13 | 14 | From HuggingFace format: 15 | ```bash 16 | CONFIG_FILE=../train_configs/llama4_17bx16e.toml ./convert_hf_to_dcp_with_gpus.sh --checkpoint.enable_checkpoint --checkpoint.convert_path=[checkpoint_folder] --checkpoint.convert_load_every_n_ranks=8 17 | ``` 18 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | 9 | set -ex 10 | 11 | # use envs as local overrides for convenience 12 | # e.g. 13 | # LOG_RANK=0,1 NGPU=4 ./convert_hf_to_dcp_with_gpus.sh 14 | NGPU=${NGPU:-"8"} 15 | LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} 16 | CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} 17 | 18 | overrides="" 19 | if [ $# -ne 0 ]; then 20 | overrides="$*" 21 | fi 22 | 23 | PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ 24 | torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ 25 | --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ 26 | convert_hf_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides 27 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/scripts/convert_meta_to_dcp_with_gpus.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the BSD-style license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | set -ex 9 | 10 | # use envs as local overrides for convenience 11 | # e.g. 
12 | # LOG_RANK=0,1 NGPU=4 ./convert_meta_to_dcp_with_gpus.sh 13 | NGPU=${NGPU:-"8"} 14 | LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7} 15 | CONFIG_FILE=${CONFIG_FILE:-"../train_configs/llama4_17bx16e.toml"} 16 | 17 | overrides="" 18 | if [ $# -ne 0 ]; then 19 | overrides="$*" 20 | fi 21 | 22 | PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ 23 | torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ 24 | --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ 25 | convert_meta_to_dcp_with_gpus.py --job.config_file ${CONFIG_FILE} $overrides 26 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/train_configs/debug_model.toml: -------------------------------------------------------------------------------- 1 | [job] 2 | dump_folder = "./outputs" 3 | description = "Llama 4 debug training" 4 | print_args = false 5 | use_for_integration_test = true 6 | 7 | [profiling] 8 | enable_profiling = false 9 | save_traces_folder = "profile_trace" 10 | profile_freq = 10 11 | enable_memory_snapshot = false 12 | save_memory_snapshot_folder = "memory_snapshot" 13 | 14 | [metrics] 15 | log_freq = 1 16 | disable_color_printing = false 17 | enable_tensorboard = false 18 | save_tb_folder = "tb" 19 | enable_wandb = false 20 | 21 | [model] 22 | name = "llama4" 23 | flavor = "debugmodel" 24 | # test tokenizer.model, for debug purpose only 25 | tokenizer_path = "./tests/assets/test_tiktoken.model" 26 | # converters = ["float8"] 27 | 28 | [optimizer] 29 | name = "AdamW" 30 | lr = 4e-3 31 | eps = 1e-15 32 | 33 | [lr_scheduler] 34 | warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps 35 | decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps 36 | decay_type = "linear" 37 | lr_min = 0.1 38 | 39 | [training] 40 | batch_size = 8 41 | seq_len = 2048 42 | max_norm = 1.0 # grad norm clipping 43 | steps = 10 44 | compile = false 45 | dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) 46 | 47 | [parallelism] 48 | data_parallel_replicate_degree = 1 49 | data_parallel_shard_degree = -1 50 | fsdp_reshard_after_forward = "default" # default / never / always 51 | tensor_parallel_degree = 1 52 | enable_async_tensor_parallel = false 53 | pipeline_parallel_degree = 1 54 | context_parallel_degree = 1 55 | 56 | [checkpoint] 57 | enable_checkpoint = false 58 | folder = "checkpoint" 59 | interval = 10 60 | model_weights_only = false 61 | export_dtype = "float32" 62 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 63 | 64 | [activation_checkpoint] 65 | mode = "none" # ["none", "selective", "full"] 66 | selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy 67 | 68 | [float8] 69 | enable_fsdp_float8_all_gather = false 70 | precompute_float8_dynamic_scale_for_fsdp = false 71 | filter_fqns = ["output", "router.gate"] 72 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml: -------------------------------------------------------------------------------- 1 | # TODO: this toml config is still under development 2 | 3 | [job] 4 | dump_folder = "./outputs" 5 | description = "Llama 4 Maverick 17Bx128E training" 6 | 7 | [profiling] 8 | enable_profiling = false 9 | save_traces_folder = "profile_trace" 10 | profile_freq = 100 11 | 12 | [metrics] 13 | log_freq = 10 14 | enable_tensorboard = false 15 | save_tb_folder = "tb" 16 | 17 | [model] 18 | name 
= "llama4" 19 | flavor = "17bx128e" 20 | tokenizer_path = "./assets/tokenizer/tokenizer.model" 21 | # converters = ["float8"] 22 | 23 | [optimizer] 24 | name = "AdamW" 25 | lr = 4e-3 26 | eps = 1e-15 27 | 28 | [lr_scheduler] 29 | warmup_steps = 600 30 | lr_min = 0.1 31 | 32 | [training] 33 | batch_size = 1 34 | seq_len = 8192 35 | max_norm = 1.0 # grad norm clipping 36 | steps = 3000 37 | compile = false 38 | dataset = "c4" 39 | 40 | [parallelism] 41 | data_parallel_replicate_degree = 1 42 | data_parallel_shard_degree = -1 43 | tensor_parallel_degree = 8 44 | enable_async_tensor_parallel = false 45 | pipeline_parallel_degree = 4 46 | # pipeline_parallel_schedule = "interleaved1f1b" 47 | # pipeline_parallel_microbatches = 2 48 | context_parallel_degree = 1 49 | 50 | [checkpoint] 51 | enable_checkpoint = false 52 | folder = "checkpoint" 53 | interval = 500 54 | model_weights_only = false 55 | export_dtype = "float32" 56 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 57 | 58 | [activation_checkpoint] 59 | mode = "full" # ["none", "selective", "full"] 60 | 61 | [float8] 62 | enable_fsdp_float8_all_gather = false 63 | precompute_float8_dynamic_scale_for_fsdp = false 64 | filter_fqns = ["output", "router.gate"] 65 | -------------------------------------------------------------------------------- /torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml: -------------------------------------------------------------------------------- 1 | # NOTE: this toml config is a preset for 64 H100 GPUs. 2 | 3 | [job] 4 | dump_folder = "./outputs" 5 | description = "Llama 4 Scout 17Bx16E training" 6 | 7 | [profiling] 8 | enable_profiling = false 9 | save_traces_folder = "profile_trace" 10 | profile_freq = 100 11 | 12 | [metrics] 13 | log_freq = 10 14 | enable_tensorboard = false 15 | save_tb_folder = "tb" 16 | 17 | [model] 18 | name = "llama4" 19 | flavor = "17bx16e" 20 | tokenizer_path = "./assets/tokenizer/tokenizer.model" 21 | # converters = ["float8"] 22 | 23 | [optimizer] 24 | name = "AdamW" 25 | lr = 4e-3 26 | eps = 1e-15 27 | 28 | [lr_scheduler] 29 | warmup_steps = 600 30 | lr_min = 0.1 31 | 32 | [training] 33 | batch_size = 8 34 | seq_len = 8192 35 | max_norm = 1.0 # grad norm clipping 36 | steps = 3000 37 | compile = false 38 | dataset = "c4" 39 | 40 | [parallelism] 41 | data_parallel_replicate_degree = 1 42 | data_parallel_shard_degree = -1 43 | tensor_parallel_degree = 8 44 | enable_async_tensor_parallel = false 45 | pipeline_parallel_degree = 1 46 | context_parallel_degree = 1 47 | 48 | [checkpoint] 49 | enable_checkpoint = false 50 | folder = "checkpoint" 51 | interval = 500 52 | model_weights_only = false 53 | export_dtype = "float32" 54 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 55 | 56 | [activation_checkpoint] 57 | mode = "full" # ["none", "selective", "full"] 58 | 59 | [float8] 60 | enable_fsdp_float8_all_gather = false 61 | precompute_float8_dynamic_scale_for_fsdp = false 62 | filter_fqns = ["output", "router.gate"] 63 | -------------------------------------------------------------------------------- /torchtitan/experiments/multimodal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from mm_dataset import build_mm_dataloader 8 | 9 | from torchtitan.components.loss import build_cross_entropy_loss 10 | from torchtitan.components.lr_scheduler import build_lr_schedulers 11 | from torchtitan.components.optimizer import build_optimizers 12 | from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer 13 | from torchtitan.models.llama3 import parallelize_llama, pipeline_llama 14 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec 15 | 16 | from .model import ModelArgs, MultimodalDecoder, VisionEncoder 17 | 18 | __all__ = ["VisionEncoder", "ModelArgs", "MultimodalDecoder"] 19 | 20 | llama4_mm_configs = { 21 | # TODO: add configs for llama4 multimodal 22 | } 23 | 24 | register_train_spec( 25 | TrainSpec( 26 | name="llama4_multimodal", 27 | cls=MultimodalDecoder, 28 | config=llama4_mm_configs, 29 | parallelize_fn=parallelize_llama, 30 | pipelining_fn=pipeline_llama, 31 | build_optimizers_fn=build_optimizers, 32 | build_lr_schedulers_fn=build_lr_schedulers, 33 | build_dataloader_fn=build_mm_dataloader, 34 | build_tokenizer_fn=build_tiktoken_tokenizer, 35 | build_loss_fn=build_cross_entropy_loss, 36 | ) 37 | ) 38 | -------------------------------------------------------------------------------- /torchtitan/experiments/multimodal/check_padding_mm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import click 7 | 8 | from mm_dataset import build_mm_dataloader 9 | from tokenizer.tiktoken import build_tiktoken_tokenizer 10 | 11 | from torchtitan.config_manager import ConfigManager 12 | from torchtitan.tools.logging import init_logger 13 | 14 | 15 | @click.command() 16 | @click.option("--dataset", default="OBELICS") 17 | @click.option("--batch-size", default=4) 18 | @click.option("--seq-len", default=4096) 19 | @click.option("--tokenizer-path", required=True) 20 | @click.option("--dp-rank", default=0) 21 | @click.option("--dp-world-size", default=2) 22 | @click.option("--batch-number", default=4) 23 | def main( 24 | dataset: str, 25 | batch_size: int, 26 | seq_len: int, 27 | tokenizer_path: str, 28 | dp_rank: int, 29 | dp_world_size: int, 30 | batch_number: int, 31 | ): 32 | init_logger() 33 | config_manager = ConfigManager() 34 | config = config_manager.parse_args( 35 | [ 36 | "--training.dataset", 37 | dataset, 38 | "--training.batch_size", 39 | str(batch_size), 40 | "--training.seq_len", 41 | str(seq_len), 42 | "--model.tokenizer_path", 43 | tokenizer_path, 44 | ] 45 | ) 46 | tokenizer = build_tiktoken_tokenizer(config) 47 | dl = build_mm_dataloader( 48 | dp_world_size=dp_world_size, 49 | dp_rank=dp_rank, 50 | tokenizer=tokenizer, 51 | job_config=config, 52 | ) 53 | dl_iter = iter(dl) 54 | 55 | for _ in range(batch_number): 56 | batch = next(dl_iter) 57 | 58 | # Analyze Batch 59 | # input_ids 60 | total_input_ids = batch["input_ids"].shape[0] * batch["input_ids"].shape[1] 61 | total_non_padding_tokens = total_input_ids - int( 62 | (batch["input_ids"] == 128004).sum() 63 | ) 64 | total_padding_tokens = total_input_ids - total_non_padding_tokens 65 | print(f"Padding tokens in each sample: {(batch['input_ids'] == 128004).sum(dim=1)}") 66 | print( 67 | f"Unpadded tokens: {total_non_padding_tokens}, Total tokens in batch: {total_input_ids}" 68 | ) 69 | print( 70 | f"Padded 
text tokens: {total_padding_tokens}, {(total_padding_tokens) / total_input_ids * 100:.2f}%" 71 | ) 72 | print(80 * "#") 73 | # Images 74 | padded_images = 0 75 | padded_tiles = 0 76 | for sample in batch["encoder_input"]["images"]: 77 | for image in sample: 78 | if int(image.sum()) == 0: 79 | padded_images += 1 80 | for tile in image: 81 | if int(tile.sum()) == 0: 82 | padded_tiles += 1 83 | 84 | total_images = ( 85 | batch["encoder_input"]["images"].shape[0] 86 | * batch["encoder_input"]["images"].shape[1] 87 | ) 88 | 89 | print( 90 | f"Unpadded images: {total_images - padded_images}, Total images in batch: {total_images}" 91 | ) 92 | print( 93 | f'Padded images: {padded_images}, {padded_images / total_images * 100:.2f}% (Each image with shape {list(batch["encoder_input"]["images"][0, 0].shape)})' # noqa: B950 94 | ) 95 | print(80 * "#") 96 | # Tiles 97 | total_number_of_tiles = total_images * batch["encoder_input"]["images"].shape[2] 98 | 99 | print( 100 | f"Unpadded number of tiles: {total_number_of_tiles - padded_tiles}, Total number of tiles: {total_number_of_tiles}" 101 | ) 102 | print( 103 | f'Padded tiles: {padded_tiles}, {padded_tiles / total_number_of_tiles * 100:.2f}% (Each with shape {list(batch["encoder_input"]["images"][0, 0, 0].shape)})' # noqa: B950 104 | ) 105 | print(80 * "#") 106 | 107 | 108 | if __name__ == "__main__": 109 | main() 110 | -------------------------------------------------------------------------------- /torchtitan/experiments/multimodal/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | -------------------------------------------------------------------------------- /torchtitan/experiments/multimodal/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtitan/experiments/multimodal/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | from typing import Optional, Union 10 | 11 | import torch 12 | from torch import nn 13 | 14 | 15 | def fixed_init_tensor( 16 | shape: torch.Size, 17 | min_val: Union[float, int] = 0.0, 18 | max_val: Union[float, int] = 1.0, 19 | nonlinear: bool = False, 20 | dtype: torch.dtype = torch.float, 21 | ): 22 | """ 23 | Utility for generating deterministic tensors of a given shape. In general stuff 24 | like torch.ones, torch.eye, etc can result in trivial outputs. This utility 25 | generates a range tensor [min_val, max_val) of a specified dtype, applies 26 | a sine function if nonlinear=True, then reshapes to the appropriate shape. 
27 | """ 28 | n_elements = math.prod(shape) 29 | step_size = (max_val - min_val) / n_elements 30 | x = torch.arange(min_val, max_val, step_size, dtype=dtype) 31 | x = x.reshape(shape) 32 | if nonlinear: 33 | return torch.sin(x) 34 | return x 35 | 36 | 37 | @torch.no_grad 38 | def fixed_init_model( 39 | model: nn.Module, 40 | min_val: Union[float, int] = 0.0, 41 | max_val: Union[float, int] = 1.0, 42 | nonlinear: bool = False, 43 | dtype: Optional[torch.dtype] = None, 44 | ): 45 | """ 46 | This utility initializes all parameters of a model deterministically using the 47 | function fixed_init_tensor above. See that docstring for details of each parameter. 48 | """ 49 | for _, param in model.named_parameters(): 50 | param.copy_( 51 | fixed_init_tensor( 52 | param.shape, 53 | min_val=min_val, 54 | max_val=max_val, 55 | nonlinear=nonlinear, 56 | dtype=param.dtype if dtype is None else dtype, 57 | ) 58 | ) 59 | -------------------------------------------------------------------------------- /torchtitan/experiments/simple_fsdp/README.md: -------------------------------------------------------------------------------- 1 | ## SimpleFSDP 2 | 3 | [![integration tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain) 4 | [![arXiv](https://img.shields.io/badge/arXiv-2411.00284-b31b1b.svg)](https://arxiv.org/abs/2411.00284) 5 | 6 | 7 | This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations. 8 | 9 | ### Enable SimpleFSDP Training 10 | 11 | ```bash 12 | CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.name llama3_simple_fsdp --training.compile 13 | ``` 14 | 15 | ### Composability Support 16 | 17 | Some of the features require the updates from PyTorch, with which we are working on providing composability support for the following features: 18 | 19 | | Feature | Support | 20 | | :--------: | :--------: | 21 | |Meta Initialization| ✅ | 22 | |Activation Checkpointing| ✅ | 23 | |Mixed Precision Training| ✅ | 24 | |Tensor Parallelism| ✅ | 25 | |Context Parallelism| ✅ | 26 | |Pipeline Parallelism| ✅ | 27 | |Distributed Checkpointing| 🚧 | 28 | |Float8 Training| 🚧 | 29 | 30 | 31 | ### Citation 32 | 33 | If you find SimpleFSDP useful, please kindly consider citing the following paper: 34 | 35 | ```latex 36 | @article{zhang2024simplefsdp, 37 | title={SimpleFSDP: Simpler Fully Sharded Data Parallel with torch. compile}, 38 | author={Zhang, Ruisi and Liu, Tianyu and Feng, Will and Gu, Andrew and Purandare, Sanket and Liang, Wanchao and Massa, Francisco}, 39 | journal={arXiv preprint arXiv:2411.00284}, 40 | year={2024} 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /torchtitan/experiments/simple_fsdp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 8 | 9 | from torchtitan.components.loss import build_cross_entropy_loss 10 | from torchtitan.components.lr_scheduler import build_lr_schedulers 11 | from torchtitan.components.optimizer import build_optimizers 12 | from torchtitan.datasets.hf_datasets import build_hf_dataloader 13 | from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer 14 | from torchtitan.models.llama3 import llama3_configs, pipeline_llama 15 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec 16 | 17 | from .model import SimpleFSDPTransformer 18 | from .parallelize_llama import parallelize_llama 19 | 20 | register_train_spec( 21 | TrainSpec( 22 | name="llama3_simple_fsdp", 23 | cls=SimpleFSDPTransformer, 24 | config=llama3_configs, 25 | parallelize_fn=parallelize_llama, 26 | pipelining_fn=pipeline_llama, 27 | build_optimizers_fn=build_optimizers, 28 | build_lr_schedulers_fn=build_lr_schedulers, 29 | build_dataloader_fn=build_hf_dataloader, 30 | build_tokenizer_fn=build_tiktoken_tokenizer, 31 | build_loss_fn=build_cross_entropy_loss, 32 | ) 33 | ) 34 | -------------------------------------------------------------------------------- /torchtitan/experiments/simple_fsdp/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torchtitan.models.llama3 import Transformer, TransformerModelArgs 8 | from .simple_fsdp import disable_data_parallel 9 | 10 | 11 | class SimpleFSDPTransformer(Transformer): 12 | def __init__(self, model_args: TransformerModelArgs): 13 | super().__init__(model_args) 14 | self.init_weights() 15 | 16 | def init_weights(self, *args, **kwargs): 17 | with disable_data_parallel(): 18 | super().init_weights(*args, **kwargs) 19 | -------------------------------------------------------------------------------- /torchtitan/experiments/simple_fsdp/parallelize_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from torch.distributed import DeviceMesh 11 | 12 | from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP 13 | from torchtitan.distributed import ParallelDims 14 | from torchtitan.models.llama3.parallelize_llama import apply_ac, apply_tp 15 | from torchtitan.tools.logging import logger 16 | 17 | from .simple_fsdp import data_parallel, MixedPrecisionPolicy 18 | 19 | 20 | def parallelize_llama( 21 | model: nn.Module, 22 | world_mesh: DeviceMesh, 23 | parallel_dims: ParallelDims, 24 | job_config: JobConfig, 25 | ): 26 | """ 27 | Apply tensor parallelism, activation checkpointing, torch.compile, and data 28 | parallelism to the model. 29 | 30 | NOTE: The passed-in model preferably should be on meta device. Otherwise, 31 | the model must fit on GPU or CPU memory. 
32 | """ 33 | if parallel_dims.tp_enabled: 34 | if ( 35 | job_config.parallelism.enable_async_tensor_parallel 36 | and not job_config.training.compile 37 | ): 38 | raise RuntimeError("Async TP requires --training.compile") 39 | 40 | enable_float8_linear = "float8" in job_config.model.converters 41 | float8_is_rowwise = job_config.float8.recipe_name in ( 42 | "rowwise", 43 | "rowwise_with_gw_hp", 44 | ) 45 | 46 | # For now, float8 all-gather with TP is only supported for tensorwise 47 | # float8 scaling recipes. For rowwise recipes, we use regular TP and 48 | # all-gather happens in high precision. 49 | enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise 50 | 51 | tp_mesh = world_mesh["tp"] 52 | apply_tp( 53 | model, 54 | world_mesh["tp"], 55 | loss_parallel=parallel_dims.loss_parallel_enabled, 56 | enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, 57 | enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, 58 | ) 59 | 60 | if job_config.activation_checkpoint.mode != "none": 61 | apply_ac(model, job_config.activation_checkpoint) 62 | 63 | # apply data parallel 64 | if ( 65 | parallel_dims.dp_replicate_enabled 66 | or parallel_dims.dp_shard_enabled 67 | or parallel_dims.cp_enabled 68 | ): 69 | if parallel_dims.dp_replicate_enabled: 70 | if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: 71 | dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") 72 | dp_mode = "hybrid_shard" 73 | else: 74 | dp_mesh_dim_names = ("dp_replicate",) 75 | dp_mode = "replicate" 76 | else: 77 | dp_mesh_dim_names = ("dp_shard_cp",) 78 | dp_mode = "fully_shard" 79 | 80 | mp_policy = MixedPrecisionPolicy( 81 | param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], 82 | reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], 83 | ) 84 | 85 | model = data_parallel( 86 | model, 87 | world_mesh[tuple(dp_mesh_dim_names)], 88 | mode=dp_mode, 89 | ac_mode=job_config.activation_checkpoint.mode, 90 | mp_policy=mp_policy, 91 | tp_mesh=tp_mesh if parallel_dims.tp_enabled else None, 92 | ) 93 | logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode) 94 | 95 | if job_config.training.compile: 96 | torch._inductor.config.reorder_for_peak_memory = False 97 | model = torch.compile(model, fullgraph=True) 98 | 99 | return model 100 | -------------------------------------------------------------------------------- /torchtitan/experiments/simple_fsdp/tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /torchtitan/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # Import the built-in models here so that the corresponding register_model_spec() 9 | # will be called. 
10 | import torchtitan.models.llama3 # noqa: F401 11 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 8 | 9 | from torchtitan.components.loss import build_cross_entropy_loss 10 | from torchtitan.components.lr_scheduler import build_lr_schedulers 11 | from torchtitan.components.optimizer import build_optimizers 12 | from torchtitan.datasets.hf_datasets import build_hf_dataloader 13 | from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer 14 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec 15 | 16 | from .model import Transformer, TransformerModelArgs 17 | from .parallelize_llama import parallelize_llama 18 | from .pipeline_llama import pipeline_llama 19 | 20 | __all__ = [ 21 | "parallelize_llama", 22 | "pipeline_llama", 23 | "TransformerModelArgs", 24 | "Transformer", 25 | "llama3_configs", 26 | ] 27 | 28 | 29 | llama3_configs = { 30 | "debugmodel": TransformerModelArgs( 31 | dim=256, n_layers=6, n_heads=16, rope_theta=500000 32 | ), 33 | "debugmodel_flex_attn": TransformerModelArgs( 34 | dim=256, 35 | n_layers=6, 36 | n_heads=16, 37 | rope_theta=500000, 38 | use_flex_attn=True, 39 | attn_mask_type="block_causal", 40 | ), 41 | "8B": TransformerModelArgs( 42 | dim=4096, 43 | n_layers=32, 44 | n_heads=32, 45 | n_kv_heads=8, 46 | ffn_dim_multiplier=1.3, 47 | multiple_of=1024, 48 | rope_theta=500000, 49 | ), 50 | "70B": TransformerModelArgs( 51 | dim=8192, 52 | n_layers=80, 53 | n_heads=64, 54 | n_kv_heads=8, 55 | ffn_dim_multiplier=1.3, 56 | multiple_of=4096, 57 | rope_theta=500000, 58 | ), 59 | "405B": TransformerModelArgs( 60 | dim=16384, 61 | n_layers=126, 62 | n_heads=128, 63 | n_kv_heads=8, 64 | ffn_dim_multiplier=1.2, 65 | multiple_of=4096, 66 | rope_theta=500000, 67 | ), 68 | } 69 | 70 | 71 | register_train_spec( 72 | TrainSpec( 73 | name="llama3", 74 | cls=Transformer, 75 | config=llama3_configs, 76 | parallelize_fn=parallelize_llama, 77 | pipelining_fn=pipeline_llama, 78 | build_optimizers_fn=build_optimizers, 79 | build_lr_schedulers_fn=build_lr_schedulers, 80 | build_dataloader_fn=build_hf_dataloader, 81 | build_tokenizer_fn=build_tiktoken_tokenizer, 82 | build_loss_fn=build_cross_entropy_loss, 83 | ) 84 | ) 85 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/debug_model.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | 3 | [job] 4 | dump_folder = "./outputs" 5 | description = "Llama 3 debug training" 6 | print_args = false 7 | use_for_integration_test = true 8 | 9 | [profiling] 10 | enable_profiling = false 11 | save_traces_folder = "profile_trace" 12 | profile_freq = 10 13 | enable_memory_snapshot = false 14 | save_memory_snapshot_folder = "memory_snapshot" 15 | 16 | [metrics] 17 | log_freq = 1 18 | disable_color_printing = false 19 | enable_tensorboard = false 20 | save_tb_folder = "tb" 21 | enable_wandb = false 22 | 23 | [model] 24 | name = "llama3" 25 | flavor = "debugmodel" 26 | # test tokenizer.model, for debug purpose only 27 | tokenizer_path = 
"./tests/assets/test_tiktoken.model" 28 | # converters = ["float8"] 29 | 30 | [optimizer] 31 | name = "AdamW" 32 | lr = 8e-4 33 | eps = 1e-8 34 | 35 | [lr_scheduler] 36 | warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps 37 | decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps 38 | decay_type = "linear" 39 | lr_min = 0.0 40 | 41 | [training] 42 | batch_size = 8 43 | seq_len = 2048 44 | max_norm = 1.0 # grad norm clipping 45 | steps = 10 46 | compile = false 47 | dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) 48 | 49 | [parallelism] 50 | data_parallel_replicate_degree = 1 51 | data_parallel_shard_degree = -1 52 | fsdp_reshard_after_forward = "default" # default / never / always 53 | tensor_parallel_degree = 1 54 | enable_async_tensor_parallel = false 55 | pipeline_parallel_degree = 1 56 | context_parallel_degree = 1 57 | 58 | [checkpoint] 59 | enable_checkpoint = false 60 | folder = "checkpoint" 61 | interval = 10 62 | model_weights_only = false 63 | export_dtype = "float32" 64 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 65 | 66 | [activation_checkpoint] 67 | mode = "selective" # ["none", "selective", "full"] 68 | selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy 69 | 70 | [float8] 71 | enable_fsdp_float8_all_gather = false 72 | precompute_float8_dynamic_scale_for_fsdp = false 73 | filter_fqns = ["output"] 74 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/llama3_405b.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # NOTE: this toml config is a preset for 128 H100 GPUs. 3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 405B training" 7 | 8 | [profiling] 9 | enable_profiling = true 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 10 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | 18 | [model] 19 | name = "llama3" 20 | flavor = "405B" 21 | tokenizer_path = "./assets/tokenizer/original/tokenizer.model" 22 | converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 8e-5 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps 31 | 32 | [training] 33 | batch_size = 2 34 | seq_len = 8192 35 | max_norm = 1.0 # grad norm clipping 36 | steps = 3000 37 | compile = true 38 | dataset = "c4" 39 | 40 | [parallelism] 41 | data_parallel_replicate_degree = 1 42 | data_parallel_shard_degree = -1 43 | tensor_parallel_degree = 8 # 8-way TP 44 | enable_async_tensor_parallel = true 45 | pipeline_parallel_degree = 1 46 | context_parallel_degree = 1 47 | 48 | [checkpoint] 49 | enable_checkpoint = false 50 | folder = "checkpoint" 51 | interval = 500 52 | model_weights_only = false 53 | export_dtype = "float32" 54 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 55 | 56 | [activation_checkpoint] 57 | mode = "full" # ["none", "selective", "full"] 58 | 59 | [float8] 60 | enable_fsdp_float8_all_gather = true 61 | precompute_float8_dynamic_scale_for_fsdp = true 62 | filter_fqns = ["output"] 63 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/llama3_70b.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # 
NOTE: this toml config is a preset for 64 A100 GPUs. 3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 70B training" 7 | 8 | [profiling] 9 | enable_profiling = true 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 10 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | 18 | [model] 19 | name = "llama3" 20 | flavor = "70B" 21 | tokenizer_path = "./assets/tokenizer/original/tokenizer.model" 22 | # converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 1.5e-4 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps 31 | 32 | [training] 33 | batch_size = 8 34 | seq_len = 8192 35 | max_norm = 1.0 # grad norm clipping 36 | steps = 1000 37 | compile = false 38 | dataset = "c4" 39 | 40 | [parallelism] 41 | data_parallel_replicate_degree = 1 42 | data_parallel_shard_degree = -1 43 | tensor_parallel_degree = 8 # 8-way TP 44 | pipeline_parallel_degree = 1 45 | context_parallel_degree = 1 46 | 47 | [checkpoint] 48 | enable_checkpoint = false 49 | folder = "checkpoint" 50 | interval = 500 51 | model_weights_only = false 52 | export_dtype = "float32" 53 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 54 | 55 | [activation_checkpoint] 56 | mode = "full" 57 | 58 | [float8] 59 | enable_fsdp_float8_all_gather = false 60 | precompute_float8_dynamic_scale_for_fsdp = false 61 | filter_fqns = ["output"] 62 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/llama3_8b.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # NOTE: this toml config is a preset for 64 A100 GPUs. 
3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 8B training" 7 | 8 | [profiling] 9 | enable_profiling = true 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 10 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | 18 | [model] 19 | name = "llama3" 20 | flavor = "8B" 21 | tokenizer_path = "./assets/tokenizer/original/tokenizer.model" 22 | # converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 3e-4 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 200 # lr scheduler warm up 31 | 32 | [training] 33 | batch_size = 1 34 | seq_len = 8192 35 | max_norm = 1.0 # grad norm clipping 36 | steps = 1000 37 | compile = false 38 | dataset = "c4" 39 | 40 | [parallelism] 41 | data_parallel_replicate_degree = 1 42 | data_parallel_shard_degree = -1 43 | tensor_parallel_degree = 1 44 | pipeline_parallel_degree = 1 45 | context_parallel_degree = 1 46 | 47 | [checkpoint] 48 | enable_checkpoint = false 49 | folder = "checkpoint" 50 | interval = 500 51 | model_weights_only = false 52 | export_dtype = "float32" 53 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 54 | 55 | [activation_checkpoint] 56 | mode = "selective" # ["none", "selective", "full"] 57 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy 58 | 59 | [float8] 60 | enable_fsdp_float8_all_gather = false 61 | precompute_float8_dynamic_scale_for_fsdp = false 62 | filter_fqns = ["output"] 63 | -------------------------------------------------------------------------------- /torchtitan/protocols/model_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from typing import Dict, List, Protocol, Union 7 | 8 | import torch.nn as nn 9 | 10 | from torchtitan.config_manager import JobConfig 11 | from torchtitan.distributed import ParallelDims 12 | from torchtitan.tools.logging import logger 13 | 14 | 15 | class ModelConverter(Protocol): 16 | """General model converter interface. 17 | 18 | A model converter applies a modification to a PyTorch model. 19 | Typical use cases are: 20 | - Quantization: using QAT, FP8, ... specialized linear layers; 21 | - Fused optimized layers (e.g. flash-attention, norms, ...) 22 | """ 23 | 24 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): 25 | ... 26 | 27 | def convert(self, model: nn.Module): 28 | """In-place conversion of the model.""" 29 | ... 30 | 31 | def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]): 32 | """Post-optimizer (optional) hook (e.g. compute weights statistics).""" 33 | ... 34 | 35 | 36 | _registry_model_converter_cls: Dict[str, type[ModelConverter]] = {} 37 | """Registry of model converter classes. 38 | """ 39 | 40 | 41 | def register_model_converter(converter_cls: type[ModelConverter], name: str): 42 | """Register a model converter class. 43 | 44 | A registered model converter can be applied to any model 45 | using the `model.converters` config parameter. 46 | """ 47 | assert ( 48 | name not in _registry_model_converter_cls 49 | ), f"A model converter '{name}' is already registered." 
50 | _registry_model_converter_cls[name] = converter_cls 51 | 52 | 53 | class ModelConvertersContainer(ModelConverter): 54 | """Model converters sequential container. 55 | 56 | The class builds the sequence of model converters defined in the `model.converters` 57 | job config, and applies them to the model sequentially. 58 | """ 59 | 60 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): 61 | converter_classes = [ 62 | _registry_model_converter_cls[name] for name in job_config.model.converters 63 | ] 64 | self.converters = [ 65 | mh_cls(job_config, parallel_dims) for mh_cls in converter_classes 66 | ] 67 | self.print_after_conversion = job_config.model.print_after_conversion 68 | 69 | def convert(self, model: nn.Module): 70 | for mh in self.converters: 71 | mh.convert(model) 72 | if self.print_after_conversion: 73 | logger.info(f"Model definition after conversion:\n\n{model}\n\n") 74 | 75 | def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]): 76 | for mh in self.converters: 77 | mh.post_optimizer_hook(model) 78 | 79 | 80 | def build_model_converters( 81 | job_config: JobConfig, parallel_dims: ParallelDims 82 | ) -> ModelConvertersContainer: 83 | """Build the collection of model converters to apply to the model.""" 84 | return ModelConvertersContainer(job_config, parallel_dims) 85 | -------------------------------------------------------------------------------- /torchtitan/protocols/train_spec.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 8 | 9 | from abc import abstractmethod 10 | from collections.abc import Callable, Mapping 11 | from dataclasses import dataclass 12 | from typing import Protocol, TypeAlias 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.distributed.pipelining.schedules import _PipelineSchedule 17 | 18 | from torchtitan.components.dataloader import BaseDataLoader 19 | from torchtitan.components.ft import FTManager 20 | from torchtitan.components.loss import LossFunction 21 | from torchtitan.components.lr_scheduler import LRSchedulersContainer 22 | from torchtitan.components.metrics import MetricsProcessor 23 | from torchtitan.components.optimizer import OptimizersContainer 24 | from torchtitan.components.tokenizer import Tokenizer 25 | from torchtitan.config_manager import JobConfig 26 | 27 | DeviceType = int | str | torch.device 28 | 29 | 30 | @dataclass 31 | class BaseModelArgs: 32 | """All ModelArgs should inherit from this class. 33 | 34 | The only usage of this class is type checking, but it allows us to extend common 35 | arguments to all models in the future. 36 | """ 37 | 38 | _enforced: str = "This field is used to enforce all fields have defaults." 39 | 40 | @abstractmethod 41 | def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: 42 | pass 43 | 44 | @abstractmethod 45 | def get_nparams_and_flops( 46 | self, model: nn.Module, seq_len: int 47 | ) -> tuple[int, float]: 48 | pass 49 | 50 | 51 | class ModelProtocol(Protocol): 52 | """Defines the interface for a model class. 53 | 54 | This is used to enforce that all model classes have some methods that are 55 | required by the TorchTitan trainer. 
56 | """ 57 | 58 | @classmethod 59 | def from_model_args(cls, args: BaseModelArgs) -> nn.Module: 60 | ... 61 | 62 | 63 | ParallelizeFunction: TypeAlias = Callable[..., nn.Module] 64 | PipeliningFunction: TypeAlias = Callable[ 65 | ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool] 66 | ] 67 | DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader] 68 | TokenizerBuilder: TypeAlias = Callable[..., Tokenizer] 69 | MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] 70 | OptimizersBuilder: TypeAlias = Callable[ 71 | [list[nn.Module], JobConfig, FTManager], OptimizersContainer 72 | ] 73 | LRSchedulersBuilder: TypeAlias = Callable[ 74 | [OptimizersContainer, JobConfig], LRSchedulersContainer 75 | ] 76 | LossFunctionBuilder: TypeAlias = Callable[..., LossFunction] 77 | 78 | 79 | @dataclass 80 | class TrainSpec: 81 | name: str 82 | cls: type[nn.Module] 83 | config: Mapping[str, BaseModelArgs] 84 | parallelize_fn: ParallelizeFunction 85 | pipelining_fn: PipeliningFunction | None 86 | build_optimizers_fn: OptimizersBuilder 87 | build_lr_schedulers_fn: LRSchedulersBuilder 88 | build_dataloader_fn: DataLoaderBuilder 89 | build_tokenizer_fn: TokenizerBuilder | None 90 | build_loss_fn: LossFunctionBuilder 91 | build_metrics_processor_fn: MetricsProcessorBuilder | None = None 92 | 93 | 94 | _train_specs = {} 95 | 96 | 97 | def register_train_spec(train_spec: TrainSpec) -> None: 98 | global _train_specs 99 | if train_spec.name in _train_specs: 100 | raise ValueError(f"Model {train_spec.name} is already registered.") 101 | 102 | _train_specs[train_spec.name] = train_spec 103 | 104 | 105 | def get_train_spec(name: str) -> TrainSpec: 106 | global _train_specs 107 | if name not in _train_specs: 108 | raise ValueError(f"Model {name} is not registered.") 109 | return _train_specs[name] 110 | 111 | 112 | def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None: 113 | global _train_specs 114 | for name, train_spec in _train_specs.items(): 115 | _train_specs[name] = func(train_spec) 116 | -------------------------------------------------------------------------------- /torchtitan/tools/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | 10 | 11 | logger = logging.getLogger() 12 | 13 | 14 | def init_logger(): 15 | logger.setLevel(logging.INFO) 16 | ch = logging.StreamHandler() 17 | ch.setLevel(logging.INFO) 18 | formatter = logging.Formatter( 19 | "[titan] %(asctime)s - %(name)s - %(levelname)s - %(message)s" 20 | ) 21 | ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | # suppress verbose torch.profiler logging 25 | os.environ["KINETO_LOG_LEVEL"] = "5" 26 | --------------------------------------------------------------------------------