├── .python-version
├── assets
│   ├── output.wav
│   ├── llama-mimi.png
│   ├── great_day_gt.wav
│   └── log_validation.png
├── tests
│   ├── __init__.py
│   └── unit_tests
│       ├── __init__.py
│       └── test_audio_array_to_text.py
├── train.sh
├── torchtitan
│   ├── distributed
│   │   ├── __init__.py
│   │   ├── parallel_dims.py
│   │   └── pipeline.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── llama3
│   │   │   ├── model
│   │   │   │   ├── state_dict_adapter.py
│   │   │   │   └── args.py
│   │   │   ├── train_configs
│   │   │   │   ├── llama3_70b.toml
│   │   │   │   ├── llama3_405b.toml
│   │   │   │   ├── llama3_8b.toml
│   │   │   │   └── debug_model.toml
│   │   │   ├── __init__.py
│   │   │   └── infra
│   │   │       └── pipeline.py
│   │   └── attention.py
│   ├── __init__.py
│   ├── components
│   │   ├── quantization
│   │   │   ├── __init__.py
│   │   │   ├── utils.py
│   │   │   ├── mx.py
│   │   │   └── float8.py
│   │   ├── loss.py
│   │   ├── dataloader.py
│   │   ├── ft.py
│   │   ├── validate.py
│   │   ├── lr_scheduler.py
│   │   ├── optimizer.py
│   │   ├── metrics.py
│   │   └── tokenizer.py
│   ├── tools
│   │   ├── logging.py
│   │   ├── profiling.py
│   │   └── utils.py
│   ├── protocols
│   │   ├── state_dict_adapter.py
│   │   ├── model_converter.py
│   │   └── train_spec.py
│   └── datasets
│       └── hf_datasets.py
├── .gitignore
├── pyproject.toml
├── scripts
│   └── convert_dcp_to_hf.py
├── LICENSE
├── config
│   ├── llama3_1_8b.toml
│   ├── llama3_2_1b.toml
│   └── llama3_2_1b_peoples_speech.toml
├── eval
│   ├── salmon.py
│   ├── sLM21.py
│   └── sStoryCloze.py
├── README.md
└── inference.py
/.python-version:
--------------------------------------------------------------------------------
1 | 3.13
2 |
3 |
--------------------------------------------------------------------------------
/assets/output.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/output.wav
--------------------------------------------------------------------------------
/assets/llama-mimi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/llama-mimi.png
--------------------------------------------------------------------------------
/assets/great_day_gt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/great_day_gt.wav
--------------------------------------------------------------------------------
/assets/log_validation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/log_validation.png
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/tests/unit_tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | set -eux
2 |
3 | export NGPU=8
4 | export LOG_RANK=0
5 | export CONFIG_FILE="config/llama3_2_1b_peoples_speech.toml"
6 |
7 | torchrun --nproc_per_node=${NGPU} --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
8 | -m torchtitan.train --job.config_file ${CONFIG_FILE}
9 |
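10 | # --local-ranks-filter limits console output to LOG_RANK (rank 0 here); edit
11 | # NGPU and CONFIG_FILE above to launch the other presets under ./config.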
--------------------------------------------------------------------------------
/torchtitan/distributed/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | from torchtitan.distributed.parallel_dims import ParallelDims
9 |
10 |
11 | __all__ = ["ParallelDims"]
12 |
--------------------------------------------------------------------------------
/torchtitan/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | # Import the built-in models here so that the corresponding register_model_spec()
9 | # will be called.
10 | # import torchtitan.models.deepseek_v3 # noqa: F401
11 | import torchtitan.models.llama3 # noqa: F401
12 |
--------------------------------------------------------------------------------
/torchtitan/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # Import to register quantization modules.
8 | import torchtitan.components.quantization # noqa: F401
9 |
10 | # Import the built-in models here so that the corresponding register_model_spec()
11 | # will be called.
12 | import torchtitan.models # noqa: F401
13 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | .DS_Store
4 | *.egg-info
5 | build
6 | outputs
7 | dist/*
8 | .vscode
9 | .vs
10 |
11 | # data
12 | data
13 | out
14 | wandb
15 |
16 | torchtitan/datasets/**/*.model
17 |
18 | # tokenizer models
19 | assets/**/*.model
20 | assets/**/*.json
21 | assets/**/*.txt
22 | torchtitan/experiments/flux/assets/*
23 |
24 | # temp files
25 | *.log
26 | error.json
27 | _remote_module_non_scriptable.py
28 |
29 | # Editor temporaries (VIM)
30 | [._]*.s[a-v][a-z]
31 | [._]*.sw[a-p]
32 | [._]s[a-rt-v][a-z]
33 | [._]ss[a-gi-z]
34 | [._]sw[a-p]
35 | Session.vim
36 | Sessionx.vim
37 | .netrwhist
38 | *~
39 | .~lock.*
40 |
41 | # macOS dir files
42 | .DS_Store
43 |
--------------------------------------------------------------------------------
/torchtitan/models/llama3/model/state_dict_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Any
8 |
9 | from torchtitan.protocols.state_dict_adapter import StateDictAdapter
10 |
11 |
12 | class Llama3StateDictAdapter(StateDictAdapter):
13 | @staticmethod
14 | def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
15 | # TODO: implement this
16 | return state_dict
17 |
18 | @staticmethod
19 | def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
20 | # TODO: implement this
21 | return hf_state_dict
22 |
--------------------------------------------------------------------------------
/torchtitan/components/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # [Note] Getting the 'torchao' package:
8 | # This script requires the 'torchao' package to function correctly.
9 | # Please ensure you have this package installed from the appropriate repository.
10 | # You can obtain it from https://github.com/pytorch/ao by following the
11 | # installation instructions.
12 |
13 | # Note: Performance
14 | # The quantization modules are intended to be run under `torch.compile` for competitive performance.
15 |
16 | # Import to register quantization modules as ModelConverter
17 | import torchtitan.components.quantization.float8 # noqa: F401
18 | import torchtitan.components.quantization.mx # noqa: F401
19 |
--------------------------------------------------------------------------------
/torchtitan/tools/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 |
10 |
11 | logger = logging.getLogger()
12 |
13 |
14 | def init_logger():
15 | rank = int(os.environ.get("RANK", 0))  # defaults to 0
16 |
17 | if rank == 0:
18 | logger.setLevel(logging.INFO)
19 | ch = logging.StreamHandler()
20 | ch.setLevel(logging.INFO)
21 | formatter = logging.Formatter(
22 | f"[rank{rank}]:[titan] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
23 | )
24 | ch.setFormatter(formatter)
25 | logger.addHandler(ch)
26 | else:
27 | # suppress log output on non-zero ranks
28 | logger.setLevel(logging.ERROR)  # logging.CRITICAL would also work
29 |
30 | # suppress verbose torch.profiler logging
31 | os.environ["KINETO_LOG_LEVEL"] = "5"
32 |
--------------------------------------------------------------------------------
/torchtitan/components/quantization/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch.nn as nn
8 |
9 |
10 | def module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
11 | """
12 | Filter function to determine which modules should be converted.
13 | For both Float8 and MXFP8, we only convert Linear modules
14 | with dimensions divisible by 16 and not matching any filtered FQNs.
15 | """
16 | if not isinstance(mod, nn.Linear):
17 | return False
18 |
19 | # All dims must be divisible by 16 due to float8 tensorcore hardware requirements.
20 | dims_multiples_of_16 = (
21 | mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0
22 | )
23 |
24 | # If the fqn matches any filtered fqn, then we should not convert this module.
25 | is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
26 |
27 | return dims_multiples_of_16 and not is_filtered_fqn
28 |
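29 | # Illustrative behavior (hypothetical FQNs, assuming filter_fqns=["output"]):
30 | #   module_filter_fn(nn.Linear(4096, 4096), "layers.0.attention.wq", ["output"])  -> True
31 | #   module_filter_fn(nn.Linear(4096, 4096), "output", ["output"])                 -> False (filtered FQN)
32 | #   module_filter_fn(nn.Linear(4096, 100),  "layers.0.mlp.w1", ["output"])        -> False (dims not divisible by 16)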
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "torchtitan"
3 | version = "0.1.0"
4 | description = "Llama-Mimi speech language model training on torchtitan"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 | "accelerate>=1.9.0",
9 | "blobfile>=3.0.0",
10 | "datasets>=3.6.0",
11 | "fsspec>=2025.3.0",
12 | "librosa>=0.11.0",
13 | "openai-whisper>=20250625",
14 | "soundfile>=0.13.1",
15 | "tabulate>=0.9.0",
16 | "tiktoken>=0.9.0",
17 | "tomli>=1.1.0 ; python_full_version < '3.11'",
18 | "torch>=2.7.1",
19 | "torchaudio>=2.7.1",
20 | "torchcodec==0.5.0",
21 | "torchdata>=0.8.0",
22 | "transformers>=4.53.2",
23 | "tyro>=0.9.25",
24 | "wandb>=0.20.1",
25 | "seaborn>=0.13.2",
26 | "moshi>=0.2.11",
27 | "audiobox-aesthetics>=0.0.4",
28 | "pyannote-audio>=3.3.2",
29 | "python-dotenv>=1.1.1",
30 | "openai>=1.107.1",
31 | "flash-attn==2.8.2",
32 | ]
33 |
34 | [dependency-groups]
35 | dev = [
36 | "matplotlib>=3.10.3",
37 | "mypy>=1.17.0",
38 | "ruff>=0.12.4",
39 | ]
40 |
41 |
42 | [build-system]
43 | requires = ["hatchling"]
44 | build-backend = "hatchling.build"
45 |
46 | [tool.hatch.build.targets.wheel]
47 | packages = ["torchtitan"]
48 |
49 | [tool.uv]
50 | no-build-isolation-package = ["flash-attn"]
51 |
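52 | # Note: flash-attn is built without isolation because its build imports the
53 | # already-installed torch; keep similar torch-dependent builds in this list.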
--------------------------------------------------------------------------------
/torchtitan/protocols/state_dict_adapter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import ABC, abstractmethod
8 | from typing import Any
9 |
10 |
11 | class StateDictAdapter(ABC):
12 | """Abstract base class for state dict transformations.
13 |
14 | This class defines the interface for converting between native model
15 | state dict format and other model state dict formats.
16 | """
17 |
18 | @staticmethod
19 | @abstractmethod
20 | def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
21 | """Convert from native model state dict to HuggingFace format.
22 |
23 | Args:
24 | state_dict: The native model state dict
25 |
26 | Returns:
27 | The converted HuggingFace format state dict
28 | """
29 | pass
30 |
31 | @staticmethod
32 | @abstractmethod
33 | def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
34 | """Obtain native model state dict from HuggingFace format.
35 |
36 | Args:
37 | hf_state_dict: The HuggingFace format state dict
38 |
39 | Returns:
40 | The converted native model state dict
41 | """
42 | pass
43 |
--------------------------------------------------------------------------------
/scripts/convert_dcp_to_hf.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | from torchtitan.components.checkpoint import ModelWrapper
3 | import torch.distributed.checkpoint as dcp
4 | import torch
5 | from torchtitan.train import expand_tokenizer_with_unit_tokens
6 |
7 | if __name__ == "__main__":
8 | num_quantizers = 4
9 | checkpoint_id = f"outputs/Llama-3.2-1B_peoples_speech-q4-s1024/checkpoint/step-5000"
10 | output_dir = f"models/Llama-3.2-1B_peoples_speech-q4-s1024"
11 | model_name = "meta-llama/Llama-3.2-1B"
12 | model = AutoModelForCausalLM.from_pretrained(model_name)
13 |
14 | device = "cuda" if torch.cuda.is_available() else "cpu"
15 | tokenizer = AutoTokenizer.from_pretrained(model_name)
16 | tokenizer.pad_token_id = 0
17 | tokenizer = expand_tokenizer_with_unit_tokens(
18 | tokenizer,
19 | codebook_size=2048,
20 | num_quantizers=num_quantizers,
21 | )
22 |
23 | embedding_size = model.get_input_embeddings().weight.shape[0]
24 | if len(tokenizer) > embedding_size:
25 | model.resize_token_embeddings(len(tokenizer))
26 |
27 | wrapped = ModelWrapper(model)
28 | print(wrapped)
29 | dcp.load(wrapped.state_dict(), checkpoint_id=checkpoint_id)
30 | model.config.num_quantizers = num_quantizers
31 | model.save_pretrained(output_dir)
32 | tokenizer.save_pretrained(output_dir)
33 |
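34 | # Once saved, the exported checkpoint loads like any Hugging Face model, e.g.
35 | # AutoModelForCausalLM.from_pretrained(output_dir) together with the tokenizer
36 | # saved alongside it; num_quantizers is preserved in the model config.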
--------------------------------------------------------------------------------
/torchtitan/components/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import functools
8 | from typing import Callable, TypeAlias
9 |
10 | import torch
11 |
12 | from torchtitan.config_manager import JobConfig
13 | from torchtitan.tools.logging import logger
14 |
15 | LossFunction: TypeAlias = Callable[..., torch.Tensor]
16 |
17 |
18 | def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
19 | """Common cross-entropy loss function for Transformer models training."""
20 | return torch.nn.functional.cross_entropy(
21 | pred.flatten(0, 1).float(), labels.flatten(0, 1)
22 | )
23 |
24 |
25 | def build_cross_entropy_loss(job_config: JobConfig):
26 | loss_fn = cross_entropy_loss
27 | if job_config.training.compile:
28 | logger.info("Compiling the loss function with torch.compile")
29 | loss_fn = torch.compile(loss_fn)
30 | return loss_fn
31 |
32 |
33 | def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps):
34 | """Add a mean reduction over `accumulation_steps` to the given
35 | `unwrapped_loss_fn`.
36 | """
37 |
38 | @functools.wraps(unwrapped_loss_fn)
39 | def accumulated_loss_fn(*args, **kwargs):
40 | loss = unwrapped_loss_fn(*args, **kwargs)
41 | return loss / accumulation_steps
42 |
43 | return accumulated_loss_fn
44 |
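45 | # Illustrative effect: with accumulation_steps=4, the wrapped loss returns 1/4 of
46 | # the per-microbatch loss, so summing gradients over 4 microbatches matches the
47 | # gradient scale of a single batch four times as large.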
--------------------------------------------------------------------------------
/torchtitan/models/llama3/train_configs/llama3_70b.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 | # NOTE: this toml config is a preset for 64 A100 GPUs.
3 |
4 | [job]
5 | dump_folder = "./outputs"
6 | description = "Llama 3 70B training"
7 |
8 | [profiling]
9 | enable_profiling = true
10 | save_traces_folder = "profile_trace"
11 | profile_freq = 100
12 |
13 | [metrics]
14 | log_freq = 10
15 | enable_tensorboard = true
16 | save_tb_folder = "tb"
17 |
18 | [model]
19 | name = "llama3"
20 | flavor = "70B"
21 | tokenizer_path = "./assets/tokenizer/Llama-3.1-8B"
22 | # converters = ["float8"]
23 |
24 | [optimizer]
25 | name = "AdamW"
26 | lr = 1.5e-4
27 | eps = 1e-8
28 |
29 | [lr_scheduler]
30 | warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
31 |
32 | [training]
33 | local_batch_size = 8
34 | seq_len = 8192
35 | max_norm = 1.0 # grad norm clipping
36 | steps = 1000
37 | compile = false
38 | dataset = "c4"
39 |
40 | [parallelism]
41 | data_parallel_replicate_degree = 1
42 | data_parallel_shard_degree = -1
43 | tensor_parallel_degree = 8 # 8-way TP
44 | pipeline_parallel_degree = 1
45 | context_parallel_degree = 1
46 |
47 | [checkpoint]
48 | enable_checkpoint = false
49 | folder = "checkpoint"
50 | interval = 500
51 | last_save_model_only = true
52 | export_dtype = "float32"
53 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
54 |
55 | [activation_checkpoint]
56 | mode = "full"
57 |
58 | [float8]
59 | enable_fsdp_float8_all_gather = false
60 | precompute_float8_dynamic_scale_for_fsdp = false
61 | filter_fqns = ["output"]
62 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) Meta Platforms, Inc. and affiliates.
4 | LLM-jp (2025)
5 |
6 | Redistribution and use in source and binary forms, with or without modification,
7 | are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this list
10 | of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice, this
13 | list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its contributors may
17 | be used to endorse or promote products derived from this software without specific
18 | prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY
21 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
22 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
23 | SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
24 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
25 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
26 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
28 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
29 | DAMAGE.
30 |
--------------------------------------------------------------------------------
/torchtitan/models/llama3/train_configs/llama3_405b.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 | # NOTE: this toml config is a preset for 128 H100 GPUs.
3 |
4 | [job]
5 | dump_folder = "./outputs"
6 | description = "Llama 3 405B training"
7 |
8 | [profiling]
9 | enable_profiling = true
10 | save_traces_folder = "profile_trace"
11 | profile_freq = 100
12 |
13 | [metrics]
14 | log_freq = 10
15 | enable_tensorboard = true
16 | save_tb_folder = "tb"
17 |
18 | [model]
19 | name = "llama3"
20 | flavor = "405B"
21 | tokenizer_path = "./assets/tokenizer/Llama-3.1-8B"
22 | converters = ["float8"]
23 |
24 | [optimizer]
25 | name = "AdamW"
26 | lr = 8e-5
27 | eps = 1e-8
28 |
29 | [lr_scheduler]
30 | warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
31 |
32 | [training]
33 | local_batch_size = 2
34 | seq_len = 8192
35 | max_norm = 1.0 # grad norm clipping
36 | steps = 3000
37 | compile = true
38 | dataset = "c4"
39 |
40 | [parallelism]
41 | data_parallel_replicate_degree = 1
42 | data_parallel_shard_degree = -1
43 | tensor_parallel_degree = 8 # 8-way TP
44 | enable_async_tensor_parallel = true
45 | pipeline_parallel_degree = 1
46 | context_parallel_degree = 1
47 |
48 | [checkpoint]
49 | enable_checkpoint = false
50 | folder = "checkpoint"
51 | interval = 500
52 | last_save_model_only = true
53 | export_dtype = "float32"
54 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
55 |
56 | [activation_checkpoint]
57 | mode = "full" # ["none", "selective", "full"]
58 |
59 | [float8]
60 | enable_fsdp_float8_all_gather = true
61 | precompute_float8_dynamic_scale_for_fsdp = true
62 | filter_fqns = ["output"]
63 |
--------------------------------------------------------------------------------
/tests/unit_tests/test_audio_array_to_text.py:
--------------------------------------------------------------------------------
1 | from torchtitan.datasets.hf_datasets import audio_array_to_text
2 | import torchaudio
3 | from transformers import MimiModel, AutoFeatureExtractor
4 |
5 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to("cpu")
6 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
7 | print(feature_extractor.sampling_rate)
8 |
9 | audio_path = "assets/great_day.mp3"
10 | waveform, sample_rate = torchaudio.load(audio_path)
11 | # load audio array
12 | print(waveform)
13 | print(sample_rate)
14 |
15 |
16 | text = audio_array_to_text(waveform[0], audio_tokenizer, feature_extractor, 4)
17 | print(text)
18 | assert text == ""
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/torchtitan/models/llama3/train_configs/llama3_8b.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 | # NOTE: this toml config is a preset for 64 A100 GPUs.
3 |
4 | [job]
5 | dump_folder = "./outputs"
6 | description = "Llama 3 8B training"
7 |
8 | [profiling]
9 | enable_profiling = false
10 | save_traces_folder = "profile_trace"
11 | profile_freq = 100
12 |
13 | [metrics]
14 | log_freq = 1
15 | enable_tensorboard = true
16 | save_tb_folder = "tb"
17 | enable_wandb = true
18 |
19 | [model]
20 | name = "llama3"
21 | flavor = "8B"
22 | tokenizer_path = "./assets/tokenizer/Llama-3.1-8B"
23 | num_quantizers = 4  # number of Mimi quantizer streams (codebooks) used for audio tokens
24 | # converters = ["float8"]
25 |
26 | [optimizer]
27 | name = "AdamW"
28 | lr = 3e-4
29 | eps = 1e-8
30 |
31 | [lr_scheduler]
32 | warmup_steps = 200 # lr scheduler warm up
33 |
34 | [training]
35 | local_batch_size = 8
36 | global_batch_size = 256
37 | seq_len = 1024
38 | max_norm = 1.0 # grad norm clipping
39 | steps = 30000
40 | compile = false
41 | dataset = "libri_light"
42 |
43 | [parallelism]
44 | data_parallel_replicate_degree = 1
45 | data_parallel_shard_degree = -1
46 | tensor_parallel_degree = 1
47 | pipeline_parallel_degree = 1
48 | context_parallel_degree = 1
49 |
50 | [checkpoint]
51 | enable_checkpoint = true
52 | folder = "checkpoint"
53 | interval = 100
54 | last_save_model_only = false
55 | export_dtype = "float32"
56 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
57 |
58 | [activation_checkpoint]
59 | mode = "none" # ["none", "selective", "full"]
60 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
61 |
62 | [float8]
63 | enable_fsdp_float8_all_gather = false
64 | precompute_float8_dynamic_scale_for_fsdp = false
65 | filter_fqns = ["output"]
66 |
--------------------------------------------------------------------------------
/config/llama3_1_8b.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 | # NOTE: this toml config is a preset for 64 A100 GPUs.
3 |
4 | [job]
5 | dump_folder = "./outputs"
6 | description = "Llama 3 8B training"
7 |
8 | [profiling]
9 | enable_profiling = false
10 | save_traces_folder = "profile_trace"
11 | profile_freq = 100
12 |
13 | [metrics]
14 | log_freq = 1
15 | enable_tensorboard = true
16 | save_tb_folder = "tb"
17 | enable_wandb = true
18 |
19 | [model]
20 | name = "meta-llama/Llama-3.1-8B"
21 | num_quantizers = 4
22 | # converters = ["float8"]
23 |
24 | [optimizer]
25 | name = "AdamW"
26 | lr = 3e-4
27 | eps = 1e-8
28 |
29 | [lr_scheduler]
30 | warmup_steps = 1500 # lr scheduler warm up
31 | decay_ratio = 0.2
32 | lr_min = 0.1
33 |
34 | [training]
35 | local_batch_size = 8
36 | global_batch_size = 1024
37 | seq_len = 1024
38 | max_norm = 1.0 # grad norm clipping
39 | steps = 100_000
40 | compile = false
41 | dataset = "all" #"libri_light"
42 | task = "a2a"
43 |
44 | [validation]
45 | enabled = true
46 | dataset = "librispeech_asr"
47 | local_batch_size = 8
48 | seq_len = 1024
49 | freq = 100
50 | steps = 10
51 |
52 | [evaluation]
53 | enable_evaluation = false
54 | evaluation_freq = 500
55 |
56 | [parallelism]
57 | data_parallel_replicate_degree = 1
58 | data_parallel_shard_degree = -1
59 | tensor_parallel_degree = 1
60 | pipeline_parallel_degree = 1
61 | context_parallel_degree = 1
62 |
63 | [checkpoint]
64 | enable_checkpoint = true
65 | folder = "checkpoint"
66 | interval = 1000
67 | last_save_model_only = false
68 | export_dtype = "float32"
69 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
70 | keep_latest_k = 10
71 |
72 | [activation_checkpoint]
73 | mode = "none" # ["none", "selective", "full"]
74 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
75 |
76 | [float8]
77 | enable_fsdp_float8_all_gather = false
78 | precompute_float8_dynamic_scale_for_fsdp = false
79 | filter_fqns = ["output"]
80 |
--------------------------------------------------------------------------------
/config/llama3_2_1b.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 | # NOTE: this toml config is a preset for 64 A100 GPUs.
3 |
4 | [job]
5 | dump_folder = "./outputs"
6 | description = "Llama 3 1B training"
7 |
8 | [profiling]
9 | enable_profiling = false
10 | save_traces_folder = "profile_trace"
11 | profile_freq = 100
12 |
13 | [metrics]
14 | log_freq = 1
15 | enable_tensorboard = true
16 | save_tb_folder = "tb"
17 | enable_wandb = true
18 |
19 | [model]
20 | name = "meta-llama/Llama-3.2-1B"
21 | num_quantizers = 4
22 | # converters = ["float8"]
23 |
24 | [optimizer]
25 | name = "AdamW"
26 | lr = 3e-4
27 | eps = 1e-8
28 |
29 | [lr_scheduler]
30 | warmup_steps = 1500 # lr scheduler warm up
31 | decay_ratio = 0.2
32 | lr_min = 0.1
33 |
34 | [training]
35 | local_batch_size = 32
36 | global_batch_size = 1024
37 | seq_len = 1024
38 | max_norm = 1.0 # grad norm clipping
39 | steps = 100_000
40 | compile = false
41 | dataset = "all" #"libri_light" # "Emilia-librilight"
42 | task = "a2a"
43 |
44 | [validation]
45 | enabled = true
46 | dataset = "librispeech_asr"
47 | local_batch_size = 8
48 | seq_len = 1024
49 | freq = 100
50 | steps = 10
51 |
52 | [evaluation]
53 | enable_evaluation = false
54 | evaluation_freq = 500
55 |
56 | [parallelism]
57 | data_parallel_replicate_degree = 1
58 | data_parallel_shard_degree = -1
59 | tensor_parallel_degree = 1
60 | pipeline_parallel_degree = 1
61 | context_parallel_degree = 1
62 |
63 | [checkpoint]
64 | enable_checkpoint = true
65 | folder = "checkpoint"
66 | interval = 1000
67 | last_save_model_only = false
68 | export_dtype = "float32"
69 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
70 | keep_latest_k = 10
71 |
72 | [activation_checkpoint]
73 | mode = "none" # ["none", "selective", "full"]
74 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
75 |
76 | [float8]
77 | enable_fsdp_float8_all_gather = false
78 | precompute_float8_dynamic_scale_for_fsdp = false
79 | filter_fqns = ["output"]
80 |
--------------------------------------------------------------------------------
/config/llama3_2_1b_peoples_speech.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 | # NOTE: this toml config is a preset for 64 A100 GPUs.
3 |
4 | [job]
5 | dump_folder = "./outputs"
6 | description = "Llama 3 1B training"
7 |
8 | [profiling]
9 | enable_profiling = false
10 | save_traces_folder = "profile_trace"
11 | profile_freq = 100
12 |
13 | [metrics]
14 | log_freq = 1
15 | enable_tensorboard = true
16 | save_tb_folder = "tb"
17 | enable_wandb = true
18 |
19 | [model]
20 | name = "meta-llama/Llama-3.2-1B"
21 | num_quantizers = 4
22 | # converters = ["float8"]
23 |
24 | [optimizer]
25 | name = "AdamW"
26 | lr = 3e-4
27 | eps = 1e-8
28 |
29 | [lr_scheduler]
30 | warmup_steps = 250 # lr scheduler warm up
31 | decay_ratio = 0.2
32 | lr_min = 0.1
33 |
34 | [training]
35 | local_batch_size = 32
36 | global_batch_size = 1024
37 | seq_len = 1024
38 | max_norm = 1.0 # grad norm clipping
39 | steps = 5_000
40 | compile = false
41 | dataset = "peoples_speech" #"libri_light" # "Emilia-librilight"
42 | task = "a2a"
43 |
44 | [validation]
45 | enabled = true
46 | dataset = "librispeech_asr_test"
47 | local_batch_size = 8
48 | seq_len = 1024
49 | freq = 100
50 | steps = 10
51 |
52 | [evaluation]
53 | enable_evaluation = false
54 | evaluation_freq = 500
55 |
56 | [parallelism]
57 | data_parallel_replicate_degree = 1
58 | data_parallel_shard_degree = -1
59 | tensor_parallel_degree = 1
60 | pipeline_parallel_degree = 1
61 | context_parallel_degree = 1
62 |
63 | [checkpoint]
64 | enable_checkpoint = true
65 | folder = "checkpoint"
66 | interval = 1000
67 | last_save_model_only = false
68 | export_dtype = "float32"
69 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
70 | keep_latest_k = 10
71 |
72 | [activation_checkpoint]
73 | mode = "none" # ["none", "selective", "full"]
74 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
75 |
76 | [float8]
77 | enable_fsdp_float8_all_gather = false
78 | precompute_float8_dynamic_scale_for_fsdp = false
79 | filter_fqns = ["output"]
80 |
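81 | # Note: with the 8-GPU launch in train.sh, global_batch_size = 1024 and
82 | # local_batch_size = 32 imply 1024 / (32 * 8) = 4 gradient accumulation steps
83 | # per optimizer update.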
--------------------------------------------------------------------------------
/torchtitan/models/llama3/train_configs/debug_model.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 |
3 | [job]
4 | dump_folder = "./outputs"
5 | description = "Llama 3 debug training"
6 | print_args = false
7 | use_for_integration_test = true
8 |
9 | [profiling]
10 | enable_profiling = false
11 | save_traces_folder = "profile_trace"
12 | profile_freq = 10
13 | enable_memory_snapshot = false
14 | save_memory_snapshot_folder = "memory_snapshot"
15 |
16 | [metrics]
17 | log_freq = 1
18 | disable_color_printing = false
19 | enable_tensorboard = false
20 | save_tb_folder = "tb"
21 | enable_wandb = false
22 |
23 | [model]
24 | name = "llama3"
25 | flavor = "debugmodel"
26 | # test folder with tokenizer.json, for debug purpose only
27 | tokenizer_path = "./tests/assets/tokenizer"
28 | # converters = ["float8"]
29 |
30 | [optimizer]
31 | name = "AdamW"
32 | lr = 8e-4
33 | eps = 1e-8
34 |
35 | [lr_scheduler]
36 | warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
37 | decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
38 | decay_type = "linear"
39 | lr_min = 0.0
40 |
41 | [training]
42 | local_batch_size = 8
43 | seq_len = 2048
44 | max_norm = 1.0 # grad norm clipping
45 | steps = 10
46 | compile = false
47 | dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
48 |
49 | [parallelism]
50 | data_parallel_replicate_degree = 1
51 | data_parallel_shard_degree = -1
52 | fsdp_reshard_after_forward = "default" # default / never / always
53 | tensor_parallel_degree = 1
54 | enable_async_tensor_parallel = false
55 | pipeline_parallel_degree = 1
56 | context_parallel_degree = 1
57 |
58 | [checkpoint]
59 | enable_checkpoint = false
60 | folder = "checkpoint"
61 | interval = 10
62 | last_save_model_only = false
63 | export_dtype = "float32"
64 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
65 |
66 | [activation_checkpoint]
67 | mode = "selective" # ["none", "selective", "full"]
68 | selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy
69 |
70 | [float8]
71 | enable_fsdp_float8_all_gather = false
72 | precompute_float8_dynamic_scale_for_fsdp = false
73 | filter_fqns = ["output"]
74 |
75 | [validation]
76 | enabled = false
77 | dataset = "c4_validation"
78 | freq = 5
79 | steps = 10
80 |
--------------------------------------------------------------------------------
/torchtitan/models/llama3/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from torchtitan.components.loss import build_cross_entropy_loss
8 | from torchtitan.components.lr_scheduler import build_lr_schedulers
9 | from torchtitan.components.optimizer import build_optimizers
10 | from torchtitan.components.tokenizer import build_hf_tokenizer
11 | from torchtitan.components.validate import build_validator
12 | from torchtitan.datasets.hf_datasets import build_hf_dataloader
13 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
14 |
15 | from .infra.parallelize import parallelize_llama
16 | from .infra.pipeline import pipeline_llama
17 | from .model.args import TransformerModelArgs
18 | from .model.model import Transformer
19 | from .model.state_dict_adapter import Llama3StateDictAdapter
20 |
21 | __all__ = [
22 | "parallelize_llama",
23 | "pipeline_llama",
24 | "TransformerModelArgs",
25 | "Transformer",
26 | "llama3_configs",
27 | ]
28 |
29 |
30 | llama3_configs = {
31 | "debugmodel": TransformerModelArgs(
32 | dim=256, n_layers=6, n_heads=16, rope_theta=500000
33 | ),
34 | "debugmodel_flex_attn": TransformerModelArgs(
35 | dim=256,
36 | n_layers=6,
37 | n_heads=16,
38 | rope_theta=500000,
39 | use_flex_attn=True,
40 | attn_mask_type="block_causal",
41 | ),
42 | "8B": TransformerModelArgs(
43 | dim=4096,
44 | n_layers=32,
45 | n_heads=32,
46 | n_kv_heads=8,
47 | ffn_dim_multiplier=1.3,
48 | multiple_of=1024,
49 | rope_theta=500000,
50 | ),
51 | "70B": TransformerModelArgs(
52 | dim=8192,
53 | n_layers=80,
54 | n_heads=64,
55 | n_kv_heads=8,
56 | ffn_dim_multiplier=1.3,
57 | multiple_of=4096,
58 | rope_theta=500000,
59 | ),
60 | "405B": TransformerModelArgs(
61 | dim=16384,
62 | n_layers=126,
63 | n_heads=128,
64 | n_kv_heads=8,
65 | ffn_dim_multiplier=1.2,
66 | multiple_of=4096,
67 | rope_theta=500000,
68 | ),
69 | }
70 |
71 |
72 | register_train_spec(
73 | TrainSpec(
74 | name="llama3",
75 | model_cls=Transformer,
76 | model_args=llama3_configs,
77 | parallelize_fn=parallelize_llama,
78 | pipelining_fn=pipeline_llama,
79 | build_optimizers_fn=build_optimizers,
80 | build_lr_schedulers_fn=build_lr_schedulers,
81 | build_dataloader_fn=build_hf_dataloader,
82 | build_tokenizer_fn=build_hf_tokenizer,
83 | build_loss_fn=build_cross_entropy_loss,
84 | build_validator_fn=build_validator,
85 | state_dict_adapter=Llama3StateDictAdapter,
86 | )
87 | )
88 |
--------------------------------------------------------------------------------
/torchtitan/models/llama3/model/args.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8 |
9 |
10 | from dataclasses import dataclass
11 |
12 | from torch import nn
13 |
14 | from torchtitan.components.tokenizer import BaseTokenizer
15 | from torchtitan.config_manager import JobConfig
16 | from torchtitan.protocols.train_spec import BaseModelArgs
17 |
18 |
19 | @dataclass
20 | class TransformerModelArgs(BaseModelArgs):
21 | dim: int = 4096
22 | n_layers: int = 32
23 | n_heads: int = 32
24 | n_kv_heads: int | None = None
25 | vocab_size: int = -1 # defined later by tokenizer
26 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27 | ffn_dim_multiplier: float | None = None
28 | norm_eps: float = 1e-5
29 | rope_theta: float = 10000
30 |
31 | max_seq_len: int = 2048
32 | # If `True`, then each transformer block init uses its layer ID, and if
33 | # `False`, each uses the total number of transformer blocks
34 | depth_init: bool = True
35 |
36 | use_flex_attn: bool = False
37 | attn_mask_type: str = "causal"
38 | eos_id: int = 0
39 |
40 | def update_from_config(
41 | self, job_config: JobConfig, tokenizer: BaseTokenizer
42 | ) -> None:
43 | self.vocab_size = tokenizer.get_vocab_size()
44 | self.max_seq_len = job_config.training.seq_len
45 | self.eos_id = tokenizer.eos_id
46 |
47 | if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
48 | raise ValueError(
49 | "FlexAttention is not compatible with CP yet. "
50 | "We are still working on this."
51 | )
52 |
53 | def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
54 | nparams = sum(p.numel() for p in model.parameters())
55 | nparams_embedding = sum(
56 | sum(p.numel() for p in m.parameters())
57 | for m in model.children()
58 | if isinstance(m, nn.Embedding)
59 | )
60 |
61 | l, h, q, t = (
62 | self.n_layers,
63 | self.n_heads,
64 | self.dim // self.n_heads,
65 | seq_len,
66 | )
67 | # Reasoning behind the factor of 12 for the self-attention part of the formula:
68 | # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6)
69 | # 2. the flash attention does 1 more matmul recomputation in the backward
70 | # but recomputation should not be counted in calculating MFU (+0)
71 | # 3. each matmul performs 1 multiplication and 1 addition (*2)
72 | # 4. we follow the convention and do not account for sparsity in causal attention
73 | num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
74 |
75 | return nparams, num_flops_per_token
76 |
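77 | # Rough worked example for the 8B flavor at seq_len = 8192 (approximate figures):
78 | #   attention term: 12 * 32 * 32 * 128 * 8192 ≈ 1.3e10 FLOPs per token
79 | #   dense term:     6 * (nparams - nparams_embedding) ≈ 6 * 7.5e9 ≈ 4.5e10
80 | #   total:          ≈ 5.8e10 FLOPs per token, which feeds the MFU estimate.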
--------------------------------------------------------------------------------
/torchtitan/protocols/model_converter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from typing import Dict, List, Protocol, Union
7 |
8 | import torch.nn as nn
9 |
10 | from torchtitan.config_manager import JobConfig
11 | from torchtitan.distributed import ParallelDims
12 | from torchtitan.tools.logging import logger
13 |
14 |
15 | class ModelConverter(Protocol):
16 | """General model converter interface.
17 |
18 | A model converter applies a modification to a PyTorch model.
19 | Typical use cases are:
20 | - Quantization: using QAT, FP8, ... specialized linear layers;
21 | - Fused optimized layers (e.g. flash-attention, norms, ...)
22 | """
23 |
24 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ...
25 |
26 | def convert(self, model: nn.Module):
27 | """Inplace convertion of the model."""
28 | ...
29 |
30 | def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
31 | """Post-optimizer (optional) hook (e.g. compute weights statistics)."""
32 | ...
33 |
34 |
35 | _registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
36 | """Registry of model converter classes.
37 | """
38 |
39 |
40 | def register_model_converter(converter_cls: type[ModelConverter], name: str):
41 | """Register a model converter class.
42 |
43 | A registered model converter can be applied on any model
44 | using the `model.converters` config parameter.
45 | """
46 | assert name not in _registry_model_converter_cls, (
47 | f"A model converter '{name}' is already registered."
48 | )
49 | _registry_model_converter_cls[name] = converter_cls
50 |
51 |
52 | class ModelConvertersContainer(ModelConverter):
53 | """Model converters sequential container.
54 |
55 | This class builds the sequence of model converters defined in the `model.converters`
56 | job config and applies them to the model sequentially.
57 | """
58 |
59 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
60 | converter_classes = [
61 | _registry_model_converter_cls[name] for name in job_config.model.converters
62 | ]
63 | self.converters = [
64 | mh_cls(job_config, parallel_dims) for mh_cls in converter_classes
65 | ]
66 | self.print_after_conversion = job_config.model.print_after_conversion
67 |
68 | def convert(self, model: nn.Module):
69 | for mh in self.converters:
70 | mh.convert(model)
71 | if self.print_after_conversion:
72 | logger.info(f"Model definion after conversion:\n\n{model}\n\n")
73 |
74 | def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
75 | for mh in self.converters:
76 | mh.post_optimizer_hook(model)
77 |
78 |
79 | def build_model_converters(
80 | job_config: JobConfig, parallel_dims: ParallelDims
81 | ) -> ModelConvertersContainer:
82 | """Build the collection of model converters to apply to the model."""
83 | return ModelConvertersContainer(job_config, parallel_dims)
84 |
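85 | # Illustrative flow: with `converters = ["float8"]` in the [model] section of the
86 | # job TOML, build_model_converters() looks up the class registered under "float8"
87 | # (see torchtitan/components/quantization) and applies its convert() to the model.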
--------------------------------------------------------------------------------
/torchtitan/components/quantization/mx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from functools import partial
8 | from importlib.metadata import version
9 | from importlib.util import find_spec
10 | from typing import Any, List
11 |
12 | import torch.nn as nn
13 |
14 | from torchtitan.config_manager import JobConfig, MX
15 | from torchtitan.distributed import ParallelDims
16 | from torchtitan.protocols.model_converter import (
17 | ModelConverter,
18 | register_model_converter,
19 | )
20 | from torchtitan.tools.logging import logger
21 | from torchtitan.tools.utils import has_cuda_capability
22 |
23 | from .utils import module_filter_fn
24 |
25 | # Maps titan recipe names to torchao mx recipe names
26 | NAME_MAP = {"mxfp8": "mxfp8_cublas"}
27 |
28 |
29 | class MXConverter(ModelConverter):
30 | """Converts the linear layers of `model` to `MXLinear`."""
31 |
32 | enabled: bool
33 | filter_fqns: List[str]
34 | mx_config: Any # MXLinearConfig type when imported
35 |
36 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
37 | # Ensure minimum torchao versions
38 | if find_spec("torchao") is None:
39 | raise ImportError(
40 | "torchao is not installed. Please install it to use MXFP8 linear layers."
41 | )
42 | torchao_version = version("torchao")
43 | mxfp8_min_version = "0.11.0"
44 | if torchao_version < mxfp8_min_version:
45 | raise ImportError(
46 | f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again"
47 | )
48 |
49 | # Can be removed if we enable the emulated versions
50 | assert has_cuda_capability(10, 0), (
51 | "MXFP8 is only supported on SM100 or architectures"
52 | )
53 |
54 | self.enabled = True
55 | mx_job_config: MX = job_config.mx
56 | self.filter_fqns = mx_job_config.filter_fqns
57 |
58 | # Configure MXFP8
59 | from torchao.prototype.mx_formats.config import MXLinearConfig
60 |
61 | config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
62 | config.use_fp8_dim1_cast_triton_kernel = (
63 | mx_job_config.use_fp8_dim1_cast_triton_kernel
64 | )
65 | self.config = config
66 |
67 | logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")
68 |
69 | def convert(self, model: nn.Module):
70 | """
71 | Converts the linear layers of `model` to `MXLinear`.
72 | Note that today, only dynamic tensor scaling (the default) is supported.
73 | This will mutate the model inplace.
74 | """
75 | if not self.enabled:
76 | return
77 |
78 | from torchao.prototype.mx_formats.config import MXLinearConfig
79 | from torchao.quantization import quantize_
80 |
81 | assert isinstance(self.config, MXLinearConfig)
82 | quantize_(
83 | model,
84 | config=self.config,
85 | filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns),
86 | )
87 | logger.info("Swapped to MXLinear layers")
88 |
89 | def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
90 | """
91 | MXFP8 doesn't require any post-optimizer hooks at the moment
92 | """
93 | return
94 |
95 |
96 | register_model_converter(MXConverter, "mx")
97 |
--------------------------------------------------------------------------------
/eval/salmon.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, Audio, get_dataset_config_names
2 | from transformers import MimiModel, AutoFeatureExtractor
3 | import os
4 | from torchtitan.datasets.hf_datasets import audio_array_to_text
5 | import torch
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from argparse import ArgumentParser
8 | from tqdm import tqdm
9 | import json
10 | import torch.nn.functional as F
11 | from torch.utils.data import Dataset, DataLoader
12 | from torch.nn.utils.rnn import pad_sequence
13 |
14 |
15 | def parse_args():
16 | parser = ArgumentParser(description="Audio completion generation script")
17 | parser.add_argument(
18 | "--model_name",
19 | type=str,
20 | default="llm-jp/Llama-Mimi-1.3B",
21 | help="Run name for the model",
22 | )
23 | parser.add_argument(
24 | "--output_dir", type=str, default="results", help="Output directory"
25 | )
26 | return parser.parse_args()
27 |
28 |
29 | def compute_loss(
30 | audio, audio_tokenizer, feature_extractor, num_quantizers, tokenizer, model, device
31 | ):
32 | text = audio_array_to_text(
33 | audio, audio_tokenizer, feature_extractor, num_quantizers
34 | )
35 | inputs = tokenizer(text, return_tensors="pt")
36 |
37 | labels = inputs.input_ids.clone()
38 | labels[labels == tokenizer.pad_token_id] = -100
39 | inputs = inputs.to(device)
40 | outputs = model(input_ids=inputs.input_ids, labels=labels)
41 | loss = outputs.loss
42 | return loss
43 |
44 |
45 | if __name__ == "__main__":
46 | args = parse_args()
47 |
48 | device = "cuda" if torch.cuda.is_available() else "cpu"
49 |
50 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to(device)
51 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
52 |
53 | model = (
54 | AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16)
55 | .eval()
56 | .to(device)
57 | )
58 | num_quantizers = model.config.num_quantizers
59 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
60 |
61 | # Set pad token if not already set
62 | if tokenizer.pad_token is None:
63 | tokenizer.pad_token = tokenizer.eos_token
64 |
65 | tasks = get_dataset_config_names("slprl/SALMon")
66 | tasks = [c for c in tasks if not c.startswith("all_")]
67 | result = {}
68 | for task in tasks:
69 | ds = load_dataset("slprl/SALMon", task, split="train")
70 | ds = ds.cast_column(
71 | "negative_audio", Audio(sampling_rate=feature_extractor.sampling_rate)
72 | )
73 | ds = ds.cast_column(
74 | "positive_audio", Audio(sampling_rate=feature_extractor.sampling_rate)
75 | )
76 |
77 | total_correct = 0
78 | total_samples = 0
79 |
80 | for example in tqdm(ds):
81 | negative_audio = example["negative_audio"]["array"]
82 | positive_audio = example["positive_audio"]["array"]
83 |
84 | neg_loss = compute_loss(
85 | negative_audio,
86 | audio_tokenizer,
87 | feature_extractor,
88 | num_quantizers,
89 | tokenizer,
90 | model,
91 | device,
92 | )
93 | pos_loss = compute_loss(
94 | positive_audio,
95 | audio_tokenizer,
96 | feature_extractor,
97 | num_quantizers,
98 | tokenizer,
99 | model,
100 | device,
101 | )
102 | # print(f"Neg loss: {neg_loss.item()}, Pos loss: {pos_loss.item()}")
103 |
104 | total_correct += (neg_loss > pos_loss).item()
105 | total_samples += 1
106 |
107 | acc = total_correct / total_samples
108 | result[task] = acc
109 | print(f"Accuracy for {task}: {acc}")
110 |
111 | output_path = os.path.join(
112 | args.output_dir, "SALMON", args.model_name, "accuracy.json"
113 | )
114 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
115 | with open(output_path, "w") as f:
116 | json.dump(result, f, indent=4)
117 |
--------------------------------------------------------------------------------
/torchtitan/components/dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8 |
9 | import pickle
10 | from abc import ABC, abstractmethod
11 | from collections.abc import Callable
12 | from typing import Any
13 |
14 | from torch.distributed.checkpoint.stateful import Stateful
15 | from torch.utils.data import IterableDataset
16 | from torchdata.stateful_dataloader import StatefulDataLoader
17 | from torchtitan.tools.logging import logger
18 |
19 |
20 | class DataloaderStopIteration(StopIteration):
21 | """An exception that indicates dataloader exhaustion."""
22 |
23 | pass
24 |
25 |
26 | class BaseDataLoader(Stateful, ABC):
27 | """Base class for all dataloaders.
28 |
29 | This is used to enforce that all dataloaders have the methods defined in ``Stateful``,
30 | ``state_dict()`` and ``load_state_dict()``.
31 | """
32 |
33 | @abstractmethod
34 | def __iter__(self): ...
35 |
36 |
37 | class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader):
38 | """Dataloader that is aware of distributed data parallelism.
39 |
40 | This dataloader is used to load data in a distributed data parallel fashion. It also
41 | utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary
42 | methods such as ``__iter__``.
43 |
44 | Args:
45 | dataset (IterableDataset): The dataset to iterate over.
46 | dp_rank: Data parallelism rank for this dataloader.
47 | dp_world_size: The world size of the data parallelism.
48 | batch_size: The batch size to use for each iteration.
49 | collate_fn: Optional function to collate samples in a batch.
50 | """
51 |
52 | dp_rank: int
53 | dp_world_size: int
54 | batch_size: int
55 |
56 | def __init__(
57 | self,
58 | dataset: IterableDataset,
59 | dp_rank: int,
60 | dp_world_size: int,
61 | batch_size: int,
62 | collate_fn: Callable | None = None,
63 | num_workers: int = 2,
64 | prefetch_factor: int | None = 2,
65 | pin_memory: bool = True,
66 | persistent_workers: bool | None = True,
67 | ):
68 | self.dp_world_size = dp_world_size
69 | self.dp_rank = dp_rank
70 | self.batch_size = batch_size
71 | super().__init__(
72 | dataset,
73 | batch_size,
74 | collate_fn=collate_fn,
75 | num_workers=num_workers,
76 | prefetch_factor=prefetch_factor,
77 | pin_memory=pin_memory,
78 | persistent_workers=persistent_workers,
79 | )
80 | self._rank_id = f"dp_rank_{dp_rank}"
81 |
82 | def state_dict(self) -> dict[str, Any]:
83 | # Store state only for dp rank to avoid replicating the same state across other dimensions.
84 | return {
85 | # We don't have to use pickle as DCP will serialize the state_dict. However,
86 | # we have to keep this for backward compatibility.
87 | self._rank_id: pickle.dumps(super().state_dict()),
88 | "world_size": self.dp_world_size,
89 | }
90 |
91 | def load_state_dict(self, state_dict: dict[str, Any]) -> None:
92 | # State being empty is valid.
93 | if not state_dict:
94 | return
95 |
96 | if self._rank_id not in state_dict:
97 | logger.warning(
98 | f"DataLoader state is empty for dp rank {self.dp_rank}, "
99 | "expected key {self._rank_id}"
100 | )
101 | return
102 |
103 | assert self.dp_world_size == state_dict["world_size"], (
104 | "dp_degree is inconsistent before and after checkpoint, "
105 | "dataloader resharding is not supported yet."
106 | )
107 | # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
108 | # keep this for backward compatibility.
109 | super().load_state_dict(pickle.loads(state_dict[self._rank_id]))
110 |
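111 | # Illustrative usage (assuming `ds` is an IterableDataset already sharded by dp rank):
112 | #   dl = ParallelAwareDataloader(ds, dp_rank=0, dp_world_size=8, batch_size=32)
113 | #   state = dl.state_dict()       # per-dp-rank state, serialized for DCP checkpoints
114 | #   dl.load_state_dict(state)     # restores the iteration position after resume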
--------------------------------------------------------------------------------
/eval/sLM21.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, Audio
2 | from transformers import MimiModel, AutoFeatureExtractor
3 | import os
4 | from torchtitan.datasets.hf_datasets import audio_array_to_text
5 | import torch
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from argparse import ArgumentParser
8 | from tqdm import tqdm
9 | import json
10 |
11 |
12 | def parse_args():
13 | parser = ArgumentParser(description="sLM21 (sWUGGY/sBLIMP) evaluation script")
14 | parser.add_argument(
15 | "--model_name",
16 | type=str,
17 | default="llm-jp/Llama-Mimi-1.3B",
18 | help="Run name for the model",
19 | )
20 | parser.add_argument(
21 | "--output_dir", type=str, default="results", help="Output directory"
22 | )
23 | return parser.parse_args()
24 |
25 |
26 | def compute_loss(
27 | audio, audio_tokenizer, feature_extractor, num_quantizers, tokenizer, model, device
28 | ):
29 | text = audio_array_to_text(
30 | audio, audio_tokenizer, feature_extractor, num_quantizers
31 | )
32 | inputs = tokenizer(text, return_tensors="pt")
33 |
34 | labels = inputs.input_ids.clone()
35 | labels[labels == tokenizer.pad_token_id] = -100
36 | mod_mask = torch.zeros_like(labels)
37 | mod_mask[:, 2 :: model.config.num_quantizers] = 1
38 | labels[mod_mask == 0] = -100
39 | # print([tokenizer.decode(ids) for ids in neg_labels[0].tolist() if ids != -100])
40 | inputs = inputs.to(device)
41 | outputs = model(input_ids=inputs.input_ids, labels=labels)
42 | loss = outputs.loss
43 | return loss
44 |
45 |
46 | if __name__ == "__main__":
47 | args = parse_args()
48 |
49 | device = "cuda" if torch.cuda.is_available() else "cpu"
50 |
51 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to(device)
52 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
53 |
54 | model = (
55 | AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16)
56 | .eval()
57 | .to(device)
58 | )
59 | num_quantizers = model.config.num_quantizers
60 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
61 |
62 | # Set pad token if not already set
63 | if tokenizer.pad_token is None:
64 | tokenizer.pad_token = tokenizer.eos_token
65 |
66 | tasks = ["sWUGGY", "sBLIMP"]
67 | result = {}
68 | for task in tasks:
69 | ds = load_dataset(
70 | f"speed/{task}", split="train"
71 | ) # .shuffle(seed=42).select(range(1000))
72 | ds = ds.cast_column(
73 | "negative", Audio(sampling_rate=feature_extractor.sampling_rate)
74 | )
75 | ds = ds.cast_column(
76 | "positive", Audio(sampling_rate=feature_extractor.sampling_rate)
77 | )
78 |
79 | total_correct = 0
80 | total_samples = 0
81 |
82 | for example in tqdm(ds):
83 | negative_audio = example["negative"]["array"]
84 | positive_audio = example["positive"]["array"]
85 |
86 | neg_loss = compute_loss(
87 | negative_audio,
88 | audio_tokenizer,
89 | feature_extractor,
90 | num_quantizers,
91 | tokenizer,
92 | model,
93 | device,
94 | )
95 | pos_loss = compute_loss(
96 | positive_audio,
97 | audio_tokenizer,
98 | feature_extractor,
99 | num_quantizers,
100 | tokenizer,
101 | model,
102 | device,
103 | )
104 | # print(f"Neg loss: {neg_loss.item()}, Pos loss: {pos_loss.item()}")
105 |
106 | total_correct += (neg_loss > pos_loss).item()
107 | total_samples += 1
108 |
109 | acc = total_correct / total_samples
110 | result[task] = acc
111 | print(f"Accuracy for {task}: {acc}")
112 |
113 | output_path = os.path.join(
114 | args.output_dir, "sLM21", args.model_name, "accuracy.json"
115 | )
116 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
117 | with open(output_path, "w") as f:
118 | json.dump(result, f, indent=4)
119 |
--------------------------------------------------------------------------------
/eval/sStoryCloze.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset, Audio
2 | from transformers import MimiModel, AutoFeatureExtractor
3 | import os
4 | from torchtitan.datasets.hf_datasets import audio_array_to_text
5 | import torch
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from argparse import ArgumentParser
8 | from tqdm import tqdm
9 | import json
10 |
11 |
12 | def parse_args():
13 |     parser = ArgumentParser(description="Spoken StoryCloze evaluation script")
14 | parser.add_argument(
15 | "--model_name",
16 | type=str,
17 | default="llm-jp/Llama-Mimi-1.3B",
18 |         help="Hugging Face model name or path",
19 | )
20 | parser.add_argument(
21 | "--output_dir", type=str, default="results", help="Output directory"
22 | )
23 | return parser.parse_args()
24 |
25 |
26 | def compute_loss(
27 | audio, audio_tokenizer, feature_extractor, num_quantizers, tokenizer, model, device
28 | ):
29 | text = audio_array_to_text(
30 | audio, audio_tokenizer, feature_extractor, num_quantizers
31 | )
32 | inputs = tokenizer(text, return_tensors="pt")
33 |
34 | labels = inputs.input_ids.clone()
35 | labels[labels == tokenizer.pad_token_id] = -100
36 | mod_mask = torch.zeros_like(labels)
37 | mod_mask[:, 2 :: model.config.num_quantizers] = 1
38 | labels[mod_mask == 0] = -100
39 |     # print([tokenizer.decode(ids) for ids in labels[0].tolist() if ids != -100])
40 | inputs = inputs.to(device)
41 | outputs = model(input_ids=inputs.input_ids, labels=labels)
42 | loss = outputs.loss
43 | return loss
44 |
45 |
46 | if __name__ == "__main__":
47 | args = parse_args()
48 |
49 | import multiprocessing as mp
50 |
51 | mp.set_start_method("spawn", force=True)
52 |
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 |
55 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to(device)
56 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
57 |
58 | model = (
59 | AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16)
60 | .eval()
61 | .to(device)
62 | )
63 | num_quantizers = model.config.num_quantizers
64 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
65 |
66 | # Set pad token if not already set
67 | if tokenizer.pad_token is None:
68 | tokenizer.pad_token = tokenizer.eos_token
69 |
70 | tasks = ["sStoryCloze", "tStoryCloze"]
71 | result = {}
72 | for task in tasks:
73 | ds = load_dataset(
74 | f"speed/{task}", split="train"
75 | ) # .shuffle(seed=42).select(range(1000))
76 | ds = ds.cast_column(
77 | "negative", Audio(sampling_rate=feature_extractor.sampling_rate)
78 | )
79 | ds = ds.cast_column(
80 | "positive", Audio(sampling_rate=feature_extractor.sampling_rate)
81 | )
82 |
83 | total_correct = 0
84 | total_samples = 0
85 |
86 | for example in tqdm(ds):
87 | negative_audio = example["negative"]["array"]
88 | positive_audio = example["positive"]["array"]
89 |
90 | neg_loss = compute_loss(
91 | negative_audio,
92 | audio_tokenizer,
93 | feature_extractor,
94 | num_quantizers,
95 | tokenizer,
96 | model,
97 | device,
98 | )
99 | pos_loss = compute_loss(
100 | positive_audio,
101 | audio_tokenizer,
102 | feature_extractor,
103 | num_quantizers,
104 | tokenizer,
105 | model,
106 | device,
107 | )
108 |
109 | total_correct += (neg_loss > pos_loss).item()
110 | total_samples += 1
111 |
112 | acc = total_correct / total_samples
113 | result[task] = acc
114 | print(f"Accuracy for {task}: {acc}")
115 |
116 | output_path = os.path.join(
117 | args.output_dir, "sStoryCloze", args.model_name, "accuracy.json"
118 | )
119 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
120 | with open(output_path, "w") as f:
121 | json.dump(result, f, indent=4)
122 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Llama-Mimi
4 | #### Autoregressive Speech Language Modeling with Interleaved Semantic and Acoustic Tokens
5 | | [📃Paper](https://arxiv.org/abs/2509.14882) | [🤗Models](https://huggingface.co/llm-jp/Llama-Mimi-1.3B) | [🗣️Online Demo](https://speed1313.github.io/llama-mimi/) |
6 |
7 |

8 |
9 |
10 |
11 |
12 |
13 | ## Introduction
14 | Llama-Mimi is a speech language model that uses a unified tokenizer (Mimi) and a single Transformer decoder (Llama) to jointly model sequences of interleaved semantic and acoustic tokens.
15 | Trained on ~240k hours of English audio, Llama-Mimi achieves state-of-the-art performance in acoustic consistency on [SALMon](https://arxiv.org/abs/2409.07437) and effectively preserves speaker identity.
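
A minimal sketch of what this interleaving looks like (the `<value_level>` text format mirrors `audio_array_to_text` in `inference.py`; the code values below are made up, and `num_quantizers = 4` is assumed only for illustration):

```python
# Each audio frame contributes `num_quantizers` Mimi codes: level 0 is the semantic
# codebook, higher levels are acoustic. All codes of frame t are written before frame t+1.
num_quantizers = 4  # illustrative; the real value comes from the model config
frames = [[812, 101, 7, 455], [330, 92, 615, 18]]  # two frames of made-up Mimi codes

tokens = [f"<{code}_{level}>" for frame in frames for level, code in enumerate(frame)]
print("".join(tokens))
# <812_0><101_1><7_2><455_3><330_0><92_1><615_2><18_3>
```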
16 |
17 | Visit our [demo site](https://speed1313.github.io/llama-mimi/) to hear generated speech samples.
18 |
19 |
20 | ## Repository Overview
21 | This repository lets you:
22 | - Run inference with our pretrained models
23 | - Pre-train Llama-Mimi on [The People's Speech](https://huggingface.co/datasets/MLCommons/peoples_speech)
24 | - Evaluate the model on multiple benchmarks
25 |
26 | ## Setup
27 |
28 |
29 | Install dependencies using uv:
30 | ```bash
31 | uv sync
32 | ```
33 |
34 | ## Generate Speech
35 |
36 | Generate audio continuations from a given audio prompt using our pretrained model (Llama-Mimi-1.3B):
37 | ```bash
38 | uv run python inference.py
39 | ```
40 |
41 | [▶️ Listen to samples on our demo site](https://speed1313.github.io/llama-mimi)
42 |
43 | ## Pre-train Llama-Mimi on The People's Speech
44 |
45 | To pre-train Llama-Mimi on [The People's Speech](https://huggingface.co/datasets/MLCommons/peoples_speech) (30k hours), first download the dataset locally:
46 | ```bash
47 | uv run huggingface-cli download MLCommons/peoples_speech --repo-type dataset --local-dir data/peoples_speech
48 | ```
49 |
50 | Then launch training with:
51 | ```bash
52 | torchrun --nproc_per_node=8 --local-ranks-filter 0 \
53 | --role rank --tee 3 -m torchtitan.train \
54 | --job.config_file config/llama3_2_1b_peoples_speech.toml
55 | ```
56 | This configuration trains Llama-Mimi-1.3B for 5,000 steps with a global batch size of 1,024 on 8 GPUs, taking about 8 hours.
57 | Training progress can be monitored with Weights & Biases (W&B).
58 |
59 |
60 |

61 |
62 |
63 | To use a custom dataset, update the dataset configuration in `torchtitan/datasets/hf_datasets.py`. We recommend downloading multiple large datasets, shuffling them, and then loading them with `load_dataset()` from local files, as sketched below.
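
A hypothetical sketch of that pattern with the `datasets` library (the glob path and column layout are placeholders, not files shipped with this repo):

```python
# Load locally downloaded parquet shards and shuffle them before training.
from datasets import load_dataset

ds = load_dataset(
    "parquet",
    data_files={"train": "data/peoples_speech/**/*.parquet"},  # placeholder path
    split="train",
)
ds = ds.shuffle(seed=42)
```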
64 |
65 | After training, convert the DCP checkpoint to Hugging Face format so the model can be used with the `transformers` library:
66 |
67 | ```bash
68 | uv run python scripts/convert_dcp_to_hf.py
69 | ```
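
Once converted, the checkpoint can be loaded like any Hugging Face model. A sketch (`path/to/converted_hf_checkpoint` is a placeholder for the output directory written by the conversion script):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: point this at the directory produced by scripts/convert_dcp_to_hf.py.
ckpt_dir = "path/to/converted_hf_checkpoint"
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.bfloat16).eval()
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
```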
70 |
71 |
72 | ## Evaluation
73 | Evaluate models on [SALMon](https://github.com/slp-rl/salmon), [sLM21](https://arxiv.org/abs/2104.14700) (sWUGGY and sBLIMP), and [sStoryCloze](https://github.com/slp-rl/SpokenStoryCloze) tasks.
74 |
75 | SALMon:
76 | ```bash
77 | uv run python eval/salmon.py --model_name llm-jp/Llama-Mimi-1.3B
78 | ```
79 |
80 | sStoryCloze:
81 | ```bash
82 | uv run python eval/sStoryCloze.py --model_name llm-jp/Llama-Mimi-1.3B
83 | ```
84 |
85 | sLM21:
86 | ```bash
87 | uv run python eval/sLM21.py --model_name llm-jp/Llama-Mimi-1.3B
88 | ```
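
All three scripts use the same pairwise likelihood protocol: each pair of spoken utterances is tokenized, the model's loss is computed on the audio-token positions, and a pair counts as correct when the positive utterance gets the lower loss. A minimal, standalone sketch of the scoring rule (loss values are made up):

```python
# A pair is correct when the positive (well-formed) utterance is more likely,
# i.e. its loss is lower than the negative utterance's loss.
def pairwise_accuracy(pos_losses: list[float], neg_losses: list[float]) -> float:
    assert len(pos_losses) == len(neg_losses)
    correct = sum(neg > pos for pos, neg in zip(pos_losses, neg_losses))
    return correct / len(pos_losses)

print(pairwise_accuracy([2.1, 3.0], [2.5, 2.8]))  # 0.5
```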
89 |
90 |
91 |
92 | ## Acknowledgements
93 |
94 | - Our training code is built on top of [TorchTitan](https://github.com/pytorch/torchtitan).
95 |
96 | - Our model employs [Llama 3](https://arxiv.org/abs/2407.21783) as the base language model, and [Mimi](https://arxiv.org/abs/2410.00037) as the audio tokenizer.
97 |
98 |
99 | ## Citation
100 | Star us on GitHub if you find this repository useful! ⭐
101 |
102 | If you find this work interesting, please cite our paper:
103 | ```
104 | @misc{sugiura2025llamamimispeechlanguagemodels,
105 | title={Llama-Mimi: Speech Language Models with Interleaved Semantic and Acoustic Tokens},
106 | author={Issa Sugiura and Shuhei Kurita and Yusuke Oda and Ryuichiro Higashinaka},
107 | year={2025},
108 | eprint={2509.14882},
109 | archivePrefix={arXiv},
110 | primaryClass={cs.CL},
111 | url={https://arxiv.org/abs/2509.14882},
112 | }
113 | ```
114 |
115 |
--------------------------------------------------------------------------------
/torchtitan/protocols/train_spec.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from abc import abstractmethod
8 | from collections.abc import Callable
9 | from dataclasses import dataclass
10 | from typing import Protocol, TypeAlias
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.distributed.pipelining.schedules import _PipelineSchedule
15 |
16 | from torchtitan.components.dataloader import BaseDataLoader
17 | from torchtitan.components.ft import FTManager
18 | from torchtitan.components.loss import LossFunction
19 | from torchtitan.components.lr_scheduler import LRSchedulersContainer
20 | from torchtitan.components.metrics import MetricsProcessor
21 | from torchtitan.components.optimizer import OptimizersContainer
22 | from torchtitan.components.tokenizer import BaseTokenizer
23 | from torchtitan.components.validate import BaseValidator
24 | from torchtitan.config_manager import JobConfig
25 | from torchtitan.distributed import ParallelDims
26 | from torchtitan.protocols.state_dict_adapter import StateDictAdapter
27 |
28 |
29 | @dataclass
30 | class BaseModelArgs:
31 | """All ModelArgs should inherit from this class.
32 |
33 |     The only usage of this class is type checking, but it allows us to extend
34 |     common arguments to all models in the future.
35 | """
36 |
37 | _enforced: str = "This field is used to enforce all fields have defaults."
38 |
39 | @abstractmethod
40 | def update_from_config(
41 | self, job_config: JobConfig, tokenizer: BaseTokenizer
42 | ) -> None:
43 | pass
44 |
45 | @abstractmethod
46 | def get_nparams_and_flops(
47 | self, model: nn.Module, seq_len: int
48 | ) -> tuple[int, float]:
49 | pass
50 |
51 |
52 | class ModelProtocol(Protocol):
53 | """Defines the interface for a model class.
54 |
55 | This is used to enforce that all model classes have some methods that are
56 | required by the trainer.
57 | """
58 |
59 | def __init__(self, model_args: BaseModelArgs) -> None:
60 | pass
61 |
62 | @abstractmethod
63 | def init_weights(self, buffer_device: torch.device | None = None) -> None:
64 | """Initialize model weights.
65 |
66 | Args:
67 | buffer_device: Optional device to place buffers on during initialization.
68 | """
69 | pass
70 |
71 |
72 | ParallelizeFunction: TypeAlias = Callable[..., nn.Module]
73 | PipeliningFunction: TypeAlias = Callable[
74 | ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool]
75 | ]
76 | DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader]
77 | TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer]
78 | MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor]
79 | OptimizersBuilder: TypeAlias = Callable[
80 | [list[nn.Module], JobConfig, ParallelDims, FTManager | None],
81 | OptimizersContainer,
82 | ]
83 | LRSchedulersBuilder: TypeAlias = Callable[
84 | [OptimizersContainer, JobConfig], LRSchedulersContainer
85 | ]
86 | LossFunctionBuilder: TypeAlias = Callable[..., LossFunction]
87 | ValidatorBuilder: TypeAlias = Callable[..., BaseValidator]
88 |
89 |
90 | @dataclass
91 | class TrainSpec:
92 | name: str
93 | model_cls: type[ModelProtocol]
94 | model_args: dict[str, BaseModelArgs]
95 | parallelize_fn: ParallelizeFunction
96 | pipelining_fn: PipeliningFunction | None
97 | build_optimizers_fn: OptimizersBuilder
98 | build_lr_schedulers_fn: LRSchedulersBuilder
99 | build_dataloader_fn: DataLoaderBuilder
100 | build_tokenizer_fn: TokenizerBuilder | None
101 | build_loss_fn: LossFunctionBuilder
102 | build_validator_fn: ValidatorBuilder | None = None
103 | build_metrics_processor_fn: MetricsProcessorBuilder | None = None
104 | state_dict_adapter: type[StateDictAdapter] | None = None
105 |
106 |
107 | _train_specs = {}
108 |
109 |
110 | def register_train_spec(train_spec: TrainSpec) -> None:
111 | global _train_specs
112 | if train_spec.name in _train_specs:
113 | raise ValueError(f"Model {train_spec.name} is already registered.")
114 |
115 | _train_specs[train_spec.name] = train_spec
116 |
117 |
118 | def get_train_spec(name: str) -> TrainSpec:
119 | global _train_specs
120 | if name not in _train_specs:
121 | raise ValueError(f"Model {name} is not registered.")
122 | return _train_specs[name]
123 |
124 |
125 | def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None:
126 | global _train_specs
127 | for name, train_spec in _train_specs.items():
128 | _train_specs[name] = func(train_spec)
129 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | from transformers import (
2 | AutoModelForCausalLM,
3 | AutoTokenizer,
4 | MimiModel,
5 | AutoFeatureExtractor,
6 | StoppingCriteria,
7 | )
8 | import torch
9 | import torchaudio
10 | import re
11 | import requests
12 | import io
13 |
14 |
15 | def audio_array_to_text(
16 | audio_array: torch.tensor,
17 | audio_tokenizer,
18 | feature_extractor,
19 | num_quantizers: int,
20 | ) -> str:
21 | inputs = feature_extractor(
22 | raw_audio=audio_array,
23 | sampling_rate=feature_extractor.sampling_rate,
24 | return_tensors="pt",
25 | ).to(audio_tokenizer.device)
26 | with torch.no_grad():
27 | encoder_outputs = audio_tokenizer.encode(
28 | inputs["input_values"],
29 | inputs["padding_mask"],
30 | num_quantizers=num_quantizers,
31 | )
32 | flatten_audio_codes = encoder_outputs.audio_codes.transpose(1, 2).reshape(-1)
33 | assert flatten_audio_codes.numel() % num_quantizers == 0
34 | steps = []
35 | for i in range(0, flatten_audio_codes.numel(), num_quantizers):
36 | group = [
37 | f"<{flatten_audio_codes[i + j].item()}_{j}>" for j in range(num_quantizers)
38 | ]
39 | steps.append(group)
40 |
41 | parts = [tok for step in steps for tok in step]
42 |
43 | text = "".join(parts)
44 |
45 | return f""
46 |
47 |
48 | def text_to_audio_values(
49 | text: str,
50 | num_quantizers: int,
51 | output_file: str,
52 | audio_tokenizer,
53 | feature_extractor,
54 | ):
55 | # Extract (val, idx) pairs from the format in the text
56 | matches = re.findall(r"<(\d+)_(\d+)>", text)
57 | vals = []
58 | for i in range(0, len(matches), num_quantizers):
59 | chunk = matches[i : i + num_quantizers]
60 | if len(chunk) < num_quantizers:
61 | break
62 | indices = [int(idx) for _, idx in chunk]
63 | if indices == list(range(num_quantizers)):
64 | vals.extend(int(val) for val, _ in chunk)
65 | else:
66 | break
67 | vals = vals[: len(vals) - len(vals) % num_quantizers]
68 | tensor_bt4 = torch.tensor(vals).reshape(1, -1, num_quantizers) # (B, T, 4)
69 | tensor_b4t = tensor_bt4.transpose(1, 2) # (B, 4, T)
70 | audio_values = audio_tokenizer.decode(tensor_b4t)[0]
71 | torchaudio.save(
72 | output_file,
73 | audio_values[0].detach().cpu(),
74 | feature_extractor.sampling_rate,
75 | )
76 |
77 |
78 | class StopOnAudioEnd(StoppingCriteria):
79 | def __init__(self, tokenizer):
80 | self.tokenizer = tokenizer
81 | self.target_text = ""
82 | self.target_ids = tokenizer(
83 | self.target_text, add_special_tokens=False
84 | ).input_ids
85 |
86 | def __call__(self, input_ids, scores, **kwargs):
87 | if len(input_ids[0]) < len(self.target_ids):
88 | return False
89 | return input_ids[0][-len(self.target_ids) :].tolist() == self.target_ids
90 |
91 |
92 | temperature = 0.8
93 | top_k = 30
94 | do_sample = True
95 | max_length = 1024
96 | device = "cuda" if torch.cuda.is_available() else "cpu"
97 | model_id = "llm-jp/Llama-Mimi-1.3B"
98 | model = (
99 | AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
100 | .eval()
101 | .to(device)
102 | )
103 | num_quantizers = model.config.num_quantizers
104 | tokenizer = AutoTokenizer.from_pretrained(model_id)
105 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi")
106 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
107 | stopping_criteria = StopOnAudioEnd(tokenizer)
108 |
109 | audio_url = (
110 | "https://speed1313.github.io/llama-mimi/data/prompt/natural/great_day_gt.wav"
111 | )
112 | response = requests.get(audio_url)
113 | response.raise_for_status()
114 | waveform, sample_rate = torchaudio.load(io.BytesIO(response.content))
115 | if sample_rate != feature_extractor.sampling_rate:
116 | waveform = torchaudio.transforms.Resample(
117 | sample_rate, feature_extractor.sampling_rate
118 | )(waveform)
119 | sample_rate = feature_extractor.sampling_rate
120 | prompt_array = waveform.squeeze().cpu().numpy()
121 |
122 | text = audio_array_to_text(
123 | prompt_array, audio_tokenizer, feature_extractor, num_quantizers
124 | )
125 |
126 | text = text.replace("", "")
127 | inputs = tokenizer(text, return_tensors="pt").to(device)
128 |
129 | with torch.no_grad():
130 | generated = model.generate(
131 | **inputs,
132 | max_length=max_length,
133 | do_sample=do_sample,
134 | temperature=temperature,
135 | top_k=top_k,
136 | stopping_criteria=[stopping_criteria],
137 | )
138 |
139 | generated_text = tokenizer.decode(generated[0])
140 |
141 | text_to_audio_values(
142 | generated_text,
143 | num_quantizers=num_quantizers,
144 | output_file="output.wav",
145 | audio_tokenizer=audio_tokenizer,
146 | feature_extractor=feature_extractor,
147 | )
148 |
--------------------------------------------------------------------------------
/torchtitan/tools/profiling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import contextlib
8 | import os
9 | import pickle
10 | import time
11 |
12 | import torch
13 |
14 | from torchtitan.config_manager import JobConfig
15 | from torchtitan.tools.logging import logger
16 |
17 | # the number of warmup steps before the active step in each profiling cycle
18 | WARMUP = 3
19 |
20 | # how many memory allocation/free ops to record in memory snapshots
21 | MEMORY_SNAPSHOT_MAX_ENTRIES = 100000
22 |
23 |
24 | @contextlib.contextmanager
25 | def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0):
26 | # get user defined profiler settings
27 | enable_profiling = config.profiling.enable_profiling
28 |
29 | if enable_profiling:
30 | dump_dir = config.job.dump_folder
31 | save_trace_dir = config.profiling.save_traces_folder
32 | trace_dir = os.path.join(dump_dir, save_trace_dir)
33 | profile_freq = config.profiling.profile_freq
34 |
35 | rank = torch.distributed.get_rank()
36 |
37 | replica_id = None
38 | if config.fault_tolerance.enable:
39 | replica_id = config.fault_tolerance.replica_id
40 |
41 | def trace_handler(prof):
42 | curr_trace_dir_name = "iteration_" + str(prof.step_num)
43 | curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
44 | if not os.path.exists(curr_trace_dir):
45 | os.makedirs(curr_trace_dir, exist_ok=True)
46 |
47 | logger.info(f"Dumping profiler traces at step {prof.step_num}")
48 | begin = time.monotonic()
49 |
50 | output_file = curr_trace_dir
51 | if replica_id is not None:
52 | output_file = os.path.join(output_file, f"replica{replica_id}")
53 | output_file = os.path.join(output_file, f"rank{rank}_trace.json")
54 |
55 | prof.export_chrome_trace(output_file)
56 | logger.info(
57 | f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds"
58 | )
59 |
60 | logger.info(f"Profiling active. Traces will be saved at {trace_dir}")
61 |
62 | if not os.path.exists(trace_dir):
63 | os.makedirs(trace_dir, exist_ok=True)
64 |
65 | warmup, active = WARMUP, 1
66 | wait = profile_freq - (active + warmup)
67 | assert wait >= 0, (
68 | "profile_freq must be greater than or equal to warmup + active"
69 | )
70 | gpu_device_profiled = None
71 | if torch.cuda.is_available():
72 | gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA
73 | elif torch.xpu.is_available():
74 | gpu_device_profiled = torch.profiler.ProfilerActivity.XPU
75 | with torch.profiler.profile(
76 | activities=[
77 | torch.profiler.ProfilerActivity.CPU,
78 | gpu_device_profiled,
79 | ],
80 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
81 | on_trace_ready=trace_handler,
82 | record_shapes=True,
83 | ) as torch_profiler:
84 | torch_profiler.step_num = global_step
85 | yield torch_profiler
86 | else:
87 | torch_profiler = contextlib.nullcontext()
88 | yield None
89 |
90 |
91 | @contextlib.contextmanager
92 | def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0):
93 | enable_snapshot = config.profiling.enable_memory_snapshot
94 | if enable_snapshot:
95 | snapshot_folder = config.profiling.save_memory_snapshot_folder
96 | snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder)
97 | if not os.path.exists(snapshot_dir):
98 | os.makedirs(snapshot_dir, exist_ok=True)
99 | rank = torch.distributed.get_rank()
100 |
101 | class MemoryProfiler:
102 | def __init__(self, step_num: int, freq: int):
103 | torch.cuda.memory._record_memory_history(
104 | max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES
105 | )
106 |                 # when resuming training, we start from the last step
107 | self.step_num = step_num
108 | self.freq = freq
109 |
110 | def step(self, exit_ctx: bool = False):
111 | self.step_num += 1
112 | if not exit_ctx and self.step_num % self.freq != 0:
113 | return
114 | if not exit_ctx:
115 | curr_step = self.step_num
116 | dir_name = f"iteration_{curr_step}"
117 | else:
118 | # dump as iteration_0_exit if OOM at iter 1
119 | curr_step = self.step_num - 1
120 | dir_name = f"iteration_{curr_step}_exit"
121 | curr_snapshot_dir = os.path.join(snapshot_dir, dir_name)
122 | if not os.path.exists(curr_snapshot_dir):
123 | os.makedirs(curr_snapshot_dir, exist_ok=True)
124 | logger.info(f"Dumping memory snapshot at step {curr_step}")
125 | begin = time.monotonic()
126 | with open(
127 | f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb"
128 | ) as output:
129 | pickle.dump(torch.cuda.memory._snapshot(), output)
130 | logger.info(
131 | f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds"
132 | )
133 |
134 | logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}")
135 | profiler = MemoryProfiler(global_step, config.profiling.profile_freq)
136 | try:
137 | yield profiler
138 | except torch.OutOfMemoryError:
139 | profiler.step(exit_ctx=True)
140 | else:
141 | yield None
142 |
--------------------------------------------------------------------------------
/torchtitan/models/llama3/infra/pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | # This file applies the PT-D pipeline parallelism to the Llama model.
8 |
9 | import copy
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.distributed import DeviceMesh
14 | from torch.distributed.pipelining import PipelineStage
15 | from torch.distributed.pipelining.schedules import (
16 | _PipelineSchedule,
17 | get_schedule_class,
18 | ScheduleZBVZeroBubble,
19 | )
20 |
21 | from torchtitan.components.loss import LossFunction
22 | from torchtitan.config_manager import JobConfig
23 | from torchtitan.distributed import ParallelDims
24 | from torchtitan.distributed.pipeline import (
25 | build_pipeline_schedule,
26 | generate_split_points,
27 | stage_ids_this_rank,
28 | )
29 | from torchtitan.protocols.train_spec import ParallelizeFunction
30 | from torchtitan.tools.logging import logger
31 |
32 | from ..model.args import TransformerModelArgs
33 |
34 |
35 | def pipeline_llama(
36 | model: nn.Module,
37 | parallel_dims: ParallelDims,
38 | job_config: JobConfig,
39 | device: torch.device,
40 | model_config: TransformerModelArgs,
41 | parallelize_fn: ParallelizeFunction,
42 | loss_fn: LossFunction,
43 | ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
44 | pp_mesh = parallel_dims.world_mesh["pp"]
45 |
46 | stages, model_parts = pipeline_llama_manual_split(
47 | model, pp_mesh, parallel_dims, job_config, device, model_config
48 | )
49 |
50 | # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
51 | # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
52 | # optimizer, and checkpointing
53 | for i, m in enumerate(model_parts):
54 | # apply SPMD-style PT-D techniques
55 | m = parallelize_fn(m, parallel_dims, job_config)
56 | model_parts[i] = m
57 | # NOTE: this is to update the model in the stage
58 | # in case the model is modified e.g. by torch.compile
59 | stages[i].submod = m
60 |
61 | pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
62 |
63 | # This is used in the train loop to determine whether to pass in the input_ids and labels
64 | has_first_stage = False
65 | has_last_stage = False
66 | for stage in stages:
67 | if stage.is_first:
68 | has_first_stage = True
69 | if stage.is_last:
70 | has_last_stage = True
71 |
72 | return pp_schedule, model_parts, has_first_stage, has_last_stage
73 |
74 |
75 | def pipeline_llama_manual_split(
76 | whole_model: nn.Module,
77 | pp_mesh: DeviceMesh,
78 | parallel_dims: ParallelDims,
79 | job_config: JobConfig,
80 | device: torch.device,
81 | model_config: TransformerModelArgs,
82 | ) -> tuple[list[PipelineStage], list[nn.Module]]:
83 | """
84 |     This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
85 |
86 | It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects.
87 |
88 | The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD
89 | parallelism.
90 | """
91 | pp_rank = pp_mesh.get_local_rank()
92 | pp_size = pp_mesh.size()
93 | parallelism_config = job_config.parallelism
94 |
95 | splits = parallelism_config.pipeline_parallel_split_points or generate_split_points(
96 | parallelism_config.pipeline_parallel_schedule,
97 | parallel_dims.pp,
98 | model_config.n_layers,
99 | parallelism_config.pipeline_parallel_layers_per_stage,
100 | )
101 |
102 | def _build_stage(
103 | stage_idx: int,
104 | start_layer: str | None,
105 | stop_layer: str | None,
106 | is_first: bool = False,
107 | is_last: bool = False,
108 | ) -> tuple[PipelineStage, nn.Module]:
109 | model = copy.deepcopy(whole_model)
110 | if not is_first:
111 | model.tok_embeddings = None
112 |
113 | drop_layers = start_layer is not None
114 | for name in list(model.layers.keys()):
115 | # we keep layers in a contiguous region between start (inclusive) and stop (exclusive)
116 | if f"layers.{name}" == start_layer:
117 | drop_layers = False
118 | if f"layers.{name}" == stop_layer:
119 | drop_layers = True
120 | if drop_layers:
121 | del model.layers[name]
122 |
123 | if not is_last:
124 | model.norm = None
125 | model.output = None
126 |
127 | stage = PipelineStage(
128 | model,
129 | stage_idx,
130 | num_stages,
131 | device,
132 | group=pp_mesh.get_group("pp"),
133 | )
134 | return stage, model
135 |
136 | num_stages = len(splits) + 1
137 | stage_idx = pp_rank
138 |
139 | stages = []
140 | models = []
141 |
142 | schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule)
143 | style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
144 |
145 | for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
146 | start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
147 | stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
148 | stage, model_chunk = _build_stage(
149 | stage_idx,
150 | start_layer,
151 | stop_layer,
152 | is_first=stage_idx == 0,
153 | is_last=stage_idx == num_stages - 1,
154 | )
155 | logger.info(
156 | f"PP rank {pp_rank} is building stage_idx {stage_idx}"
157 | f" with start_layer {start_layer}, stop_layer {stop_layer}"
158 | )
159 | stages.append(stage)
160 | models.append(model_chunk)
161 | return stages, models
162 |
--------------------------------------------------------------------------------
/torchtitan/components/ft.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import importlib
8 | from contextlib import nullcontext
9 | from datetime import timedelta
10 | from typing import ContextManager, Optional, TYPE_CHECKING, Union
11 |
12 | import torch
13 | import torch.distributed as dist
14 | from torch.distributed._composable.fsdp.fully_shard import FSDPModule
15 | from torch.distributed.distributed_c10d import ReduceOp
16 | from torchtitan.config_manager import FaultTolerance as FTConfig
17 |
18 | if importlib.util.find_spec("torchft") is not None:
19 | import torchft as ft
20 |
21 | if TYPE_CHECKING:
22 | from torchft import local_sgd
23 |
24 | has_torchft = True
25 | else:
26 | has_torchft = False
27 |
28 |
29 | class FTManager:
30 | def __init__(
31 | self,
32 | ft_config: FTConfig,
33 | ) -> None:
34 | if not ft_config.enable:
35 | self._manager = None
36 | return
37 |
38 | if not has_torchft:
39 | raise ImportError("torchft is not installed. Please install it.")
40 |
41 | process_group_timeout = timedelta(
42 | milliseconds=ft_config.process_group_timeout_ms
43 | )
44 | if ft_config.process_group == "gloo":
45 | pg = ft.ProcessGroupGloo(timeout=process_group_timeout)
46 | elif ft_config.process_group == "nccl":
47 | pg = ft.ProcessGroupNCCL(timeout=process_group_timeout)
48 | else:
49 |             raise ValueError(f"Unsupported process group: {ft_config.process_group}")
50 |
51 |         # If a semi-sync training method is specified, the quorum should be synchronous
52 | self.use_async_quorum = ft_config.semi_sync_method is None
53 |
54 | self._manager = ft.Manager(
55 | pg=pg,
56 | min_replica_size=ft_config.min_replica_size,
57 | load_state_dict=None,
58 | state_dict=None,
59 | use_async_quorum=self.use_async_quorum,
60 | replica_id=f"torchtitan_ft_{ft_config.replica_id}",
61 | )
62 | self.group_size = ft_config.group_size
63 | self.replica_id = ft_config.replica_id
64 |
65 | if self.use_async_quorum:
66 | self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager)
67 | self.replicate_pg.register("dp_replicate")
68 |
69 | @property
70 | def enabled(self) -> bool:
71 | return self._manager is not None
72 |
73 | @property
74 | def manager(self) -> "ft.Manager":
75 | assert self._manager is not None
76 | return self._manager
77 |
78 | def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]:
79 | if self.enabled:
80 | return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank
81 | else:
82 | return dp_degree, dp_rank
83 |
84 | def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None:
85 | if self.enabled and self.use_async_quorum:
86 |
87 | def all_reduce_hook(output):
88 | dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG)
89 |
90 | def apply_set_all_reduce_hook(m):
91 | if isinstance(m, FSDPModule):
92 | m.set_all_reduce_hook(all_reduce_hook)
93 |
94 | for model_part in model_parts:
95 | model_part.apply(apply_set_all_reduce_hook)
96 |
97 | @property
98 | def loss_sync_pg(
99 | self,
100 | ) -> Optional["ft.process_group.ManagedProcessGroup"]:
101 | if self.enabled and self.use_async_quorum:
102 | return self.replicate_pg
103 | else:
104 | # skip loss sync when using semi-sync training
105 | return None
106 |
107 |
108 | def maybe_semi_sync_training(
109 | ft_config: FTConfig,
110 | ft_manager: FTManager,
111 | model_parts: list[torch.nn.Module],
112 | optimizer: torch.optim.Optimizer,
113 | ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]:
114 | """
115 |     If TorchFT is enabled and a semi-sync method is configured, use it (DiLoCo or LocalSGD); otherwise return a null context.
116 | """
117 | semi_sync_method = ft_config.semi_sync_method
118 | if ft_config.enable and semi_sync_method is not None:
119 | from torchft import local_sgd
120 |
121 | assert ft_manager._manager is not None, (
122 | "FTManager must be enabled to use semi-sync training."
123 | )
124 | if semi_sync_method.lower() == "diloco":
125 | # Create the outer optimizer based on the inner optimizer parameters.
126 | params = [group["params"] for group in optimizer.param_groups]
127 | params = [param for sublist in params for param in sublist]
128 | outer_optimizers = []
129 | for model in model_parts:
130 | params = [p for p in model.parameters() if p.requires_grad]
131 | outer_optimizer = torch.optim.SGD(
132 | params, lr=0.7, momentum=0.9, nesterov=True
133 | )
134 | outer_optimizers.append(outer_optimizer)
135 |
136 | return local_sgd.DiLoCo(
137 | manager=ft_manager._manager,
138 | model_fragments=model_parts,
139 | inner_optimizer=optimizer,
140 | outer_optimizer=outer_optimizers,
141 | sync_every=ft_config.sync_steps,
142 | should_quantize=ft_config.should_quantize,
143 | fragment_sync_delay=ft_config.fragment_sync_delay,
144 | fragment_update_alpha=ft_config.fragment_update_alpha,
145 | )
146 | elif semi_sync_method.lower() == "local_sgd":
147 | assert len(model_parts) == 1
148 | return local_sgd.LocalSGD(
149 | manager=ft_manager._manager,
150 | model=model_parts[0],
151 | optimizer=optimizer,
152 | sync_every=ft_config.sync_steps,
153 | )
154 | else:
155 | raise ValueError(
156 | f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported."
157 | )
158 | return nullcontext()
159 |
--------------------------------------------------------------------------------
/torchtitan/components/validate.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Generator
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.distributed.fsdp import FSDPModule
12 | from torchtitan.components.dataloader import BaseDataLoader
13 | from torchtitan.components.metrics import MetricsProcessor
14 | from torchtitan.components.tokenizer import BaseTokenizer
15 | from torchtitan.config_manager import JobConfig
16 | from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader
17 | from torchtitan.distributed import ParallelDims, utils as dist_utils
18 | from torchtitan.tools import utils
19 |
20 |
21 | class BaseValidator:
22 | def __init__(self, job_config: JobConfig):
23 | self.job_config = job_config
24 |
25 | def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
26 | raise NotImplementedError("validate method not implemented")
27 |
28 | def should_validate(self, step: int) -> bool:
29 | return step % self.job_config.validation.freq == 0
30 |
31 |
32 | class Validator(BaseValidator):
33 | """
34 | Simple validator focused on correctness and integration.
35 |
36 | Args:
37 | job_config: Job configuration
38 | validation_dataloader: The validation dataloader
39 | model: The model to validate (single model, no parallelism)
40 | """
41 |
42 | validation_dataloader: BaseDataLoader
43 |
44 | def __init__(
45 | self,
46 | job_config: JobConfig,
47 | dp_world_size: int,
48 | dp_rank: int,
49 | tokenizer: BaseTokenizer,
50 | audio_tokenizer,
51 | feature_extractor,
52 | parallel_dims: ParallelDims,
53 | validation_context: Generator[None, None, None],
54 | maybe_enable_amp: Generator[None, None, None],
55 | metrics_processor: MetricsProcessor,
56 | ):
57 | self.job_config = job_config
58 | self.parallel_dims = parallel_dims
59 | self.validation_dataloader = build_hf_validation_dataloader(
60 | dp_world_size=dp_world_size,
61 | dp_rank=dp_rank,
62 | tokenizer=tokenizer,
63 | audio_tokenizer=audio_tokenizer,
64 | feature_extractor=feature_extractor,
65 | job_config=job_config,
66 | )
67 | self.validation_context = validation_context
68 | self.maybe_enable_amp = maybe_enable_amp
69 | self.metrics_processor = metrics_processor
70 |
71 | @torch.no_grad()
72 | def validate(
73 | self,
74 | model_parts: list[nn.Module],
75 | step: int,
76 | ) -> dict[str, float]:
77 | # Set model to eval mode
78 | # TODO: currently does not support pipeline parallelism
79 | model = model_parts[0]
80 | model.eval()
81 |
82 | parallel_dims = self.parallel_dims
83 |
84 | accumulated_losses = []
85 | device_type = utils.device_type
86 | num_steps = 0
87 |
88 | for input_dict in self.validation_dataloader:
89 | if (
90 | self.job_config.validation.steps != -1
91 | and num_steps >= self.job_config.validation.steps
92 | ):
93 | break
94 |
95 | self.metrics_processor.ntokens_since_last_log += input_dict[
96 | "input_ids"
97 | ].numel()
98 | input_ids = input_dict["input_ids"].to(device_type)
99 | attention_mask = input_dict["attention_mask"].to(device_type)
100 |
101 | optional_context_parallel_ctx = (
102 | dist_utils.create_context_parallel_ctx(
103 | cp_mesh=parallel_dims.world_mesh["cp"],
104 | cp_buffers=[input_ids],
105 | cp_seq_dims=[1],
106 |                         cp_no_restore_buffers={input_ids},
107 | cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
108 | )
109 | if parallel_dims.cp_enabled
110 | else None
111 | )
112 |
113 | with self.validation_context(optional_context_parallel_ctx):
114 | assert len(model_parts) == 1
115 | with self.maybe_enable_amp:
116 | outputs = model_parts[0](
117 | input_ids=input_ids,
118 | attention_mask=attention_mask,
119 | labels=input_ids,
120 | )
121 | loss = outputs.loss
122 |
123 | accumulated_losses.append(loss.detach())
124 |
125 | num_steps += 1
126 |
127 | # Compute average loss
128 | loss = torch.sum(torch.stack(accumulated_losses))
129 | loss /= num_steps
130 | if parallel_dims.dp_cp_enabled:
131 | global_avg_loss = dist_utils.dist_mean(
132 | loss, parallel_dims.world_mesh["dp_cp"]
133 | )
134 | else:
135 | global_avg_loss = loss.item()
136 |
137 | self.metrics_processor.log_validation(loss=global_avg_loss, step=step)
138 |
139 |         # Reshard after running the forward pass
140 | # This is to ensure the model weights are sharded the same way for checkpoint saving.
141 | for module in model.modules():
142 | if isinstance(module, FSDPModule):
143 | module.reshard()
144 |
145 | # Set model back to train mode
146 | model.train()
147 |
148 |
149 | def build_validator(
150 | job_config: JobConfig,
151 | dp_world_size: int,
152 | dp_rank: int,
153 | tokenizer: BaseTokenizer,
154 | audio_tokenizer,
155 | feature_extractor,
156 | parallel_dims: ParallelDims,
157 | validation_context: Generator[None, None, None],
158 | maybe_enable_amp: Generator[None, None, None],
159 | metrics_processor: MetricsProcessor | None = None,
160 | ) -> BaseValidator:
161 | """Build a simple validator focused on correctness."""
162 | return Validator(
163 | job_config=job_config,
164 | dp_world_size=dp_world_size,
165 | dp_rank=dp_rank,
166 | tokenizer=tokenizer,
167 | audio_tokenizer=audio_tokenizer,
168 | feature_extractor=feature_extractor,
169 | parallel_dims=parallel_dims,
170 | validation_context=validation_context,
171 | maybe_enable_amp=maybe_enable_amp,
172 | metrics_processor=metrics_processor,
173 | )
174 |
--------------------------------------------------------------------------------
/torchtitan/tools/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import gc
8 | import subprocess
9 | import time
10 | from dataclasses import dataclass
11 | from typing import Optional
12 |
13 | import torch
14 | from torch._utils import _get_available_device_type, _get_device_module
15 |
16 | from torchtitan.tools.logging import logger
17 |
18 |
19 | def has_cuda_capability(major: int, minor: int) -> bool:
20 | return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
21 | major,
22 | minor,
23 | )
24 |
25 |
26 | def get_device_info() -> tuple[str, torch.device]:
27 | device_type = _get_available_device_type() or "cuda"
28 | device_module = _get_device_module(device_type) # default device_module:torch.cuda
29 | return device_type, device_module
30 |
31 |
32 | device_type, device_module = get_device_info()
33 |
34 |
35 | # used to avoid stragglers in garbage collection
36 | class GarbageCollection:
37 | def __init__(self, gc_freq: int = 1000, debug: bool = False):
38 | assert gc_freq > 0, "gc_freq must be a positive integer"
39 | self.gc_freq = gc_freq
40 | self.debug = debug
41 | gc.disable()
42 | self.collect("Initial GC collection.")
43 | if debug:
44 | from torch.utils.viz._cycles import warn_tensor_cycles
45 |
46 | if torch.distributed.get_rank() == 0:
47 | warn_tensor_cycles()
48 |
49 | def run(self, step_count: int):
50 | if self.debug:
51 | self.collect(
52 | "Force GC to perform collection to obtain debug information.",
53 | generation=2,
54 | )
55 | gc.collect()
56 | elif step_count > 1 and step_count % self.gc_freq == 0:
57 |             self.collect("Performing periodic GC collection.")
58 |
59 | @staticmethod
60 | def collect(reason: str, generation: int = 1):
61 | begin = time.monotonic()
62 | gc.collect(generation)
63 | logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin)
64 |
65 |
66 | # hardcoded BF16 peak FLOPS for NVIDIA A100, H100, H200, B200, L40S GPUs, AMD MI250X, MI300X, MI325X, and Intel PVC
67 | def get_peak_flops(device_name: str) -> int:
68 | try:
69 | # Run the lspci command and capture the output
70 | result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True)
71 | # Filter the output for lines containing both "NVIDIA" and "H100"
72 | filtered_lines = [
73 | line
74 | for line in result.stdout.splitlines()
75 | if "NVIDIA" in line and "H100" in line
76 | ]
77 | # Join all filtered lines into a single string
78 | device_name = " ".join(filtered_lines) or device_name
79 | except FileNotFoundError as e:
80 | logger.warning(f"Error running lspci: {e}, fallback to use device_name")
81 | if "A100" in device_name:
82 | # data from https://www.nvidia.com/en-us/data-center/a100/
83 | return 312e12
84 | elif "H100" in device_name:
85 | # data from https://www.nvidia.com/en-us/data-center/h100/
86 |         # NOTE: datasheet values assume sparsity; the dense peak FLOPS used here are one-half of those.
87 | if "NVL" in device_name:
88 | return 835e12
89 | elif "PCIe" in device_name:
90 | return 756e12
91 | else: # for H100 SXM and other variants
92 | return 989e12
93 | elif "H200" in device_name:
94 | # data from https://www.nvidia.com/en-us/data-center/h200/
95 | return 989e12
96 | elif "B200" in device_name:
97 | # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
98 | return 2.25e15
99 | elif "MI300X" in device_name or "MI325X" in device_name:
100 | # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
101 | # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
102 | return 1300e12
103 | elif "MI250X" in device_name:
104 | # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
105 | return 191.5e12
106 | elif "Data Center GPU Max 1550" in device_name:
107 | # Also known as Ponte Vecchio (PVC).
108 | # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
109 | # Dot Product Accumulate Systolic (DPAS):
110 | # - Freq: 1300MHz
111 | # - #ops: 512
112 | # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
113 | # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
114 | max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
115 | return 512 * max_comp_units * 1300 * 10**6
116 |     elif "l40s" in device_name.lower():
117 | # data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
118 | return 362e12
119 |
120 | else: # for other GPU types, assume A100
121 | logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
122 | return 312e12
123 |
124 |
125 | @dataclass(frozen=True)
126 | class Color:
127 | black = "\033[30m"
128 | red = "\033[31m"
129 | green = "\033[32m"
130 | yellow = "\033[33m"
131 | blue = "\033[34m"
132 | magenta = "\033[35m"
133 | cyan = "\033[36m"
134 | white = "\033[37m"
135 | reset = "\033[39m"
136 | orange = "\033[38;2;180;60;0m"
137 | turquoise = "\033[38;2;54;234;195m"
138 |
139 |
140 | @dataclass(frozen=True)
141 | class NoColor:
142 | black = ""
143 | red = ""
144 | green = ""
145 | yellow = ""
146 | blue = ""
147 | magenta = ""
148 | cyan = ""
149 | white = ""
150 | reset = ""
151 | orange = ""
152 | turquoise = ""
153 |
154 |
155 | assert set(NoColor.__dataclass_fields__.keys()) == set(
156 | Color.__dataclass_fields__.keys()
157 | ), "NoColor must have the same fields as Color."
158 |
159 |
160 | def check_if_feature_in_pytorch(
161 | feature_name: str,
162 | pull_request: str,
163 | min_nightly_version: Optional[str] = None,
164 | ) -> None:
165 | if "git" in torch.__version__: # pytorch is built from source
166 | # notify users to check if the pull request is included in their pytorch
167 | logger.warning(
168 |             "detected that PyTorch is built from source. Please make sure the PR "
169 |             f"({pull_request}) is included in your PyTorch build for correct {feature_name}."
170 | )
171 | elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
172 | logger.warning(
173 |             f"detected that the PyTorch version {torch.__version__} is older than "
174 |             f"{min_nightly_version}. Please upgrade to a newer version to include the "
175 |             f"change in ({pull_request}) for correct {feature_name}."
176 | )
177 |
--------------------------------------------------------------------------------
/torchtitan/components/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import copy
8 | import functools
9 | import math
10 | from typing import Any, Callable, Iterator
11 |
12 | from torch.distributed.checkpoint.stateful import Stateful
13 | from torch.optim.lr_scheduler import LambdaLR, LRScheduler
14 |
15 | from torchtitan.components.optimizer import OptimizersContainer
16 | from torchtitan.config_manager import JobConfig
17 | from torchtitan.tools.logging import logger
18 |
19 | __all__ = [
20 | "LRSchedulersContainer",
21 | "build_lr_schedulers",
22 | ]
23 |
24 |
25 | class LRSchedulersContainer(Stateful):
26 | """Container for multiple learning rate schedulers.
27 |
28 | This class is used to wrap multiple LRSchedulers into a single object that can be
29 | used to reduce the complexity of the training loop. This mimics the behavior of
30 | ``torch.optim.lr_scheduler.LRScheduler``. The design concept is the same as
31 | ``OptimizersContainer``. This class currently only supports ``LambdaLR``.
32 |
33 | **Note**
34 | Users who want to customize the lr_scheduler behavior can inherit from this class and
35 | extend the functionality as needed. The following methods must follow the same
36 | signature as ``torch.optim.lr_scheduler.LRScheduler`` class: ``step()``, ``state_dict()``,
37 | ``load_state_dict()``.
38 |
39 | **Limitations**
40 | This class assumes all the lr schedulers are the same. There is no easy way to support
41 | resharding for multiple different LRSchedulers because LRScheduler.state_dict() is not
42 | resharding friendly. Therefore, the limitation is used to allow TorchTitan to support
43 |     resharding friendly. Therefore, this limitation is in place so that TorchTitan can support
44 |
45 | Args:
46 | optimizers (OptimizersContainer): The corresponding optimizers for the lr_schedulers.
47 | """
48 |
49 | schedulers: list[LRScheduler]
50 |
51 | def __init__(self, optimizers: OptimizersContainer, lr_lambda: Callable) -> None:
52 | assert len(optimizers) > 0, (
53 | "Must have at least one optimizer to create LRScheduler"
54 | )
55 |
56 | self.schedulers = [LambdaLR(optimizer, lr_lambda) for optimizer in optimizers]
57 |
58 | def __iter__(self) -> Iterator[LRScheduler]:
59 | return iter(self.schedulers)
60 |
61 | def __len__(self) -> int:
62 | return len(self.schedulers)
63 |
64 | def step(self) -> None:
65 | for scheduler in self.schedulers:
66 | scheduler.step()
67 |
68 | def state_dict(self) -> dict[str, Any]:
69 | # While there may be multiple schedulers, we only save the first one because
70 | # the state_dict is the same for all. See the limitations section in the
71 | # docstring.
72 | return self.schedulers[0].state_dict()
73 |
74 | def load_state_dict(self, state_dict: dict[str, Any]) -> None:
75 | # Load the same state_dict for all schedulers. The key value we're concerned
76 | # within ``LRScheduler.state_dict()`` is ``last_epoch``, which is an integer
77 | # that is immutable. As long as ``training.steps`` and ``lr_scheduler.warmup_steps``
78 | # in ``job_config`` remain unchanged when resuming from a checkpoint, this
79 | # approach is safe. We call ``copy()`` here to ensure extra safety.
80 | for scheduler in self.schedulers:
81 | scheduler.load_state_dict(copy.deepcopy(state_dict))
82 |
83 |
84 | def build_lr_schedulers(
85 | optimizers: OptimizersContainer, job_config: JobConfig
86 | ) -> LRSchedulersContainer:
87 | """Create a LRSchedulerContainer for the given optimizers and job config.
88 |
89 | This function creates a ``LRSchedulersContainer`` for the given optimizers.
90 | ``job_config`` should define the correct lr scheduler parameters.
91 |
92 | **Note**
93 | Users who want to customize the lr scheduler behavior can create their own
94 | ``LRSchedulersContainer`` subclass and ``build_lr_scheduler``. Passing the
95 | customized ``build_lr_schedulers`` to ``TrainSpec`` will create the customized
96 | ``LRSchedulersContainer``.
97 |
98 |
99 | Args:
100 | optimizers (OptimizersContainer): The corresponding optimizers for the
101 | lr_schedulers.
102 | """
103 | training_steps = job_config.training.steps
104 | warmup_steps = int(job_config.lr_scheduler.warmup_steps)
105 |
106 | if warmup_steps > training_steps:
107 | logger.warning(
108 | f"Warmup steps ({warmup_steps}) exceed total training steps ({training_steps}). "
109 | f"Adjusting warmup steps to {training_steps}."
110 | )
111 | warmup_steps = training_steps
112 |
113 | if job_config.lr_scheduler.decay_ratio is not None:
114 | decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio)
115 | if warmup_steps + decay_steps > training_steps:
116 | logger.warning(
117 | f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed "
118 | f"total training steps ({training_steps}). "
119 | f"Adjusting decay steps to {training_steps - warmup_steps}."
120 | )
121 | decay_steps = training_steps - warmup_steps
122 | else:
123 | decay_steps = training_steps - warmup_steps
124 |     # Add a virtual last step to prevent the learning rate from dropping to 0
125 | stable_steps = training_steps + 1 - warmup_steps - decay_steps
126 | lr_decay_type = job_config.lr_scheduler.decay_type
127 | lr_min = job_config.lr_scheduler.lr_min
128 |
129 | def linear_warmup_stable_decay(
130 | current_step: int,
131 | warmup_steps: int,
132 | stable_steps: int,
133 | decay_steps: int,
134 | lr_decay_type: str,
135 | lr_min: float,
136 | ):
137 | """
138 | Computes linear warmup followed by stable learning rate for a while,
139 | then some type of decay.
140 |
141 | Per LambdaLR requirement, this is accomplished by returning
142 | a multiplicative factor `curr_adjustment` ranging from 1 to 0
143 | to adjust the learning rate to create the desired schedule.
144 |
145 | We offer three types of learning rate decay schedules:
146 | 1. `linear`: decays linearly from 1 to 0 over the decay period.
147 | 2. `sqrt`: decays as 1 minus the square root of the decay progress.
148 |         3. `cosine`: follows the first half-period of the cosine curve, decaying from 1 to 0.
149 |
150 | If `lr_min` is specified, the decay range is scaled from 1 to `lr_min`
151 | to ensure the learning rate does not drop below this minimum value.
152 | """
153 | warmup_stable_steps = warmup_steps + stable_steps
154 | if current_step < warmup_steps:
155 | # linear warmup
156 | # 0-indexed step, hence + 1 adjustments
157 | current_step += 1
158 | assert warmup_steps != 0, (
159 | "warmup_steps must not be zero to reach this branch"
160 | )
161 | curr_adjustment = float(current_step / warmup_steps)
162 | elif current_step < warmup_stable_steps:
163 | curr_adjustment = 1.0
164 | else:
165 | # 0-indexed step, hence + 1 adjustments
166 | current_step += 1
167 | assert decay_steps != 0, "decay_steps must not be zero to reach this branch"
168 | progress = float(current_step - warmup_stable_steps) / decay_steps
169 |
170 | if lr_decay_type == "linear":
171 | curr_adjustment = 1 - progress
172 | elif lr_decay_type == "sqrt":
173 | curr_adjustment = 1 - math.sqrt(progress)
174 | elif lr_decay_type == "cosine":
175 | curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
176 | curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment
177 | return curr_adjustment
178 |
179 | lr_lambda = functools.partial(
180 | linear_warmup_stable_decay,
181 | warmup_steps=warmup_steps,
182 | stable_steps=stable_steps,
183 | decay_steps=decay_steps,
184 | lr_decay_type=lr_decay_type,
185 | lr_min=lr_min,
186 | )
187 | return LRSchedulersContainer(optimizers, lr_lambda)
188 |
--------------------------------------------------------------------------------
/torchtitan/distributed/parallel_dims.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import dataclass
8 | from functools import cached_property
9 |
10 | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
11 |
12 | from torchtitan.tools.logging import logger
13 | from torchtitan.tools.utils import device_type
14 |
15 |
16 | __all__ = ["ParallelDims"]
17 |
18 |
19 | @dataclass
20 | class ParallelDims:
21 | dp_replicate: int
22 | dp_shard: int
23 | cp: int
24 | tp: int
25 | pp: int
26 | ep: int
27 | world_size: int
28 |
29 | _world_mesh: DeviceMesh = None
30 |
31 | def __post_init__(self):
32 | self._validate()
33 |
34 | def _validate(self):
35 | dp_replicate, dp_shard, cp, tp, pp, ep = (
36 | self.dp_replicate,
37 | self.dp_shard,
38 | self.cp,
39 | self.tp,
40 | self.pp,
41 | self.ep,
42 | )
43 | for d in (dp_replicate, cp, tp, pp, ep):
44 | assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
45 |
46 |         assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
47 | if dp_shard < 0:
48 | self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp)
49 | assert dp_shard >= 1
50 |
51 | assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
52 | f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
53 | f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
54 | )
55 |
56 | if ep > 1:
57 | # EP would borrow all cp and some dp_shard degree
58 | assert ep % cp == 0 and (dp_shard * cp) % ep == 0
59 |
60 | def build_mesh(self) -> DeviceMesh:
61 | # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel
62 | # is not very clean, due to the limited support from DeviceMesh
63 | # for creating two staggered meshes. Will improve.
64 | if self.ep > 1:
65 | return self._build_mesh_with_ep()
66 | else:
67 | return self._build_mesh_without_ep()
68 |
69 | def _build_mesh_with_ep(self) -> DeviceMesh:
70 | # With ep, dp_shard and ep are derived submeshes:
71 | # dp_shard = dp_shard_mod_ep * dp_shard_in_ep
72 | # ep = dp_shard_in_ep * cp
73 | dp_shard_mod_ep = self.dp_shard * self.cp // self.ep
74 | dp_shard_in_ep = self.ep // self.cp
75 |
76 | dims = []
77 | names = []
78 | for d, name in zip(
79 | [
80 | self.pp,
81 | self.dp_replicate,
82 | dp_shard_mod_ep,
83 | dp_shard_in_ep,
84 | self.cp,
85 | self.tp,
86 | ],
87 | ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"],
88 | ):
89 | # dp_shard_mod_ep is needed even if it's 1, since its FSDP wrapping
90 | # helps the MoE layers do mixed precision training
91 | if d > 1 or name == "dp_shard_mod_ep":
92 | dims.append(d)
93 | names.append(name)
94 |
95 | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
96 | mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
97 |
98 | # Create all the submesh here to ensure all required process groups are
99 | # initialized:
100 | # Mesh for data loading (no communication on this mesh)
101 | dp_mesh_dim_names = []
102 | # Mesh for param sharding
103 | dp_shard_cp_mesh_dim_names = []
104 | # Mesh for loss all-reduce
105 | dp_cp_mesh_dim_names = []
106 | # Mesh for ep
107 | ep_mesh_dim_names = []
108 |
109 | if self.dp_replicate_enabled:
110 | dp_mesh_dim_names.append("dp_replicate")
111 | dp_cp_mesh_dim_names.append("dp_replicate")
112 | # dp_shard_mod_ep is always needed, even if it's 1
113 | dp_mesh_dim_names.append("dp_shard_mod_ep")
114 | dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep")
115 | dp_cp_mesh_dim_names.append("dp_shard_mod_ep")
116 | if "dp_shard_in_ep" in names:
117 | dp_mesh_dim_names.append("dp_shard_in_ep")
118 | dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep")
119 | dp_cp_mesh_dim_names.append("dp_shard_in_ep")
120 | ep_mesh_dim_names.append("dp_shard_in_ep")
121 | if self.cp_enabled:
122 | dp_shard_cp_mesh_dim_names.append("cp")
123 | dp_cp_mesh_dim_names.append("cp")
124 | ep_mesh_dim_names.append("cp")
125 |
126 | mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
127 | mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp")
128 | mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
129 | mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")
130 |
131 | return mesh
132 |
133 | def _build_mesh_without_ep(self) -> DeviceMesh:
134 | dims = []
135 | names = []
136 | for d, name in zip(
137 | [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
138 | ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
139 | ):
140 | if d > 1:
141 | dims.append(d)
142 | names.append(name)
143 |
144 | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
145 | mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
146 |
147 | # Create all the submesh here to ensure all required process groups are
148 | # initialized:
149 | # Mesh for data loading (no communication on this mesh)
150 | dp_mesh_dim_names = []
151 | # Mesh for param sharding
152 | dp_shard_cp_mesh_dim_names = []
153 | # Mesh for loss all-reduce
154 | dp_cp_mesh_dim_names = []
155 |
156 | if self.dp_replicate_enabled:
157 | dp_mesh_dim_names.append("dp_replicate")
158 | dp_cp_mesh_dim_names.append("dp_replicate")
159 | if self.dp_shard_enabled:
160 | dp_mesh_dim_names.append("dp_shard")
161 | dp_shard_cp_mesh_dim_names.append("dp_shard")
162 | dp_cp_mesh_dim_names.append("dp_shard")
163 | if self.cp_enabled:
164 | dp_shard_cp_mesh_dim_names.append("cp")
165 | dp_cp_mesh_dim_names.append("cp")
166 |
167 | if dp_mesh_dim_names != []:
168 | mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
169 | if dp_shard_cp_mesh_dim_names != []:
170 | mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
171 | mesh_dim_name="dp_shard_cp"
172 | )
173 | if dp_cp_mesh_dim_names != []:
174 | mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
175 |
176 | return mesh
177 |
178 | @property
179 | def world_mesh(self) -> DeviceMesh:
180 | # doing late init so ParallelDims can still be used as a lightweight
181 | # dataclass without having to initialize the world mesh
182 | if self._world_mesh is None:
183 | self._world_mesh = self.build_mesh()
184 | return self._world_mesh
185 |
186 | @property
187 | def dp_enabled(self):
188 | return self.dp_replicate > 1 or self.dp_shard > 1
189 |
190 | @property
191 | def dp_replicate_enabled(self):
192 | return self.dp_replicate > 1
193 |
194 | @property
195 | def dp_shard_enabled(self):
196 | return self.dp_shard > 1
197 |
198 | @property
199 | def cp_enabled(self):
200 | return self.cp > 1
201 |
202 | @property
203 | def dp_cp_enabled(self):
204 | return self.dp_enabled or self.cp_enabled
205 |
206 | @property
207 | def fsdp_enabled(self):
208 | return self.dp_shard_enabled or self.cp_enabled
209 |
210 | @property
211 | def tp_enabled(self):
212 | return self.tp > 1
213 |
214 | @property
215 | def pp_enabled(self):
216 | return self.pp > 1
217 |
218 | @property
219 | def ep_enabled(self):
220 | return self.ep > 1
221 |
222 | @cached_property
223 | def non_data_parallel_size(self):
224 | return self.cp * self.tp * self.pp
225 |
226 | @cached_property
227 | def seq_len_divisor(self):
228 | # Sequence Parallel requires that seq_len be divisible by TP degree.
229 | # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001
230 |
231 | # Context Parallel requires that seq_len be divisible by 2 * CP degree,
232 | # when load balancing is enabled (by default).
233 | # https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246
234 | return self.tp * (self.cp * 2)
235 |
236 | @cached_property
237 | def dense_params_mesh_ndim(self):
238 | # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh
239 | return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled
240 |
--------------------------------------------------------------------------------
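
A rough, self-contained illustration of the ParallelDims arithmetic above (hypothetical degrees, no torchtitan imports): dp_shard=-1 is resolved from the remaining world size, and the combined sequence-length divisor follows from the TP and CP constraints noted in seq_len_divisor.

# standalone sketch of the ParallelDims._validate() arithmetic
world_size = 8
dp_replicate, dp_shard, cp, tp, pp = 1, -1, 1, 2, 2

if dp_shard < 0:
    # -1 means "use whatever is left over after the other parallelisms"
    dp_shard = world_size // (dp_replicate * cp * tp * pp)

assert dp_replicate * dp_shard * cp * tp * pp == world_size, "degrees must factor the world size"

# Sequence Parallel needs seq_len % tp == 0; Context Parallel with load balancing
# needs seq_len % (2 * cp) == 0, hence the combined divisor.
seq_len_divisor = tp * (cp * 2)
print(f"dp_shard={dp_shard}, seq_len_divisor={seq_len_divisor}")  # dp_shard=2, seq_len_divisor=4

--------------------------------------------------------------------------------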
/torchtitan/distributed/pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | from typing import Callable
8 |
9 | from torch.distributed.pipelining.schedules import (
10 | _PipelineSchedule,
11 | _PipelineScheduleRuntime,
12 | get_schedule_class,
13 | PipelineScheduleMulti,
14 | PipelineScheduleSingle,
15 | )
16 | from torch.distributed.pipelining.stage import PipelineStage
17 |
18 | from torchtitan.config_manager import JobConfig
19 | from torchtitan.tools.logging import logger
20 |
21 |
22 | __all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"]
23 |
24 |
25 | # TODO: It's unclear if this API is general enough to be used by other models.
26 | # If not, we should move it to a Transformer-specific directory.
27 | def generate_split_points(
28 | schedule_str: str,
29 | pp_degree: int,
30 | num_layers: int,
31 | num_layers_per_stage: int | None,
32 | input_weight: int = 1,
33 | output_weight: int = 1,
34 | ) -> list[str]:
35 | """
36 | Generate a list of split points based on the input configs. In this function,
37 | the number of effective layers considered is the summation of num_layers,
38 | input_weight, and output_weight.
39 |
40 | If num_layers_per_stage is given, we require a rigid fit of the
41 | effective layers (regular layers + weighted input + weighted output)
42 | onto pipeline stages and ranks, with several assertions. It is the users'
43 | responsibility to figure out the input weight, output weight, and the
44 | number of regular layers, so that they can be arranged neatly.
45 |
46 | If num_layers_per_stage is None, we by default set each pipeline rank
47 | to have 1 stage if schedule_str is a single-stage schedule, or 2 virtual stages
48 | if it is a multi-stage schedule, and try to distribute all effective layers
49 | evenly onto the PP stages. If there are extra layers, we disperse them in
50 | the starting stages.
51 |
52 | Args:
53 | schedule_str (str): The string of the schedule name.
54 | pp_degree (int): The pipeline parallel dimension.
55 | num_layers (int): The number of layers in the model.
56 | num_layers_per_stage (int | None): The number of layers per (virtual) pipeline stage.
57 | input_weight (int): The number of layers the input modules count as in the layer calculation.
58 | output_weight (int): The number of layers the output modules count as in the layer calculation.
59 |
60 | Returns:
61 | list[str]: A list of split point FQNs.
62 | """
63 |
64 | schedule_class = get_schedule_class(schedule_str)
65 | is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
66 |
67 | num_effective_layers = num_layers + input_weight + output_weight
68 |
69 | if num_layers_per_stage is not None:
70 | # If num_layers_per_stage is provided, we require a rigid fit of the effective layers
71 | assert num_effective_layers % pp_degree == 0
72 | num_layers_per_pipeline_rank = num_effective_layers // pp_degree
73 |
74 | assert num_layers_per_pipeline_rank % num_layers_per_stage == 0
75 | num_stages_per_rank = num_layers_per_pipeline_rank // num_layers_per_stage
76 |
77 | num_total_virtual_stages = num_stages_per_rank * pp_degree
78 | num_extra_layers = 0
79 |
80 | if is_single_stage_schedule:
81 | assert num_stages_per_rank == 1, (
82 | f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single-stage schedules."
83 | )
84 | else:
85 | assert num_stages_per_rank >= 2, (
86 | f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi-stage schedules."
87 | )
88 | else:
89 | # In a multi-stage schedule, if num_layers_per_stage is not
90 | # provided, by default each pipeline rank has 2 virtual stages.
91 | num_stages_per_rank = 1 if is_single_stage_schedule else 2
92 | num_total_virtual_stages = pp_degree * num_stages_per_rank
93 |
94 | if num_total_virtual_stages > num_effective_layers:
95 | raise ValueError(
96 | "The number of total stages cannot be greater than the number of effective layers."
97 | )
98 |
99 | num_layers_per_stage = num_effective_layers // num_total_virtual_stages
100 | num_extra_layers = num_effective_layers % num_total_virtual_stages
101 |
102 | assert num_layers_per_stage >= max(input_weight, output_weight)
103 |
104 | splits = []
105 | current_layer = 0
106 | for i in range(num_total_virtual_stages - 1):
107 | if i == 0:
108 | current_layer += num_layers_per_stage - input_weight
109 | else:
110 | current_layer += num_layers_per_stage
111 | # extra layers will be dispersed to the first stages
112 | if num_extra_layers > 0:
113 | current_layer += 1
114 | num_extra_layers -= 1
115 | splits.append("layers." + str(current_layer))
116 |
117 | logger.info(
118 | "No 'pipeline_parallel_split_points' provided. Here is the auto-generated split, "
119 | f"which may be sub-optimal: {splits}."
120 | )
121 | return splits
122 |
123 |
124 | def build_pipeline_schedule(
125 | job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable
126 | ) -> _PipelineSchedule:
127 | """Builds a pipeline schedule for the given job configuration and stages.
128 |
129 | Args:
130 | job_config (JobConfig): The job configuration.
131 | stages (list[PipelineStage]): The stages to be scheduled.
132 | loss_fn (Callable): The loss function.
133 |
134 | Returns:
135 | _PipelineSchedule: The pipeline schedule for the given stages.
136 | """
137 | pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv
138 |
139 | # Validate that pp_schedule_csv is a valid path
140 | if pp_schedule_csv:
141 | if not os.path.isfile(pp_schedule_csv):
142 | raise FileNotFoundError(
143 | f"The specified path {pp_schedule_csv} does not exist or is not a file."
144 | )
145 | schedule_class = _PipelineScheduleRuntime
146 | else:
147 | schedule_class = get_schedule_class(
148 | job_config.parallelism.pipeline_parallel_schedule
149 | )
150 |
151 | looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
152 | microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
153 | batch_size = job_config.training.local_batch_size
154 | # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
155 | if batch_size % microbatch_size != 0:
156 | raise ValueError(
157 | f"Batch size {job_config.training.local_batch_size} must be divisible by number of microbatches {n_microbatches}. "
158 | "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
159 | )
160 | n_microbatches = batch_size // microbatch_size
161 | # We expect that the number of local stages (`len(stages)`) is the same across all ranks
162 | num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
163 | if n_microbatches < num_total_stages:
164 | logger.warning(
165 | f"Number of microbatches ({n_microbatches}) is less than the total number "
166 | f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
167 | )
168 |
169 | schedule = schedule_class(
170 | stages if looped_schedule else stages[0],
171 | n_microbatches=n_microbatches,
172 | loss_fn=loss_fn,
173 | )
174 | logger.info(
175 | f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
176 | f"with {n_microbatches} microbatches and {num_total_stages} stages."
177 | )
178 |
179 | if pp_schedule_csv:
180 | assert schedule_class in [
181 | PipelineScheduleSingle,
182 | PipelineScheduleMulti,
183 | _PipelineScheduleRuntime,
184 | ], (
185 | "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), "
186 | "and _PipelineScheduleRuntime support csv schedules"
187 | )
188 | schedule._load_csv(pp_schedule_csv)
189 |
190 | return schedule
191 |
192 |
193 | # TODO(whc) should this be a utility inside torch.pipelining?
194 | def stage_ids_this_rank(
195 | pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
196 | ) -> tuple[int, ...]:
197 | """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
198 | assert num_stages % pp_size == 0, (
199 | f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
200 | )
201 | stages_per_rank = num_stages // pp_size
202 | if style == "loop":
203 | return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
204 | elif style == "v":
205 | assert stages_per_rank == 2, (
206 | f"v schedules assume 2 stages per rank, got {stages_per_rank}"
207 | )
208 | stage_v_pairs = list(
209 | zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
210 | )
211 | return stage_v_pairs[pp_rank]
212 |
--------------------------------------------------------------------------------
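
To make the default even-split behaviour of generate_split_points above concrete, here is a small standalone sketch (hypothetical layer counts, no torchtitan imports) of the path where num_layers_per_stage is None and the schedule is single-stage:

# hypothetical example: 30 transformer layers, pp_degree=4, single-stage schedule,
# with input and output modules each weighted as one layer
num_layers, pp_degree, input_weight, output_weight = 30, 4, 1, 1
num_effective_layers = num_layers + input_weight + output_weight      # 32
num_total_virtual_stages = pp_degree * 1                              # 1 stage per rank
layers_per_stage = num_effective_layers // num_total_virtual_stages   # 8
extra = num_effective_layers % num_total_virtual_stages               # 0

splits, current = [], 0
for i in range(num_total_virtual_stages - 1):
    current += layers_per_stage - input_weight if i == 0 else layers_per_stage
    if extra > 0:  # extra layers are dispersed to the first stages
        current += 1
        extra -= 1
    splits.append(f"layers.{current}")

print(splits)  # ['layers.7', 'layers.15', 'layers.23']

--------------------------------------------------------------------------------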
/torchtitan/components/quantization/float8.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from functools import partial
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | from torchtitan.config_manager import Float8, JobConfig
12 | from torchtitan.distributed import ParallelDims
13 | from torchtitan.protocols.model_converter import (
14 | ModelConverter,
15 | register_model_converter,
16 | )
17 | from torchtitan.tools.logging import logger
18 | from torchtitan.tools.utils import has_cuda_capability
19 |
20 | from .utils import module_filter_fn
21 |
22 | AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"
23 |
24 |
25 | class Float8Converter(ModelConverter):
26 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
27 | self.enabled = False
28 |
29 | float8_config: Float8 = job_config.float8
30 | if has_cuda_capability(8, 9) or (
31 | float8_config.emulate and not job_config.training.compile
32 | ):
33 | pass
34 | else:
35 | raise ValueError(
36 | "Failed to swap to Float8Linear because float8 is only supported on SM89 or later."
37 | "To enable testing on older hardware, set `float8.emulate` to True in eager mode.",
38 | )
39 | try:
40 | from torchao.float8 import Float8LinearConfig
41 | except ImportError as e:
42 | raise ImportError(
43 | "torchao is not installed. Please install it to use float8 linear layers."
44 | ) from e
45 |
46 | if float8_config.recipe_name is not None and not hasattr(
47 | Float8LinearConfig, "from_recipe_name"
48 | ):
49 | logger.warning(
50 | "Failed to swap to Float8Linear with recipe lookup because the torchao version "
51 | "is too old, please install torchao v0.9.0 or later and try again",
52 | )
53 | return
54 |
55 | self.enabled = True
56 | self.filter_fqns = float8_config.filter_fqns
57 | self.moe_fqns = float8_config.moe_fqns_prototype
58 | self.filter_fn = self._init_filter_fn(float8_config)
59 |
60 | # Validate MoE training prototype limitations.
61 | if self.moe_fqns:
62 | assert job_config.parallelism.pipeline_parallel_degree == 1, (
63 | "Float8 MoE training prototype does not yet support pipeline parallelism"
64 | )
65 | assert job_config.parallelism.context_parallel_degree == 1, (
66 | "Float8 MoE training prototype does not yet support context parallelism"
67 | )
68 |
69 | if float8_config.recipe_name is not None:
70 | assert not float8_config.enable_fsdp_float8_all_gather, (
71 | "using `float8_config.enable_fsdp_float8_all_gather` together "
72 | "with `float8_config.recipe_name` is not supported"
73 | )
74 |
75 | assert not float8_config.force_recompute_fp8_weight_in_bwd, (
76 | "using `float8_config.force_recompute_fp8_weight_in_bwd` together "
77 | "with `float8_config.recipe_name` is not supported"
78 | )
79 |
80 | self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
81 | self.precompute_scale = False
82 | logger.info(
83 | f"Float8 training active with recipe {float8_config.recipe_name}"
84 | )
85 |
86 | # short-term solution for https://github.com/pytorch/pytorch/issues/150859
87 | if float8_config.recipe_name == "rowwise":
88 | torch._inductor.config.emulate_precision_casts = True
89 | logger.debug(
90 | "Set torch._inductor.config.emulate_precision_casts to True"
91 | )
92 | else:
93 | # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
94 | enable_fsdp_float8_all_gather = (
95 | parallel_dims.dp_shard_enabled
96 | and float8_config.enable_fsdp_float8_all_gather
97 | )
98 | self.config = Float8LinearConfig(
99 | enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
100 | force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
101 | emulate=float8_config.emulate,
102 | )
103 | # for precompute_float8_dynamic_scale_for_fsdp
104 | self.precompute_scale = (
105 | enable_fsdp_float8_all_gather
106 | and float8_config.precompute_float8_dynamic_scale_for_fsdp
107 | )
108 | logger.info("Float8 tensorwise scaled training active")
109 |
110 | def _init_filter_fn(self, float8_config: Float8):
111 | # use the auto filter if the "auto_filter_small_kn" flag is one of the given filter_fqns.
112 | use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns
113 | if use_auto_filter:
114 | try:
115 | from torchao.float8 import _auto_filter_for_recipe
116 |
117 | logger.info(
118 | "Using _auto_filter_for_recipe to avoid converting linear layers with dims too small "
119 | "to benefit from float8 training. See docs/float8.md for more info."
120 | )
121 |
122 | recipe_name = (
123 | float8_config.recipe_name
124 | if float8_config.recipe_name
125 | else "tensorwise"
126 | )
127 |
128 | # remove auto filter flag from filter_fqns before passing to _auto_filter_for_recipe
129 | float8_config.filter_fqns.remove(AUTO_FILTER_SMALL_KN_FLAG)
130 |
131 | return _auto_filter_for_recipe(
132 | recipe_name,
133 | filter_fqns=float8_config.filter_fqns,
134 | )
135 | except ImportError:
136 | logger.warning(
137 | (
138 | "Using default module_filter_fn for float8 model conversion. "
139 | "To use _auto_filter_for_recipe, please install torchao nightly build."
140 | )
141 | )
142 |
143 | # use default filter func
144 | return partial(module_filter_fn, filter_fqns=float8_config.filter_fqns)
145 |
146 | def convert(self, model: nn.Module):
147 | """
148 | This function converts the linear layers of `model` to `Float8Linear`.
149 | Note that today, only dynamic tensor scaling (the default) is supported.
150 | This will mutate the model inplace.
151 | """
152 | if not self.enabled:
153 | return
154 |
155 | # MoE conversion must take place before Float8Linear conversion, otherwise the Float8Linears will
156 | # be converted back to nn.Linear:
157 | # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299
158 | # TODO: add warning in torchao when this happens, or find a better way to avoid this.
159 | if self.moe_fqns:
160 | self._convert_moe_layers(model)
161 |
162 | from torchao.float8 import convert_to_float8_training
163 |
164 | # Mutates the model inplace replacing instances of nn.Linear with Float8Linear
165 | convert_to_float8_training(
166 | model,
167 | config=self.config,
168 | module_filter_fn=self.filter_fn,
169 | )
170 | logger.info(
171 | "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
172 | f"{self.config.enable_fsdp_float8_all_gather}"
173 | )
174 |
175 | def _convert_moe_layers(self, model: nn.Module):
176 | """
177 | Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
178 | to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
179 | """
180 | from torchao.quantization.quant_api import quantize_
181 |
182 | try:
183 | from torchao.prototype.moe_training.conversion_utils import (
184 | MoETrainingConfig,
185 | )
186 | except ImportError as e:
187 | raise ImportError(
188 | "torchao installation does not have MoE training support. Please install torchao nightly build."
189 | ) from e
190 |
191 | def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
192 | for target_fqn in self.moe_fqns:
193 | if target_fqn in cur_fqn:
194 | return True
195 | return False
196 |
197 | config = MoETrainingConfig()
198 | quantize_(model, config=config, filter_fn=moe_module_filter_fn)
199 | logger.info(
200 | f"Converted MoE layers matching FQNS {self.moe_fqns} "
201 | "to use dynamic float8 rowwise quantization with scaled grouped GEMMs"
202 | )
203 |
204 | def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
205 | if not self.enabled:
206 | return
207 |
208 | if not self.precompute_scale:
209 | return
210 |
211 | from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
212 |
213 | models = [model] if isinstance(model, nn.Module) else model
214 | for m in models:
215 | precompute_float8_dynamic_scale_for_fsdp(m)
216 |
217 |
218 | register_model_converter(Float8Converter, "float8")
219 |
--------------------------------------------------------------------------------
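
Float8Converter above delegates FQN filtering to module_filter_fn from components/quantization/utils.py. A plausible minimal version of that kind of filter is sketched below for illustration only; it is an assumption about its behaviour, not the repository's actual implementation. convert_to_float8_training expects a (module, fqn) -> bool callback, which is why float8.py binds filter_fqns up front with functools.partial.

from functools import partial

import torch.nn as nn

def example_module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    """Sketch of an FQN-based filter: True means "convert this module to Float8Linear"."""
    if not isinstance(mod, nn.Linear):
        return False
    # skip modules whose fully qualified name matches a user-provided filter, e.g. "output"
    return not any(f in fqn for f in filter_fqns)

filter_fn = partial(example_module_filter_fn, filter_fqns=["output"])
print(filter_fn(nn.Linear(4, 4), "layers.0.attention.wq"))  # True
print(filter_fn(nn.Linear(4, 4), "output"))                 # False

--------------------------------------------------------------------------------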
/torchtitan/models/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8 |
9 | from typing import Callable, ClassVar
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch.nn.attention import sdpa_kernel, SDPBackend
14 | from torch.nn.attention.flex_attention import (
15 | _mask_mod_signature,
16 | BlockMask,
17 | create_block_mask,
18 | flex_attention,
19 | )
20 |
21 | from torchtitan.tools.utils import has_cuda_capability
22 |
23 | # FlexAttention mask type. For each mask type, we initialize it at most once per
24 | # batch. FLEX_ATTN_MASK_T is used as the key to record which masks have
25 | # already been initialized.
26 | FLEX_ATTN_MASK_T = tuple[str, int | None]
27 |
28 |
29 | class FlexAttention(torch.nn.Module):
30 | """FlexAttention module that uses torch.nn.attention.flex_attention.
31 |
32 | This module is a wrapper around torch.nn.attention.flex_attention. This module
33 | implements certain common attention types, such as causal and block_causal.
34 |
35 | Args:
36 | attn_mask_type (str): The type of attention mask. Currently, we support
37 | "causal" and "block_causal". "causal" means the lower triangle of the
38 | attention matrix is masked. "block_causal" means the attention matrix
39 | is divided into blocks, where block boundary is defined by EOS token,
40 | and the lower triangle of each block is masked.
41 | fixed_block_size (int | None): The block size to be used to perform attention.
42 | If specified, each sequence will be further divided to blocks, where each
43 | block has the maximum size of ``fixed_block_size``. A query will only attend
44 | to the keys within the same block.
45 | """
46 |
47 | # We registered flex_attention related attributes as class variables as we
48 | # need to amortize the cost of compilation.
49 | flex_attn: ClassVar[Callable] = torch.compile(
50 | flex_attention, mode="max-autotune-no-cudagraphs"
51 | )
52 | compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
53 | used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
54 | # Attention mask type to the created BlockMask.
55 | # This allows us to keep track of the block masks created for each
56 | # new batch. We will use this to update the block mask when a
57 | # new batch is created. This also allows users to create different
58 | # block masks for different layers.
59 | block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {}
60 |
61 | # Instance variables.
62 | attn_mask_type: str
63 |
64 | def __init__(
65 | self, attn_mask_type: str, fixed_block_size: int | None = None
66 | ) -> None:
67 | super().__init__()
68 | if attn_mask_type not in ["causal", "block_causal"]:
69 | raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
70 | self.attn_mask_type = attn_mask_type
71 | self.fixed_block_size = fixed_block_size
72 |
73 | FlexAttention.used_attn_mask_types.add(self.mask_key)
74 |
75 | @property
76 | def mask_key(self) -> FLEX_ATTN_MASK_T:
77 | return (self.attn_mask_type, self.fixed_block_size)
78 |
79 | def forward(
80 | self,
81 | q: torch.Tensor,
82 | k: torch.Tensor,
83 | v: torch.Tensor,
84 | scale: float | None = None,
85 | ) -> torch.Tensor:
86 | block_mask = FlexAttention.block_masks[self.mask_key]
87 | return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
88 |
89 | @staticmethod
90 | def _get_causal_mask_mod() -> _mask_mod_signature:
91 | def causal_mask(
92 | b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
93 | ):
94 | return q_idx >= kv_idx
95 |
96 | return causal_mask
97 |
98 | @staticmethod
99 | def _get_block_causal_mask_mod(
100 | batch: torch.Tensor, eos_id: int
101 | ) -> _mask_mod_signature:
102 | # batch is [b, s, h, d] shape
103 | mask = batch == eos_id
104 | mask[:, -1] = True
105 | acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
106 | seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
107 | seq_idx[:, 1:] = acc_mask[:, :-1]
108 |
109 | def block_causal_mask(
110 | b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
111 | ):
112 | return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
113 |
114 | return block_causal_mask
115 |
116 | @staticmethod
117 | def _fixed_block_mask_mod(
118 | mask_mod: _mask_mod_signature, fixed_block_size: int
119 | ) -> _mask_mod_signature:
120 | """
121 | Given an arbitrary mask_mod, divide the input sequence into blocks
122 | and only allow attention within the same block.
123 |
124 | Args:
125 | mask_mod: The mask mod to apply to the documents
126 | fixed_block_size: The number of tokens in each block.
127 | """
128 |
129 | # Credit to @drisspg.
130 | def blocked_mask_mod(
131 | b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
132 | ):
133 | # Get the block index of the query and key
134 | q_block = q_idx // fixed_block_size
135 | kv_block = kv_idx // fixed_block_size
136 | # Only allow attention within the same block
137 | same_block = q_block == kv_block
138 | # Apply the original mask mod
139 | inner_mask = mask_mod(
140 | b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size
141 | )
142 |
143 | return same_block & inner_mask
144 |
145 | blocked_mask_mod.__name__ = (
146 | f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}"
147 | )
148 |
149 | return blocked_mask_mod
150 |
151 | @staticmethod
152 | @torch.no_grad()
153 | def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:
154 | # batch is [b, s, h, d] shape
155 | for mask_key in FlexAttention.used_attn_mask_types:
156 | attn_mask_type, fixed_block_size = mask_key
157 | match attn_mask_type:
158 | case "causal":
159 | if FlexAttention.block_masks.get(mask_key, None) is not None:
160 | continue
161 | # We don't care about batch dimension --
162 | # all samples have the same lower triangle mask.
163 | batch_dimension = 1
164 | mask_mod = FlexAttention._get_causal_mask_mod()
165 | case "block_causal":
166 | if eos_id is None:
167 | raise RuntimeError(
168 | "eos_id must be provided for block_causal mask."
169 | )
170 | batch_dimension = batch.shape[0]
171 | mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id)
172 | case _:
173 | raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}")
174 |
175 | if fixed_block_size is not None and fixed_block_size > 0:
176 | mask_mod = FlexAttention._fixed_block_mask_mod(
177 | mask_mod, fixed_block_size
178 | )
179 |
180 | seq_len = batch.shape[1]
181 | block_mask = FlexAttention.compiled_create_block_mask(
182 | mask_mod, batch_dimension, None, seq_len, seq_len
183 | )
184 | FlexAttention.block_masks[mask_key] = block_mask
185 |
186 |
187 | class ScaledDotProductAttention(torch.nn.Module):
188 | backends: ClassVar[list[SDPBackend]] = []
189 |
190 | def __init__(self, attn_mask_type: str) -> None:
191 | super().__init__()
192 | if attn_mask_type != "causal":
193 | raise ValueError(
194 | "TorchTitan with SDPA currently only supports causal mask."
195 | )
196 |
197 | ScaledDotProductAttention._init_backend()
198 |
199 | @classmethod
200 | def _init_backend(cls) -> None:
201 | if cls.backends:
202 | return
203 |
204 | # Add CuDNN on B200 w/ highest priority
205 | cls.backends = [
206 | SDPBackend.FLASH_ATTENTION,
207 | SDPBackend.EFFICIENT_ATTENTION,
208 | SDPBackend.MATH,
209 | ]
210 | if has_cuda_capability(10, 0):
211 | cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION)
212 |
213 | def forward(
214 | self,
215 | q: torch.Tensor,
216 | k: torch.Tensor,
217 | v: torch.Tensor,
218 | scale: float | None = None,
219 | ) -> torch.Tensor:
220 | assert self.backends, "SDPA Backends should not be empty."
221 | with sdpa_kernel(self.backends, set_priority=True):
222 | return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
223 |
224 |
225 | def build_attention(
226 | use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None
227 | ):
228 | if use_flex_attn:
229 | return FlexAttention(attn_mask_type, fixed_block_size)
230 | else:
231 | if fixed_block_size is not None:
232 | raise ValueError(
233 | "TorchTitan with SDPA currently does not support fixed_block_size."
234 | )
235 | if attn_mask_type != "causal":
236 | raise ValueError(
237 | "TorchTitan with SDPA currently only supports causal mask."
238 | )
239 | return ScaledDotProductAttention(attn_mask_type)
240 |
241 |
242 | def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None:
243 | FlexAttention.init_attention_mask(batch, eos_id)
244 |
--------------------------------------------------------------------------------
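
The block-causal masking in FlexAttention above is easiest to see on a toy batch. The sketch below (illustration only, made-up token ids) derives the per-token document index from EOS positions the same way _get_block_causal_mask_mod does, then evaluates the mask for a couple of (q_idx, kv_idx) pairs:

import torch

eos_id = 0
# one sequence of 8 token ids with an EOS at position 3 -> two documents: [0..3] and [4..7]
batch = torch.tensor([[5, 7, 2, eos_id, 9, 4, 6, 1]])

mask = batch == eos_id
mask[:, -1] = True                                  # treat the final position as a boundary
acc = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc, dtype=torch.int32)
seq_idx[:, 1:] = acc[:, :-1]                        # document index per position
print(seq_idx.tolist())                             # [[0, 0, 0, 0, 1, 1, 1, 1]]

def block_causal(b: int, q_idx: int, kv_idx: int) -> bool:
    return bool((seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx))

print(block_causal(0, q_idx=2, kv_idx=1))  # True: same document and causal
print(block_causal(0, q_idx=5, kv_idx=2))  # False: query is in the second document

--------------------------------------------------------------------------------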
/torchtitan/datasets/hf_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from dataclasses import dataclass
8 |
9 | from typing import Any
10 |
11 | import torch
12 |
13 | from datasets import Dataset, load_dataset, Audio
14 | from datasets.distributed import split_dataset_by_node
15 | from torch.distributed.checkpoint.stateful import Stateful
16 | from torch.utils.data import IterableDataset
17 |
18 | from torchtitan.components.dataloader import ParallelAwareDataloader
19 | from torchtitan.components.tokenizer import BaseTokenizer
20 | from torchtitan.config_manager import JobConfig
21 | from torchtitan.tools.logging import logger
22 | from transformers import DataCollatorForSeq2Seq
23 |
24 |
25 | def audio_array_to_text(
26 | audio_array: torch.Tensor,
27 | audio_tokenizer,
28 | feature_extractor,
29 | num_quantizers: int,
30 | max_seconds: int = 20,
31 | ) -> str:
32 | # truncate the audio array to the expected length
33 | if audio_array.shape[-1] > max_seconds * feature_extractor.sampling_rate:
34 | audio_array = audio_array[: max_seconds * feature_extractor.sampling_rate]
35 | # extract codec features for the (possibly truncated) waveform
36 | inputs = feature_extractor(
37 | raw_audio=audio_array,
38 | sampling_rate=feature_extractor.sampling_rate,
39 | return_tensors="pt",
40 | ).to(audio_tokenizer.device)
41 | with torch.no_grad():
42 | # Encode the audio input to get the audio codes
43 | # This will return a tensor of shape (batch_size, num_quantizers, sequence_length)
44 | # where each quantizer's output is in a separate dimension
45 | encoder_outputs = audio_tokenizer.encode(
46 | inputs["input_values"],
47 | inputs["padding_mask"],
48 | num_quantizers=num_quantizers,
49 | )
50 | flatten_audio_codes = encoder_outputs.audio_codes.transpose(1, 2).reshape(-1)
51 | assert flatten_audio_codes.numel() % num_quantizers == 0
52 | steps = []
53 | for i in range(0, flatten_audio_codes.numel(), num_quantizers):
54 | group = [
55 | f"<{flatten_audio_codes[i + j].item()}_{j}>" for j in range(num_quantizers)
56 | ]
57 | steps.append(group)
58 |
59 | parts = [tok for step in steps for tok in step]
60 |
61 | text = "".join(parts)
62 |
63 | del inputs, encoder_outputs, flatten_audio_codes
64 | torch.cuda.empty_cache()
65 | return f""
66 |
67 |
68 | def process_audio(
69 | sample: dict[str, Any],
70 | audio_tokenizer,
71 | feature_extractor,
72 | num_quantizers: int,
73 | task: str = "a2a",
74 | ) -> str:
75 | audio_sample = sample["audio"]["array"]
76 | text = audio_array_to_text(
77 | audio_sample,
78 | audio_tokenizer,
79 | feature_extractor,
80 | num_quantizers,
81 | )
82 | if task == "tts":
83 | transcription = sample["text"]
84 | text = transcription + text
85 | return text
86 |
87 |
88 | class HuggingFaceDataset(IterableDataset, Stateful):
89 | def __init__(
90 | self,
91 | dataset_name: str,
92 | tokenizer: BaseTokenizer,
93 | audio_tokenizer=None,
94 | feature_extractor=None,
95 | num_quantizers: int = 4,
96 | seq_len: int = 2048,
97 | dp_rank: int = 0,
98 | dp_world_size: int = 1,
99 | infinite: bool = False,
100 | task: str = "a2a",
101 | ) -> None:
102 | if dataset_name == "peoples_speech":
103 | ds = load_dataset(
104 | "parquet",
105 | data_files="data/peoples_speech/**/*.parquet",
106 | split="train",
107 | streaming=True,
108 | )
109 | ds = ds.cast_column(
110 | "audio", Audio(sampling_rate=feature_extractor.sampling_rate)
111 | )
112 | elif dataset_name == "librispeech_asr_train":
113 | ds = load_dataset(
114 | "openslr/librispeech_asr",
115 | split="train.other.500",
116 | streaming=True,
117 | )
118 | ds = ds.cast_column(
119 | "audio", Audio(sampling_rate=feature_extractor.sampling_rate)
120 | )
121 | elif dataset_name == "librispeech_asr_test":
122 | ds = load_dataset(
123 | "openslr/librispeech_asr",
124 | split="test.clean",
125 | streaming=True,
126 | )
127 | ds = ds.cast_column(
128 | "audio", Audio(sampling_rate=feature_extractor.sampling_rate)
129 | )
130 |
131 | else:
132 | raise ValueError(f"Dataset {dataset_name} is not supported. ")
133 |
134 | self.dataset_name = dataset_name
135 | self._data = split_dataset_by_node(ds, dp_rank, dp_world_size)
136 | self._data = self._data.shuffle(seed=42, buffer_size=10_000)
137 | self.tokenizer = tokenizer
138 | self.audio_tokenizer = audio_tokenizer
139 | self.feature_extractor = feature_extractor
140 | self.num_quantizers = num_quantizers
141 | self.seq_len = seq_len
142 | self.task = task
143 | self.infinite = infinite
144 |
145 | # Variables for checkpointing
146 | self._sample_idx = 0
147 | self._token_buffer: list[int] = []
148 |
149 | def _get_data_iter(self):
150 | # For map-style datasets, resume by skipping to the correct index
151 | # For iterable-style datasets, the underlying iterator already points to the correct index
152 | if isinstance(self._data, Dataset):
153 | if self._sample_idx == len(self._data):
154 | return iter([])
155 | else:
156 | return iter(self._data.skip(self._sample_idx))
157 |
158 | return iter(self._data)
159 |
160 | def __iter__(self):
161 | while True:
162 | data_iter = self._get_data_iter()
163 | while True:
164 | try:
165 | sample = next(data_iter)
166 | except StopIteration:
167 | break
168 | except Exception as e:
169 | logger.error(
170 | f"Error while iterating over dataset {self.dataset_name}: {e}"
171 | )
172 | self._sample_idx += 1
173 | continue
174 |
175 | try:
176 | sample_text = process_audio(
177 | sample,
178 | self.audio_tokenizer,
179 | self.feature_extractor,
180 | self.num_quantizers,
181 | self.task,
182 | )
183 | self._sample_idx += 1
184 | yield self.tokenizer(
185 | sample_text,
186 | max_length=self.seq_len,
187 | padding="max_length",
188 | truncation=True,
189 | )
190 | except Exception as e:
191 | logger.error(
192 | f"Error while processing sample in dataset {self.dataset_name}: {e}"
193 | )
194 | self._sample_idx += 1
195 | continue
196 |
197 | if not self.infinite:
198 | logger.warning(f"Dataset {self.dataset_name} has run out of data")
199 | break
200 | else:
201 | # Reset offset for the next iteration
202 | self._sample_idx = 0
203 | logger.warning(f"Dataset {self.dataset_name} is being re-looped")
204 | # Ensures re-looping a dataset loaded from a checkpoint works correctly
205 | if not isinstance(self._data, Dataset):
206 | if hasattr(self._data, "set_epoch") and hasattr(
207 | self._data, "epoch"
208 | ):
209 | self._data.set_epoch(self._data.epoch + 1)
210 |
211 | def load_state_dict(self, state_dict):
212 | if isinstance(self._data, Dataset):
213 | self._sample_idx = state_dict["sample_idx"]
214 | else:
215 | assert "data" in state_dict
216 | self._data.load_state_dict(state_dict["data"])
217 |
218 | def state_dict(self):
219 | _state_dict = {"token_buffer": self._token_buffer}
220 |
221 | if isinstance(self._data, Dataset):
222 | _state_dict["sample_idx"] = self._sample_idx
223 | else:
224 | # Save the iterable dataset's state to later efficiently resume from it
225 | # https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration
226 | _state_dict["data"] = self._data.state_dict()
227 |
228 | return _state_dict
229 |
230 |
231 | def build_hf_dataloader(
232 | dp_world_size: int,
233 | dp_rank: int,
234 | tokenizer: BaseTokenizer,
235 | audio_tokenizer,
236 | feature_extractor,
237 | job_config: JobConfig,
238 | infinite: bool = True,
239 | ) -> ParallelAwareDataloader:
240 | """Build a data loader for HuggingFace datasets."""
241 | dataset_name = job_config.training.dataset
242 | batch_size = job_config.training.local_batch_size
243 | seq_len = job_config.training.seq_len
244 |
245 | hf_ds = HuggingFaceDataset(
246 | dataset_name=dataset_name,
247 | tokenizer=tokenizer,
248 | audio_tokenizer=audio_tokenizer,
249 | feature_extractor=feature_extractor,
250 | num_quantizers=job_config.model.num_quantizers,
251 | seq_len=seq_len,
252 | dp_rank=dp_rank,
253 | dp_world_size=dp_world_size,
254 | infinite=infinite,
255 | task=job_config.training.task,
256 | )
257 |
258 | collate_fn = DataCollatorForSeq2Seq(
259 | tokenizer,
260 | pad_to_multiple_of=8,
261 | return_tensors="pt",
262 | padding=True,
263 | max_length=seq_len,
264 | )
265 |
266 | return ParallelAwareDataloader(
267 | dataset=hf_ds,
268 | dp_rank=dp_rank,
269 | dp_world_size=dp_world_size,
270 | batch_size=batch_size,
271 | collate_fn=collate_fn,
272 | )
273 |
274 |
275 | def build_hf_validation_dataloader(
276 | dp_world_size: int,
277 | dp_rank: int,
278 | tokenizer: BaseTokenizer,
279 | audio_tokenizer,
280 | feature_extractor,
281 | job_config: JobConfig,
282 | ) -> ParallelAwareDataloader:
283 | """Build a validation data loader for HuggingFace datasets."""
284 | dataset_name = job_config.validation.dataset
285 | batch_size = job_config.validation.local_batch_size
286 | seq_len = job_config.validation.seq_len
287 |
288 | hf_ds = HuggingFaceDataset(
289 | dataset_name=dataset_name,
290 | tokenizer=tokenizer,
291 | audio_tokenizer=audio_tokenizer,
292 | feature_extractor=feature_extractor,
293 | num_quantizers=job_config.model.num_quantizers,
294 | seq_len=seq_len,
295 | dp_rank=dp_rank,
296 | dp_world_size=dp_world_size,
297 | infinite=False,
298 | task=job_config.training.task,
299 | )
300 |
301 | collate_fn = DataCollatorForSeq2Seq(
302 | tokenizer,
303 | pad_to_multiple_of=8,
304 | return_tensors="pt",
305 | padding=True,
306 | max_length=seq_len,
307 | )
308 |
309 | return ParallelAwareDataloader(
310 | dataset=hf_ds,
311 | dp_rank=dp_rank,
312 | dp_world_size=dp_world_size,
313 | batch_size=batch_size,
314 | collate_fn=collate_fn,
315 | )
316 |
--------------------------------------------------------------------------------
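
The token layout produced by audio_array_to_text above is easier to see on fake codec output. The sketch below (illustration only; the code values are made up) flattens a (1, num_quantizers, T) tensor the same way and renders the interleaved <code_level> strings:

import torch

num_quantizers = 4
# fake Mimi-style codes: batch=1, num_quantizers=4, T=2 time steps
audio_codes = torch.tensor([[[11, 12],    # quantizer 0
                             [21, 22],    # quantizer 1
                             [31, 32],    # quantizer 2
                             [41, 42]]])  # quantizer 3

# (1, Q, T) -> (1, T, Q) -> flat, so tokens are grouped per time step
flat = audio_codes.transpose(1, 2).reshape(-1)
assert flat.numel() % num_quantizers == 0

parts = []
for i in range(0, flat.numel(), num_quantizers):
    parts.extend(f"<{flat[i + j].item()}_{j}>" for j in range(num_quantizers))

print("".join(parts))  # <11_0><21_1><31_2><41_3><12_0><22_1><32_2><42_3>

--------------------------------------------------------------------------------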
/torchtitan/components/optimizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import functools
8 | from typing import Any, Generic, Iterator, TypeVar
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.distributed.checkpoint.state_dict import (
13 | get_optimizer_state_dict,
14 | set_optimizer_state_dict,
15 | StateDictOptions,
16 | )
17 | from torch.distributed.checkpoint.stateful import Stateful
18 | from torch.optim import Optimizer
19 |
20 | from torchtitan.components.ft import FTManager, has_torchft
21 | from torchtitan.config_manager import JobConfig
22 | from torchtitan.distributed import ParallelDims
23 |
24 | __all__ = [
25 | "OptimizersContainer",
26 | "build_optimizers",
27 | ]
28 |
29 |
30 | if has_torchft:
31 | import torchft as ft
32 |
33 |
34 | T = TypeVar("T", bound=Optimizer)
35 |
36 |
37 | class OptimizersContainer(Optimizer, Stateful, Generic[T]):
38 | """A container for multiple optimizers.
39 |
40 | This class is used to wrap multiple optimizers into a single object that can be
41 | used to reduce the complexity of the training loop. This mimics the behavior of
42 | ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``.
43 |
44 | **Note**
45 | Users who want to customize the optimizer behavior can inherit from this class and
46 | extend the functionality as needed. The following methods must follow the same signature
47 | as ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``,
48 | ``load_state_dict()``.
49 |
50 | **Limitations**
51 | This class assumes that all the optimizers are the same type and have the same
52 | configurations. With this assumption, TorchTitan can support lr scheduler resharding
53 | (e.g., loading a checkpoint with a different number of GPUs and/or different
54 | parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the
55 | resharding for the optimizer state but not for the lr scheduler state, hence the limitation.
56 |
57 | Args:
58 | model_parts (List[nn.Module]): List of model parts to be optimized.
59 | optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers.
60 | name (str): Name of the optimizers.
61 | """
62 |
63 | optimizers: list[T]
64 | model_parts: list[nn.Module]
65 |
66 | def __init__(
67 | self,
68 | model_parts: list[nn.Module],
69 | optimizer_cls: type[T],
70 | optimizer_kwargs: dict[str, Any],
71 | ) -> None:
72 | all_params = []
73 | self.optimizers = []
74 | self.model_parts = model_parts
75 | for model in self.model_parts:
76 | params = [p for p in model.parameters() if p.requires_grad]
77 | self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
78 | all_params.extend(params)
79 | self._validate_length(len(self.model_parts))
80 | self._post_init(all_params, optimizer_kwargs)
81 |
82 | def __iter__(self) -> Iterator[T]:
83 | return iter(self.optimizers)
84 |
85 | def __len__(self) -> int:
86 | return len(self.optimizers)
87 |
88 | def step(self, *args, **kwargs) -> None:
89 | for optimizer in self.optimizers:
90 | optimizer.step(*args, **kwargs)
91 |
92 | def zero_grad(self, *args, **kwargs) -> None:
93 | for optimizer in self.optimizers:
94 | optimizer.zero_grad(*args, **kwargs)
95 |
96 | def state_dict(self) -> dict[str, Any]:
97 | func = functools.partial(
98 | get_optimizer_state_dict,
99 | options=StateDictOptions(flatten_optimizer_state_dict=True),
100 | )
101 | return {
102 | k: v
103 | for sd in map(func, self.model_parts, self.optimizers)
104 | for k, v in sd.items()
105 | }
106 |
107 | def load_state_dict(self, state_dict: dict[str, Any]) -> None:
108 | func = functools.partial(
109 | set_optimizer_state_dict,
110 | optim_state_dict=state_dict,
111 | options=StateDictOptions(flatten_optimizer_state_dict=True),
112 | )
113 | list(map(func, self.model_parts, self.optimizers))
114 |
115 | def _validate_length(self, expected_length: int) -> None:
116 | assert expected_length == len(self.optimizers), (
117 | "Must pass one optimizer per model part or per param if "
118 | "using OptimizersInBackwardContainer."
119 | )
120 |
121 | def _post_init(
122 | self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any]
123 | ) -> None:
124 | # We need to call Optimizer.__init__() to initialize some necessary optimizer
125 | # functionality such as hooks.
126 | Optimizer.__init__(self, all_params, optimizer_kwargs)
127 |
128 |
129 | class OptimizersInBackwardContainer(OptimizersContainer):
130 | """OptimizersContainer for executing ``optim.step()`` in backward pass.
131 |
132 | This class extends ``OptimizersContainer`` to support optimizer step in
133 | backward pass. ``step()`` and ``zero_grad()`` are no-op in this class.
134 | Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to
135 | execute these methods when the gradient is accumulated.
136 | """
137 |
138 | def __init__(
139 | self,
140 | model_parts: list[nn.Module],
141 | optimizer_cls: type[T],
142 | optimizer_kwargs: dict[str, Any],
143 | ) -> None:
144 | all_params = []
145 | self.model_parts = model_parts
146 |
147 | optim_dict = {}
148 | for model in self.model_parts:
149 | for p in model.parameters():
150 | if p.requires_grad:
151 | optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
152 | all_params.append(p)
153 |
154 | def optim_hook(param) -> None:
155 | optim_dict[param].step()
156 | optim_dict[param].zero_grad()
157 |
158 | for model in self.model_parts:
159 | for param in model.parameters():
160 | if param.requires_grad:
161 | param.register_post_accumulate_grad_hook(optim_hook)
162 |
163 | self.optimizers = list(optim_dict.values())
164 |
165 | self._validate_length(
166 | sum(len(list(model.parameters())) for model in self.model_parts)
167 | )
168 | self._post_init(all_params, optimizer_kwargs)
169 |
170 | def step(self) -> None:
171 | pass
172 |
173 | def zero_grad(self) -> None:
174 | pass
175 |
176 |
177 | class FTOptimizersContainer(OptimizersContainer):
178 | def __init__(
179 | self,
180 | model_parts: list[nn.Module],
181 | optimizer_cls: type[T],
182 | optimizer_kwargs: dict[str, Any],
183 | ft_manager: "ft.Manager",
184 | use_ft_optimizer: bool = True,
185 | ) -> None:
186 | super().__init__(model_parts, optimizer_cls, optimizer_kwargs)
187 |
188 | # Force to initialize the optimizer state so that `optim.step()`
189 | # won't be called by state_dict() and load_state_dict().
190 | _ = {
191 | k: v
192 | for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
193 | for k, v in sd.items()
194 | }
195 | self.cache_state_dict: dict[str, Any] = {}
196 | self._ft_optimizer = ft.Optimizer(ft_manager, self)
197 | # Whether to determine quorum using FT.optimizer,
198 | # in semi-sync training we use the synchronization step to start quorum
199 | self._use_ft_optimizer: bool = use_ft_optimizer
200 |
201 | def init_cache_state_dict(self) -> None:
202 | self.cache_state_dict = super().state_dict()
203 |
204 | def state_dict(self) -> dict[str, Any]:
205 | return self.cache_state_dict
206 |
207 | def load_state_dict(self, state_dict: dict[str, Any]) -> None:
208 | # We have to invalidate the `cache_state_dict` because optimizer uses
209 | # assign instead of copy when doing `load_state_dict()`. Without
210 | # invalidating the `cache_state_dict`, there will be memory leakage.
211 | self.cache_state_dict = {}
212 | super().load_state_dict(state_dict)
213 | self.init_cache_state_dict()
214 |
215 | def step(self, *args, **kwargs) -> None:
216 | """Calling the correct step() depending on the caller.
217 |
218 | TorchFT's OptimizerWrapper.step() is designed to be called only once
219 | per train step per ft.Manager, regardless of how many optimizers are used.
220 | Hence we will need to appropriately dispatch the call.
221 | """
222 | if self._use_ft_optimizer:
223 | self._use_ft_optimizer = False
224 | self._ft_optimizer.step(*args, **kwargs)
225 | self._use_ft_optimizer = True
226 | else:
227 | super().step(*args, **kwargs)
228 |
229 | def zero_grad(self, *args, **kwargs) -> None:
230 | """Calling the correct zero_grad() depending on the caller.
231 |
232 | Check the comment in ``step()``.
233 | """
234 | if self._use_ft_optimizer:
235 | self._use_ft_optimizer = False
236 | self._ft_optimizer.zero_grad(*args, **kwargs)
237 | self._use_ft_optimizer = True
238 | else:
239 | super().zero_grad(*args, **kwargs)
240 |
241 |
242 | def build_optimizers(
243 | model_parts: list[nn.Module],
244 | job_config: JobConfig,
245 | parallel_dims: ParallelDims,
246 | ft_manager: FTManager | None = None,
247 | ) -> OptimizersContainer:
248 | """Create a OptimizersContainer for the given model parts and job config.
249 |
250 | This function creates a ``OptimizersContainer`` for the given model parts.
251 | ``job_config`` should define the correct optimizer name and parameters.
252 | This function currently supports creating ``OptimizersContainer`` and
253 | ``OptimizersInBackwardContainer``.
254 |
255 | **Note**
256 | Users who want to customize the optimizer behavior can create their own
257 | ``OptimizersContainer`` subclass and their own ``build_optimizers`` function. Passing the
258 | customized ``build_optimizers`` to ``TrainSpec`` will create the customized
259 | ``OptimizersContainer``.
260 |
261 | Args:
262 | model_parts (List[nn.Module]): List of model parts to be optimized.
263 | job_config (JobConfig): Job config containing the optimizer name and parameters.
264 | parallel_dims (ParallelDims): Parallel dimensions for the model.
265 | """
266 | optim_in_bwd = job_config.optimizer.early_step_in_backward
267 | if optim_in_bwd:
268 | if parallel_dims.ep_enabled:
269 | raise NotImplementedError(
270 | "Optimizers in backward is not supported with Expert Parallel."
271 | )
272 | if parallel_dims.pp_enabled:
273 | raise NotImplementedError(
274 | "Optimizers in backward is not supported with Pipeline Parallel."
275 | )
276 | if ft_manager and ft_manager.enabled:
277 | raise NotImplementedError(
278 | "TorchFT is not supported with optimizers in backward."
279 | )
280 |
281 | name = job_config.optimizer.name
282 | lr = job_config.optimizer.lr
283 | beta1 = job_config.optimizer.beta1
284 | beta2 = job_config.optimizer.beta2
285 | eps = job_config.optimizer.eps
286 | weight_decay = job_config.optimizer.weight_decay
287 |
288 | optim_implementation = job_config.optimizer.implementation
289 | assert optim_implementation in ["fused", "foreach", "for-loop"]
290 |
291 | fused = optim_implementation == "fused"
292 | foreach = optim_implementation == "foreach"
293 |
294 | optimizer_kwargs = {
295 | "lr": lr,
296 | "betas": (beta1, beta2),
297 | "eps": eps,
298 | "weight_decay": weight_decay,
299 | "fused": fused,
300 | "foreach": foreach,
301 | }
302 |
303 | optimizer_classes = {
304 | "Adam": torch.optim.Adam,
305 | "AdamW": torch.optim.AdamW,
306 | }
307 | if name not in optimizer_classes:
308 | raise NotImplementedError(f"Optimizer {name} not added.")
309 | optimizer_cls = optimizer_classes[name]
310 |
311 | if optim_in_bwd:
312 | return OptimizersInBackwardContainer(
313 | model_parts, optimizer_cls, optimizer_kwargs
314 | )
315 |
316 | if ft_manager and ft_manager.enabled:
317 | return FTOptimizersContainer(
318 | model_parts,
319 | optimizer_cls,
320 | optimizer_kwargs,
321 | ft_manager.manager,
322 | use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None,
323 | )
324 |
325 | return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
326 |
--------------------------------------------------------------------------------
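
OptimizersInBackwardContainer above relies on register_post_accumulate_grad_hook to step each parameter's optimizer as soon as its gradient is ready, so no explicit step() is needed after backward(). A standalone sketch of that mechanism (toy model and made-up hyperparameters):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)

# one tiny optimizer per parameter, stepped from inside the backward pass
optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters() if p.requires_grad}

def optim_hook(param: torch.Tensor) -> None:
    # fires once the gradient for `param` has been fully accumulated
    optim_dict[param].step()
    optim_dict[param].zero_grad()

for p in model.parameters():
    if p.requires_grad:
        p.register_post_accumulate_grad_hook(optim_hook)

loss = model(torch.randn(4, 8)).square().mean()
loss.backward()  # parameters are updated here; no optimizer.step() call afterwards

--------------------------------------------------------------------------------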
/torchtitan/components/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import time
9 | from collections import namedtuple
10 | from datetime import datetime
11 | from typing import Any, TYPE_CHECKING
12 |
13 | import torch
14 | from torchtitan.components.lr_scheduler import LRSchedulersContainer
15 | from torchtitan.components.optimizer import OptimizersContainer
16 | from torchtitan.config_manager import JobConfig
17 | from torchtitan.distributed import ParallelDims
18 | from torchtitan.tools import utils
19 | from torchtitan.tools.logging import logger
20 | from torchtitan.tools.utils import Color, device_module, device_type
21 |
22 | if TYPE_CHECKING:
23 | pass
24 |
25 |
26 | # named tuple for passing device memory stats for logging
27 | DeviceMemStats = namedtuple(
28 | "DeviceMemStats",
29 | [
30 | "max_active_gib",
31 | "max_active_pct",
32 | "max_reserved_gib",
33 | "max_reserved_pct",
34 | "num_alloc_retries",
35 | "num_ooms",
36 | ],
37 | )
38 |
39 |
40 | class DeviceMemoryMonitor:
41 | def __init__(self, device: str = f"{device_type}:0"):
42 | self.device = torch.device(device) # device object
43 | self.device_name = device_module.get_device_name(self.device)
44 | self.device_index = device_module.current_device()
45 | self.device_capacity = device_module.get_device_properties(
46 | self.device
47 | ).total_memory
48 | self.device_capacity_gib = self._to_gib(self.device_capacity)
49 |
50 | device_module.reset_peak_memory_stats()
51 | device_module.empty_cache()
52 |
53 | def _to_gib(self, memory_in_bytes):
54 | # NOTE: a GiB (gibibyte) is 1024**3 bytes, whereas a GB (gigabyte) is 1000**3 bytes
55 | _gib_in_bytes = 1024 * 1024 * 1024
56 | memory_in_gib = memory_in_bytes / _gib_in_bytes
57 | return memory_in_gib
58 |
59 | def _to_pct(self, memory):
60 | return 100 * memory / self.device_capacity
61 |
62 | def get_peak_stats(self):
63 | device_info = device_module.memory_stats(self.device)
64 |
65 | max_active = device_info.get("active_bytes.all.peak", -1)
66 | max_active_gib = self._to_gib(max_active)
67 | max_active_pct = self._to_pct(max_active)
68 |
69 | max_reserved = device_info.get("reserved_bytes.all.peak", -1)
70 | max_reserved_gib = self._to_gib(max_reserved)
71 | max_reserved_pct = self._to_pct(max_reserved)
72 |
73 | num_retries = device_info.get("num_alloc_retries", -1)
74 | num_ooms = device_info.get("num_ooms", -1)
75 |
76 | if num_retries > 0:
77 | logger.warning(
78 | f"{num_retries} {device_type.upper()} memory allocation retries."
79 | )
80 | if num_ooms > 0:
81 | logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.")
82 |
83 | return DeviceMemStats(
84 | max_active_gib,
85 | max_active_pct,
86 | max_reserved_gib,
87 | max_reserved_pct,
88 | num_retries,
89 | num_ooms,
90 | )
91 |
92 | def reset_peak_stats(self):
93 | device_module.reset_peak_memory_stats()
94 |
95 |
96 | def build_device_memory_monitor():
97 | device_memory_monitor = DeviceMemoryMonitor(device_type)
98 | logger.info(
99 | f"{device_type.upper()} capacity: {device_memory_monitor.device_name} "
100 | f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory"
101 | )
102 | return device_memory_monitor
103 |
104 |
105 | class BaseLogger:
106 | """Logger that does nothing, used when logging is disabled."""
107 |
108 | def log(self, metrics: dict[str, Any], step: int) -> None:
109 | pass
110 |
111 | def close(self) -> None:
112 | pass
113 |
114 |
115 | class WandBLogger(BaseLogger):
116 | """Logger implementation for Weights & Biases."""
117 |
118 | def __init__(self, log_dir: str, job_config: JobConfig, tag: str | None = None):
119 |         # Import wandb lazily so the dependency is only needed when W&B logging is enabled
120 | import wandb
121 |
122 | self.wandb = wandb
123 | self.tag = tag
124 |
125 | # Create logging directory
126 | os.makedirs(log_dir, exist_ok=True)
127 |
128 | self.wandb.init(
129 | project=os.getenv("WANDB_PROJECT", "torchtitan"),
130 | dir=log_dir,
131 | name=job_config.job.wandb_run_name,
132 | config=job_config.to_dict(),
133 | )
134 | logger.info("WandB logging enabled")
135 |
136 | def log(self, metrics: dict[str, Any], step: int) -> None:
137 | wandb_metrics = {
138 | (k if self.tag is None else f"{self.tag}/{k}"): v
139 | for k, v in metrics.items()
140 | }
141 | self.wandb.log(wandb_metrics, step=step)
142 |
143 | def close(self) -> None:
144 | if self.wandb.run is not None:
145 | self.wandb.finish()
146 |
147 |
148 | def ensure_pp_loss_visible(
149 | parallel_dims: ParallelDims, job_config: JobConfig, color: Color
150 | ) -> None:
151 | """
152 | Ensures that the loss is visible on the console for pipeline-parallel training.
153 |
154 | For pipeline-parallel training, the loss is only visible on the last pipeline stage.
155 | This function checks if the appropriate rank is included in the LOG_RANK environment
156 | variable and warns if it's not.
157 | """
158 |
159 | # V Block Schedules return loss on rank 0
160 | if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
161 | return
162 |
163 | # Calculate the rank where loss is visible (first rank of the last pipeline stage)
164 | world_size = parallel_dims.world_size
165 | pp_size = parallel_dims.pp
166 | loss_visible_rank = (world_size // pp_size) * (pp_size - 1)
167 |
168 | # Check if the loss-visible rank is included in LOG_RANK environment variable
169 | env_logged_ranks = os.environ.get("LOG_RANK", "").split(",")
170 | if env_logged_ranks == [""]:
171 | env_logged_ranks = []
172 |
173 | if str(loss_visible_rank) not in env_logged_ranks:
174 | logger.warning(
175 | f"{color.red}Pipeline Parallel loss is not visible. "
176 | f"Please add {color.yellow}rank {loss_visible_rank}{color.red} "
177 |             f"to the LOG_RANK environment variable in train.sh.{color.reset}"
178 | )
179 |
180 |
181 | def _get_metrics_rank(
182 | parallel_dims: ParallelDims,
183 | job_config: JobConfig,
184 | ) -> int:
185 | """
186 | Determines which rank should log metrics.
187 |
188 | Returns:
189 | int: The rank responsible for logging metrics:
190 | - Rank 0 for non-pipeline-parallel configs
191 | - Rank 0 for pipeline-parallel 'ZBVZeroBubble' schedule
192 | - The first rank of the last pipeline stage for other pipeline-parallel schedules
193 | """
194 | # Early return for non-pipeline-parallel configurations
195 | if not parallel_dims.pp_enabled:
196 | return 0
197 |
198 | # V Block Schedules return loss on rank 0
199 | if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble":
200 | return 0
201 |
202 | # Calculate first rank of the last pipeline stage
203 | world_size = parallel_dims.world_size
204 | pp_size = parallel_dims.pp
205 | return (world_size // pp_size) * (pp_size - 1)
206 |
207 |
208 | def _build_metric_logger(
209 | job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None
210 | ) -> BaseLogger:
211 | """
212 | Build an appropriate metric logger based on configuration.
213 | """
214 | metrics_config = job_config.metrics
215 |
216 | # Log initial config state
217 |     logger.debug(f"Building logger with config: wandb={metrics_config.enable_wandb}")
218 |
219 | # Check if any logging backend is enabled
220 | has_logging_enabled = metrics_config.enable_wandb
221 |
222 | # Determine if this rank should log
223 | should_log = has_logging_enabled
224 | if (not metrics_config.save_for_all_ranks) and should_log:
225 | metrics_rank = _get_metrics_rank(parallel_dims, job_config)
226 | should_log = torch.distributed.get_rank() == metrics_rank
227 |
228 | logger.debug(
229 | f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}"
230 | )
231 |
232 | if not should_log:
233 | logger.debug("Returning BaseLogger due to should_log=False")
234 | return BaseLogger()
235 |
236 | # Setup logging directory
237 | dump_dir = job_config.job.dump_folder
238 | base_log_dir = os.path.join(
239 | dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M")
240 | )
241 |
242 | if job_config.fault_tolerance.enable:
243 | base_log_dir = os.path.join(
244 | base_log_dir,
245 | f"replica_{job_config.fault_tolerance.replica_id}",
246 | )
247 |
248 | if metrics_config.save_for_all_ranks:
249 | base_log_dir = os.path.join(
250 | base_log_dir, f"rank_{torch.distributed.get_rank()}"
251 | )
252 |
253 | # Create loggers in priority order
254 | if metrics_config.enable_wandb:
255 | logger.debug("Attempting to create WandB logger")
256 | try:
257 | return WandBLogger(base_log_dir, job_config, tag)
258 | except Exception as e:
259 | if "No module named 'wandb'" in str(e):
260 | logger.error(
261 | "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'."
262 | )
263 | else:
264 | logger.error(f"Failed to create WandB logger: {e}")
265 |
266 | logger.debug("No loggers enabled, returning BaseLogger")
267 | return BaseLogger()
268 |
269 |
270 | class MetricsProcessor:
271 |     """Metrics processor that processes and logs metrics.
272 | 
273 |     The current MetricsProcessor logs some metrics to STDOUT and some metrics to
274 |     WandB.
275 |
276 | Args:
277 | job_config (JobConfig): Job configuration.
278 | parallel_dims (ParallelDims): Parallel dimensions.
279 | tag (Optional[str]): Tag to use for WandB. Defaults to None.
280 | """
281 |
282 | logger: BaseLogger
283 | parallel_dims: ParallelDims
284 | job_config: JobConfig
285 | device_memory_monitor: DeviceMemoryMonitor
286 | color: utils.NoColor | utils.Color
287 |
288 | gpu_peak_flops: int
289 | ntokens_since_last_log: int
290 | data_loading_times: list[float]
291 | time_last_log: float
292 |
293 | num_flops_per_token: int
294 | optimizers: OptimizersContainer | None
295 | lr_schedulers: LRSchedulersContainer | None
296 |
297 | def __init__(
298 | self,
299 | job_config: JobConfig,
300 | parallel_dims: ParallelDims,
301 | tag: str | None = None,
302 | ):
303 | self.logger = _build_metric_logger(job_config, parallel_dims, tag)
304 | self.parallel_dims = parallel_dims
305 | self.job_config = job_config
306 | self.device_memory_monitor = build_device_memory_monitor()
307 | # used for colorful printing
308 | self.color = (
309 | utils.NoColor()
310 | if job_config.metrics.disable_color_printing
311 | else utils.Color()
312 | )
313 |
314 | self.gpu_peak_flops = utils.get_peak_flops(
315 | self.device_memory_monitor.device_name
316 | )
317 | self.ntokens_since_last_log = 0
318 | self.data_loading_times = []
319 | self.time_last_log = time.perf_counter()
320 | self.device_memory_monitor.reset_peak_stats()
321 |
322 | # These variables have to be set later as they depend on other components or model.
323 | self.num_flops_per_token = -1
324 | self.optimizers = None
325 | self.lr_schedulers = None
326 |
327 | def should_log(self, step: int) -> bool:
328 | return step == 1 or step % self.job_config.metrics.log_freq == 0
329 |
330 | def log(
331 | self,
332 | step: int,
333 | global_avg_loss: float,
334 | global_max_loss: float,
335 | grad_norm: float,
336 | lr: float,
337 | epoch: int,
338 | extra_metrics: dict[str, Any] | None = None,
339 | ):
340 | assert self.num_flops_per_token > 0, "num_flops_per_token must be set"
341 |
342 | time_delta = time.perf_counter() - self.time_last_log
343 |
344 | # tokens per second per device, abbreviated as tps
345 | tps = self.ntokens_since_last_log / (
346 | time_delta * self.parallel_dims.non_data_parallel_size
347 | )
348 | # model FLOPS utilization
349 | # For its definition and calculation, please refer to the PaLM paper:
350 | # https://arxiv.org/abs/2204.02311
351 | mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops
352 | tflops = self.num_flops_per_token * tps / 1e12
353 |
354 | time_end_to_end = time_delta / self.job_config.metrics.log_freq
355 | time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times)
356 | time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta
357 |
358 | device_mem_stats = self.device_memory_monitor.get_peak_stats()
359 |
360 | metrics = {
361 | "loss_metrics/global_avg_loss": global_avg_loss,
362 | "loss_metrics/global_max_loss": global_max_loss,
363 | "grad_norm": grad_norm,
364 | "lr": lr,
365 | "epoch": epoch,
366 | "throughput(tps)": tps,
367 | "tflops": tflops,
368 | "mfu(%)": mfu,
369 | "time_metrics/end_to_end(s)": time_end_to_end,
370 | "time_metrics/data_loading(s)": time_data_loading,
371 | "time_metrics/data_loading(%)": time_data_loading_pct,
372 | "memory/max_active(GiB)": device_mem_stats.max_active_gib,
373 | "memory/max_active(%)": device_mem_stats.max_active_pct,
374 | "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
375 | "memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
376 | "memory/num_alloc_retries": device_mem_stats.num_alloc_retries,
377 | "memory/num_ooms": device_mem_stats.num_ooms,
378 | }
379 |
380 | if extra_metrics:
381 | metrics.update(extra_metrics)
382 |
383 | self.logger.log(metrics, step)
384 |
385 | color = self.color
386 | logger.info(
387 | f"{color.red}step: {step:2} "
388 | f"{color.green}loss: {global_avg_loss:7.4f} "
389 | f"{color.orange}grad_norm: {grad_norm:7.4f} "
390 | f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
391 | f"({device_mem_stats.max_reserved_pct:.2f}%) "
392 | f"{color.blue}tps: {round(tps):,} "
393 | f"{color.cyan}tflops: {tflops:,.2f} "
394 | f"{color.magenta}mfu: {mfu:.2f}%{color.reset}"
395 | )
396 |
397 | self.ntokens_since_last_log = 0
398 | self.data_loading_times.clear()
399 | self.time_last_log = time.perf_counter()
400 | self.device_memory_monitor.reset_peak_stats()
401 |
402 | def log_validation(self, loss: float, step: int):
403 | time_delta = time.perf_counter() - self.time_last_log
404 |
405 | device_mem_stats = self.device_memory_monitor.get_peak_stats()
406 |
407 | # tokens per second per device, abbreviated as tps
408 | tps = self.ntokens_since_last_log / (
409 | time_delta * self.parallel_dims.non_data_parallel_size
410 | )
411 |
412 | metrics = {
413 | "validation_metrics/loss": loss,
414 | "validation_metrics/throughput(tps)": tps,
415 | "validation_metrics/memory/max_active(GiB)": device_mem_stats.max_active_gib,
416 | "validation_metrics/memory/max_active(%)": device_mem_stats.max_active_pct,
417 | "validation_metrics/memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib,
418 | "validation_metrics/memory/max_reserved(%)": device_mem_stats.max_reserved_pct,
419 | }
420 | self.logger.log(metrics, step)
421 |
422 | color = self.color
423 | logger.info(
424 | f"{color.yellow}validate step: {step:2} "
425 | f"{color.green}loss: {loss:7.4f} "
426 | f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB"
427 | f"({device_mem_stats.max_reserved_pct:.2f}%) "
428 | f"{color.blue}tps: {round(tps):,}{color.reset}"
429 | )
430 |
431 | self.ntokens_since_last_log = 0
432 | self.time_last_log = time.perf_counter()
433 | self.device_memory_monitor.reset_peak_stats()
434 |
435 | def close(self):
436 | self.logger.close()
437 |
438 |
439 | def build_metrics_processor(
440 | job_config: JobConfig,
441 | parallel_dims: ParallelDims,
442 | tag: str | None = None,
443 | ) -> MetricsProcessor:
444 | """Create a metrics processor.
445 |
446 | Args:
447 | job_config (JobConfig): Job configuration.
448 | parallel_dims (ParallelDims): Parallel dimensions.
449 | tag (str | None): Tag to use for WandB. Defaults to None.
450 |
451 | Returns:
452 | MetricsProcessor: A metrics processor.
453 | """
454 | return MetricsProcessor(job_config, parallel_dims, tag)
455 |
--------------------------------------------------------------------------------
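The throughput figures logged by MetricsProcessor.log above come from a few ratios. A worked sketch of the same arithmetic with made-up inputs (the token count, parameter count, and peak-FLOPS value are hypothetical, not measurements from this repository):

    # Hypothetical inputs for one logging interval
    ntokens_since_last_log = 524_288   # tokens seen by this device since the last log
    time_delta = 10.0                  # seconds since the last log
    non_data_parallel_size = 1         # TP * PP * CP degree (1 = pure data parallel)
    num_flops_per_token = 6 * 1.2e9    # rough 6*N estimate for a ~1.2B-parameter decoder
    gpu_peak_flops = 989e12            # e.g. H100 dense BF16 peak

    tps = ntokens_since_last_log / (time_delta * non_data_parallel_size)
    tflops = num_flops_per_token * tps / 1e12
    mfu = 100 * num_flops_per_token * tps / gpu_peak_flops  # model FLOPS utilization, in %

    print(f"tps={tps:,.0f}  tflops={tflops:.1f}  mfu={mfu:.1f}%")
    # -> tps=52,429  tflops=377.5  mfu=38.2%

In the real processor, num_flops_per_token is set later by the trainer and gpu_peak_flops comes from utils.get_peak_flops on the detected device name.
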
/torchtitan/components/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import json
9 |
10 | import os
11 | from abc import ABC, abstractmethod
12 | from typing import Any, Optional, Union
13 |
14 | from tokenizers import AddedToken, Tokenizer
15 | from torchtitan.config_manager import JobConfig
16 | from torchtitan.tools.logging import logger
17 | from typing_extensions import override
18 |
19 |
20 | class BaseTokenizer(ABC):
21 |     # Base tokenizer interface, mainly for typing purposes
22 | def __init__(self):
23 | self.eos_id = 0
24 |
25 | @abstractmethod
26 | def encode(self, *args, **kwargs) -> list[int]: ...
27 |
28 | @abstractmethod
29 | def decode(self, *args, **kwargs) -> str: ...
30 |
31 | @abstractmethod
32 | def get_vocab_size(self) -> int: ...
33 |
34 |
35 | class HuggingFaceTokenizer(BaseTokenizer):
36 | """
37 | A tokenizer wrapper that handles BOS/EOS token inference and encoding.
38 |
39 | This class loads tokenizer files and automatically infers BOS/EOS tokens from
40 | a configuration file (tokenizer_config.json). It provides an encode method that adds
41 | BOS/EOS tokens based on whether the underlying tokenizer adds them automatically.
42 |
43 | Args:
44 | tokenizer_path (str): Path to directory containing tokenizer files
45 | """
46 |
47 | def __init__(
48 | self,
49 | tokenizer_path: str,
50 | ):
51 | super().__init__()
52 | self.tokenizer_path = tokenizer_path
53 |
54 | # Initialize BOS/EOS token attributes (frequently used)
55 | self.bos_id = None
56 | self.eos_id = None
57 | self.bos_token = None
58 | self.eos_token = None
59 |
60 | # Load the underlying tokenizer
61 | self.tokenizer = self._load_tokenizer_from_path(tokenizer_path)
62 |
63 | # Load configuration files
64 | self.config = self._load_config(
65 | os.path.join(tokenizer_path, "tokenizer_config.json")
66 | )
67 |
68 |         # Infer special tokens and whether BOS/EOS should be added automatically
69 | self._infer_special_tokens()
70 | self._infer_should_add_bos_eos()
71 |
72 | def _load_config(self, config_path: str) -> Optional[dict]:
73 | """Load configuration from JSON file if it exists."""
74 | if os.path.exists(config_path):
75 | with open(config_path, "r") as f:
76 | return json.load(f)
77 | return None
78 |
79 | def _load_tokenizer_from_path(self, tokenizer_path: str) -> Tokenizer:
80 | """Load tokenizer from various file formats."""
81 | if not os.path.exists(tokenizer_path):
82 | raise FileNotFoundError(f"Tokenizer path '{tokenizer_path}' does not exist")
83 |
84 | # Define paths for different tokenizer file types
85 | tokenizer_json_path = os.path.join(tokenizer_path, "tokenizer.json")
86 | vocab_txt_path = os.path.join(tokenizer_path, "vocab.txt")
87 | vocab_json_path = os.path.join(tokenizer_path, "vocab.json")
88 | merges_txt_path = os.path.join(tokenizer_path, "merges.txt")
89 |
90 | # Strategy 1: Load from tokenizer.json (preferred for modern tokenizers)
91 | if os.path.exists(tokenizer_json_path):
92 | logger.info("Loading tokenizer from tokenizer.json")
93 | return Tokenizer.from_file(tokenizer_json_path)
94 | # Strategy 2: Load from vocab files (with or without merges.txt)
95 | elif os.path.exists(vocab_json_path) or os.path.exists(vocab_txt_path):
96 | # Load vocabulary
97 | if os.path.exists(vocab_json_path):
98 | logger.info("Loading vocabulary from vocab.json")
99 | with open(vocab_json_path, "r") as f:
100 | vocab = json.load(f)
101 | vocab_source = "vocab.json"
102 | else:
103 | logger.info("Loading vocabulary from vocab.txt")
104 | vocab = {}
105 | with open(vocab_txt_path, "r") as f:
106 | for i, line in enumerate(f):
107 | token = line.strip()
108 | if token:
109 | vocab[token] = i
110 | vocab_source = "vocab.txt"
111 |
112 | # Strategy 2a: Use BPE if merges.txt exists
113 | if os.path.exists(merges_txt_path):
114 | logger.info(f"Loading BPE tokenizer from {vocab_source} + merges.txt")
115 | from tokenizers import decoders, pre_tokenizers, processors
116 | from tokenizers.models import BPE
117 |
118 | # Load merges from file and convert to tuples
119 | merges = []
120 | with open(merges_txt_path, "r") as f:
121 | for line in f:
122 | line = line.strip()
123 | if line and not line.startswith(
124 | "#"
125 | ): # Skip comments and empty lines
126 | parts = line.split()
127 | if len(parts) >= 2:
128 | merges.append((parts[0], parts[1]))
129 |
130 | # Create BPE model
131 | bpe_model = BPE(vocab=vocab, merges=merges)
132 | tokenizer = Tokenizer(bpe_model)
133 |
134 | # Configure GPT-2 style components for proper space handling
135 | tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(
136 | add_prefix_space=False
137 | )
138 | tokenizer.decoder = decoders.ByteLevel()
139 | tokenizer.post_processor = processors.ByteLevel(trim_offsets=True)
140 |
141 | return tokenizer
142 |
143 | # Strategy 2b: Use WordLevel if no merges.txt
144 | else:
145 | logger.info(f"Loading WordLevel tokenizer from {vocab_source}")
146 | from tokenizers.models import WordLevel
147 |
148 | word_level_model = WordLevel(vocab=vocab, unk_token="[UNK]")
149 | return Tokenizer(word_level_model)
150 |
151 | else:
152 | # List available files for debugging
153 | available_files = [
154 | f
155 | for f in os.listdir(tokenizer_path)
156 | if os.path.isfile(os.path.join(tokenizer_path, f))
157 | ]
158 | raise FileNotFoundError(
159 | f"No supported tokenizer files found in '{tokenizer_path}'. "
160 | f"Available files: {available_files}. "
161 |                 "Looking for: tokenizer.json, or vocab.json/vocab.txt (optionally with merges.txt)"
162 | )
163 |
164 | def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[str]:
165 | """
166 | Parse special tokens from config that can be either strings or dicts.
167 |         HF tokens are stored as either {'bos_token': '<s>'} or {'bos_token': {'content': '<s>', ...}}.
168 | """
169 | token = config.get(key)
170 | if isinstance(token, dict):
171 | if "content" not in token:
172 | raise ValueError(f"Could not parse {key} from config")
173 | token = token["content"]
174 | elif token is not None and not isinstance(token, str):
175 | raise ValueError(
176 | f"Could not parse {key} from config - expected string or dict"
177 | )
178 | return token
179 |
180 | def _process_special_token(
181 | self, token_str: str, token_config: dict, token_id: Optional[int] = None
182 | ) -> AddedToken:
183 | """
184 | Process a special token and update BOS/EOS attributes if applicable.
185 |
186 | Args:
187 | token_str: The token string content
188 | token_config: Token configuration dictionary
189 | token_id: Optional explicit token ID (for added_tokens_decoder)
190 |
191 | Returns:
192 | AddedToken object to be added to the tokenizer
193 | """
194 | # Get reference BOS/EOS tokens from config for comparison
195 | config_bos_token = (
196 | self._get_token_from_config(self.config, "bos_token")
197 | if self.config
198 | else None
199 | )
200 | config_eos_token = (
201 | self._get_token_from_config(self.config, "eos_token")
202 | if self.config
203 | else None
204 | )
205 |
206 | # Store BOS/EOS tokens as class attributes if they match
207 | if token_str == config_bos_token:
208 | self.bos_token = token_str
209 | self.bos_id = (
210 | token_id
211 | if token_id is not None
212 | else self.tokenizer.token_to_id(token_str)
213 | )
214 | elif token_str == config_eos_token:
215 | self.eos_token = token_str
216 | self.eos_id = (
217 | token_id
218 | if token_id is not None
219 | else self.tokenizer.token_to_id(token_str)
220 | )
221 |
222 | # Create AddedToken object based on config format
223 | if isinstance(token_config, dict):
224 | if token_config.get("__type") == "AddedToken" or "content" in token_config:
225 | # Handle both AddedToken format and added_tokens_decoder format
226 | return AddedToken(
227 | content=token_str,
228 | single_word=token_config.get("single_word", False),
229 | lstrip=token_config.get("lstrip", False),
230 | rstrip=token_config.get("rstrip", False),
231 | normalized=token_config.get("normalized", True),
232 | special=token_config.get("special", True),
233 | )
234 |
235 | # Fallback to simple special token
236 | return AddedToken(content=token_str, special=True)
237 |
238 | def _infer_special_tokens(self):
239 | """
240 | Read special tokens from config and add them to the underlying tokenizer.
241 | Store BOS/EOS tokens as class attributes since they are frequently used.
242 |
243 | This method handles multiple token configuration formats:
244 | 1. Standard top-level keys (bos_token, eos_token, etc.)
245 | 2. added_tokens_decoder dictionary (used by models like Llama 3.1)
246 | """
247 | standard_keys = [
248 | "bos_token",
249 | "eos_token",
250 | "pad_token",
251 | "unk_token",
252 | "sep_token",
253 | "cls_token",
254 | "mask_token",
255 | ]
256 |
257 | # List to collect AddedToken objects for updating the underlying tokenizer
258 | added_tokens_to_add = []
259 |
260 | if not self.config:
261 | return
262 |
263 | # Process standard top-level token keys
264 | for key in standard_keys:
265 | token_config = self.config.get(key)
266 | if token_config is not None:
267 | token_str = self._get_token_from_config(self.config, key)
268 | if token_str is not None:
269 | added_token = self._process_special_token(token_str, token_config)
270 | added_tokens_to_add.append(added_token)
271 |
272 | # Process added_tokens_decoder (comprehensive special token definitions)
273 | added_tokens_decoder = self.config.get("added_tokens_decoder", {})
274 | for token_id_str, token_config in added_tokens_decoder.items():
275 | if isinstance(token_config, dict) and "content" in token_config:
276 | token_str = token_config["content"]
277 | token_id = int(token_id_str)
278 | added_token = self._process_special_token(
279 | token_str, token_config, token_id
280 | )
281 | added_tokens_to_add.append(added_token)
282 |
283 | # Update the underlying tokenizer with special tokens
284 | if added_tokens_to_add:
285 | self.tokenizer.add_special_tokens(added_tokens_to_add)
286 |
287 | # Update BOS/EOS token IDs after adding to tokenizer (in case they changed)
288 | if self.bos_token:
289 | self.bos_id = self.tokenizer.token_to_id(self.bos_token)
290 | if self.eos_token:
291 | self.eos_id = self.tokenizer.token_to_id(self.eos_token)
292 |
293 | def _infer_should_add_bos_eos(self):
294 | """
295 | Determine if we should add BOS/EOS tokens based on config settings.
296 | If config explicitly specifies add_bos_token/add_eos_token, follow that.
297 | Otherwise, determine if the underlying tokenizer automatically adds them.
298 | """
299 | self.default_add_bos = False
300 | self.default_add_eos = False
301 | self.hf_adds_bos = False
302 | self.hf_adds_eos = False
303 |
304 | # First, determine if underlying tokenizer auto-adds BOS/EOS tokens empirically
305 | encoded_empty_str = self.tokenizer.encode("").ids
306 | if self.bos_id is not None and self.bos_id in encoded_empty_str:
307 | self.hf_adds_bos = True
308 | if self.eos_id is not None and self.eos_id in encoded_empty_str:
309 | self.hf_adds_eos = True
310 |
311 | # Check tokenizer_config.json for explicit settings - these override empirical detection
312 | if self.config:
313 | config_add_bos = self.config.get("add_bos_token")
314 | config_add_eos = self.config.get("add_eos_token")
315 | if config_add_bos is not None:
316 | self.default_add_bos = bool(config_add_bos)
317 | if config_add_eos is not None:
318 | self.default_add_eos = bool(config_add_eos)
319 |
320 | def encode(self, *args, **kwargs) -> list[int]:
321 | """
322 | Encode text into token IDs with BOS/EOS handling.
323 |
324 | Args:
325 | text (str): The text to encode
326 | add_bos (bool): Whether to add BOS token (if not already added by tokenizer)
327 | add_eos (bool): Whether to add EOS token (if not already added by tokenizer)
328 |
329 | Returns:
330 | list[int]: List of token IDs
331 | """
332 | # Extract arguments
333 | if len(args) >= 1:
334 | text = args[0]
335 | else:
336 | text = kwargs.get("text", "")
337 |
338 | add_bos = kwargs.get("add_bos", self.default_add_bos)
339 | add_eos = kwargs.get("add_eos", self.default_add_eos)
340 |
341 | # Get base token IDs from the underlying tokenizer
342 | token_ids = self.tokenizer.encode(text).ids
343 |
344 | # Add BOS token if requested and not already added by tokenizer
345 | if not self.hf_adds_bos and add_bos:
346 | if self.bos_id is not None:
347 | token_ids.insert(0, self.bos_id)
348 |
349 | # Add EOS token if requested and not already added by tokenizer
350 | if not self.hf_adds_eos and add_eos:
351 | if self.eos_id is not None:
352 | token_ids.append(self.eos_id)
353 |
354 | return token_ids
355 |
356 | @override
357 | def decode(self, *args, **kwargs) -> str:
358 | """
359 | Decode token IDs back to text.
360 |
361 | Args:
362 | token_ids (list[int]): List of token IDs to decode
363 | **kwargs: Additional arguments passed to the underlying tokenizer's decode method
364 | (e.g., skip_special_tokens)
365 |
366 | Returns:
367 | str: Decoded text
368 | """
369 | # Extract token_ids from arguments
370 | if len(args) >= 1:
371 | token_ids = args[0]
372 | # Pass through remaining kwargs
373 | return self.tokenizer.decode(token_ids, **kwargs)
374 | else:
375 | token_ids = kwargs.pop("token_ids", [])
376 | # Pass through remaining kwargs after removing token_ids
377 | return self.tokenizer.decode(token_ids, **kwargs)
378 |
379 | @property
380 | def vocab_size(self) -> int:
381 | """Get the vocabulary size."""
382 | return self.tokenizer.get_vocab_size()
383 |
384 | def get_vocab_size(self) -> int:
385 | """Get the vocabulary size."""
386 | return self.tokenizer.get_vocab_size()
387 |
388 | def get_vocab(self) -> dict[str, int]:
389 | """Get the vocabulary as a dictionary."""
390 | return self.tokenizer.get_vocab()
391 |
392 | def token_to_id(self, token: str) -> Optional[int]:
393 | """Convert token to ID."""
394 | return self.tokenizer.token_to_id(token)
395 |
396 | def id_to_token(self, token_id: int) -> Optional[str]:
397 | """Convert ID to token."""
398 | return self.tokenizer.id_to_token(token_id)
399 |
400 |
401 | def build_hf_tokenizer(
402 | job_config: JobConfig,
403 | ) -> Union[HuggingFaceTokenizer, BaseTokenizer]:
404 | """
405 | Builds a HuggingFaceTokenizer from the specified path.
406 |
407 | This function creates a HuggingFaceTokenizer instance that handles BOS/EOS token
408 | inference and intelligent encoding. The tokenizer automatically detects and loads
409 | from various file formats and infers special token behavior.
410 |
411 | Args:
412 |         job_config (JobConfig): A JobConfig object containing the path to the tokenizer directory.
413 |
414 | Returns:
415 | tokenizer (HuggingFaceTokenizer): Loaded tokenizer instance with intelligent BOS/EOS handling
416 | """
417 | tokenizer = HuggingFaceTokenizer(job_config.model.tokenizer_path)
418 | return tokenizer
419 |
--------------------------------------------------------------------------------
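A minimal usage sketch for the HuggingFaceTokenizer wrapper above (the tokenizer directory path is hypothetical; it only needs to contain a HuggingFace tokenizer.json plus tokenizer_config.json):

    from torchtitan.components.tokenizer import HuggingFaceTokenizer

    # Hypothetical local directory holding tokenizer.json + tokenizer_config.json
    tokenizer = HuggingFaceTokenizer("assets/tokenizer")

    # encode() only inserts BOS/EOS itself when the underlying tokenizer does not add them
    ids = tokenizer.encode("What a great day!", add_bos=True, add_eos=True)
    text = tokenizer.decode(ids, skip_special_tokens=True)

    print(tokenizer.bos_id, tokenizer.eos_id, tokenizer.get_vocab_size())
    print(ids)
    print(text)

build_hf_tokenizer wraps the same constructor, taking the directory from job_config.model.tokenizer_path.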