├── .python-version ├── assets ├── output.wav ├── llama-mimi.png ├── great_day_gt.wav └── log_validation.png ├── tests ├── __init__.py └── unit_tests │ ├── __init__.py │ └── test_audio_array_to_text.py ├── train.sh ├── torchtitan ├── distributed │ ├── __init__.py │ ├── parallel_dims.py │ └── pipeline.py ├── models │ ├── __init__.py │ ├── llama3 │ │ ├── model │ │ │ ├── state_dict_adapter.py │ │ │ └── args.py │ │ ├── train_configs │ │ │ ├── llama3_70b.toml │ │ │ ├── llama3_405b.toml │ │ │ ├── llama3_8b.toml │ │ │ └── debug_model.toml │ │ ├── __init__.py │ │ └── infra │ │ │ └── pipeline.py │ └── attention.py ├── __init__.py ├── components │ ├── quantization │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── mx.py │ │ └── float8.py │ ├── loss.py │ ├── dataloader.py │ ├── ft.py │ ├── validate.py │ ├── lr_scheduler.py │ ├── optimizer.py │ ├── metrics.py │ └── tokenizer.py ├── tools │ ├── logging.py │ ├── profiling.py │ └── utils.py ├── protocols │ ├── state_dict_adapter.py │ ├── model_converter.py │ └── train_spec.py └── datasets │ └── hf_datasets.py ├── .gitignore ├── pyproject.toml ├── scripts └── convert_dcp_to_hf.py ├── LICENSE ├── config ├── llama3_1_8b.toml ├── llama3_2_1b.toml └── llama3_2_1b_peoples_speech.toml ├── eval ├── salmon.py ├── sLM21.py └── sStoryCloze.py ├── README.md └── inference.py /.python-version: -------------------------------------------------------------------------------- 1 | 3.13 2 | 3 | -------------------------------------------------------------------------------- /assets/output.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/output.wav -------------------------------------------------------------------------------- /assets/llama-mimi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/llama-mimi.png -------------------------------------------------------------------------------- /assets/great_day_gt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/great_day_gt.wav -------------------------------------------------------------------------------- /assets/log_validation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/llm-jp/llama-mimi/HEAD/assets/log_validation.png -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /tests/unit_tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | set -eux 2 | 3 | export NGPU=8 4 | export LOG_RANK=0 5 | export CONFIG_FILE="config/llama3_2_1b_peoples_speech.toml" 6 | 7 | torchrun --nproc_per_node=${NGPU} --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ 8 | -m torchtitan.train --job.config_file ${CONFIG_FILE} 9 | -------------------------------------------------------------------------------- /torchtitan/distributed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from torchtitan.distributed.parallel_dims import ParallelDims 9 | 10 | 11 | __all__ = ["ParallelDims"] 12 | -------------------------------------------------------------------------------- /torchtitan/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | # Import the built-in models here so that the corresponding register_model_spec() 9 | # will be called. 10 | # import torchtitan.models.deepseek_v3 # noqa: F401 11 | import torchtitan.models.llama3 # noqa: F401 12 | -------------------------------------------------------------------------------- /torchtitan/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Import to register quantization modules. 8 | import torchtitan.components.quantization # noqa: F401 9 | 10 | # Import the built-in models here so that the corresponding register_model_spec() 11 | # will be called. 12 | import torchtitan.models # noqa: F401 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .idea 3 | .DS_Store 4 | *.egg-info 5 | build 6 | outputs 7 | dist/* 8 | .vscode 9 | .vs 10 | 11 | # data 12 | data 13 | out 14 | wandb 15 | 16 | torchtitan/datasets/**/*.model 17 | 18 | # tokenizer models 19 | assets/**/*.model 20 | assets/**/*.json 21 | assets/**/*.txt 22 | torchtitan/experiments/flux/assets/* 23 | 24 | # temp files 25 | *.log 26 | error.json 27 | _remote_module_non_scriptable.py 28 | 29 | # Editor temporaries (VIM) 30 | [._]*.s[a-v][a-z] 31 | [._]*.sw[a-p] 32 | [._]s[a-rt-v][a-z] 33 | [._]ss[a-gi-z] 34 | [._]sw[a-p] 35 | Session.vim 36 | Sessionx.vim 37 | .netrwhist 38 | *~ 39 | .~lock.* 40 | 41 | # macOS dir files 42 | .DS_Store 43 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/model/state_dict_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from typing import Any
8 | 
9 | from torchtitan.protocols.state_dict_adapter import StateDictAdapter
10 | 
11 | 
12 | class Llama3StateDictAdapter(StateDictAdapter):
13 |     @staticmethod
14 |     def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]:
15 |         # TODO: implement this
16 |         return state_dict
17 | 
18 |     @staticmethod
19 |     def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]:
20 |         # TODO: implement this
21 |         return hf_state_dict
22 | 
--------------------------------------------------------------------------------
/torchtitan/components/quantization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | # [Note] Getting the 'torchao' package:
8 | # This script requires the 'torchao' package to function correctly.
9 | # Please ensure you have this package installed from the appropriate repository.
10 | # You can obtain it from https://github.com/pytorch/ao by following the
11 | # installation instructions.
12 | 
13 | # Note: Performance
14 | # The quantization modules are intended to be run under `torch.compile` for competitive performance
15 | 
16 | # Import to register quantization modules as ModelConverter
17 | import torchtitan.components.quantization.float8 # noqa: F401
18 | import torchtitan.components.quantization.mx # noqa: F401
19 | 
--------------------------------------------------------------------------------
/torchtitan/tools/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import logging
8 | import os
9 | 
10 | 
11 | logger = logging.getLogger()
12 | 
13 | 
14 | def init_logger():
15 |     rank = int(os.environ.get("RANK", 0))  # defaults to 0
16 | 
17 |     if rank == 0:
18 |         logger.setLevel(logging.INFO)
19 |         ch = logging.StreamHandler()
20 |         ch.setLevel(logging.INFO)
21 |         formatter = logging.Formatter(
22 |             f"[rank{rank}]:[titan] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
23 |         )
24 |         ch.setFormatter(formatter)
25 |         logger.addHandler(ch)
26 |     else:
27 |         # suppress log output when rank != 0
28 |         logger.setLevel(logging.ERROR)  # logging.CRITICAL would also work
29 | 
30 |     # suppress verbose torch.profiler logging
31 |     os.environ["KINETO_LOG_LEVEL"] = "5"
32 | 
--------------------------------------------------------------------------------
/torchtitan/components/quantization/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import torch.nn as nn
8 | 
9 | 
10 | def module_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
11 |     """
12 |     Filter function to determine which modules should be converted.
13 | For both Float8 and MXFP8, we only convert Linear modules 14 | with dimensions divisible by 16 and not matching any filtered FQNs. 15 | """ 16 | if not isinstance(mod, nn.Linear): 17 | return False 18 | 19 | # All dims must be divisible by 16 due to float8 tensorcore hardware requirements. 20 | dims_multiples_of_16 = ( 21 | mod.weight.shape[0] % 16 == 0 and mod.weight.shape[1] % 16 == 0 22 | ) 23 | 24 | # If the fqn matches any filtered fqn, then we should not convert this module. 25 | is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns) 26 | 27 | return dims_multiples_of_16 and not is_filtered_fqn 28 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "torchtitan" 3 | version = "0.1.0" 4 | description = "Add your description here" 5 | readme = "README.md" 6 | requires-python = ">=3.12" 7 | dependencies = [ 8 | "accelerate>=1.9.0", 9 | "blobfile>=3.0.0", 10 | "datasets>=3.6.0", 11 | "fsspec>=2025.3.0", 12 | "librosa>=0.11.0", 13 | "openai-whisper>=20250625", 14 | "soundfile>=0.13.1", 15 | "tabulate>=0.9.0", 16 | "tiktoken>=0.9.0", 17 | "tomli>=1.1.0 ; python_full_version < '3.11'", 18 | "torch>=2.7.1", 19 | "torchaudio>=2.7.1", 20 | "torchcodec==0.5.0", 21 | "torchdata>=0.8.0", 22 | "transformers>=4.53.2", 23 | "tyro>=0.9.25", 24 | "wandb>=0.20.1", 25 | "seaborn>=0.13.2", 26 | "moshi>=0.2.11", 27 | "audiobox-aesthetics>=0.0.4", 28 | "pyannote-audio>=3.3.2", 29 | "python-dotenv>=1.1.1", 30 | "openai>=1.107.1", 31 | "flash-attn==2.8.2", 32 | ] 33 | 34 | [dependency-groups] 35 | dev = [ 36 | "matplotlib>=3.10.3", 37 | "mypy>=1.17.0", 38 | "ruff>=0.12.4", 39 | ] 40 | 41 | 42 | [build-system] 43 | requires = ["hatchling"] 44 | build-backend = "hatchling.build" 45 | 46 | [tool.hatch.build.targets.wheel] 47 | packages = ["torchtitan"] 48 | 49 | [tool.uv] 50 | no-build-isolation-package = ["flash-attn"] 51 | -------------------------------------------------------------------------------- /torchtitan/protocols/state_dict_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from abc import ABC, abstractmethod 8 | from typing import Any 9 | 10 | 11 | class StateDictAdapter(ABC): 12 | """Abstract base class for state dict transformations. 13 | 14 | This class defines the interface for converting between native model 15 | state dict format and other model state dict formats. 16 | """ 17 | 18 | @staticmethod 19 | @abstractmethod 20 | def to_hf(state_dict: dict[str, Any]) -> dict[str, Any]: 21 | """Convert from native model state dict to HuggingFace format. 22 | 23 | Args: 24 | state_dict: The native model state dict 25 | 26 | Returns: 27 | The converted HuggingFace format state dict 28 | """ 29 | pass 30 | 31 | @staticmethod 32 | @abstractmethod 33 | def from_hf(hf_state_dict: dict[str, Any]) -> dict[str, Any]: 34 | """Obtain native model state dict from HuggingFace format. 
35 | 36 | Args: 37 | hf_state_dict: The HuggingFace format state dict 38 | 39 | Returns: 40 | The converted native model state dict 41 | """ 42 | pass 43 | -------------------------------------------------------------------------------- /scripts/convert_dcp_to_hf.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | from torchtitan.components.checkpoint import ModelWrapper 3 | import torch.distributed.checkpoint as dcp 4 | import torch 5 | from torchtitan.train import expand_tokenizer_with_unit_tokens 6 | 7 | if __name__ == "__main__": 8 | num_quantizers = 4 9 | checkpoint_id = f"outputs/Llama-3.2-1B_peoples_speech-q4-s1024/checkpoint/step-5000" 10 | output_dir = f"models/Llama-3.2-1B_peoples_speech-q4-s1024" 11 | model_name = "meta-llama/Llama-3.2-1B" 12 | model = AutoModelForCausalLM.from_pretrained(model_name) 13 | 14 | device = "cuda" if torch.cuda.is_available() else "cpu" 15 | tokenizer = AutoTokenizer.from_pretrained(model_name) 16 | tokenizer.pad_token_id = 0 17 | tokenizer = expand_tokenizer_with_unit_tokens( 18 | tokenizer, 19 | codebook_size=2048, 20 | num_quantizers=num_quantizers, 21 | ) 22 | 23 | embedding_size = model.get_input_embeddings().weight.shape[0] 24 | if len(tokenizer) > embedding_size: 25 | model.resize_token_embeddings(len(tokenizer)) 26 | 27 | wrapped = ModelWrapper(model) 28 | print(wrapped) 29 | dcp.load(wrapped.state_dict(), checkpoint_id=checkpoint_id) 30 | model.config.num_quantizers = num_quantizers 31 | model.save_pretrained(output_dir) 32 | tokenizer.save_pretrained(output_dir) 33 | -------------------------------------------------------------------------------- /torchtitan/components/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import functools 8 | from typing import Callable, TypeAlias 9 | 10 | import torch 11 | 12 | from torchtitan.config_manager import JobConfig 13 | from torchtitan.tools.logging import logger 14 | 15 | LossFunction: TypeAlias = Callable[..., torch.Tensor] 16 | 17 | 18 | def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 19 | """Common cross-entropy loss function for Transformer models training.""" 20 | return torch.nn.functional.cross_entropy( 21 | pred.flatten(0, 1).float(), labels.flatten(0, 1) 22 | ) 23 | 24 | 25 | def build_cross_entropy_loss(job_config: JobConfig): 26 | loss_fn = cross_entropy_loss 27 | if job_config.training.compile: 28 | logger.info("Compiling the loss function with torch.compile") 29 | loss_fn = torch.compile(loss_fn) 30 | return loss_fn 31 | 32 | 33 | def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps): 34 | """Add a mean reduction over `accumulation_steps` to the given 35 | `unwrapped_loss_fn`. 
36 | """ 37 | 38 | @functools.wraps(unwrapped_loss_fn) 39 | def accumulated_loss_fn(*args, **kwargs): 40 | loss = unwrapped_loss_fn(*args, **kwargs) 41 | return loss / accumulation_steps 42 | 43 | return accumulated_loss_fn 44 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/llama3_70b.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # NOTE: this toml config is a preset for 64 A100 GPUs. 3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 70B training" 7 | 8 | [profiling] 9 | enable_profiling = true 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 10 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | 18 | [model] 19 | name = "llama3" 20 | flavor = "70B" 21 | tokenizer_path = "./assets/tokenizer/Llama-3.1-8B" 22 | # converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 1.5e-4 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps 31 | 32 | [training] 33 | local_batch_size = 8 34 | seq_len = 8192 35 | max_norm = 1.0 # grad norm clipping 36 | steps = 1000 37 | compile = false 38 | dataset = "c4" 39 | 40 | [parallelism] 41 | data_parallel_replicate_degree = 1 42 | data_parallel_shard_degree = -1 43 | tensor_parallel_degree = 8 # 8-way TP 44 | pipeline_parallel_degree = 1 45 | context_parallel_degree = 1 46 | 47 | [checkpoint] 48 | enable_checkpoint = false 49 | folder = "checkpoint" 50 | interval = 500 51 | last_save_model_only = true 52 | export_dtype = "float32" 53 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 54 | 55 | [activation_checkpoint] 56 | mode = "full" 57 | 58 | [float8] 59 | enable_fsdp_float8_all_gather = false 60 | precompute_float8_dynamic_scale_for_fsdp = false 61 | filter_fqns = ["output"] 62 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | (c) Meta Platforms, Inc. and affiliates. 4 | LLM-jp (2025) 5 | 6 | Redistribution and use in source and binary forms, with or without modification, 7 | are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice,this list 10 | of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, this 13 | list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its contributors may 17 | be used to endorse or promote products derived from this software without specific 18 | prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY 21 | EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 22 | OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT 23 | SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, 24 | INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED 25 | TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 26 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 27 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN 28 | ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH 29 | DAMAGE. 30 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/llama3_405b.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # NOTE: this toml config is a preset for 128 H100 GPUs. 3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 405B training" 7 | 8 | [profiling] 9 | enable_profiling = true 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 10 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | 18 | [model] 19 | name = "llama3" 20 | flavor = "405B" 21 | tokenizer_path = "./assets/tokenizer/Llama-3.1-8B" 22 | converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 8e-5 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps 31 | 32 | [training] 33 | local_batch_size = 2 34 | seq_len = 8192 35 | max_norm = 1.0 # grad norm clipping 36 | steps = 3000 37 | compile = true 38 | dataset = "c4" 39 | 40 | [parallelism] 41 | data_parallel_replicate_degree = 1 42 | data_parallel_shard_degree = -1 43 | tensor_parallel_degree = 8 # 8-way TP 44 | enable_async_tensor_parallel = true 45 | pipeline_parallel_degree = 1 46 | context_parallel_degree = 1 47 | 48 | [checkpoint] 49 | enable_checkpoint = false 50 | folder = "checkpoint" 51 | interval = 500 52 | last_save_model_only = true 53 | export_dtype = "float32" 54 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 55 | 56 | [activation_checkpoint] 57 | mode = "full" # ["none", "selective", "full"] 58 | 59 | [float8] 60 | enable_fsdp_float8_all_gather = true 61 | precompute_float8_dynamic_scale_for_fsdp = true 62 | filter_fqns = ["output"] 63 | -------------------------------------------------------------------------------- /tests/unit_tests/test_audio_array_to_text.py: -------------------------------------------------------------------------------- 1 | from torchtitan.datasets.hf_datasets import audio_array_to_text 2 | import torchaudio 3 | from transformers import MimiModel, AutoFeatureExtractor 4 | 5 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to("cpu") 6 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi") 7 | print(feature_extractor.sampling_rate) 8 | 9 | audio_path = "assets/great_day.mp3" 10 | waveform, sample_rate = torchaudio.load(audio_path) 11 | # load audio array 12 | print(waveform) 13 | print(sample_rate) 14 | 15 | 16 | text = audio_array_to_text(waveform[0], audio_tokenizer, feature_extractor, 4) 17 | print(text) 18 | assert text == "" 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/llama3_8b.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # NOTE: this toml config is a preset for 64 
A100 GPUs.
3 | 
4 | [job]
5 | dump_folder = "./outputs"
6 | description = "Llama 3 8B training"
7 | 
8 | [profiling]
9 | enable_profiling = false
10 | save_traces_folder = "profile_trace"
11 | profile_freq = 100
12 | 
13 | [metrics]
14 | log_freq = 1
15 | enable_tensorboard = true
16 | save_tb_folder = "tb"
17 | enable_wandb = true
18 | 
19 | [model]
20 | name = "llama3"
21 | flavor = "8B"
22 | tokenizer_path = "./assets/tokenizer/Llama-3.1-8B"
23 | num_quantizers = 4 # Number of Mimi quantizers (audio codebooks) to use
24 | # converters = ["float8"]
25 | 
26 | [optimizer]
27 | name = "AdamW"
28 | lr = 3e-4
29 | eps = 1e-8
30 | 
31 | [lr_scheduler]
32 | warmup_steps = 200 # lr scheduler warm up
33 | 
34 | [training]
35 | local_batch_size = 8
36 | global_batch_size = 256
37 | seq_len = 1024
38 | max_norm = 1.0 # grad norm clipping
39 | steps = 30000
40 | compile = false
41 | dataset = "libri_light"
42 | 
43 | [parallelism]
44 | data_parallel_replicate_degree = 1
45 | data_parallel_shard_degree = -1
46 | tensor_parallel_degree = 1
47 | pipeline_parallel_degree = 1
48 | context_parallel_degree = 1
49 | 
50 | [checkpoint]
51 | enable_checkpoint = true
52 | folder = "checkpoint"
53 | interval = 100
54 | last_save_model_only = false
55 | export_dtype = "float32"
56 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
57 | 
58 | [activation_checkpoint]
59 | mode = "none" # ["none", "selective", "full"]
60 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
61 | 
62 | [float8]
63 | enable_fsdp_float8_all_gather = false
64 | precompute_float8_dynamic_scale_for_fsdp = false
65 | filter_fqns = ["output"]
66 | 
--------------------------------------------------------------------------------
/config/llama3_1_8b.toml:
--------------------------------------------------------------------------------
1 | # torchtitan Config.toml
2 | # NOTE: this toml config is a preset for 64 A100 GPUs.
3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 8B training" 7 | 8 | [profiling] 9 | enable_profiling = false 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 1 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | enable_wandb = true 18 | 19 | [model] 20 | name = "meta-llama/Llama-3.1-8B" 21 | num_quantizers = 4 22 | # converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 3e-4 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 1500 # lr scheduler warm up 31 | decay_ratio = 0.2 32 | lr_min = 0.1 33 | 34 | [training] 35 | local_batch_size = 8 36 | global_batch_size = 1024 37 | seq_len = 1024 38 | max_norm = 1.0 # grad norm clipping 39 | steps = 100_000 40 | compile = false 41 | dataset = "all" #"libri_light" 42 | task = "a2a" 43 | 44 | [validation] 45 | enabled = true 46 | dataset = "librispeech_asr" 47 | local_batch_size = 8 48 | seq_len = 1024 49 | freq = 100 50 | steps = 10 51 | 52 | [evaluation] 53 | enable_evaluation = false 54 | evaluation_freq = 500 55 | 56 | [parallelism] 57 | data_parallel_replicate_degree = 1 58 | data_parallel_shard_degree = -1 59 | tensor_parallel_degree = 1 60 | pipeline_parallel_degree = 1 61 | context_parallel_degree = 1 62 | 63 | [checkpoint] 64 | enable_checkpoint = true 65 | folder = "checkpoint" 66 | interval = 1000 67 | last_save_model_only = false 68 | export_dtype = "float32" 69 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 70 | keep_latest_k = 10 71 | 72 | [activation_checkpoint] 73 | mode = "none" # ["none", "selective", "full"] 74 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy 75 | 76 | [float8] 77 | enable_fsdp_float8_all_gather = false 78 | precompute_float8_dynamic_scale_for_fsdp = false 79 | filter_fqns = ["output"] 80 | -------------------------------------------------------------------------------- /config/llama3_2_1b.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # NOTE: this toml config is a preset for 64 A100 GPUs. 
3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 1B training" 7 | 8 | [profiling] 9 | enable_profiling = false 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 1 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | enable_wandb = true 18 | 19 | [model] 20 | name = "meta-llama/Llama-3.2-1B" 21 | num_quantizers = 4 22 | # converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 3e-4 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 1500 # lr scheduler warm up 31 | decay_ratio = 0.2 32 | lr_min = 0.1 33 | 34 | [training] 35 | local_batch_size = 32 36 | global_batch_size = 1024 37 | seq_len = 1024 38 | max_norm = 1.0 # grad norm clipping 39 | steps = 100_000 40 | compile = false 41 | dataset = "all" #"libri_light" # "Emilia-librilight" 42 | task = "a2a" 43 | 44 | [validation] 45 | enabled = true 46 | dataset = "librispeech_asr" 47 | local_batch_size = 8 48 | seq_len = 1024 49 | freq = 100 50 | steps = 10 51 | 52 | [evaluation] 53 | enable_evaluation = false 54 | evaluation_freq = 500 55 | 56 | [parallelism] 57 | data_parallel_replicate_degree = 1 58 | data_parallel_shard_degree = -1 59 | tensor_parallel_degree = 1 60 | pipeline_parallel_degree = 1 61 | context_parallel_degree = 1 62 | 63 | [checkpoint] 64 | enable_checkpoint = true 65 | folder = "checkpoint" 66 | interval = 1000 67 | last_save_model_only = false 68 | export_dtype = "float32" 69 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 70 | keep_latest_k = 10 71 | 72 | [activation_checkpoint] 73 | mode = "none" # ["none", "selective", "full"] 74 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy 75 | 76 | [float8] 77 | enable_fsdp_float8_all_gather = false 78 | precompute_float8_dynamic_scale_for_fsdp = false 79 | filter_fqns = ["output"] 80 | -------------------------------------------------------------------------------- /config/llama3_2_1b_peoples_speech.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | # NOTE: this toml config is a preset for 64 A100 GPUs. 
3 | 4 | [job] 5 | dump_folder = "./outputs" 6 | description = "Llama 3 1B training" 7 | 8 | [profiling] 9 | enable_profiling = false 10 | save_traces_folder = "profile_trace" 11 | profile_freq = 100 12 | 13 | [metrics] 14 | log_freq = 1 15 | enable_tensorboard = true 16 | save_tb_folder = "tb" 17 | enable_wandb = true 18 | 19 | [model] 20 | name = "meta-llama/Llama-3.2-1B" 21 | num_quantizers = 4 22 | # converters = ["float8"] 23 | 24 | [optimizer] 25 | name = "AdamW" 26 | lr = 3e-4 27 | eps = 1e-8 28 | 29 | [lr_scheduler] 30 | warmup_steps = 250 # lr scheduler warm up 31 | decay_ratio = 0.2 32 | lr_min = 0.1 33 | 34 | [training] 35 | local_batch_size = 32 36 | global_batch_size = 1024 37 | seq_len = 1024 38 | max_norm = 1.0 # grad norm clipping 39 | steps = 5_000 40 | compile = false 41 | dataset = "peoples_speech" #"libri_light" # "Emilia-librilight" 42 | task = "a2a" 43 | 44 | [validation] 45 | enabled = true 46 | dataset = "librispeech_asr_test" 47 | local_batch_size = 8 48 | seq_len = 1024 49 | freq = 100 50 | steps = 10 51 | 52 | [evaluation] 53 | enable_evaluation = false 54 | evaluation_freq = 500 55 | 56 | [parallelism] 57 | data_parallel_replicate_degree = 1 58 | data_parallel_shard_degree = -1 59 | tensor_parallel_degree = 1 60 | pipeline_parallel_degree = 1 61 | context_parallel_degree = 1 62 | 63 | [checkpoint] 64 | enable_checkpoint = true 65 | folder = "checkpoint" 66 | interval = 1000 67 | last_save_model_only = false 68 | export_dtype = "float32" 69 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 70 | keep_latest_k = 10 71 | 72 | [activation_checkpoint] 73 | mode = "none" # ["none", "selective", "full"] 74 | selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy 75 | 76 | [float8] 77 | enable_fsdp_float8_all_gather = false 78 | precompute_float8_dynamic_scale_for_fsdp = false 79 | filter_fqns = ["output"] 80 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/train_configs/debug_model.toml: -------------------------------------------------------------------------------- 1 | # torchtitan Config.toml 2 | 3 | [job] 4 | dump_folder = "./outputs" 5 | description = "Llama 3 debug training" 6 | print_args = false 7 | use_for_integration_test = true 8 | 9 | [profiling] 10 | enable_profiling = false 11 | save_traces_folder = "profile_trace" 12 | profile_freq = 10 13 | enable_memory_snapshot = false 14 | save_memory_snapshot_folder = "memory_snapshot" 15 | 16 | [metrics] 17 | log_freq = 1 18 | disable_color_printing = false 19 | enable_tensorboard = false 20 | save_tb_folder = "tb" 21 | enable_wandb = false 22 | 23 | [model] 24 | name = "llama3" 25 | flavor = "debugmodel" 26 | # test folder with tokenizer.json, for debug purpose only 27 | tokenizer_path = "./tests/assets/tokenizer" 28 | # converters = ["float8"] 29 | 30 | [optimizer] 31 | name = "AdamW" 32 | lr = 8e-4 33 | eps = 1e-8 34 | 35 | [lr_scheduler] 36 | warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps 37 | decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps 38 | decay_type = "linear" 39 | lr_min = 0.0 40 | 41 | [training] 42 | local_batch_size = 8 43 | seq_len = 2048 44 | max_norm = 1.0 # grad norm clipping 45 | steps = 10 46 | compile = false 47 | dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) 48 | 49 | [parallelism] 50 | data_parallel_replicate_degree = 1 51 | data_parallel_shard_degree = -1 52 | fsdp_reshard_after_forward = "default" # 
default / never / always 53 | tensor_parallel_degree = 1 54 | enable_async_tensor_parallel = false 55 | pipeline_parallel_degree = 1 56 | context_parallel_degree = 1 57 | 58 | [checkpoint] 59 | enable_checkpoint = false 60 | folder = "checkpoint" 61 | interval = 10 62 | last_save_model_only = false 63 | export_dtype = "float32" 64 | async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] 65 | 66 | [activation_checkpoint] 67 | mode = "selective" # ["none", "selective", "full"] 68 | selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy 69 | 70 | [float8] 71 | enable_fsdp_float8_all_gather = false 72 | precompute_float8_dynamic_scale_for_fsdp = false 73 | filter_fqns = ["output"] 74 | 75 | [validation] 76 | enabled = false 77 | dataset = "c4_validation" 78 | freq = 5 79 | steps = 10 80 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torchtitan.components.loss import build_cross_entropy_loss 8 | from torchtitan.components.lr_scheduler import build_lr_schedulers 9 | from torchtitan.components.optimizer import build_optimizers 10 | from torchtitan.components.tokenizer import build_hf_tokenizer 11 | from torchtitan.components.validate import build_validator 12 | from torchtitan.datasets.hf_datasets import build_hf_dataloader 13 | from torchtitan.protocols.train_spec import register_train_spec, TrainSpec 14 | 15 | from .infra.parallelize import parallelize_llama 16 | from .infra.pipeline import pipeline_llama 17 | from .model.args import TransformerModelArgs 18 | from .model.model import Transformer 19 | from .model.state_dict_adapter import Llama3StateDictAdapter 20 | 21 | __all__ = [ 22 | "parallelize_llama", 23 | "pipeline_llama", 24 | "TransformerModelArgs", 25 | "Transformer", 26 | "llama3_configs", 27 | ] 28 | 29 | 30 | llama3_configs = { 31 | "debugmodel": TransformerModelArgs( 32 | dim=256, n_layers=6, n_heads=16, rope_theta=500000 33 | ), 34 | "debugmodel_flex_attn": TransformerModelArgs( 35 | dim=256, 36 | n_layers=6, 37 | n_heads=16, 38 | rope_theta=500000, 39 | use_flex_attn=True, 40 | attn_mask_type="block_causal", 41 | ), 42 | "8B": TransformerModelArgs( 43 | dim=4096, 44 | n_layers=32, 45 | n_heads=32, 46 | n_kv_heads=8, 47 | ffn_dim_multiplier=1.3, 48 | multiple_of=1024, 49 | rope_theta=500000, 50 | ), 51 | "70B": TransformerModelArgs( 52 | dim=8192, 53 | n_layers=80, 54 | n_heads=64, 55 | n_kv_heads=8, 56 | ffn_dim_multiplier=1.3, 57 | multiple_of=4096, 58 | rope_theta=500000, 59 | ), 60 | "405B": TransformerModelArgs( 61 | dim=16384, 62 | n_layers=126, 63 | n_heads=128, 64 | n_kv_heads=8, 65 | ffn_dim_multiplier=1.2, 66 | multiple_of=4096, 67 | rope_theta=500000, 68 | ), 69 | } 70 | 71 | 72 | register_train_spec( 73 | TrainSpec( 74 | name="llama3", 75 | model_cls=Transformer, 76 | model_args=llama3_configs, 77 | parallelize_fn=parallelize_llama, 78 | pipelining_fn=pipeline_llama, 79 | build_optimizers_fn=build_optimizers, 80 | build_lr_schedulers_fn=build_lr_schedulers, 81 | build_dataloader_fn=build_hf_dataloader, 82 | build_tokenizer_fn=build_hf_tokenizer, 83 | build_loss_fn=build_cross_entropy_loss, 84 | 
build_validator_fn=build_validator, 85 | state_dict_adapter=Llama3StateDictAdapter, 86 | ) 87 | ) 88 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/model/args.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 8 | 9 | 10 | from dataclasses import dataclass 11 | 12 | from torch import nn 13 | 14 | from torchtitan.components.tokenizer import BaseTokenizer 15 | from torchtitan.config_manager import JobConfig 16 | from torchtitan.protocols.train_spec import BaseModelArgs 17 | 18 | 19 | @dataclass 20 | class TransformerModelArgs(BaseModelArgs): 21 | dim: int = 4096 22 | n_layers: int = 32 23 | n_heads: int = 32 24 | n_kv_heads: int | None = None 25 | vocab_size: int = -1 # defined later by tokenizer 26 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 27 | ffn_dim_multiplier: float | None = None 28 | norm_eps: float = 1e-5 29 | rope_theta: float = 10000 30 | 31 | max_seq_len: int = 2048 32 | # If `True`, then each transformer block init uses its layer ID, and if 33 | # `False`, each uses the total number of transformer blocks 34 | depth_init: bool = True 35 | 36 | use_flex_attn: bool = False 37 | attn_mask_type: str = "causal" 38 | eos_id: int = 0 39 | 40 | def update_from_config( 41 | self, job_config: JobConfig, tokenizer: BaseTokenizer 42 | ) -> None: 43 | self.vocab_size = tokenizer.get_vocab_size() 44 | self.max_seq_len = job_config.training.seq_len 45 | self.eos_id = tokenizer.eos_id 46 | 47 | if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: 48 | raise ValueError( 49 | "FlexAttention is not compatible with CP yet. " 50 | "We are still working on this." 51 | ) 52 | 53 | def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: 54 | nparams = sum(p.numel() for p in model.parameters()) 55 | nparams_embedding = sum( 56 | sum(p.numel() for p in m.parameters()) 57 | for m in model.children() 58 | if isinstance(m, nn.Embedding) 59 | ) 60 | 61 | l, h, q, t = ( 62 | self.n_layers, 63 | self.n_heads, 64 | self.dim // self.n_heads, 65 | seq_len, 66 | ) 67 | # Reasoning behind the factor of 12 for the self-attention part of the formula: 68 | # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) 69 | # 2. the flash attention does 1 more matmul recomputation in the backward 70 | # but recomputation should not be counted in calculating MFU (+0) 71 | # 3. each matmul performs 1 multiplication and 1 addition (*2) 72 | # 4. we follow the convention and do not account for sparsity in causal attention 73 | num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t 74 | 75 | return nparams, num_flops_per_token 76 | -------------------------------------------------------------------------------- /torchtitan/protocols/model_converter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | from typing import Dict, List, Protocol, Union
7 | 
8 | import torch.nn as nn
9 | 
10 | from torchtitan.config_manager import JobConfig
11 | from torchtitan.distributed import ParallelDims
12 | from torchtitan.tools.logging import logger
13 | 
14 | 
15 | class ModelConverter(Protocol):
16 |     """General model converter interface.
17 | 
18 |     A model converter applies a modification to a PyTorch model.
19 |     Typical use cases are:
20 |     - Quantization: using QAT, FP8, ... specialized linear layers;
21 |     - Fused optimized layers (e.g. flash-attention, norms, ...)
22 |     """
23 | 
24 |     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ...
25 | 
26 |     def convert(self, model: nn.Module):
27 |         """In-place conversion of the model."""
28 |         ...
29 | 
30 |     def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
31 |         """Post-optimizer (optional) hook (e.g. compute weights statistics)."""
32 |         ...
33 | 
34 | 
35 | _registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
36 | """Registry of model converter classes.
37 | """
38 | 
39 | 
40 | def register_model_converter(converter_cls: type[ModelConverter], name: str):
41 |     """Register a model converter class.
42 | 
43 |     A registered model converter can be applied on any model
44 |     using the `model.converters` config parameter.
45 |     """
46 |     assert name not in _registry_model_converter_cls, (
47 |         f"A model converter '{name}' is already registered."
48 |     )
49 |     _registry_model_converter_cls[name] = converter_cls
50 | 
51 | 
52 | class ModelConvertersContainer(ModelConverter):
53 |     """Sequential container of model converters.
54 | 
55 |     This class builds the sequence of model converters defined in the `model.converters`
56 |     job config and applies them to the model sequentially.
57 |     """
58 | 
59 |     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
60 |         converter_classes = [
61 |             _registry_model_converter_cls[name] for name in job_config.model.converters
62 |         ]
63 |         self.converters = [
64 |             mh_cls(job_config, parallel_dims) for mh_cls in converter_classes
65 |         ]
66 |         self.print_after_conversion = job_config.model.print_after_conversion
67 | 
68 |     def convert(self, model: nn.Module):
69 |         for mh in self.converters:
70 |             mh.convert(model)
71 |         if self.print_after_conversion:
72 |             logger.info(f"Model definition after conversion:\n\n{model}\n\n")
73 | 
74 |     def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
75 |         for mh in self.converters:
76 |             mh.post_optimizer_hook(model)
77 | 
78 | 
79 | def build_model_converters(
80 |     job_config: JobConfig, parallel_dims: ParallelDims
81 | ) -> ModelConvertersContainer:
82 |     """Build the collection of model converters to apply to the model."""
83 |     return ModelConvertersContainer(job_config, parallel_dims)
84 | 
--------------------------------------------------------------------------------
/torchtitan/components/quantization/mx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from functools import partial
8 | from importlib.metadata import version
9 | from importlib.util import find_spec
10 | from typing import Any, List
11 | 
12 | import torch.nn as nn
13 | 
14 | from torchtitan.config_manager import JobConfig, MX
15 | from torchtitan.distributed import ParallelDims
16 | from torchtitan.protocols.model_converter import (
17 |     ModelConverter,
18 |     register_model_converter,
19 | )
20 | from torchtitan.tools.logging import logger
21 | from torchtitan.tools.utils import has_cuda_capability
22 | 
23 | from .utils import module_filter_fn
24 | 
25 | # Maps titan recipe names to torchao mx recipe names
26 | NAME_MAP = {"mxfp8": "mxfp8_cublas"}
27 | 
28 | 
29 | class MXConverter(ModelConverter):
30 |     """Converts the linear layers of `model` to `MXLinear`."""
31 | 
32 |     enabled: bool
33 |     filter_fqns: List[str]
34 |     mx_config: Any # MXLinearConfig type when imported
35 | 
36 |     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
37 |         # Ensure minimum torchao versions
38 |         if find_spec("torchao") is None:
39 |             raise ImportError(
40 |                 "torchao is not installed. Please install it to use MXFP8 linear layers."
41 |             )
42 |         torchao_version = version("torchao")
43 |         mxfp8_min_version = "0.11.0"
44 |         if torchao_version < mxfp8_min_version:
45 |             raise ImportError(
46 |                 f"torchao version {torchao_version} is too old, please install torchao {mxfp8_min_version} or later and try again"
47 |             )
48 | 
49 |         # Can be removed if we enable the emulated versions
50 |         assert has_cuda_capability(10, 0), (
51 |             "MXFP8 is only supported on SM100 or later architectures"
52 |         )
53 | 
54 |         self.enabled = True
55 |         mx_job_config: MX = job_config.mx
56 |         self.filter_fqns = mx_job_config.filter_fqns
57 | 
58 |         # Configure MXFP8
59 |         from torchao.prototype.mx_formats.config import MXLinearConfig
60 | 
61 |         config = MXLinearConfig.from_recipe_name(NAME_MAP[mx_job_config.recipe_name])
62 |         config.use_fp8_dim1_cast_triton_kernel = (
63 |             mx_job_config.use_fp8_dim1_cast_triton_kernel
64 |         )
65 |         self.config = config
66 | 
67 |         logger.info(f"Float8 training active with recipe {mx_job_config.recipe_name}")
68 | 
69 |     def convert(self, model: nn.Module):
70 |         """
71 |         Converts the linear layers of `model` to `MXLinear`.
72 |         Note that today, only dynamic tensor scaling (the default) is supported.
73 |         This will mutate the model in place.
74 | """ 75 | if not self.enabled: 76 | return 77 | 78 | from torchao.prototype.mx_formats.config import MXLinearConfig 79 | from torchao.quantization import quantize_ 80 | 81 | assert isinstance(self.config, MXLinearConfig) 82 | quantize_( 83 | model, 84 | config=self.config, 85 | filter_fn=partial(module_filter_fn, filter_fqns=self.filter_fqns), 86 | ) 87 | logger.info("Swapped to MXLinear layers") 88 | 89 | def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): 90 | """ 91 | MXFP8 doesn't require any post-optimizer hooks at the moment 92 | """ 93 | return 94 | 95 | 96 | register_model_converter(MXConverter, "mx") 97 | -------------------------------------------------------------------------------- /eval/salmon.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Audio, get_dataset_config_names 2 | from transformers import MimiModel, AutoFeatureExtractor 3 | import os 4 | from torchtitan.datasets.hf_datasets import audio_array_to_text 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from argparse import ArgumentParser 8 | from tqdm import tqdm 9 | import json 10 | import torch.nn.functional as F 11 | from torch.utils.data import Dataset, DataLoader 12 | from torch.nn.utils.rnn import pad_sequence 13 | 14 | 15 | def parse_args(): 16 | parser = ArgumentParser(description="Audio completion generation script") 17 | parser.add_argument( 18 | "--model_name", 19 | type=str, 20 | default="llm-jp/Llama-Mimi-1.3B", 21 | help="Run name for the model", 22 | ) 23 | parser.add_argument( 24 | "--output_dir", type=str, default="results", help="Output directory" 25 | ) 26 | return parser.parse_args() 27 | 28 | 29 | def compute_loss( 30 | audio, audio_tokenizer, feature_extractor, num_quantizers, tokenizer, model, device 31 | ): 32 | text = audio_array_to_text( 33 | audio, audio_tokenizer, feature_extractor, num_quantizers 34 | ) 35 | inputs = tokenizer(text, return_tensors="pt") 36 | 37 | labels = inputs.input_ids.clone() 38 | labels[labels == tokenizer.pad_token_id] = -100 39 | inputs = inputs.to(device) 40 | outputs = model(input_ids=inputs.input_ids, labels=labels) 41 | loss = outputs.loss 42 | return loss 43 | 44 | 45 | if __name__ == "__main__": 46 | args = parse_args() 47 | 48 | device = "cuda" if torch.cuda.is_available() else "cpu" 49 | 50 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to(device) 51 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi") 52 | 53 | model = ( 54 | AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16) 55 | .eval() 56 | .to(device) 57 | ) 58 | num_quantizers = model.config.num_quantizers 59 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 60 | 61 | # Set pad token if not already set 62 | if tokenizer.pad_token is None: 63 | tokenizer.pad_token = tokenizer.eos_token 64 | 65 | tasks = get_dataset_config_names("slprl/SALMon") 66 | tasks = [c for c in tasks if not c.startswith("all_")] 67 | result = {} 68 | for task in tasks: 69 | ds = load_dataset("slprl/SALMon", task, split="train") 70 | ds = ds.cast_column( 71 | "negative_audio", Audio(sampling_rate=feature_extractor.sampling_rate) 72 | ) 73 | ds = ds.cast_column( 74 | "positive_audio", Audio(sampling_rate=feature_extractor.sampling_rate) 75 | ) 76 | 77 | total_correct = 0 78 | total_samples = 0 79 | 80 | for example in tqdm(ds): 81 | negative_audio = example["negative_audio"]["array"] 82 | positive_audio = 
example["positive_audio"]["array"] 83 | 84 | neg_loss = compute_loss( 85 | negative_audio, 86 | audio_tokenizer, 87 | feature_extractor, 88 | num_quantizers, 89 | tokenizer, 90 | model, 91 | device, 92 | ) 93 | pos_loss = compute_loss( 94 | positive_audio, 95 | audio_tokenizer, 96 | feature_extractor, 97 | num_quantizers, 98 | tokenizer, 99 | model, 100 | device, 101 | ) 102 | # print(f"Neg loss: {neg_loss.item()}, Pos loss: {pos_loss.item()}") 103 | 104 | total_correct += (neg_loss > pos_loss).item() 105 | total_samples += 1 106 | 107 | acc = total_correct / total_samples 108 | result[task] = acc 109 | print(f"Accuracy for {task}: {acc}") 110 | 111 | output_path = os.path.join( 112 | args.output_dir, "SALMON", args.model_name, "accuracy.json" 113 | ) 114 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 115 | with open(output_path, "w") as f: 116 | json.dump(result, f, indent=4) 117 | -------------------------------------------------------------------------------- /torchtitan/components/dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 8 | 9 | import pickle 10 | from abc import ABC, abstractmethod 11 | from collections.abc import Callable 12 | from typing import Any 13 | 14 | from torch.distributed.checkpoint.stateful import Stateful 15 | from torch.utils.data import IterableDataset 16 | from torchdata.stateful_dataloader import StatefulDataLoader 17 | from torchtitan.tools.logging import logger 18 | 19 | 20 | class DataloaderStopIteration(StopIteration): 21 | """An exception that indicates dataloader exhaustion.""" 22 | 23 | pass 24 | 25 | 26 | class BaseDataLoader(Stateful, ABC): 27 | """Base class for all dataloaders. 28 | 29 | This is used to enforce that all dataloaders have the methods defined in ``Stateful``, 30 | ``state_dict()`` and ``load_state_dict()``. 31 | """ 32 | 33 | @abstractmethod 34 | def __iter__(self): ... 35 | 36 | 37 | class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): 38 | """Dataloader that is aware of distributed data parallelism. 39 | 40 | This dataloader is used to load data in a distributed data parallel fashion. It also 41 | utilizes ``torchdata.stateful_dataloader.StatefulDataLoader`` to implement the necessary 42 | methods such as ``__iter__``. 43 | 44 | Args: 45 | dataset (IterableDataset): The dataset to iterate over. 46 | dp_rank: Data parallelism rank for this dataloader. 47 | dp_world_size: The world size of the data parallelism. 48 | batch_size: The batch size to use for each iteration. 49 | collate_fn: Optional function to collate samples in a batch. 
50 |     """
51 | 
52 |     dp_rank: int
53 |     dp_world_size: int
54 |     batch_size: int
55 | 
56 |     def __init__(
57 |         self,
58 |         dataset: IterableDataset,
59 |         dp_rank: int,
60 |         dp_world_size: int,
61 |         batch_size: int,
62 |         collate_fn: Callable | None = None,
63 |         num_workers: int = 2,
64 |         prefetch_factor: int | None = 2,
65 |         pin_memory: bool = True,
66 |         persistent_workers: bool | None = True,
67 |     ):
68 |         self.dp_world_size = dp_world_size
69 |         self.dp_rank = dp_rank
70 |         self.batch_size = batch_size
71 |         super().__init__(
72 |             dataset,
73 |             batch_size,
74 |             collate_fn=collate_fn,
75 |             num_workers=num_workers,
76 |             prefetch_factor=prefetch_factor,
77 |             pin_memory=pin_memory,
78 |             persistent_workers=persistent_workers,
79 |         )
80 |         self._rank_id = f"dp_rank_{dp_rank}"
81 | 
82 |     def state_dict(self) -> dict[str, Any]:
83 |         # Store state only for dp rank to avoid replicating the same state across other dimensions.
84 |         return {
85 |             # We don't have to use pickle as DCP will serialize the state_dict. However,
86 |             # we have to keep this for backward compatibility.
87 |             self._rank_id: pickle.dumps(super().state_dict()),
88 |             "world_size": self.dp_world_size,
89 |         }
90 | 
91 |     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
92 |         # State being empty is valid.
93 |         if not state_dict:
94 |             return
95 | 
96 |         if self._rank_id not in state_dict:
97 |             logger.warning(
98 |                 f"DataLoader state is empty for dp rank {self.dp_rank}, "
99 |                 f"expected key {self._rank_id}"
100 |             )
101 |             return
102 | 
103 |         assert self.dp_world_size == state_dict["world_size"], (
104 |             "dp_degree is inconsistent before and after checkpoint, "
105 |             "dataloader resharding is not supported yet."
106 |         )
107 |         # We don't have to use pickle as DCP will serialize the state_dict. However, we have to
108 |         # keep this for backward compatibility.
109 | super().load_state_dict(pickle.loads(state_dict[self._rank_id])) 110 | -------------------------------------------------------------------------------- /eval/sLM21.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Audio 2 | from transformers import MimiModel, AutoFeatureExtractor 3 | import os 4 | from torchtitan.datasets.hf_datasets import audio_array_to_text 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from argparse import ArgumentParser 8 | from tqdm import tqdm 9 | import json 10 | 11 | 12 | def parse_args(): 13 | parser = ArgumentParser(description="Audio completion generation script") 14 | parser.add_argument( 15 | "--model_name", 16 | type=str, 17 | default="llm-jp/Llama-Mimi-1.3B", 18 | help="Run name for the model", 19 | ) 20 | parser.add_argument( 21 | "--output_dir", type=str, default="results", help="Output directory" 22 | ) 23 | return parser.parse_args() 24 | 25 | 26 | def compute_loss( 27 | audio, audio_tokenizer, feature_extractor, num_quantizers, tokenizer, model, device 28 | ): 29 | text = audio_array_to_text( 30 | audio, audio_tokenizer, feature_extractor, num_quantizers 31 | ) 32 | inputs = tokenizer(text, return_tensors="pt") 33 | 34 | labels = inputs.input_ids.clone() 35 | labels[labels == tokenizer.pad_token_id] = -100 36 | mod_mask = torch.zeros_like(labels) 37 | mod_mask[:, 2 :: model.config.num_quantizers] = 1 38 | labels[mod_mask == 0] = -100 39 | # print([tokenizer.decode(ids) for ids in neg_labels[0].tolist() if ids != -100]) 40 | inputs = inputs.to(device) 41 | outputs = model(input_ids=inputs.input_ids, labels=labels) 42 | loss = outputs.loss 43 | return loss 44 | 45 | 46 | if __name__ == "__main__": 47 | args = parse_args() 48 | 49 | device = "cuda" if torch.cuda.is_available() else "cpu" 50 | 51 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to(device) 52 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi") 53 | 54 | model = ( 55 | AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16) 56 | .eval() 57 | .to(device) 58 | ) 59 | num_quantizers = model.config.num_quantizers 60 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 61 | 62 | # Set pad token if not already set 63 | if tokenizer.pad_token is None: 64 | tokenizer.pad_token = tokenizer.eos_token 65 | 66 | tasks = ["sWUGGY", "sBLIMP"] 67 | result = {} 68 | for task in tasks: 69 | ds = load_dataset( 70 | f"speed/{task}", split="train" 71 | ) # .shuffle(seed=42).select(range(1000)) 72 | ds = ds.cast_column( 73 | "negative", Audio(sampling_rate=feature_extractor.sampling_rate) 74 | ) 75 | ds = ds.cast_column( 76 | "positive", Audio(sampling_rate=feature_extractor.sampling_rate) 77 | ) 78 | 79 | total_correct = 0 80 | total_samples = 0 81 | 82 | for example in tqdm(ds): 83 | negative_audio = example["negative"]["array"] 84 | positive_audio = example["positive"]["array"] 85 | 86 | neg_loss = compute_loss( 87 | negative_audio, 88 | audio_tokenizer, 89 | feature_extractor, 90 | num_quantizers, 91 | tokenizer, 92 | model, 93 | device, 94 | ) 95 | pos_loss = compute_loss( 96 | positive_audio, 97 | audio_tokenizer, 98 | feature_extractor, 99 | num_quantizers, 100 | tokenizer, 101 | model, 102 | device, 103 | ) 104 | # print(f"Neg loss: {neg_loss.item()}, Pos loss: {pos_loss.item()}") 105 | 106 | total_correct += (neg_loss > pos_loss).item() 107 | total_samples += 1 108 | 109 | acc = total_correct / total_samples 110 | 
result[task] = acc 111 | print(f"Accuracy for {task}: {acc}") 112 | 113 | output_path = os.path.join( 114 | args.output_dir, "sLM21", args.model_name, "accuracy.json" 115 | ) 116 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 117 | with open(output_path, "w") as f: 118 | json.dump(result, f, indent=4) 119 | -------------------------------------------------------------------------------- /eval/sStoryCloze.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset, Audio 2 | from transformers import MimiModel, AutoFeatureExtractor 3 | import os 4 | from torchtitan.datasets.hf_datasets import audio_array_to_text 5 | import torch 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from argparse import ArgumentParser 8 | from tqdm import tqdm 9 | import json 10 | 11 | 12 | def parse_args(): 13 | parser = ArgumentParser(description="Audio completion generation script") 14 | parser.add_argument( 15 | "--model_name", 16 | type=str, 17 | default="llm-jp/Llama-Mimi-1.3B", 18 | help="Run name for the model", 19 | ) 20 | parser.add_argument( 21 | "--output_dir", type=str, default="results", help="Output directory" 22 | ) 23 | return parser.parse_args() 24 | 25 | 26 | def compute_loss( 27 | audio, audio_tokenizer, feature_extractor, num_quantizers, tokenizer, model, device 28 | ): 29 | text = audio_array_to_text( 30 | audio, audio_tokenizer, feature_extractor, num_quantizers 31 | ) 32 | inputs = tokenizer(text, return_tensors="pt") 33 | 34 | labels = inputs.input_ids.clone() 35 | labels[labels == tokenizer.pad_token_id] = -100 36 | mod_mask = torch.zeros_like(labels) 37 | mod_mask[:, 2 :: model.config.num_quantizers] = 1 38 | labels[mod_mask == 0] = -100 39 | # print([tokenizer.decode(ids) for ids in neg_labels[0].tolist() if ids != -100]) 40 | inputs = inputs.to(device) 41 | outputs = model(input_ids=inputs.input_ids, labels=labels) 42 | loss = outputs.loss 43 | return loss 44 | 45 | 46 | if __name__ == "__main__": 47 | args = parse_args() 48 | 49 | import multiprocessing as mp 50 | 51 | mp.set_start_method("spawn", force=True) 52 | 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | 55 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi").to(device) 56 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi") 57 | 58 | model = ( 59 | AutoModelForCausalLM.from_pretrained(args.model_name, torch_dtype=torch.float16) 60 | .eval() 61 | .to(device) 62 | ) 63 | num_quantizers = model.config.num_quantizers 64 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 65 | 66 | # Set pad token if not already set 67 | if tokenizer.pad_token is None: 68 | tokenizer.pad_token = tokenizer.eos_token 69 | 70 | tasks = ["sStoryCloze", "tStoryCloze"] 71 | result = {} 72 | for task in tasks: 73 | ds = load_dataset( 74 | f"speed/{task}", split="train" 75 | ) # .shuffle(seed=42).select(range(1000)) 76 | ds = ds.cast_column( 77 | "negative", Audio(sampling_rate=feature_extractor.sampling_rate) 78 | ) 79 | ds = ds.cast_column( 80 | "positive", Audio(sampling_rate=feature_extractor.sampling_rate) 81 | ) 82 | 83 | total_correct = 0 84 | total_samples = 0 85 | 86 | for example in tqdm(ds): 87 | negative_audio = example["negative"]["array"] 88 | positive_audio = example["positive"]["array"] 89 | 90 | neg_loss = compute_loss( 91 | negative_audio, 92 | audio_tokenizer, 93 | feature_extractor, 94 | num_quantizers, 95 | tokenizer, 96 | model, 97 | device, 98 | ) 99 | pos_loss = compute_loss( 100 | 
positive_audio, 101 | audio_tokenizer, 102 | feature_extractor, 103 | num_quantizers, 104 | tokenizer, 105 | model, 106 | device, 107 | ) 108 | 109 | total_correct += (neg_loss > pos_loss).item() 110 | total_samples += 1 111 | 112 | acc = total_correct / total_samples 113 | result[task] = acc 114 | print(f"Accuracy for {task}: {acc}") 115 | 116 | output_path = os.path.join( 117 | args.output_dir, "sStoryCloze", args.model_name, "accuracy.json" 118 | ) 119 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 120 | with open(output_path, "w") as f: 121 | json.dump(result, f, indent=4) 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Llama-Mimi 4 | #### Autoregressive Speech Language Modeling with Interleaved Semantic and Acoustic Tokens 5 | | [📃Paper](https://arxiv.org/abs/2509.14882) | [🤗Models](https://huggingface.co/llm-jp/Llama-Mimi-1.3B) | [🗣️Online Demo](https://speed1313.github.io/llama-mimi/) | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | ## Introduction 14 | Llama-Mimi is a speech language model that uses a unified tokenizer (Mimi) and a single Transformer decoder (Llama) to jointly model sequences of interleaved semantic and acoustic tokens. 15 | Trained on ~240k hours of English audio, Llama-Mimi achieves state-of-the-art performance in acoustic consistency on [SALMon](https://arxiv.org/abs/2409.07437) and effectively preserves speaker identity. 16 | 17 | Visit our [demo site](https://speed1313.github.io/llama-mimi/) to hear generated speech samples. 18 | 19 | 20 | ## Repository Overview 21 | This repository lets you: 22 | - Run inference with our pretrained models 23 | - Pre-train Llama-Mimi on [The People's Speech](https://huggingface.co/datasets/MLCommons/peoples_speech) 24 | - Evaluate the model on multiple benchmarks 25 | 26 | ## Setup 27 | 28 | 29 | Install dependencies using uv: 30 | ```bash 31 | uv sync 32 | ``` 33 | 34 | ## Generate Speech 35 | 36 | Generate audio continuations from a given audio prompt using our pretrained model (Llama-Mimi-1.3B): 37 | ```bash 38 | uv run python inference.py 39 | ``` 40 | 41 | [▶️ Listen to samples on our demo site](https://speed1313.github.io/llama-mimi) 42 | 43 | ## Pre-train Llama-Mimi on The People's Speech 44 | 45 | To pre-train Llama-Mimi on [The People's Speech](https://huggingface.co/datasets/MLCommons/peoples_speech) (30k hours), first download the dataset locally: 46 | ```bash 47 | uv run huggingface-cli download MLCommons/peoples_speech --repo-type dataset --local-dir data/peoples_speech 48 | ``` 49 | 50 | Then launch training with: 51 | ```bash 52 | torchrun --nproc_per_node=8 --local-ranks-filter 0 \ 53 | --role rank --tee 3 -m torchtitan.train \ 54 | --job.config_file config/llama3_2_1b_peoples_speech.toml 55 | ``` 56 | This configuration trains Llama-Mimi-1.3B for 5,000 steps with a global batch size of 1,024 on 8 GPUs, taking about 8 hours. 57 | Training progress can be monitored with Weights & Biases (W&B). 58 | 59 |
60 | 61 |
62 | 
63 | To use a custom dataset, update the dataset configuration in `torchtitan/datasets/hf_datasets.py`. We recommend downloading multiple large datasets, shuffling them, and then using `load_dataset()` with local files.
64 | 
65 | After training, convert the DCP checkpoint to Hugging Face format so the model can be used with the `transformers` library:
66 | 
67 | ```bash
68 | uv run python scripts/convert_dcp_to_hf.py
69 | ```
70 | 
71 | 
72 | ## Evaluation
73 | Evaluate models on [SALMon](https://github.com/slp-rl/salmon), [sLM21](https://arxiv.org/abs/2104.14700) (sWUGGY and sBLIMP), and [sStoryCloze](https://github.com/slp-rl/SpokenStoryCloze) tasks.
74 | 
75 | SALMon:
76 | ```bash
77 | uv run python eval/salmon.py --model_name llm-jp/Llama-Mimi-1.3B
78 | ```
79 | 
80 | sStoryCloze:
81 | ```bash
82 | uv run python eval/sStoryCloze.py --model_name llm-jp/Llama-Mimi-1.3B
83 | ```
84 | 
85 | sLM21:
86 | ```bash
87 | uv run python eval/sLM21.py --model_name llm-jp/Llama-Mimi-1.3B
88 | ```
89 | 
90 | 
91 | 
92 | ## Acknowledgements
93 | 
94 | - Our training code is built on top of [TorchTitan](https://github.com/pytorch/torchtitan).
95 | 
96 | - Our model employs [Llama 3](https://arxiv.org/abs/2407.21783) as the base language model and [Mimi](https://arxiv.org/abs/2410.00037) as the audio tokenizer.
97 | 
98 | 
99 | ## Citation
100 | Star us on GitHub if you find this repository useful! ⭐
101 | 
102 | If you find this work interesting, please cite our paper:
103 | ```
104 | @misc{sugiura2025llamamimispeechlanguagemodels,
105 |       title={Llama-Mimi: Speech Language Models with Interleaved Semantic and Acoustic Tokens},
106 |       author={Issa Sugiura and Shuhei Kurita and Yusuke Oda and Ryuichiro Higashinaka},
107 |       year={2025},
108 |       eprint={2509.14882},
109 |       archivePrefix={arXiv},
110 |       primaryClass={cs.CL},
111 |       url={https://arxiv.org/abs/2509.14882},
112 | }
113 | ```
114 | 
115 | 
--------------------------------------------------------------------------------
/torchtitan/protocols/train_spec.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from abc import abstractmethod
8 | from collections.abc import Callable
9 | from dataclasses import dataclass
10 | from typing import Protocol, TypeAlias
11 | 
12 | import torch
13 | import torch.nn as nn
14 | from torch.distributed.pipelining.schedules import _PipelineSchedule
15 | 
16 | from torchtitan.components.dataloader import BaseDataLoader
17 | from torchtitan.components.ft import FTManager
18 | from torchtitan.components.loss import LossFunction
19 | from torchtitan.components.lr_scheduler import LRSchedulersContainer
20 | from torchtitan.components.metrics import MetricsProcessor
21 | from torchtitan.components.optimizer import OptimizersContainer
22 | from torchtitan.components.tokenizer import BaseTokenizer
23 | from torchtitan.components.validate import BaseValidator
24 | from torchtitan.config_manager import JobConfig
25 | from torchtitan.distributed import ParallelDims
26 | from torchtitan.protocols.state_dict_adapter import StateDictAdapter
27 | 
28 | 
29 | @dataclass
30 | class BaseModelArgs:
31 |     """All ModelArgs should inherit from this class.
32 | 
33 |     The only usage of this class is type checking but allows us to extend common
34 |     arguments to all models in the future.
35 | """ 36 | 37 | _enforced: str = "This field is used to enforce all fields have defaults." 38 | 39 | @abstractmethod 40 | def update_from_config( 41 | self, job_config: JobConfig, tokenizer: BaseTokenizer 42 | ) -> None: 43 | pass 44 | 45 | @abstractmethod 46 | def get_nparams_and_flops( 47 | self, model: nn.Module, seq_len: int 48 | ) -> tuple[int, float]: 49 | pass 50 | 51 | 52 | class ModelProtocol(Protocol): 53 | """Defines the interface for a model class. 54 | 55 | This is used to enforce that all model classes have some methods that are 56 | required by the trainer. 57 | """ 58 | 59 | def __init__(self, model_args: BaseModelArgs) -> None: 60 | pass 61 | 62 | @abstractmethod 63 | def init_weights(self, buffer_device: torch.device | None = None) -> None: 64 | """Initialize model weights. 65 | 66 | Args: 67 | buffer_device: Optional device to place buffers on during initialization. 68 | """ 69 | pass 70 | 71 | 72 | ParallelizeFunction: TypeAlias = Callable[..., nn.Module] 73 | PipeliningFunction: TypeAlias = Callable[ 74 | ..., tuple[_PipelineSchedule, list[nn.Module], bool, bool] 75 | ] 76 | DataLoaderBuilder: TypeAlias = Callable[..., BaseDataLoader] 77 | TokenizerBuilder: TypeAlias = Callable[..., BaseTokenizer] 78 | MetricsProcessorBuilder: TypeAlias = Callable[..., MetricsProcessor] 79 | OptimizersBuilder: TypeAlias = Callable[ 80 | [list[nn.Module], JobConfig, ParallelDims, FTManager | None], 81 | OptimizersContainer, 82 | ] 83 | LRSchedulersBuilder: TypeAlias = Callable[ 84 | [OptimizersContainer, JobConfig], LRSchedulersContainer 85 | ] 86 | LossFunctionBuilder: TypeAlias = Callable[..., LossFunction] 87 | ValidatorBuilder: TypeAlias = Callable[..., BaseValidator] 88 | 89 | 90 | @dataclass 91 | class TrainSpec: 92 | name: str 93 | model_cls: type[ModelProtocol] 94 | model_args: dict[str, BaseModelArgs] 95 | parallelize_fn: ParallelizeFunction 96 | pipelining_fn: PipeliningFunction | None 97 | build_optimizers_fn: OptimizersBuilder 98 | build_lr_schedulers_fn: LRSchedulersBuilder 99 | build_dataloader_fn: DataLoaderBuilder 100 | build_tokenizer_fn: TokenizerBuilder | None 101 | build_loss_fn: LossFunctionBuilder 102 | build_validator_fn: ValidatorBuilder | None = None 103 | build_metrics_processor_fn: MetricsProcessorBuilder | None = None 104 | state_dict_adapter: type[StateDictAdapter] | None = None 105 | 106 | 107 | _train_specs = {} 108 | 109 | 110 | def register_train_spec(train_spec: TrainSpec) -> None: 111 | global _train_specs 112 | if train_spec.name in _train_specs: 113 | raise ValueError(f"Model {train_spec.name} is already registered.") 114 | 115 | _train_specs[train_spec.name] = train_spec 116 | 117 | 118 | def get_train_spec(name: str) -> TrainSpec: 119 | global _train_specs 120 | if name not in _train_specs: 121 | raise ValueError(f"Model {name} is not registered.") 122 | return _train_specs[name] 123 | 124 | 125 | def apply_to_train_specs(func: Callable[[TrainSpec], TrainSpec]) -> None: 126 | global _train_specs 127 | for name, train_spec in _train_specs.items(): 128 | _train_specs[name] = func(train_spec) 129 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | AutoModelForCausalLM, 3 | AutoTokenizer, 4 | MimiModel, 5 | AutoFeatureExtractor, 6 | StoppingCriteria, 7 | ) 8 | import torch 9 | import torchaudio 10 | import re 11 | import requests 12 | import io 13 | 14 | 15 | def audio_array_to_text( 16 | 
audio_array: torch.tensor, 17 | audio_tokenizer, 18 | feature_extractor, 19 | num_quantizers: int, 20 | ) -> str: 21 | inputs = feature_extractor( 22 | raw_audio=audio_array, 23 | sampling_rate=feature_extractor.sampling_rate, 24 | return_tensors="pt", 25 | ).to(audio_tokenizer.device) 26 | with torch.no_grad(): 27 | encoder_outputs = audio_tokenizer.encode( 28 | inputs["input_values"], 29 | inputs["padding_mask"], 30 | num_quantizers=num_quantizers, 31 | ) 32 | flatten_audio_codes = encoder_outputs.audio_codes.transpose(1, 2).reshape(-1) 33 | assert flatten_audio_codes.numel() % num_quantizers == 0 34 | steps = [] 35 | for i in range(0, flatten_audio_codes.numel(), num_quantizers): 36 | group = [ 37 | f"<{flatten_audio_codes[i + j].item()}_{j}>" for j in range(num_quantizers) 38 | ] 39 | steps.append(group) 40 | 41 | parts = [tok for step in steps for tok in step] 42 | 43 | text = "".join(parts) 44 | 45 | return f"" 46 | 47 | 48 | def text_to_audio_values( 49 | text: str, 50 | num_quantizers: int, 51 | output_file: str, 52 | audio_tokenizer, 53 | feature_extractor, 54 | ): 55 | # Extract (val, idx) pairs from the format in the text 56 | matches = re.findall(r"<(\d+)_(\d+)>", text) 57 | vals = [] 58 | for i in range(0, len(matches), num_quantizers): 59 | chunk = matches[i : i + num_quantizers] 60 | if len(chunk) < num_quantizers: 61 | break 62 | indices = [int(idx) for _, idx in chunk] 63 | if indices == list(range(num_quantizers)): 64 | vals.extend(int(val) for val, _ in chunk) 65 | else: 66 | break 67 | vals = vals[: len(vals) - len(vals) % num_quantizers] 68 | tensor_bt4 = torch.tensor(vals).reshape(1, -1, num_quantizers) # (B, T, 4) 69 | tensor_b4t = tensor_bt4.transpose(1, 2) # (B, 4, T) 70 | audio_values = audio_tokenizer.decode(tensor_b4t)[0] 71 | torchaudio.save( 72 | output_file, 73 | audio_values[0].detach().cpu(), 74 | feature_extractor.sampling_rate, 75 | ) 76 | 77 | 78 | class StopOnAudioEnd(StoppingCriteria): 79 | def __init__(self, tokenizer): 80 | self.tokenizer = tokenizer 81 | self.target_text = "" 82 | self.target_ids = tokenizer( 83 | self.target_text, add_special_tokens=False 84 | ).input_ids 85 | 86 | def __call__(self, input_ids, scores, **kwargs): 87 | if len(input_ids[0]) < len(self.target_ids): 88 | return False 89 | return input_ids[0][-len(self.target_ids) :].tolist() == self.target_ids 90 | 91 | 92 | temperature = 0.8 93 | top_k = 30 94 | do_sample = True 95 | max_length = 1024 96 | device = "cuda" if torch.cuda.is_available() else "cpu" 97 | model_id = "llm-jp/Llama-Mimi-1.3B" 98 | model = ( 99 | AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) 100 | .eval() 101 | .to(device) 102 | ) 103 | num_quantizers = model.config.num_quantizers 104 | tokenizer = AutoTokenizer.from_pretrained(model_id) 105 | audio_tokenizer = MimiModel.from_pretrained("kyutai/mimi") 106 | feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi") 107 | stopping_criteria = StopOnAudioEnd(tokenizer) 108 | 109 | audio_url = ( 110 | "https://speed1313.github.io/llama-mimi/data/prompt/natural/great_day_gt.wav" 111 | ) 112 | response = requests.get(audio_url) 113 | response.raise_for_status() 114 | waveform, sample_rate = torchaudio.load(io.BytesIO(response.content)) 115 | if sample_rate != feature_extractor.sampling_rate: 116 | waveform = torchaudio.transforms.Resample( 117 | sample_rate, feature_extractor.sampling_rate 118 | )(waveform) 119 | sample_rate = feature_extractor.sampling_rate 120 | prompt_array = waveform.squeeze().cpu().numpy() 121 | 122 | 
text = audio_array_to_text( 123 | prompt_array, audio_tokenizer, feature_extractor, num_quantizers 124 | ) 125 | 126 | text = text.replace("", "") 127 | inputs = tokenizer(text, return_tensors="pt").to(device) 128 | 129 | with torch.no_grad(): 130 | generated = model.generate( 131 | **inputs, 132 | max_length=max_length, 133 | do_sample=do_sample, 134 | temperature=temperature, 135 | top_k=top_k, 136 | stopping_criteria=[stopping_criteria], 137 | ) 138 | 139 | generated_text = tokenizer.decode(generated[0]) 140 | 141 | text_to_audio_values( 142 | generated_text, 143 | num_quantizers=num_quantizers, 144 | output_file="output.wav", 145 | audio_tokenizer=audio_tokenizer, 146 | feature_extractor=feature_extractor, 147 | ) 148 | -------------------------------------------------------------------------------- /torchtitan/tools/profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import contextlib 8 | import os 9 | import pickle 10 | import time 11 | 12 | import torch 13 | 14 | from torchtitan.config_manager import JobConfig 15 | from torchtitan.tools.logging import logger 16 | 17 | # the number of warmup steps before the active step in each profiling cycle 18 | WARMUP = 3 19 | 20 | # how much memory allocation/free ops to record in memory snapshots 21 | MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 22 | 23 | 24 | @contextlib.contextmanager 25 | def maybe_enable_profiling(config: JobConfig, *, global_step: int = 0): 26 | # get user defined profiler settings 27 | enable_profiling = config.profiling.enable_profiling 28 | 29 | if enable_profiling: 30 | dump_dir = config.job.dump_folder 31 | save_trace_dir = config.profiling.save_traces_folder 32 | trace_dir = os.path.join(dump_dir, save_trace_dir) 33 | profile_freq = config.profiling.profile_freq 34 | 35 | rank = torch.distributed.get_rank() 36 | 37 | replica_id = None 38 | if config.fault_tolerance.enable: 39 | replica_id = config.fault_tolerance.replica_id 40 | 41 | def trace_handler(prof): 42 | curr_trace_dir_name = "iteration_" + str(prof.step_num) 43 | curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name) 44 | if not os.path.exists(curr_trace_dir): 45 | os.makedirs(curr_trace_dir, exist_ok=True) 46 | 47 | logger.info(f"Dumping profiler traces at step {prof.step_num}") 48 | begin = time.monotonic() 49 | 50 | output_file = curr_trace_dir 51 | if replica_id is not None: 52 | output_file = os.path.join(output_file, f"replica{replica_id}") 53 | output_file = os.path.join(output_file, f"rank{rank}_trace.json") 54 | 55 | prof.export_chrome_trace(output_file) 56 | logger.info( 57 | f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" 58 | ) 59 | 60 | logger.info(f"Profiling active. 
Traces will be saved at {trace_dir}") 61 | 62 | if not os.path.exists(trace_dir): 63 | os.makedirs(trace_dir, exist_ok=True) 64 | 65 | warmup, active = WARMUP, 1 66 | wait = profile_freq - (active + warmup) 67 | assert wait >= 0, ( 68 | "profile_freq must be greater than or equal to warmup + active" 69 | ) 70 | gpu_device_profiled = None 71 | if torch.cuda.is_available(): 72 | gpu_device_profiled = torch.profiler.ProfilerActivity.CUDA 73 | elif torch.xpu.is_available(): 74 | gpu_device_profiled = torch.profiler.ProfilerActivity.XPU 75 | with torch.profiler.profile( 76 | activities=[ 77 | torch.profiler.ProfilerActivity.CPU, 78 | gpu_device_profiled, 79 | ], 80 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), 81 | on_trace_ready=trace_handler, 82 | record_shapes=True, 83 | ) as torch_profiler: 84 | torch_profiler.step_num = global_step 85 | yield torch_profiler 86 | else: 87 | torch_profiler = contextlib.nullcontext() 88 | yield None 89 | 90 | 91 | @contextlib.contextmanager 92 | def maybe_enable_memory_snapshot(config: JobConfig, *, global_step: int = 0): 93 | enable_snapshot = config.profiling.enable_memory_snapshot 94 | if enable_snapshot: 95 | snapshot_folder = config.profiling.save_memory_snapshot_folder 96 | snapshot_dir = os.path.join(config.job.dump_folder, snapshot_folder) 97 | if not os.path.exists(snapshot_dir): 98 | os.makedirs(snapshot_dir, exist_ok=True) 99 | rank = torch.distributed.get_rank() 100 | 101 | class MemoryProfiler: 102 | def __init__(self, step_num: int, freq: int): 103 | torch.cuda.memory._record_memory_history( 104 | max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES 105 | ) 106 | # when resume training, we start from the last step 107 | self.step_num = step_num 108 | self.freq = freq 109 | 110 | def step(self, exit_ctx: bool = False): 111 | self.step_num += 1 112 | if not exit_ctx and self.step_num % self.freq != 0: 113 | return 114 | if not exit_ctx: 115 | curr_step = self.step_num 116 | dir_name = f"iteration_{curr_step}" 117 | else: 118 | # dump as iteration_0_exit if OOM at iter 1 119 | curr_step = self.step_num - 1 120 | dir_name = f"iteration_{curr_step}_exit" 121 | curr_snapshot_dir = os.path.join(snapshot_dir, dir_name) 122 | if not os.path.exists(curr_snapshot_dir): 123 | os.makedirs(curr_snapshot_dir, exist_ok=True) 124 | logger.info(f"Dumping memory snapshot at step {curr_step}") 125 | begin = time.monotonic() 126 | with open( 127 | f"{curr_snapshot_dir}/rank{rank}_memory_snapshot.pickle", "wb" 128 | ) as output: 129 | pickle.dump(torch.cuda.memory._snapshot(), output) 130 | logger.info( 131 | f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds" 132 | ) 133 | 134 | logger.info(f"Memory profiler active. Snapshot will be saved at {snapshot_dir}") 135 | profiler = MemoryProfiler(global_step, config.profiling.profile_freq) 136 | try: 137 | yield profiler 138 | except torch.OutOfMemoryError: 139 | profiler.step(exit_ctx=True) 140 | else: 141 | yield None 142 | -------------------------------------------------------------------------------- /torchtitan/models/llama3/infra/pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # This file applies the PT-D pipeline parallelism to the Llama model. 
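The functions below split the model into pipeline stages by fully qualified layer names: each stage keeps the layers from its start split point (inclusive) up to the next split point (exclusive), and only the first/last stages keep the embedding and norm/output modules. A minimal sketch of that bookkeeping, assuming a hypothetical 12-layer model and two split points (the numbers are illustrative, not taken from any config in this repo):

```python
# Illustrative only: how split-point FQNs partition layers into contiguous stages,
# mirroring the start-inclusive / stop-exclusive rule used in _build_stage below.
splits = ["layers.4", "layers.8"]  # hypothetical split points -> 3 stages
n_layers = 12                      # hypothetical model depth

bounds = [0] + [int(name.split(".")[1]) for name in splits] + [n_layers]
for stage_idx, (start, stop) in enumerate(zip(bounds[:-1], bounds[1:])):
    print(f"stage {stage_idx}: layers {start}..{stop - 1}")
# stage 0: layers 0..3
# stage 1: layers 4..7
# stage 2: layers 8..11
```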
8 | 9 | import copy 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.distributed import DeviceMesh 14 | from torch.distributed.pipelining import PipelineStage 15 | from torch.distributed.pipelining.schedules import ( 16 | _PipelineSchedule, 17 | get_schedule_class, 18 | ScheduleZBVZeroBubble, 19 | ) 20 | 21 | from torchtitan.components.loss import LossFunction 22 | from torchtitan.config_manager import JobConfig 23 | from torchtitan.distributed import ParallelDims 24 | from torchtitan.distributed.pipeline import ( 25 | build_pipeline_schedule, 26 | generate_split_points, 27 | stage_ids_this_rank, 28 | ) 29 | from torchtitan.protocols.train_spec import ParallelizeFunction 30 | from torchtitan.tools.logging import logger 31 | 32 | from ..model.args import TransformerModelArgs 33 | 34 | 35 | def pipeline_llama( 36 | model: nn.Module, 37 | parallel_dims: ParallelDims, 38 | job_config: JobConfig, 39 | device: torch.device, 40 | model_config: TransformerModelArgs, 41 | parallelize_fn: ParallelizeFunction, 42 | loss_fn: LossFunction, 43 | ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: 44 | pp_mesh = parallel_dims.world_mesh["pp"] 45 | 46 | stages, model_parts = pipeline_llama_manual_split( 47 | model, pp_mesh, parallel_dims, job_config, device, model_config 48 | ) 49 | 50 | # For PP with looped schedules, each item in model_parts is one stage-model-chunk. 51 | # We need to iterate through model_parts to apply SPMD parallelisms, compilation, 52 | # optimizer, and checkpointing 53 | for i, m in enumerate(model_parts): 54 | # apply SPMD-style PT-D techniques 55 | m = parallelize_fn(m, parallel_dims, job_config) 56 | model_parts[i] = m 57 | # NOTE: this is to update the model in the stage 58 | # in case the model is modified e.g. by torch.compile 59 | stages[i].submod = m 60 | 61 | pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) 62 | 63 | # This is used in the train loop to determine whether to pass in the input_ids and labels 64 | has_first_stage = False 65 | has_last_stage = False 66 | for stage in stages: 67 | if stage.is_first: 68 | has_first_stage = True 69 | if stage.is_last: 70 | has_last_stage = True 71 | 72 | return pp_schedule, model_parts, has_first_stage, has_last_stage 73 | 74 | 75 | def pipeline_llama_manual_split( 76 | whole_model: nn.Module, 77 | pp_mesh: DeviceMesh, 78 | parallel_dims: ParallelDims, 79 | job_config: JobConfig, 80 | device: torch.device, 81 | model_config: TransformerModelArgs, 82 | ) -> tuple[list[PipelineStage], list[nn.Module]]: 83 | """ 84 | This API extracts one torch.nn.Module objects for the part of the model configured to run inside this stage. 85 | 86 | It wraps the model chunk in a ManualPipelineStage object and returns both the stage and model objects. 87 | 88 | The stage object is used to create a pipeline schedule, and the model object can be used for applying SPMD 89 | parallelism. 
90 | """ 91 | pp_rank = pp_mesh.get_local_rank() 92 | pp_size = pp_mesh.size() 93 | parallelism_config = job_config.parallelism 94 | 95 | splits = parallelism_config.pipeline_parallel_split_points or generate_split_points( 96 | parallelism_config.pipeline_parallel_schedule, 97 | parallel_dims.pp, 98 | model_config.n_layers, 99 | parallelism_config.pipeline_parallel_layers_per_stage, 100 | ) 101 | 102 | def _build_stage( 103 | stage_idx: int, 104 | start_layer: str | None, 105 | stop_layer: str | None, 106 | is_first: bool = False, 107 | is_last: bool = False, 108 | ) -> tuple[PipelineStage, nn.Module]: 109 | model = copy.deepcopy(whole_model) 110 | if not is_first: 111 | model.tok_embeddings = None 112 | 113 | drop_layers = start_layer is not None 114 | for name in list(model.layers.keys()): 115 | # we keep layers in a contiguous region between start (inclusive) and stop (exclusive) 116 | if f"layers.{name}" == start_layer: 117 | drop_layers = False 118 | if f"layers.{name}" == stop_layer: 119 | drop_layers = True 120 | if drop_layers: 121 | del model.layers[name] 122 | 123 | if not is_last: 124 | model.norm = None 125 | model.output = None 126 | 127 | stage = PipelineStage( 128 | model, 129 | stage_idx, 130 | num_stages, 131 | device, 132 | group=pp_mesh.get_group("pp"), 133 | ) 134 | return stage, model 135 | 136 | num_stages = len(splits) + 1 137 | stage_idx = pp_rank 138 | 139 | stages = [] 140 | models = [] 141 | 142 | schedule_class = get_schedule_class(parallelism_config.pipeline_parallel_schedule) 143 | style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop" 144 | 145 | for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style): 146 | start_layer = splits[stage_idx - 1] if stage_idx > 0 else None 147 | stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None 148 | stage, model_chunk = _build_stage( 149 | stage_idx, 150 | start_layer, 151 | stop_layer, 152 | is_first=stage_idx == 0, 153 | is_last=stage_idx == num_stages - 1, 154 | ) 155 | logger.info( 156 | f"PP rank {pp_rank} is building stage_idx {stage_idx}" 157 | f" with start_layer {start_layer}, stop_layer {stop_layer}" 158 | ) 159 | stages.append(stage) 160 | models.append(model_chunk) 161 | return stages, models 162 | -------------------------------------------------------------------------------- /torchtitan/components/ft.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
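`FTManager` below wires TorchFT replica groups into the data-parallel setup; in particular, `get_dp_info()` widens the effective DP world size by the replica group size and offsets each rank by its replica id. A small numeric sketch of that arithmetic (the degrees are hypothetical):

```python
# Hypothetical setup: 2 TorchFT replica groups, each running 4-way data parallelism.
# get_dp_info(dp_degree, dp_rank) returns
#   (dp_degree * group_size, dp_degree * replica_id + dp_rank),
# so every (replica_id, dp_rank) pair maps to a distinct rank in a world of size 8.
dp_degree, group_size = 4, 2
global_ranks = [
    dp_degree * replica_id + dp_rank
    for replica_id in range(group_size)
    for dp_rank in range(dp_degree)
]
assert sorted(global_ranks) == list(range(dp_degree * group_size))  # 0..7, no overlap
```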
6 | 7 | import importlib 8 | from contextlib import nullcontext 9 | from datetime import timedelta 10 | from typing import ContextManager, Optional, TYPE_CHECKING, Union 11 | 12 | import torch 13 | import torch.distributed as dist 14 | from torch.distributed._composable.fsdp.fully_shard import FSDPModule 15 | from torch.distributed.distributed_c10d import ReduceOp 16 | from torchtitan.config_manager import FaultTolerance as FTConfig 17 | 18 | if importlib.util.find_spec("torchft") is not None: 19 | import torchft as ft 20 | 21 | if TYPE_CHECKING: 22 | from torchft import local_sgd 23 | 24 | has_torchft = True 25 | else: 26 | has_torchft = False 27 | 28 | 29 | class FTManager: 30 | def __init__( 31 | self, 32 | ft_config: FTConfig, 33 | ) -> None: 34 | if not ft_config.enable: 35 | self._manager = None 36 | return 37 | 38 | if not has_torchft: 39 | raise ImportError("torchft is not installed. Please install it.") 40 | 41 | process_group_timeout = timedelta( 42 | milliseconds=ft_config.process_group_timeout_ms 43 | ) 44 | if ft_config.process_group == "gloo": 45 | pg = ft.ProcessGroupGloo(timeout=process_group_timeout) 46 | elif ft_config.process_group == "nccl": 47 | pg = ft.ProcessGroupNCCL(timeout=process_group_timeout) 48 | else: 49 | raise ValueError(f"Unsuported process group: {ft_config.process_group}") 50 | 51 | # If the training method is specific, then the quorum should be synchronous 52 | self.use_async_quorum = ft_config.semi_sync_method is None 53 | 54 | self._manager = ft.Manager( 55 | pg=pg, 56 | min_replica_size=ft_config.min_replica_size, 57 | load_state_dict=None, 58 | state_dict=None, 59 | use_async_quorum=self.use_async_quorum, 60 | replica_id=f"torchtitan_ft_{ft_config.replica_id}", 61 | ) 62 | self.group_size = ft_config.group_size 63 | self.replica_id = ft_config.replica_id 64 | 65 | if self.use_async_quorum: 66 | self.replicate_pg = ft.process_group.ManagedProcessGroup(self._manager) 67 | self.replicate_pg.register("dp_replicate") 68 | 69 | @property 70 | def enabled(self) -> bool: 71 | return self._manager is not None 72 | 73 | @property 74 | def manager(self) -> "ft.Manager": 75 | assert self._manager is not None 76 | return self._manager 77 | 78 | def get_dp_info(self, dp_degree: int, dp_rank: int) -> tuple[int, int]: 79 | if self.enabled: 80 | return dp_degree * self.group_size, dp_degree * self.replica_id + dp_rank 81 | else: 82 | return dp_degree, dp_rank 83 | 84 | def maybe_set_all_reduce_hook(self, model_parts: list[torch.nn.Module]) -> None: 85 | if self.enabled and self.use_async_quorum: 86 | 87 | def all_reduce_hook(output): 88 | dist.all_reduce(output, group=self.replicate_pg, op=ReduceOp.AVG) 89 | 90 | def apply_set_all_reduce_hook(m): 91 | if isinstance(m, FSDPModule): 92 | m.set_all_reduce_hook(all_reduce_hook) 93 | 94 | for model_part in model_parts: 95 | model_part.apply(apply_set_all_reduce_hook) 96 | 97 | @property 98 | def loss_sync_pg( 99 | self, 100 | ) -> Optional["ft.process_group.ManagedProcessGroup"]: 101 | if self.enabled and self.use_async_quorum: 102 | return self.replicate_pg 103 | else: 104 | # skip loss sync when using semi-sync training 105 | return None 106 | 107 | 108 | def maybe_semi_sync_training( 109 | ft_config: FTConfig, 110 | ft_manager: FTManager, 111 | model_parts: list[torch.nn.Module], 112 | optimizer: torch.optim.Optimizer, 113 | ) -> ContextManager[Union["local_sgd.DiLoCo", "local_sgd.LocalSGD", None]]: 114 | """ 115 | If TorchFT is enabled and the config is set, use semi_sync_method 116 | """ 117 | semi_sync_method = 
ft_config.semi_sync_method 118 | if ft_config.enable and semi_sync_method is not None: 119 | from torchft import local_sgd 120 | 121 | assert ft_manager._manager is not None, ( 122 | "FTManager must be enabled to use semi-sync training." 123 | ) 124 | if semi_sync_method.lower() == "diloco": 125 | # Create the outer optimizer based on the inner optimizer parameters. 126 | params = [group["params"] for group in optimizer.param_groups] 127 | params = [param for sublist in params for param in sublist] 128 | outer_optimizers = [] 129 | for model in model_parts: 130 | params = [p for p in model.parameters() if p.requires_grad] 131 | outer_optimizer = torch.optim.SGD( 132 | params, lr=0.7, momentum=0.9, nesterov=True 133 | ) 134 | outer_optimizers.append(outer_optimizer) 135 | 136 | return local_sgd.DiLoCo( 137 | manager=ft_manager._manager, 138 | model_fragments=model_parts, 139 | inner_optimizer=optimizer, 140 | outer_optimizer=outer_optimizers, 141 | sync_every=ft_config.sync_steps, 142 | should_quantize=ft_config.should_quantize, 143 | fragment_sync_delay=ft_config.fragment_sync_delay, 144 | fragment_update_alpha=ft_config.fragment_update_alpha, 145 | ) 146 | elif semi_sync_method.lower() == "local_sgd": 147 | assert len(model_parts) == 1 148 | return local_sgd.LocalSGD( 149 | manager=ft_manager._manager, 150 | model=model_parts[0], 151 | optimizer=optimizer, 152 | sync_every=ft_config.sync_steps, 153 | ) 154 | else: 155 | raise ValueError( 156 | f"Unknown training method: {semi_sync_method}, only 'diloco' and 'local_sgd' are supported." 157 | ) 158 | return nullcontext() 159 | -------------------------------------------------------------------------------- /torchtitan/components/validate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Generator 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.distributed.fsdp import FSDPModule 12 | from torchtitan.components.dataloader import BaseDataLoader 13 | from torchtitan.components.metrics import MetricsProcessor 14 | from torchtitan.components.tokenizer import BaseTokenizer 15 | from torchtitan.config_manager import JobConfig 16 | from torchtitan.datasets.hf_datasets import build_hf_validation_dataloader 17 | from torchtitan.distributed import ParallelDims, utils as dist_utils 18 | from torchtitan.tools import utils 19 | 20 | 21 | class BaseValidator: 22 | def __init__(self, job_config: JobConfig): 23 | self.job_config = job_config 24 | 25 | def validate(self, model_parts: list[nn.Module]) -> dict[str, float]: 26 | raise NotImplementedError("validate method not implemented") 27 | 28 | def should_validate(self, step: int) -> bool: 29 | return step % self.job_config.validation.freq == 0 30 | 31 | 32 | class Validator(BaseValidator): 33 | """ 34 | Simple validator focused on correctness and integration. 
35 | 36 | Args: 37 | job_config: Job configuration 38 | validation_dataloader: The validation dataloader 39 | model: The model to validate (single model, no parallelism) 40 | """ 41 | 42 | validation_dataloader: BaseDataLoader 43 | 44 | def __init__( 45 | self, 46 | job_config: JobConfig, 47 | dp_world_size: int, 48 | dp_rank: int, 49 | tokenizer: BaseTokenizer, 50 | audio_tokenizer, 51 | feature_extractor, 52 | parallel_dims: ParallelDims, 53 | validation_context: Generator[None, None, None], 54 | maybe_enable_amp: Generator[None, None, None], 55 | metrics_processor: MetricsProcessor, 56 | ): 57 | self.job_config = job_config 58 | self.parallel_dims = parallel_dims 59 | self.validation_dataloader = build_hf_validation_dataloader( 60 | dp_world_size=dp_world_size, 61 | dp_rank=dp_rank, 62 | tokenizer=tokenizer, 63 | audio_tokenizer=audio_tokenizer, 64 | feature_extractor=feature_extractor, 65 | job_config=job_config, 66 | ) 67 | self.validation_context = validation_context 68 | self.maybe_enable_amp = maybe_enable_amp 69 | self.metrics_processor = metrics_processor 70 | 71 | @torch.no_grad() 72 | def validate( 73 | self, 74 | model_parts: list[nn.Module], 75 | step: int, 76 | ) -> dict[str, float]: 77 | # Set model to eval mode 78 | # TODO: currently does not support pipeline parallelism 79 | model = model_parts[0] 80 | model.eval() 81 | 82 | parallel_dims = self.parallel_dims 83 | 84 | accumulated_losses = [] 85 | device_type = utils.device_type 86 | num_steps = 0 87 | 88 | for input_dict in self.validation_dataloader: 89 | if ( 90 | self.job_config.validation.steps != -1 91 | and num_steps >= self.job_config.validation.steps 92 | ): 93 | break 94 | 95 | self.metrics_processor.ntokens_since_last_log += input_dict[ 96 | "input_ids" 97 | ].numel() 98 | input_ids = input_dict["input_ids"].to(device_type) 99 | attention_mask = input_dict["attention_mask"].to(device_type) 100 | 101 | optional_context_parallel_ctx = ( 102 | dist_utils.create_context_parallel_ctx( 103 | cp_mesh=parallel_dims.world_mesh["cp"], 104 | cp_buffers=[input_ids], 105 | cp_seq_dims=[1], 106 | cp_no_restore_buffers={inputs}, 107 | cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, 108 | ) 109 | if parallel_dims.cp_enabled 110 | else None 111 | ) 112 | 113 | with self.validation_context(optional_context_parallel_ctx): 114 | assert len(model_parts) == 1 115 | with self.maybe_enable_amp: 116 | outputs = model_parts[0]( 117 | input_ids=input_ids, 118 | attention_mask=attention_mask, 119 | labels=input_ids, 120 | ) 121 | loss = outputs.loss 122 | 123 | accumulated_losses.append(loss.detach()) 124 | 125 | num_steps += 1 126 | 127 | # Compute average loss 128 | loss = torch.sum(torch.stack(accumulated_losses)) 129 | loss /= num_steps 130 | if parallel_dims.dp_cp_enabled: 131 | global_avg_loss = dist_utils.dist_mean( 132 | loss, parallel_dims.world_mesh["dp_cp"] 133 | ) 134 | else: 135 | global_avg_loss = loss.item() 136 | 137 | self.metrics_processor.log_validation(loss=global_avg_loss, step=step) 138 | 139 | # Reshard after run forward pass 140 | # This is to ensure the model weights are sharded the same way for checkpoint saving. 
141 | for module in model.modules(): 142 | if isinstance(module, FSDPModule): 143 | module.reshard() 144 | 145 | # Set model back to train mode 146 | model.train() 147 | 148 | 149 | def build_validator( 150 | job_config: JobConfig, 151 | dp_world_size: int, 152 | dp_rank: int, 153 | tokenizer: BaseTokenizer, 154 | audio_tokenizer, 155 | feature_extractor, 156 | parallel_dims: ParallelDims, 157 | validation_context: Generator[None, None, None], 158 | maybe_enable_amp: Generator[None, None, None], 159 | metrics_processor: MetricsProcessor | None = None, 160 | ) -> BaseValidator: 161 | """Build a simple validator focused on correctness.""" 162 | return Validator( 163 | job_config=job_config, 164 | dp_world_size=dp_world_size, 165 | dp_rank=dp_rank, 166 | tokenizer=tokenizer, 167 | audio_tokenizer=audio_tokenizer, 168 | feature_extractor=feature_extractor, 169 | parallel_dims=parallel_dims, 170 | validation_context=validation_context, 171 | maybe_enable_amp=maybe_enable_amp, 172 | metrics_processor=metrics_processor, 173 | ) 174 | -------------------------------------------------------------------------------- /torchtitan/tools/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import gc 8 | import subprocess 9 | import time 10 | from dataclasses import dataclass 11 | from typing import Optional 12 | 13 | import torch 14 | from torch._utils import _get_available_device_type, _get_device_module 15 | 16 | from torchtitan.tools.logging import logger 17 | 18 | 19 | def has_cuda_capability(major: int, minor: int) -> bool: 20 | return torch.cuda.is_available() and torch.cuda.get_device_capability() >= ( 21 | major, 22 | minor, 23 | ) 24 | 25 | 26 | def get_device_info() -> tuple[str, torch.device]: 27 | device_type = _get_available_device_type() or "cuda" 28 | device_module = _get_device_module(device_type) # default device_module:torch.cuda 29 | return device_type, device_module 30 | 31 | 32 | device_type, device_module = get_device_info() 33 | 34 | 35 | # used to avoid stragglers in garbage collection 36 | class GarbageCollection: 37 | def __init__(self, gc_freq: int = 1000, debug: bool = False): 38 | assert gc_freq > 0, "gc_freq must be a positive integer" 39 | self.gc_freq = gc_freq 40 | self.debug = debug 41 | gc.disable() 42 | self.collect("Initial GC collection.") 43 | if debug: 44 | from torch.utils.viz._cycles import warn_tensor_cycles 45 | 46 | if torch.distributed.get_rank() == 0: 47 | warn_tensor_cycles() 48 | 49 | def run(self, step_count: int): 50 | if self.debug: 51 | self.collect( 52 | "Force GC to perform collection to obtain debug information.", 53 | generation=2, 54 | ) 55 | gc.collect() 56 | elif step_count > 1 and step_count % self.gc_freq == 0: 57 | self.collect("Peforming periodical GC collection.") 58 | 59 | @staticmethod 60 | def collect(reason: str, generation: int = 1): 61 | begin = time.monotonic() 62 | gc.collect(generation) 63 | logger.info("[GC] %s %.2f seconds.", reason, time.monotonic() - begin) 64 | 65 | 66 | # hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, AMD MI325X and Intel PVC 67 | def get_peak_flops(device_name: str) -> int: 68 | try: 69 | # Run the lspci command and capture the output 70 | result = subprocess.run(["lspci"], 
stdout=subprocess.PIPE, text=True) 71 | # Filter the output for lines containing both "NVIDIA" and "H100" 72 | filtered_lines = [ 73 | line 74 | for line in result.stdout.splitlines() 75 | if "NVIDIA" in line and "H100" in line 76 | ] 77 | # Join all filtered lines into a single string 78 | device_name = " ".join(filtered_lines) or device_name 79 | except FileNotFoundError as e: 80 | logger.warning(f"Error running lspci: {e}, fallback to use device_name") 81 | if "A100" in device_name: 82 | # data from https://www.nvidia.com/en-us/data-center/a100/ 83 | return 312e12 84 | elif "H100" in device_name: 85 | # data from https://www.nvidia.com/en-us/data-center/h100/ 86 | # NOTE: Specifications are one-half lower without sparsity. 87 | if "NVL" in device_name: 88 | return 835e12 89 | elif "PCIe" in device_name: 90 | return 756e12 91 | else: # for H100 SXM and other variants 92 | return 989e12 93 | elif "H200" in device_name: 94 | # data from https://www.nvidia.com/en-us/data-center/h200/ 95 | return 989e12 96 | elif "B200" in device_name: 97 | # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703 98 | return 2.25e15 99 | elif "MI300X" in device_name or "MI325X" in device_name: 100 | # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html 101 | # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html 102 | return 1300e12 103 | elif "MI250X" in device_name: 104 | # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD) 105 | return 191.5e12 106 | elif "Data Center GPU Max 1550" in device_name: 107 | # Also known as Ponte Vecchio (PVC). 108 | # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html 109 | # Dot Product Accumulate Systolic (DPAS): 110 | # - Freq: 1300MHz 111 | # - #ops: 512 112 | # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16) 113 | # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16) 114 | max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units 115 | return 512 * max_comp_units * 1300 * 10**6 116 | elif "l40s" in device_name: 117 | # data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413" 118 | return 362e12 119 | 120 | else: # for other GPU types, assume A100 121 | logger.warning(f"Peak flops undefined for: {device_name}, fallback to A100") 122 | return 312e12 123 | 124 | 125 | @dataclass(frozen=True) 126 | class Color: 127 | black = "\033[30m" 128 | red = "\033[31m" 129 | green = "\033[32m" 130 | yellow = "\033[33m" 131 | blue = "\033[34m" 132 | magenta = "\033[35m" 133 | cyan = "\033[36m" 134 | white = "\033[37m" 135 | reset = "\033[39m" 136 | orange = "\033[38;2;180;60;0m" 137 | turquoise = "\033[38;2;54;234;195m" 138 | 139 | 140 | @dataclass(frozen=True) 141 | class NoColor: 142 | black = "" 143 | red = "" 144 | green = "" 145 | yellow = "" 146 | blue = "" 147 | magenta = "" 148 | cyan = "" 149 | white = "" 150 | reset = "" 151 | orange = "" 152 | turquoise = "" 153 | 154 | 155 | assert set(NoColor.__dataclass_fields__.keys()) == set( 156 | Color.__dataclass_fields__.keys() 157 | ), "NoColor must have the same fields as Color." 
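The peak-FLOPS table in `get_peak_flops` above is the kind of denominator a metrics pipeline would use to report model FLOPS utilization (MFU). A back-of-the-envelope sketch with made-up throughput numbers; only the 989e12 H100 peak comes from the table above:

```python
# Hypothetical MFU estimate: achieved model FLOPS per second divided by the BF16
# peak that get_peak_flops() returns for an H100 SXM part (989e12).
achieved_flops_per_sec = 420e12   # made-up measured throughput for one GPU
peak_flops = 989e12               # get_peak_flops("NVIDIA H100 SXM")
mfu = achieved_flops_per_sec / peak_flops
print(f"MFU: {mfu:.1%}")          # ~42.5%
```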
158 | 
159 | 
160 | def check_if_feature_in_pytorch(
161 |     feature_name: str,
162 |     pull_request: str,
163 |     min_nightly_version: Optional[str] = None,
164 | ) -> None:
165 |     if "git" in torch.__version__:  # pytorch is built from source
166 |         # notify users to check if the pull request is included in their pytorch
167 |         logger.warning(
168 |             "detected that the pytorch is built from source. Please make sure the PR "
169 |             f"({pull_request}) is included in pytorch for correct {feature_name}."
170 |         )
171 |     elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
172 |         logger.warning(
173 |             f"detected that the pytorch version {torch.__version__} is older than "
174 |             f"{min_nightly_version}. Please upgrade to a newer version to include the "
175 |             f"change in ({pull_request}) for correct {feature_name}."
176 |         )
177 | 
--------------------------------------------------------------------------------
/torchtitan/components/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import copy
8 | import functools
9 | import math
10 | from typing import Any, Callable, Iterator
11 | 
12 | from torch.distributed.checkpoint.stateful import Stateful
13 | from torch.optim.lr_scheduler import LambdaLR, LRScheduler
14 | 
15 | from torchtitan.components.optimizer import OptimizersContainer
16 | from torchtitan.config_manager import JobConfig
17 | from torchtitan.tools.logging import logger
18 | 
19 | __all__ = [
20 |     "LRSchedulersContainer",
21 |     "build_lr_schedulers",
22 | ]
23 | 
24 | 
25 | class LRSchedulersContainer(Stateful):
26 |     """Container for multiple learning rate schedulers.
27 | 
28 |     This class is used to wrap multiple LRSchedulers into a single object that can be
29 |     used to reduce the complexity of the training loop. This mimics the behavior of
30 |     ``torch.optim.lr_scheduler.LRScheduler``. The design concept is the same as
31 |     ``OptimizersContainer``. This class currently only supports ``LambdaLR``.
32 | 
33 |     **Note**
34 |     Users who want to customize the lr_scheduler behavior can inherit from this class and
35 |     extend the functionality as needed. The following methods must follow the same
36 |     signature as ``torch.optim.lr_scheduler.LRScheduler`` class: ``step()``, ``state_dict()``,
37 |     ``load_state_dict()``.
38 | 
39 |     **Limitations**
40 |     This class assumes all the lr schedulers are the same. There is no easy way to support
41 |     resharding for multiple different LRSchedulers because LRScheduler.state_dict() is not
42 |     resharding friendly. Therefore, the limitation is used to allow TorchTitan to support
43 |     lr scheduler resharding.
44 | 
45 |     Args:
46 |         optimizers (OptimizersContainer): The corresponding optimizers for the lr_schedulers.
47 | """ 48 | 49 | schedulers: list[LRScheduler] 50 | 51 | def __init__(self, optimizers: OptimizersContainer, lr_lambda: Callable) -> None: 52 | assert len(optimizers) > 0, ( 53 | "Must have at least one optimizer to create LRScheduler" 54 | ) 55 | 56 | self.schedulers = [LambdaLR(optimizer, lr_lambda) for optimizer in optimizers] 57 | 58 | def __iter__(self) -> Iterator[LRScheduler]: 59 | return iter(self.schedulers) 60 | 61 | def __len__(self) -> int: 62 | return len(self.schedulers) 63 | 64 | def step(self) -> None: 65 | for scheduler in self.schedulers: 66 | scheduler.step() 67 | 68 | def state_dict(self) -> dict[str, Any]: 69 | # While there may be multiple schedulers, we only save the first one because 70 | # the state_dict is the same for all. See the limitations section in the 71 | # docstring. 72 | return self.schedulers[0].state_dict() 73 | 74 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 75 | # Load the same state_dict for all schedulers. The key value we're concerned 76 | # within ``LRScheduler.state_dict()`` is ``last_epoch``, which is an integer 77 | # that is immutable. As long as ``training.steps`` and ``lr_scheduler.warmup_steps`` 78 | # in ``job_config`` remain unchanged when resuming from a checkpoint, this 79 | # approach is safe. We call ``copy()`` here to ensure extra safety. 80 | for scheduler in self.schedulers: 81 | scheduler.load_state_dict(copy.deepcopy(state_dict)) 82 | 83 | 84 | def build_lr_schedulers( 85 | optimizers: OptimizersContainer, job_config: JobConfig 86 | ) -> LRSchedulersContainer: 87 | """Create a LRSchedulerContainer for the given optimizers and job config. 88 | 89 | This function creates a ``LRSchedulersContainer`` for the given optimizers. 90 | ``job_config`` should define the correct lr scheduler parameters. 91 | 92 | **Note** 93 | Users who want to customize the lr scheduler behavior can create their own 94 | ``LRSchedulersContainer`` subclass and ``build_lr_scheduler``. Passing the 95 | customized ``build_lr_schedulers`` to ``TrainSpec`` will create the customized 96 | ``LRSchedulersContainer``. 97 | 98 | 99 | Args: 100 | optimizers (OptimizersContainer): The corresponding optimizers for the 101 | lr_schedulers. 102 | """ 103 | training_steps = job_config.training.steps 104 | warmup_steps = int(job_config.lr_scheduler.warmup_steps) 105 | 106 | if warmup_steps > training_steps: 107 | logger.warning( 108 | f"Warmup steps ({warmup_steps}) exceed total training steps ({training_steps}). " 109 | f"Adjusting warmup steps to {training_steps}." 110 | ) 111 | warmup_steps = training_steps 112 | 113 | if job_config.lr_scheduler.decay_ratio is not None: 114 | decay_steps = round(training_steps * job_config.lr_scheduler.decay_ratio) 115 | if warmup_steps + decay_steps > training_steps: 116 | logger.warning( 117 | f"Warmup ({warmup_steps}) + decay ({decay_steps}) steps exceed " 118 | f"total training steps ({training_steps}). " 119 | f"Adjusting decay steps to {training_steps - warmup_steps}." 
120 | ) 121 | decay_steps = training_steps - warmup_steps 122 | else: 123 | decay_steps = training_steps - warmup_steps 124 | # Add a vitual last step to prevent the learning rate from dropping to 0 125 | stable_steps = training_steps + 1 - warmup_steps - decay_steps 126 | lr_decay_type = job_config.lr_scheduler.decay_type 127 | lr_min = job_config.lr_scheduler.lr_min 128 | 129 | def linear_warmup_stable_decay( 130 | current_step: int, 131 | warmup_steps: int, 132 | stable_steps: int, 133 | decay_steps: int, 134 | lr_decay_type: str, 135 | lr_min: float, 136 | ): 137 | """ 138 | Computes linear warmup followed by stable learning rate for a while, 139 | then some type of decay. 140 | 141 | Per LambdaLR requirement, this is accomplished by returning 142 | a multiplicative factor `curr_adjustment` ranging from 1 to 0 143 | to adjust the learning rate to create the desired schedule. 144 | 145 | We offer three types of learning rate decay schedules: 146 | 1. `linear`: decays linearly from 1 to 0 over the decay period. 147 | 2. `sqrt`: decays as 1 minus the square root of the decay progress. 148 | 3. `cosine`: follows a cosine curve, decaying according to the values of the half-period of the cosine function. 149 | 150 | If `lr_min` is specified, the decay range is scaled from 1 to `lr_min` 151 | to ensure the learning rate does not drop below this minimum value. 152 | """ 153 | warmup_stable_steps = warmup_steps + stable_steps 154 | if current_step < warmup_steps: 155 | # linear warmup 156 | # 0-indexed step, hence + 1 adjustments 157 | current_step += 1 158 | assert warmup_steps != 0, ( 159 | "warmup_steps must not be zero to reach this branch" 160 | ) 161 | curr_adjustment = float(current_step / warmup_steps) 162 | elif current_step < warmup_stable_steps: 163 | curr_adjustment = 1.0 164 | else: 165 | # 0-indexed step, hence + 1 adjustments 166 | current_step += 1 167 | assert decay_steps != 0, "decay_steps must not be zero to reach this branch" 168 | progress = float(current_step - warmup_stable_steps) / decay_steps 169 | 170 | if lr_decay_type == "linear": 171 | curr_adjustment = 1 - progress 172 | elif lr_decay_type == "sqrt": 173 | curr_adjustment = 1 - math.sqrt(progress) 174 | elif lr_decay_type == "cosine": 175 | curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) 176 | curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment 177 | return curr_adjustment 178 | 179 | lr_lambda = functools.partial( 180 | linear_warmup_stable_decay, 181 | warmup_steps=warmup_steps, 182 | stable_steps=stable_steps, 183 | decay_steps=decay_steps, 184 | lr_decay_type=lr_decay_type, 185 | lr_min=lr_min, 186 | ) 187 | return LRSchedulersContainer(optimizers, lr_lambda) 188 | -------------------------------------------------------------------------------- /torchtitan/distributed/parallel_dims.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
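`ParallelDims` below checks that the product of all parallelism degrees equals the world size and resolves `dp_shard=-1` from the remaining factors. A quick numeric sketch of that resolution for a hypothetical 8-GPU job (the degrees are chosen purely for illustration):

```python
# Hypothetical 8-GPU job: dp_shard=-1 is resolved exactly as in
# ParallelDims._validate() below, i.e. world_size // (dp_replicate * cp * tp * pp).
world_size, dp_replicate, cp, tp, pp = 8, 1, 1, 2, 1
dp_shard = world_size // (dp_replicate * cp * tp * pp)  # -> 4
assert dp_replicate * dp_shard * cp * tp * pp == world_size
print(f"dp_shard resolved to {dp_shard}")               # 8 GPUs = 4-way FSDP x 2-way TP
```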
6 | 7 | from dataclasses import dataclass 8 | from functools import cached_property 9 | 10 | from torch.distributed.device_mesh import DeviceMesh, init_device_mesh 11 | 12 | from torchtitan.tools.logging import logger 13 | from torchtitan.tools.utils import device_type 14 | 15 | 16 | __all__ = ["ParallelDims"] 17 | 18 | 19 | @dataclass 20 | class ParallelDims: 21 | dp_replicate: int 22 | dp_shard: int 23 | cp: int 24 | tp: int 25 | pp: int 26 | ep: int 27 | world_size: int 28 | 29 | _world_mesh: DeviceMesh = None 30 | 31 | def __post_init__(self): 32 | self._validate() 33 | 34 | def _validate(self): 35 | dp_replicate, dp_shard, cp, tp, pp, ep = ( 36 | self.dp_replicate, 37 | self.dp_shard, 38 | self.cp, 39 | self.tp, 40 | self.pp, 41 | self.ep, 42 | ) 43 | for d in (dp_replicate, cp, tp, pp, ep): 44 | assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard" 45 | 46 | assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." 47 | if dp_shard < 0: 48 | self.dp_shard = dp_shard = self.world_size // (dp_replicate * cp * tp * pp) 49 | assert dp_shard >= 1 50 | 51 | assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, ( 52 | f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * " 53 | f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" 54 | ) 55 | 56 | if ep > 1: 57 | # EP would borrow all cp and some dp_shard degree 58 | assert ep % cp == 0 and (dp_shard * cp) % ep == 0 59 | 60 | def build_mesh(self) -> DeviceMesh: 61 | # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel 62 | # is not very clean, due to the limited support from DeviceMesh 63 | # for creating two staggered meshes. Will improve. 64 | if self.ep > 1: 65 | return self._build_mesh_with_ep() 66 | else: 67 | return self._build_mesh_without_ep() 68 | 69 | def _build_mesh_with_ep(self) -> DeviceMesh: 70 | # With ep, dp_shard and ep are derived submeshes: 71 | # dp_shard = dp_shard_mod_ep * dp_shard_in_ep 72 | # ep = dp_shard_in_ep * cp 73 | dp_shard_mod_ep = self.dp_shard * self.cp // self.ep 74 | dp_shard_in_ep = self.ep // self.cp 75 | 76 | dims = [] 77 | names = [] 78 | for d, name in zip( 79 | [ 80 | self.pp, 81 | self.dp_replicate, 82 | dp_shard_mod_ep, 83 | dp_shard_in_ep, 84 | self.cp, 85 | self.tp, 86 | ], 87 | ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], 88 | ): 89 | # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping 90 | # helps the MoE layers do mixed precision training 91 | if d > 1 or name == "dp_shard_mod_ep": 92 | dims.append(d) 93 | names.append(name) 94 | 95 | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") 96 | mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) 97 | 98 | # Create all the submesh here to ensure all required process groups are 99 | # initialized: 100 | # Mesh for data loading (no communication on this mesh) 101 | dp_mesh_dim_names = [] 102 | # Mesh for param sharding 103 | dp_shard_cp_mesh_dim_names = [] 104 | # Mesh for loss all-reduce 105 | dp_cp_mesh_dim_names = [] 106 | # Mesh for ep 107 | ep_mesh_dim_names = [] 108 | 109 | if self.dp_replicate_enabled: 110 | dp_mesh_dim_names.append("dp_replicate") 111 | dp_cp_mesh_dim_names.append("dp_replicate") 112 | # dp_shard_mod_ep is always needed, even if it's 1 113 | dp_mesh_dim_names.append("dp_shard_mod_ep") 114 | dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") 115 | dp_cp_mesh_dim_names.append("dp_shard_mod_ep") 116 | if "dp_shard_in_ep" in names: 117 | 
dp_mesh_dim_names.append("dp_shard_in_ep") 118 | dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") 119 | dp_cp_mesh_dim_names.append("dp_shard_in_ep") 120 | ep_mesh_dim_names.append("dp_shard_in_ep") 121 | if self.cp_enabled: 122 | dp_shard_cp_mesh_dim_names.append("cp") 123 | dp_cp_mesh_dim_names.append("cp") 124 | ep_mesh_dim_names.append("cp") 125 | 126 | mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") 127 | mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") 128 | mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") 129 | mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") 130 | 131 | return mesh 132 | 133 | def _build_mesh_without_ep(self) -> DeviceMesh: 134 | dims = [] 135 | names = [] 136 | for d, name in zip( 137 | [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], 138 | ["pp", "dp_replicate", "dp_shard", "cp", "tp"], 139 | ): 140 | if d > 1: 141 | dims.append(d) 142 | names.append(name) 143 | 144 | logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") 145 | mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) 146 | 147 | # Create all the submesh here to ensure all required process groups are 148 | # initialized: 149 | # Mesh for data loading (no communication on this mesh) 150 | dp_mesh_dim_names = [] 151 | # Mesh for param sharding 152 | dp_shard_cp_mesh_dim_names = [] 153 | # Mesh for loss all-reduce 154 | dp_cp_mesh_dim_names = [] 155 | 156 | if self.dp_replicate_enabled: 157 | dp_mesh_dim_names.append("dp_replicate") 158 | dp_cp_mesh_dim_names.append("dp_replicate") 159 | if self.dp_shard_enabled: 160 | dp_mesh_dim_names.append("dp_shard") 161 | dp_shard_cp_mesh_dim_names.append("dp_shard") 162 | dp_cp_mesh_dim_names.append("dp_shard") 163 | if self.cp_enabled: 164 | dp_shard_cp_mesh_dim_names.append("cp") 165 | dp_cp_mesh_dim_names.append("cp") 166 | 167 | if dp_mesh_dim_names != []: 168 | mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") 169 | if dp_shard_cp_mesh_dim_names != []: 170 | mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( 171 | mesh_dim_name="dp_shard_cp" 172 | ) 173 | if dp_cp_mesh_dim_names != []: 174 | mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") 175 | 176 | return mesh 177 | 178 | @property 179 | def world_mesh(self) -> str: 180 | # doing late init so ParallelDims can still be used as a lightweight 181 | # dataclass without having to initialize the world mesh 182 | if self._world_mesh is None: 183 | self._world_mesh = self.build_mesh() 184 | return self._world_mesh 185 | 186 | @property 187 | def dp_enabled(self): 188 | return self.dp_replicate > 1 or self.dp_shard > 1 189 | 190 | @property 191 | def dp_replicate_enabled(self): 192 | return self.dp_replicate > 1 193 | 194 | @property 195 | def dp_shard_enabled(self): 196 | return self.dp_shard > 1 197 | 198 | @property 199 | def cp_enabled(self): 200 | return self.cp > 1 201 | 202 | @property 203 | def dp_cp_enabled(self): 204 | return self.dp_enabled or self.cp_enabled 205 | 206 | @property 207 | def fsdp_enabled(self): 208 | return self.dp_shard_enabled or self.cp_enabled 209 | 210 | @property 211 | def tp_enabled(self): 212 | return self.tp > 1 213 | 214 | @property 215 | def pp_enabled(self): 216 | return self.pp > 1 217 | 218 | @property 219 | def ep_enabled(self): 220 | return self.ep > 1 221 | 222 | @cached_property 223 | def non_data_parallel_size(self): 224 | return self.cp * self.tp * self.pp 225 | 226 | @cached_property 227 | def seq_len_divisor(self): 228 | # 
Sequence Parallel requires that seq_len be divisible by TP degree. 229 | # https://github.com/pytorch/torchtitan/pull/640#discussion_r1849481001 230 | 231 | # Context Parallel requires that seq_len be divisible by 2 * CP degree, 232 | # when load balancing is enabled (by default). 233 | # https://github.com/pytorch/pytorch/blob/4f62dcc/torch/distributed/tensor/experimental/_attention.py#L1246 234 | return self.tp * (self.cp * 2) 235 | 236 | @cached_property 237 | def dense_params_mesh_ndim(self): 238 | # Note: In dp2ep EP, EP params mesh ndim is 1 more due to the 'ep' mesh 239 | return self.dp_replicate_enabled + self.fsdp_enabled + self.tp_enabled 240 | -------------------------------------------------------------------------------- /torchtitan/distributed/pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import os 7 | from typing import Callable 8 | 9 | from torch.distributed.pipelining.schedules import ( 10 | _PipelineSchedule, 11 | _PipelineScheduleRuntime, 12 | get_schedule_class, 13 | PipelineScheduleMulti, 14 | PipelineScheduleSingle, 15 | ) 16 | from torch.distributed.pipelining.stage import PipelineStage 17 | 18 | from torchtitan.config_manager import JobConfig 19 | from torchtitan.tools.logging import logger 20 | 21 | 22 | __all__ = ["build_pipeline_schedule", "generate_split_points", "stage_ids_this_rank"] 23 | 24 | 25 | # TODO: It's unclear if this API is general enough to be used by other models. 26 | # If not, we should move it to a Transformer-specific directory. 27 | def generate_split_points( 28 | schedule_str: str, 29 | pp_degree: int, 30 | num_layers: int, 31 | num_layers_per_stage: int | None, 32 | input_weight: int = 1, 33 | output_weight: int = 1, 34 | ) -> list[str]: 35 | """ 36 | Generate a list of split points based on the input configs. In this function, 37 | the number of effective layers considered is the summation of num_layers, 38 | input_weight, and output_weight. 39 | 40 | If num_layers_per_virtual_stage is given, we require rigid fit of the 41 | effective layers (regular layers + weighted input + weighted output) 42 | onto pipeline stages and ranks, with several assertions. It is the users' 43 | responsibility to figure out the input weight, output weight, and the 44 | number of regular layers, so that they can be arranged neatly. 45 | 46 | If num_layers_per_virtual_stage is None, we by default set each pipeline rank 47 | to have 1 stage if schedule_str is a single-stage schedule, or 2 virtual stages 48 | if it is a multi-stage schedule, and try to distribute all effective layers 49 | evenly onto the PP stages. If there are extra layers, we disperse them in 50 | the starting stages. 51 | 52 | Args: 53 | schedule_str (str): The string of the schedule name. 54 | pp_degree (int): The pipeline parallel dimension. 55 | num_layers (int): The number of layers in the model. 56 | input_weight (int): The number of layers to consider the input modules in layer calculation. 57 | output_weight (int): The number of layers to consider the output modules in layer calculation. 58 | num_layers_per_stage (int): The number of layers per (virtual) pipeline stage. 59 | 60 | Returns: 61 | list[str]: A list of split point FQNs. 
62 | """ 63 | 64 | schedule_class = get_schedule_class(schedule_str) 65 | is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) 66 | 67 | num_effective_layers = num_layers + input_weight + output_weight 68 | 69 | if num_layers_per_stage is not None: 70 | # If num_layers_per_stage is provided, we require a rigid fit of the effective layers 71 | assert num_effective_layers % pp_degree == 0 72 | num_layers_per_pipeline_rank = num_effective_layers // pp_degree 73 | 74 | assert num_layers_per_pipeline_rank % num_layers_per_stage == 0 75 | num_stages_per_rank = num_layers_per_pipeline_rank // num_layers_per_stage 76 | 77 | num_total_virtual_stages = num_stages_per_rank * pp_degree 78 | num_extra_layers = 0 79 | 80 | if is_single_stage_schedule: 81 | assert num_stages_per_rank == 1, ( 82 | f"Number of stages per rank ({num_stages_per_rank}) must be 1 for single-stage schedules." 83 | ) 84 | else: 85 | assert num_stages_per_rank >= 2, ( 86 | f"Number of stages per rank ({num_stages_per_rank}) must be >= 2 for multi-stage schedules." 87 | ) 88 | else: 89 | # In a multi-stage schedule, if num_layers_per_stage is not 90 | # provided, by default each pipeline rank has 2 virtual stages. 91 | num_stages_per_rank = 1 if is_single_stage_schedule else 2 92 | num_total_virtual_stages = pp_degree * num_stages_per_rank 93 | 94 | if num_total_virtual_stages > num_effective_layers: 95 | raise ValueError( 96 | "The number of total stages cannot be greater than the number of effective layers." 97 | ) 98 | 99 | num_layers_per_stage = num_effective_layers // num_total_virtual_stages 100 | num_extra_layers = num_effective_layers % num_total_virtual_stages 101 | 102 | assert num_layers_per_stage >= max(input_weight, output_weight) 103 | 104 | splits = [] 105 | current_layer = 0 106 | for i in range(num_total_virtual_stages - 1): 107 | if i == 0: 108 | current_layer += num_layers_per_stage - input_weight 109 | else: 110 | current_layer += num_layers_per_stage 111 | # extra layers will be dispersed to the first stages 112 | if num_extra_layers > 0: 113 | current_layer += 1 114 | num_extra_layers -= 1 115 | splits.append("layers." + str(current_layer)) 116 | 117 | logger.info( 118 | "No 'pipeline_parallel_split_points' provided. Here is the auto-generated split, " 119 | f"which may be sub-optimal: {splits}." 120 | ) 121 | return splits 122 | 123 | 124 | def build_pipeline_schedule( 125 | job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable 126 | ) -> _PipelineSchedule: 127 | """Builds a pipeline schedule for the given job configuration and stages. 128 | 129 | Args: 130 | job_config (JobConfig): The job configuration. 131 | stages (list[PipelineStage]): The stages to be scheduled. 132 | loss_fn (Callable): The loss function. 133 | 134 | Returns: 135 | _PipelineSchedule: The pipeline schedule for the given stages. 136 | """ 137 | pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv 138 | 139 | # Validate that pp_schedule_csv is a valid path 140 | if pp_schedule_csv: 141 | if not os.path.isfile(pp_schedule_csv): 142 | raise FileNotFoundError( 143 | f"The specified path {pp_schedule_csv} does not exist or is not a file." 
144 |             )
145 |         schedule_class = _PipelineScheduleRuntime
146 |     else:
147 |         schedule_class = get_schedule_class(
148 |             job_config.parallelism.pipeline_parallel_schedule
149 |         )
150 | 
151 |     looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
152 |     microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
153 |     batch_size = job_config.training.local_batch_size
154 |     # Validate that the batch size is divisible by the microbatch size; otherwise training will hang or error.
155 |     if batch_size % microbatch_size != 0:
156 |         raise ValueError(
157 |             f"Batch size {batch_size} must be divisible by the microbatch size {microbatch_size}. "
158 |             "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size."
159 |         )
160 |     n_microbatches = batch_size // microbatch_size
161 |     # We expect that the number of local stages (`len(stages)`) is the same across all ranks
162 |     num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages)
163 |     if n_microbatches < num_total_stages:
164 |         logger.warning(
165 |             f"Number of microbatches ({n_microbatches}) is less than the total number "
166 |             f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
167 |         )
168 | 
169 |     schedule = schedule_class(
170 |         stages if looped_schedule else stages[0],
171 |         n_microbatches=n_microbatches,
172 |         loss_fn=loss_fn,
173 |     )
174 |     logger.info(
175 |         f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} "
176 |         f"with {n_microbatches} microbatches and {num_total_stages} stages."
177 |     )
178 | 
179 |     if pp_schedule_csv:
180 |         assert schedule_class in [
181 |             PipelineScheduleSingle,
182 |             PipelineScheduleMulti,
183 |             _PipelineScheduleRuntime,
184 |         ], (
185 |             "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), "
186 |             "and _PipelineScheduleRuntime support csv schedules"
187 |         )
188 |         schedule._load_csv(pp_schedule_csv)
189 | 
190 |     return schedule
191 | 
192 | 
193 | # TODO(whc) should this be a utility inside torch.pipelining?
194 | def stage_ids_this_rank(
195 |     pp_rank: int, pp_size: int, num_stages: int, style: str = "loop"
196 | ) -> tuple[int]:
197 |     """Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule"""
198 |     assert num_stages % pp_size == 0, (
199 |         f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size}"
200 |     )
201 |     stages_per_rank = num_stages // pp_size
202 |     if style == "loop":
203 |         return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
204 |     elif style == "v":
205 |         assert stages_per_rank == 2, (
206 |             f"v schedules assume 2 stages per rank, got {stages_per_rank}"
207 |         )
208 |         stage_v_pairs = list(
209 |             zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1))
210 |         )
211 |         return stage_v_pairs[pp_rank]
212 | 
--------------------------------------------------------------------------------
/torchtitan/components/quantization/float8.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the BSD-style license found in the
5 | # LICENSE file in the root directory of this source tree.
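# --- illustration: mesh-size arithmetic in ParallelDims (torchtitan/distributed/parallel_dims.py above) ---
# A worked example with made-up degrees showing how dp_shard is inferred from -1 and how
# the expert-parallel submeshes are derived (dp_shard * cp = dp_shard_mod_ep * ep, and
# ep = dp_shard_in_ep * cp). This is a pure-arithmetic sketch, not the real class.
def sketch_parallel_dims(world_size=64, dp_replicate=2, dp_shard=-1, cp=2, tp=4, pp=1, ep=4):
    if dp_shard == -1:  # -1 means "use whatever degree is left over"
        dp_shard = world_size // (dp_replicate * cp * tp * pp)
    assert dp_replicate * dp_shard * cp * tp * pp == world_size
    dp_shard_mod_ep = dp_shard * cp // ep  # FSDP dim that is *not* folded into EP
    dp_shard_in_ep = ep // cp              # FSDP dim that EP borrows on top of CP
    seq_len_divisor = tp * (cp * 2)        # TP shards the sequence; balanced CP needs 2 * cp
    return dp_shard, dp_shard_mod_ep, dp_shard_in_ep, seq_len_divisor


# sketch_parallel_dims() -> (4, 2, 2, 16): the 8-way dp_shard*cp block is reused as
# dp_shard_mod_ep(2) x dp_shard_in_ep(2) x cp(2), giving ep = 2 * 2 = 4.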
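# --- illustration: stage placement per pipeline rank (torchtitan/distributed/pipeline.py above) ---
# A small sketch of the arithmetic in stage_ids_this_rank(): "loop" interleaves stages
# round-robin across ranks, while "v" pairs each rank with a mirrored stage so that
# rank 0 holds both the first and the last stage. The degrees below are example values.
def sketch_stage_ids(pp_rank, pp_size, num_stages, style="loop"):
    stages_per_rank = num_stages // pp_size
    if style == "loop":
        return tuple(pp_rank + s * pp_size for s in range(stages_per_rank))
    # "v" style assumes exactly 2 stages per rank
    return list(zip(range(pp_size), range(num_stages - 1, pp_size - 1, -1)))[pp_rank]


# pp_size=4, num_stages=8:
#   loop -> rank 0: (0, 4)  rank 1: (1, 5)  rank 2: (2, 6)  rank 3: (3, 7)
#   v    -> rank 0: (0, 7)  rank 1: (1, 6)  rank 2: (2, 5)  rank 3: (3, 4)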
6 | from functools import partial 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from torchtitan.config_manager import Float8, JobConfig 12 | from torchtitan.distributed import ParallelDims 13 | from torchtitan.protocols.model_converter import ( 14 | ModelConverter, 15 | register_model_converter, 16 | ) 17 | from torchtitan.tools.logging import logger 18 | from torchtitan.tools.utils import has_cuda_capability 19 | 20 | from .utils import module_filter_fn 21 | 22 | AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn" 23 | 24 | 25 | class Float8Converter(ModelConverter): 26 | def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): 27 | self.enabled = False 28 | 29 | float8_config: Float8 = job_config.float8 30 | if has_cuda_capability(8, 9) or ( 31 | float8_config.emulate and not job_config.training.compile 32 | ): 33 | pass 34 | else: 35 | raise ValueError( 36 | "Failed to swap to Float8Linear because float8 is only supported on SM89 or later." 37 | "To enable testing on older hardware, set `float8.emulate` to True in eager mode.", 38 | ) 39 | try: 40 | from torchao.float8 import Float8LinearConfig 41 | except ImportError as e: 42 | raise ImportError( 43 | "torchao is not installed. Please install it to use float8 linear layers." 44 | ) from e 45 | 46 | if float8_config.recipe_name is not None and not hasattr( 47 | Float8LinearConfig, "from_recipe_name" 48 | ): 49 | logger.warning( 50 | "Failed to swap to Float8Linear with recipe lookup because the torchao version " 51 | "is too old, please install torchao v0.9.0 or later and try again", 52 | ) 53 | return 54 | 55 | self.enabled = True 56 | self.filter_fqns = float8_config.filter_fqns 57 | self.moe_fqns = float8_config.moe_fqns_prototype 58 | self.filter_fn = self._init_filter_fn(float8_config) 59 | 60 | # Validate MoE training prototype limitations. 
61 | if self.moe_fqns: 62 | assert job_config.parallelism.pipeline_parallel_degree == 1, ( 63 | "Float8 MoE training prototype does not yet support pipeline parallelism" 64 | ) 65 | assert job_config.parallelism.context_parallel_degree == 1, ( 66 | "Float8 MoE training prototype does not yet support context parallelism" 67 | ) 68 | 69 | if float8_config.recipe_name is not None: 70 | assert not float8_config.enable_fsdp_float8_all_gather, ( 71 | "using `float8_config.enable_fsdp_float8_all_gather` together " 72 | "with `float8_config.recipe_name` is not supported" 73 | ) 74 | 75 | assert not float8_config.force_recompute_fp8_weight_in_bwd, ( 76 | "using `float8_config.force_recompute_fp8_weight_in_bwd` together " 77 | "with `float8_config.recipe_name` is not supported" 78 | ) 79 | 80 | self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name) 81 | self.precompute_scale = False 82 | logger.info( 83 | f"Float8 training active with recipe {float8_config.recipe_name}" 84 | ) 85 | 86 | # short-term solution for https://github.com/pytorch/pytorch/issues/150859 87 | if float8_config.recipe_name == "rowwise": 88 | torch._inductor.config.emulate_precision_casts = True 89 | logger.debug( 90 | "Set torch._inductor.config.emulate_precision_casts to True" 91 | ) 92 | else: 93 | # Mutates the model inplace replacing instances of nn.Linear with Float8Linear 94 | enable_fsdp_float8_all_gather = ( 95 | parallel_dims.dp_shard_enabled 96 | and float8_config.enable_fsdp_float8_all_gather 97 | ) 98 | self.config = Float8LinearConfig( 99 | enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, 100 | force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd, 101 | emulate=float8_config.emulate, 102 | ) 103 | # for precompute_float8_dynamic_scale_for_fsdp 104 | self.precompute_scale = ( 105 | enable_fsdp_float8_all_gather 106 | and float8_config.precompute_float8_dynamic_scale_for_fsdp 107 | ) 108 | logger.info("Float8 tensorwise scaled training active") 109 | 110 | def _init_filter_fn(self, float8_config: Float8): 111 | # use auto_filter if filter_fqns "auto_filter_small_kn" is one of the given fqns. 112 | use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns 113 | if use_auto_filter: 114 | try: 115 | from torchao.float8 import _auto_filter_for_recipe 116 | 117 | logger.info( 118 | "Using _auto_filter_for_recipe to avoid converting linear layers with dims too small " 119 | "to benefit from float8 training. See docs/float8.md for more info." 120 | ) 121 | 122 | recipe_name = ( 123 | float8_config.recipe_name 124 | if float8_config.recipe_name 125 | else "tensorwise" 126 | ) 127 | 128 | # remove auto filter flag from filter_fqns before passing to _auto_filter_for_recipe 129 | float8_config.filter_fqns.remove(AUTO_FILTER_SMALL_KN_FLAG) 130 | 131 | return _auto_filter_for_recipe( 132 | recipe_name, 133 | filter_fqns=float8_config.filter_fqns, 134 | ) 135 | except ImportError: 136 | logger.warning( 137 | ( 138 | "Using default module_filter_fn for float8 model conversion. " 139 | "To use _auto_filter_for_recipe, please install torchao nightly build." 140 | ) 141 | ) 142 | 143 | # use default filter func 144 | return partial(module_filter_fn, filter_fqns=float8_config.filter_fqns) 145 | 146 | def convert(self, model: nn.Module): 147 | """ 148 | This function converts the linear layers of `model` to `Float8Linear`. 149 | Note that today, only dynamic tensor scaling (the default) is supported. 150 | This will mutate the model inplace. 
151 | """ 152 | if not self.enabled: 153 | return 154 | 155 | # MoE conversion must take place before Float8Linear conversion, otherwise the Float8Linears will 156 | # be converted back to nn.Linear: 157 | # https://github.com/pytorch/ao/blob/c2a6568a04075acc371a338206216bb65536fb27/torchao/quantization/quant_api.py#L294-L299 158 | # TODO: add warning in torchao when this happens, or find a better way to avoid this. 159 | if self.moe_fqns: 160 | self._convert_moe_layers(model) 161 | 162 | from torchao.float8 import convert_to_float8_training 163 | 164 | # Mutates the model inplace replacing instances of nn.Linear with Float8Linear 165 | convert_to_float8_training( 166 | model, 167 | config=self.config, 168 | module_filter_fn=self.filter_fn, 169 | ) 170 | logger.info( 171 | "Swapped to Float8Linear layers with enable_fsdp_float8_all_gather=" 172 | f"{self.config.enable_fsdp_float8_all_gather}" 173 | ) 174 | 175 | def _convert_moe_layers(self, model: nn.Module): 176 | """ 177 | Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor, 178 | to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs. 179 | """ 180 | from torchao.quantization.quant_api import quantize_ 181 | 182 | try: 183 | from torchao.prototype.moe_training.conversion_utils import ( 184 | MoETrainingConfig, 185 | ) 186 | except ImportError as e: 187 | raise ImportError( 188 | "torchao installation does not have MoE training support. Please install torchao nightly build." 189 | ) from e 190 | 191 | def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: 192 | for target_fqn in self.moe_fqns: 193 | if target_fqn in cur_fqn: 194 | return True 195 | return False 196 | 197 | config = MoETrainingConfig() 198 | quantize_(model, config=config, filter_fn=moe_module_filter_fn) 199 | logger.info( 200 | f"Converted MoE layers matching FQNS {self.moe_fqns} " 201 | "to use dynamic float8 rowwise quantization with scaled grouped GEMMs" 202 | ) 203 | 204 | def post_optimizer_hook(self, model: nn.Module | list[nn.Module]): 205 | if not self.enabled: 206 | return 207 | 208 | if not self.precompute_scale: 209 | return 210 | 211 | from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp 212 | 213 | models = [model] if isinstance(model, nn.Module) else model 214 | for m in models: 215 | precompute_float8_dynamic_scale_for_fsdp(m) 216 | 217 | 218 | register_model_converter(Float8Converter, "float8") 219 | -------------------------------------------------------------------------------- /torchtitan/models/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved. 8 | 9 | from typing import Callable, ClassVar 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch.nn.attention import sdpa_kernel, SDPBackend 14 | from torch.nn.attention.flex_attention import ( 15 | _mask_mod_signature, 16 | BlockMask, 17 | create_block_mask, 18 | flex_attention, 19 | ) 20 | 21 | from torchtitan.tools.utils import has_cuda_capability 22 | 23 | # FlexAttention mask type. For each mask type, we initialize it at most once per 24 | # batch. 
To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to 25 | # track the initialized mask. 26 | FLEX_ATTN_MASK_T = tuple[str, int | None] 27 | 28 | 29 | class FlexAttention(torch.nn.Module): 30 | """FlexAttention module that uses torch.nn.attention.flex_attention. 31 | 32 | This module is a wrapper around torch.nn.attention.flex_attention. This module 33 | implements certain common attention types, such as causal and block_causal. 34 | 35 | Args: 36 | attn_mask_type (str): The type of attention mask. Currently, we support 37 | "causal" and "block_causal". "causal" means the lower triangle of the 38 | attention matrix is masked. "block_causal" means the attention matrix 39 | is divided into blocks, where block boundary is defined by EOS token, 40 | and the lower triangle of each block is masked. 41 | fixed_block_size (int | None): The block size to be used to perform attention. 42 | If specified, each sequence will be further divided to blocks, where each 43 | block has the maximum size of ``fixed_block_size``. A query will only attend 44 | to the keys within the same block. 45 | """ 46 | 47 | # We registered flex_attention related attributes as class variables as we 48 | # need to amortize the cost of compilation. 49 | flex_attn: ClassVar[Callable] = torch.compile( 50 | flex_attention, mode="max-autotune-no-cudagraphs" 51 | ) 52 | compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask) 53 | used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set() 54 | # Attention mask type to the created BlockMask. 55 | # This allows us to keep track the created block masks for each 56 | # new batch. We will use this to update the block mask when a 57 | # new batch is created. This also allows user to create different 58 | # block masks for different layers. 59 | block_masks: ClassVar[dict[FLEX_ATTN_MASK_T, BlockMask]] = {} 60 | 61 | # Instance variables. 
62 | attn_mask_type: str 63 | 64 | def __init__( 65 | self, attn_mask_type: str, fixed_block_size: int | None = None 66 | ) -> None: 67 | super().__init__() 68 | if attn_mask_type not in ["causal", "block_causal"]: 69 | raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") 70 | self.attn_mask_type = attn_mask_type 71 | self.fixed_block_size = fixed_block_size 72 | 73 | FlexAttention.used_attn_mask_types.add(self.mask_key) 74 | 75 | @property 76 | def mask_key(self) -> FLEX_ATTN_MASK_T: 77 | return (self.attn_mask_type, self.fixed_block_size) 78 | 79 | def forward( 80 | self, 81 | q: torch.Tensor, 82 | k: torch.Tensor, 83 | v: torch.Tensor, 84 | scale: float | None = None, 85 | ) -> torch.Tensor: 86 | block_mask = FlexAttention.block_masks[self.mask_key] 87 | return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) 88 | 89 | @staticmethod 90 | def _get_causal_mask_mod() -> _mask_mod_signature: 91 | def causal_mask( 92 | b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor 93 | ): 94 | return q_idx >= kv_idx 95 | 96 | return causal_mask 97 | 98 | @staticmethod 99 | def _get_block_causal_mask_mod( 100 | batch: torch.Tensor, eos_id: int 101 | ) -> _mask_mod_signature: 102 | # batch is [b, s, h, d] shape 103 | mask = batch == eos_id 104 | mask[:, -1] = True 105 | acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1) 106 | seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32) 107 | seq_idx[:, 1:] = acc_mask[:, :-1] 108 | 109 | def block_causal_mask( 110 | b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor 111 | ): 112 | return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx) 113 | 114 | return block_causal_mask 115 | 116 | @staticmethod 117 | def _fixed_block_mask_mod( 118 | mask_mod: _mask_mod_signature, fixed_block_size: int 119 | ) -> _mask_mod_signature: 120 | """ 121 | Given an arbirary mask_mod, divide the input sequence to blocks 122 | and only allow attention within the same block. 123 | 124 | Args: 125 | mask_mod: The mask mod to apply to the documents 126 | fixed_block_size: The number of tokens in each block. 127 | """ 128 | 129 | # Credit to @drisspg. 130 | def blocked_mask_mod( 131 | b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor 132 | ): 133 | # Get the block index of the query and key 134 | q_block = q_idx // fixed_block_size 135 | kv_block = kv_idx // fixed_block_size 136 | # Only allow attention within the same block 137 | same_block = q_block == kv_block 138 | # Apply the original mask mod 139 | inner_mask = mask_mod( 140 | b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size 141 | ) 142 | 143 | return same_block & inner_mask 144 | 145 | blocked_mask_mod.__name__ = ( 146 | f"blocked_mask_mod_{mask_mod.__name__}_fixed_block_size_{fixed_block_size}" 147 | ) 148 | 149 | return blocked_mask_mod 150 | 151 | @staticmethod 152 | @torch.no_grad() 153 | def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None: 154 | # batch is [b, s, h, d] shape 155 | for mask_key in FlexAttention.used_attn_mask_types: 156 | attn_mask_type, fixed_block_size = mask_key 157 | match attn_mask_type: 158 | case "causal": 159 | if FlexAttention.block_masks.get(mask_key, None) is not None: 160 | continue 161 | # We don't care about batch dimension -- 162 | # all samples have the same lower triangle mask. 
163 | batch_dimension = 1 164 | mask_mod = FlexAttention._get_causal_mask_mod() 165 | case "block_causal": 166 | if eos_id is None: 167 | raise RuntimeError( 168 | "eos_id must be provided for block_causal mask." 169 | ) 170 | batch_dimension = batch.shape[0] 171 | mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id) 172 | case _: 173 | raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") 174 | 175 | if fixed_block_size is not None and fixed_block_size > 0: 176 | mask_mod = FlexAttention._fixed_block_mask_mod( 177 | mask_mod, fixed_block_size 178 | ) 179 | 180 | seq_len = batch.shape[1] 181 | block_mask = FlexAttention.compiled_create_block_mask( 182 | mask_mod, batch_dimension, None, seq_len, seq_len 183 | ) 184 | FlexAttention.block_masks[mask_key] = block_mask 185 | 186 | 187 | class ScaledDotProductAttention(torch.nn.Module): 188 | backends: ClassVar[list[SDPBackend]] = [] 189 | 190 | def __init__(self, attn_mask_type: str) -> None: 191 | super().__init__() 192 | if attn_mask_type != "causal": 193 | raise ValueError( 194 | "TorchTitan with SDPA currently only supports causal mask." 195 | ) 196 | 197 | ScaledDotProductAttention._init_backend() 198 | 199 | @classmethod 200 | def _init_backend(cls) -> None: 201 | if cls.backends: 202 | return 203 | 204 | # Add CuDNN on B200 w/ highest priority 205 | cls.backends = [ 206 | SDPBackend.FLASH_ATTENTION, 207 | SDPBackend.EFFICIENT_ATTENTION, 208 | SDPBackend.MATH, 209 | ] 210 | if has_cuda_capability(10, 0): 211 | cls.backends.insert(0, SDPBackend.CUDNN_ATTENTION) 212 | 213 | def forward( 214 | self, 215 | q: torch.Tensor, 216 | k: torch.Tensor, 217 | v: torch.Tensor, 218 | scale: float | None = None, 219 | ) -> torch.Tensor: 220 | assert self.backends, "SDPA Backends should not be empty." 221 | with sdpa_kernel(self.backends, set_priority=True): 222 | return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale) 223 | 224 | 225 | def build_attention( 226 | use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None 227 | ): 228 | if use_flex_attn: 229 | return FlexAttention(attn_mask_type, fixed_block_size) 230 | else: 231 | if fixed_block_size is not None: 232 | raise ValueError( 233 | "TorchTitan with SDPA currently does not support fixed_block_size." 234 | ) 235 | if attn_mask_type != "causal": 236 | raise ValueError( 237 | "TorchTitan with SDPA currently only supports causal mask." 238 | ) 239 | return ScaledDotProductAttention(attn_mask_type) 240 | 241 | 242 | def init_attention_mask(batch: torch.Tensor, eos_id: int | None = None) -> None: 243 | FlexAttention.init_attention_mask(batch, eos_id) 244 | -------------------------------------------------------------------------------- /torchtitan/datasets/hf_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
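# --- illustration: fqn filtering in Float8Converter (torchtitan/components/quantization/float8.py above) ---
# The real helper, module_filter_fn in torchtitan/components/quantization/utils.py, is not
# shown in this listing, so the stand-in below is an assumption about the idea, not its
# actual signature: convert a linear only if no configured fqn fragment matches its name.
# When "auto_filter_small_kn" appears in float8.filter_fqns, the converter above instead
# delegates to torchao's _auto_filter_for_recipe, which also skips layers whose dims are
# too small to benefit from float8.
def sketch_filter_fn(fqn: str, filter_fqns: list[str]) -> bool:
    return not any(fragment in fqn for fragment in filter_fqns)


# with filter_fqns = ["output", "router.gate"]:
#   sketch_filter_fn("layers.3.attention.wq", ["output", "router.gate"]) -> True  (convert)
#   sketch_filter_fn("output", ["output", "router.gate"])                -> False (keep high precision)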
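# --- illustration: "block_causal" document ids in FlexAttention (torchtitan/models/attention.py above) ---
# How _get_block_causal_mask_mod turns a packed batch into per-position document ids:
# positions share an id until an EOS token has been passed, and attention is then only
# allowed within the same document (and causally). eos_id and the toy batch are made up.
import torch

batch = torch.tensor([[5, 7, 0, 9, 4, 0, 2, 8]])  # one packed row, eos_id = 0
eos_id = 0
mask = batch == eos_id
mask[:, -1] = True
acc_mask = torch.cumsum(torch.where(mask, 1, 0), dim=1)
seq_idx = torch.zeros_like(acc_mask, dtype=torch.int32)
seq_idx[:, 1:] = acc_mask[:, :-1]
# seq_idx == [[0, 0, 0, 1, 1, 1, 2, 2]]: position 3 (token 9) can attend to positions
# with the same id at or before it, but never across the EOS boundary at position 2.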
6 | 7 | from dataclasses import dataclass 8 | 9 | from typing import Any 10 | 11 | import torch 12 | 13 | from datasets import Dataset, load_dataset, Audio 14 | from datasets.distributed import split_dataset_by_node 15 | from torch.distributed.checkpoint.stateful import Stateful 16 | from torch.utils.data import IterableDataset 17 | 18 | from torchtitan.components.dataloader import ParallelAwareDataloader 19 | from torchtitan.components.tokenizer import BaseTokenizer 20 | from torchtitan.config_manager import JobConfig 21 | from torchtitan.tools.logging import logger 22 | from transformers import DataCollatorForSeq2Seq 23 | 24 | 25 | def audio_array_to_text( 26 | audio_array: torch.tensor, 27 | audio_tokenizer, 28 | feature_extractor, 29 | num_quantizers: int, 30 | max_seconds: int = 20, 31 | ) -> str: 32 | # truncate the audio array to the expected length 33 | if audio_array.shape[-1] > max_seconds * feature_extractor.sampling_rate: 34 | audio_array = audio_array[: max_seconds * feature_extractor.sampling_rate] 35 | # 36 | inputs = feature_extractor( 37 | raw_audio=audio_array, 38 | sampling_rate=feature_extractor.sampling_rate, 39 | return_tensors="pt", 40 | ).to(audio_tokenizer.device) 41 | with torch.no_grad(): 42 | # Encode the audio input to get the audio codes 43 | # This will return a tensor of shape (batch_size, num_quantizers, sequence_length) 44 | # where each quantizer's output is in a separate dimension 45 | encoder_outputs = audio_tokenizer.encode( 46 | inputs["input_values"], 47 | inputs["padding_mask"], 48 | num_quantizers=num_quantizers, 49 | ) 50 | flatten_audio_codes = encoder_outputs.audio_codes.transpose(1, 2).reshape(-1) 51 | assert flatten_audio_codes.numel() % num_quantizers == 0 52 | steps = [] 53 | for i in range(0, flatten_audio_codes.numel(), num_quantizers): 54 | group = [ 55 | f"<{flatten_audio_codes[i + j].item()}_{j}>" for j in range(num_quantizers) 56 | ] 57 | steps.append(group) 58 | 59 | parts = [tok for step in steps for tok in step] 60 | 61 | text = "".join(parts) 62 | 63 | del inputs, encoder_outputs, flatten_audio_codes 64 | torch.cuda.empty_cache() 65 | return f"" 66 | 67 | 68 | def process_audio( 69 | sample: dict[str, Any], 70 | audio_tokenizer, 71 | feature_extractor, 72 | num_quantizers: int, 73 | task: str = "a2a", 74 | ) -> str: 75 | audio_sample = sample["audio"]["array"] 76 | text = audio_array_to_text( 77 | audio_sample, 78 | audio_tokenizer, 79 | feature_extractor, 80 | num_quantizers, 81 | ) 82 | if task == "tts": 83 | transcription = sample["text"] 84 | text = transcription + text 85 | return text 86 | 87 | 88 | class HuggingFaceDataset(IterableDataset, Stateful): 89 | def __init__( 90 | self, 91 | dataset_name: str, 92 | tokenizer: BaseTokenizer, 93 | audio_tokenizer=None, 94 | feature_extractor=None, 95 | num_quantizers: int = 4, 96 | seq_len: int = 2048, 97 | dp_rank: int = 0, 98 | dp_world_size: int = 1, 99 | infinite: bool = False, 100 | task: str = "a2a", 101 | ) -> None: 102 | if dataset_name == "peoples_speech": 103 | ds = load_dataset( 104 | "parquet", 105 | data_files="data/peoples_speech/**/*.parquet", 106 | split="train", 107 | streaming=True, 108 | ) 109 | ds = ds.cast_column( 110 | "audio", Audio(sampling_rate=feature_extractor.sampling_rate) 111 | ) 112 | elif dataset_name == "librispeech_asr_train": 113 | ds = load_dataset( 114 | "openslr/librispeech_asr", 115 | split="train.other.500", 116 | streaming=True, 117 | ) 118 | ds = ds.cast_column( 119 | "audio", Audio(sampling_rate=feature_extractor.sampling_rate) 120 | ) 121 | 
elif dataset_name == "librispeech_asr_test": 122 | ds = load_dataset( 123 | "openslr/librispeech_asr", 124 | split="test.clean", 125 | streaming=True, 126 | ) 127 | ds = ds.cast_column( 128 | "audio", Audio(sampling_rate=feature_extractor.sampling_rate) 129 | ) 130 | 131 | else: 132 | raise ValueError(f"Dataset {dataset_name} is not supported. ") 133 | 134 | self.dataset_name = dataset_name 135 | self._data = split_dataset_by_node(ds, dp_rank, dp_world_size) 136 | self._data = self._data.shuffle(seed=42, buffer_size=10_000) 137 | self.tokenizer = tokenizer 138 | self.audio_tokenizer = audio_tokenizer 139 | self.feature_extractor = feature_extractor 140 | self.num_quantizers = num_quantizers 141 | self.seq_len = seq_len 142 | self.task = task 143 | self.infinite = infinite 144 | 145 | # Variables for checkpointing 146 | self._sample_idx = 0 147 | self._token_buffer: list[int] = [] 148 | 149 | def _get_data_iter(self): 150 | # For map-style datasets, resume by skipping to the correct index 151 | # For iterable-style datasets, the underlying iterator already points to the correct index 152 | if isinstance(self._data, Dataset): 153 | if self._sample_idx == len(self._data): 154 | return iter([]) 155 | else: 156 | return iter(self._data.skip(self._sample_idx)) 157 | 158 | return iter(self._data) 159 | 160 | def __iter__(self): 161 | while True: 162 | data_iter = self._get_data_iter() 163 | while True: 164 | try: 165 | sample = next(data_iter) 166 | except StopIteration: 167 | break 168 | except Exception as e: 169 | logger.error( 170 | f"Error while iterating over dataset {self.dataset_name}: {e}" 171 | ) 172 | self._sample_idx += 1 173 | continue 174 | 175 | try: 176 | sample_text = process_audio( 177 | sample, 178 | self.audio_tokenizer, 179 | self.feature_extractor, 180 | self.num_quantizers, 181 | self.task, 182 | ) 183 | self._sample_idx += 1 184 | yield self.tokenizer( 185 | sample_text, 186 | max_length=self.seq_len, 187 | padding="max_length", 188 | truncation=True, 189 | ) 190 | except Exception as e: 191 | logger.error( 192 | f"Error while processing sample in dataset {self.dataset_name}: {e}" 193 | ) 194 | self._sample_idx += 1 195 | continue 196 | 197 | if not self.infinite: 198 | logger.warning(f"Dataset {self.dataset_name} has run out of data") 199 | break 200 | else: 201 | # Reset offset for the next iteration 202 | self._sample_idx = 0 203 | logger.warning(f"Dataset {self.dataset_name} is being re-looped") 204 | # Ensures re-looping a dataset loaded from a checkpoint works correctly 205 | if not isinstance(self._data, Dataset): 206 | if hasattr(self._data, "set_epoch") and hasattr( 207 | self._data, "epoch" 208 | ): 209 | self._data.set_epoch(self._data.epoch + 1) 210 | 211 | def load_state_dict(self, state_dict): 212 | if isinstance(self._data, Dataset): 213 | self._sample_idx = state_dict["sample_idx"] 214 | else: 215 | assert "data" in state_dict 216 | self._data.load_state_dict(state_dict["data"]) 217 | 218 | def state_dict(self): 219 | _state_dict = {"token_buffer": self._token_buffer} 220 | 221 | if isinstance(self._data, Dataset): 222 | _state_dict["sample_idx"] = self._sample_idx 223 | else: 224 | # Save the iterable dataset's state to later efficiently resume from it 225 | # https://huggingface.co/docs/datasets/v3.5.0/en/stream#save-a-dataset-checkpoint-and-resume-iteration 226 | _state_dict["data"] = self._data.state_dict() 227 | 228 | return _state_dict 229 | 230 | 231 | def build_hf_dataloader( 232 | dp_world_size: int, 233 | dp_rank: int, 234 | tokenizer: 
BaseTokenizer, 235 | audio_tokenizer, 236 | feature_extractor, 237 | job_config: JobConfig, 238 | infinite: bool = True, 239 | ) -> ParallelAwareDataloader: 240 | """Build a data loader for HuggingFace datasets.""" 241 | dataset_name = job_config.training.dataset 242 | batch_size = job_config.training.local_batch_size 243 | seq_len = job_config.training.seq_len 244 | 245 | hf_ds = HuggingFaceDataset( 246 | dataset_name=dataset_name, 247 | tokenizer=tokenizer, 248 | audio_tokenizer=audio_tokenizer, 249 | feature_extractor=feature_extractor, 250 | num_quantizers=job_config.model.num_quantizers, 251 | seq_len=seq_len, 252 | dp_rank=dp_rank, 253 | dp_world_size=dp_world_size, 254 | infinite=infinite, 255 | task=job_config.training.task, 256 | ) 257 | 258 | collate_fn = DataCollatorForSeq2Seq( 259 | tokenizer, 260 | pad_to_multiple_of=8, 261 | return_tensors="pt", 262 | padding=True, 263 | max_length=seq_len, 264 | ) 265 | 266 | return ParallelAwareDataloader( 267 | dataset=hf_ds, 268 | dp_rank=dp_rank, 269 | dp_world_size=dp_world_size, 270 | batch_size=batch_size, 271 | collate_fn=collate_fn, 272 | ) 273 | 274 | 275 | def build_hf_validation_dataloader( 276 | dp_world_size: int, 277 | dp_rank: int, 278 | tokenizer: BaseTokenizer, 279 | audio_tokenizer, 280 | feature_extractor, 281 | job_config: JobConfig, 282 | ) -> ParallelAwareDataloader: 283 | """Build a validation data loader for HuggingFace datasets.""" 284 | dataset_name = job_config.validation.dataset 285 | batch_size = job_config.validation.local_batch_size 286 | seq_len = job_config.validation.seq_len 287 | 288 | hf_ds = HuggingFaceDataset( 289 | dataset_name=dataset_name, 290 | tokenizer=tokenizer, 291 | audio_tokenizer=audio_tokenizer, 292 | feature_extractor=feature_extractor, 293 | num_quantizers=job_config.model.num_quantizers, 294 | seq_len=seq_len, 295 | dp_rank=dp_rank, 296 | dp_world_size=dp_world_size, 297 | infinite=False, 298 | task=job_config.training.task, 299 | ) 300 | 301 | collate_fn = DataCollatorForSeq2Seq( 302 | tokenizer, 303 | pad_to_multiple_of=8, 304 | return_tensors="pt", 305 | padding=True, 306 | max_length=seq_len, 307 | ) 308 | 309 | return ParallelAwareDataloader( 310 | dataset=hf_ds, 311 | dp_rank=dp_rank, 312 | dp_world_size=dp_world_size, 313 | batch_size=batch_size, 314 | collate_fn=collate_fn, 315 | ) 316 | -------------------------------------------------------------------------------- /torchtitan/components/optimizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
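# --- illustration: audio-code flattening in audio_array_to_text (torchtitan/datasets/hf_datasets.py above) ---
# Mimi-style codes of shape (batch, num_quantizers, frames) are interleaved frame by
# frame and rendered as "<code_codebook>" text tokens. The code values below are made up.
import torch

num_quantizers = 2
audio_codes = torch.tensor([[[11, 12, 13],    # codebook 0 over 3 frames
                             [21, 22, 23]]])  # codebook 1 over 3 frames
flat = audio_codes.transpose(1, 2).reshape(-1)  # frame-major: 11, 21, 12, 22, 13, 23
tokens = "".join(
    f"<{flat[i + j].item()}_{j}>"
    for i in range(0, flat.numel(), num_quantizers)
    for j in range(num_quantizers)
)
# tokens == "<11_0><21_1><12_0><22_1><13_0><23_1>"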
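# --- illustration: one optimizer per model part (see OptimizersContainer below) ---
# A forward-looking sketch of the core construction in this module: each model part
# (e.g. each pipeline-split submodule) gets its own optimizer, and the container fans
# step()/zero_grad() out to all of them. The tiny modules and hyperparameters are
# placeholders, not values from this repository.
import torch
import torch.nn as nn

model_parts = [nn.Linear(8, 8), nn.Linear(8, 4)]
optimizer_kwargs = {"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.1}
optimizers = [
    torch.optim.AdamW(
        [p for p in part.parameters() if p.requires_grad], **optimizer_kwargs
    )
    for part in model_parts
]


def step_all() -> None:
    for opt in optimizers:
        opt.step()
    for opt in optimizers:
        opt.zero_grad()


# OptimizersInBackwardContainer below avoids the explicit step_all() by registering a
# register_post_accumulate_grad_hook per parameter, so each parameter steps during backward.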
6 | 7 | import functools 8 | from typing import Any, Generic, Iterator, TypeVar 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.distributed.checkpoint.state_dict import ( 13 | get_optimizer_state_dict, 14 | set_optimizer_state_dict, 15 | StateDictOptions, 16 | ) 17 | from torch.distributed.checkpoint.stateful import Stateful 18 | from torch.optim import Optimizer 19 | 20 | from torchtitan.components.ft import FTManager, has_torchft 21 | from torchtitan.config_manager import JobConfig 22 | from torchtitan.distributed import ParallelDims 23 | 24 | __all__ = [ 25 | "OptimizersContainer", 26 | "build_optimizers", 27 | ] 28 | 29 | 30 | if has_torchft: 31 | import torchft as ft 32 | 33 | 34 | T = TypeVar("T", bound=Optimizer) 35 | 36 | 37 | class OptimizersContainer(Optimizer, Stateful, Generic[T]): 38 | """A container for multiple optimizers. 39 | 40 | This class is used to wrap multiple optimizers into a single object that can be 41 | used to reduce the complexity of the training loop. This mimics the behavior of 42 | ``torch.optim.Optimizer``. This class currently only supports ``Adam`` and ``AdamW``. 43 | 44 | **Note** 45 | Users who want to customize the optimizer behavior can inherit from this class and 46 | extend the functionality as needed. The following methods must follow the same signature 47 | as ``torch.optim.Optimizer`` class: ``step()``, ``zero_grad()``, ``state_dict()``, 48 | ``load_state_dict()``. 49 | 50 | **Limitations** 51 | This class assumes that all the optimizers are the same type and have the same 52 | configurations. With this assumption, TorchTitan can support lr scheduler resharding 53 | (e.g., loading a checkpoint with a different number of GPUs and/or different 54 | parallelization strategy). Note that ``get_optimizer_state_dict`` already enables the 55 | resharding for the optimizer state but not for the lr scheduler state, hence the limitation. 56 | 57 | Args: 58 | model_parts (List[nn.Module]): List of model parts to be optimized. 59 | optimizer_kwargs (Dict[str, Any]): Keyword arguments for the optimizers. 60 | name (str): Name of the optimizers. 
61 | """ 62 | 63 | optimizers: list[T] 64 | model_parts: list[nn.Module] 65 | 66 | def __init__( 67 | self, 68 | model_parts: list[nn.Module], 69 | optimizer_cls: type[T], 70 | optimizer_kwargs: dict[str, Any], 71 | ) -> None: 72 | all_params = [] 73 | self.optimizers = [] 74 | self.model_parts = model_parts 75 | for model in self.model_parts: 76 | params = [p for p in model.parameters() if p.requires_grad] 77 | self.optimizers.append(optimizer_cls(params, **optimizer_kwargs)) 78 | all_params.extend(params) 79 | self._validate_length(len(self.model_parts)) 80 | self._post_init(all_params, optimizer_kwargs) 81 | 82 | def __iter__(self) -> Iterator[T]: 83 | return iter(self.optimizers) 84 | 85 | def __len__(self) -> int: 86 | return len(self.optimizers) 87 | 88 | def step(self, *args, **kwargs) -> None: 89 | for optimizer in self.optimizers: 90 | optimizer.step(*args, **kwargs) 91 | 92 | def zero_grad(self, *args, **kwargs) -> None: 93 | for optimizer in self.optimizers: 94 | optimizer.zero_grad(*args, **kwargs) 95 | 96 | def state_dict(self) -> dict[str, Any]: 97 | func = functools.partial( 98 | get_optimizer_state_dict, 99 | options=StateDictOptions(flatten_optimizer_state_dict=True), 100 | ) 101 | return { 102 | k: v 103 | for sd in map(func, self.model_parts, self.optimizers) 104 | for k, v in sd.items() 105 | } 106 | 107 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 108 | func = functools.partial( 109 | set_optimizer_state_dict, 110 | optim_state_dict=state_dict, 111 | options=StateDictOptions(flatten_optimizer_state_dict=True), 112 | ) 113 | list(map(func, self.model_parts, self.optimizers)) 114 | 115 | def _validate_length(self, expected_length: int) -> None: 116 | assert expected_length == len(self.optimizers), ( 117 | "Must pass one optimizer per model part or per param if " 118 | "using OptimizersInBackwardContainer." 119 | ) 120 | 121 | def _post_init( 122 | self, all_params: list[nn.Parameter], optimizer_kwargs: dict[str, Any] 123 | ) -> None: 124 | # We need to call Optimizer.__init__() to initialize some necessary optimizer 125 | # functionality such as hooks. 126 | Optimizer.__init__(self, all_params, optimizer_kwargs) 127 | 128 | 129 | class OptimizersInBackwardContainer(OptimizersContainer): 130 | """OptimizersContainer for executing ``optim.step()`` in backward pass. 131 | 132 | This class extend ``OptimizersContainer`` to support optimizer step in 133 | backward pass. ``step()`` and ``zero_grad()`` are no-op in this class. 134 | Instead, ``register_post_accumulate_grad_hook`` is used to register a hook to 135 | execute these methods when the gradient is accumulated. 
136 | """ 137 | 138 | def __init__( 139 | self, 140 | model_parts: list[nn.Module], 141 | optimizer_cls: type[T], 142 | optimizer_kwargs: dict[str, Any], 143 | ) -> None: 144 | all_params = [] 145 | self.model_parts = model_parts 146 | 147 | optim_dict = {} 148 | for model in self.model_parts: 149 | for p in model.parameters(): 150 | if p.requires_grad: 151 | optim_dict[p] = optimizer_cls([p], **optimizer_kwargs) 152 | all_params.append(p) 153 | 154 | def optim_hook(param) -> None: 155 | optim_dict[param].step() 156 | optim_dict[param].zero_grad() 157 | 158 | for model in self.model_parts: 159 | for param in model.parameters(): 160 | if param.requires_grad: 161 | param.register_post_accumulate_grad_hook(optim_hook) 162 | 163 | self.optimizers = list(optim_dict.values()) 164 | 165 | self._validate_length( 166 | sum(len(list(model.parameters())) for model in self.model_parts) 167 | ) 168 | self._post_init(all_params, optimizer_kwargs) 169 | 170 | def step(self) -> None: 171 | pass 172 | 173 | def zero_grad(self) -> None: 174 | pass 175 | 176 | 177 | class FTOptimizersContainer(OptimizersContainer): 178 | def __init__( 179 | self, 180 | model_parts: list[nn.Module], 181 | optimizer_cls: type[T], 182 | optimizer_kwargs: dict[str, Any], 183 | ft_manager: "ft.Manager", 184 | use_ft_optimizer: bool = True, 185 | ) -> None: 186 | super().__init__(model_parts, optimizer_cls, optimizer_kwargs) 187 | 188 | # Force to initialize the optimizer state so that `optim.step()` 189 | # won't be called by state_dict() and load_state_dict(). 190 | _ = { 191 | k: v 192 | for sd in map(get_optimizer_state_dict, model_parts, self.optimizers) 193 | for k, v in sd.items() 194 | } 195 | self.cache_state_dict: dict[str, Any] = {} 196 | self._ft_optimizer = ft.Optimizer(ft_manager, self) 197 | # Whether to determine quorum using FT.optimizer, 198 | # in semi-sync training we use the synchronization step to start quorum 199 | self._use_ft_optimizer: bool = use_ft_optimizer 200 | 201 | def init_cache_state_dict(self) -> None: 202 | self.cache_state_dict = super().state_dict() 203 | 204 | def state_dict(self) -> dict[str, Any]: 205 | return self.cache_state_dict 206 | 207 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 208 | # We have to invalidate the `cache_state_dict` because optimizer uses 209 | # assign instead of copy when doing `load_state_dict()`. Without 210 | # invalidating the `cache_state_dict`, there will be memory leakage. 211 | self.cache_state_dict = {} 212 | super().load_state_dict(state_dict) 213 | self.init_cache_state_dict() 214 | 215 | def step(self, *args, **kwargs) -> None: 216 | """Calling the correct step() depending on the caller. 217 | 218 | TorchFT's OptimizerWrapper.step() is designed to be called only once 219 | per train step per ft.Manager regardless how many optimizers are used. 220 | Hence we will need to appropriately dispatch the call. 221 | """ 222 | if self._use_ft_optimizer: 223 | self._use_ft_optimizer = False 224 | self._ft_optimizer.step(*args, **kwargs) 225 | self._use_ft_optimizer = True 226 | else: 227 | super().step(*args, **kwargs) 228 | 229 | def zero_grad(self, *args, **kwargs) -> None: 230 | """Calling the correct zero_grad() depending on the caller. 231 | 232 | Check the comment in ``step()``. 
233 | """ 234 | if self._use_ft_optimizer: 235 | self._use_ft_optimizer = False 236 | self._ft_optimizer.zero_grad(*args, **kwargs) 237 | self._use_ft_optimizer = True 238 | else: 239 | super().zero_grad(*args, **kwargs) 240 | 241 | 242 | def build_optimizers( 243 | model_parts: list[nn.Module], 244 | job_config: JobConfig, 245 | parallel_dims: ParallelDims, 246 | ft_manager: FTManager | None = None, 247 | ) -> OptimizersContainer: 248 | """Create a OptimizersContainer for the given model parts and job config. 249 | 250 | This function creates a ``OptimizersContainer`` for the given model parts. 251 | ``job_config`` should define the correct optimizer name and parameters. 252 | This function currently supports creating ``OptimizersContainer`` and 253 | ``OptimizersInBackwardContainer``. 254 | 255 | **Note** 256 | Users who want to customize the optimizer behavior can create their own 257 | ``OptimizersContainer`` subclass and ``build_optimizers``. Passing the 258 | customized ``build_optimizers`` to ``TrainSpec`` will create the customized 259 | ``OptimizersContainer``. 260 | 261 | Args: 262 | model_parts (List[nn.Module]): List of model parts to be optimized. 263 | job_config (JobConfig): Job config containing the optimizer name and parameters. 264 | parallel_dims (ParallelDims): Parallel dimensions for the model. 265 | """ 266 | optim_in_bwd = job_config.optimizer.early_step_in_backward 267 | if optim_in_bwd: 268 | if parallel_dims.ep_enabled: 269 | raise NotImplementedError( 270 | "Optimizers in backward is not supported with Expert Parallel." 271 | ) 272 | if parallel_dims.pp_enabled: 273 | raise NotImplementedError( 274 | "Optimizers in backward is not supported with Pipeline Parallel." 275 | ) 276 | if ft_manager and ft_manager.enabled: 277 | raise NotImplementedError( 278 | "TorchFT is not supported with optimizers in backward." 279 | ) 280 | 281 | name = job_config.optimizer.name 282 | lr = job_config.optimizer.lr 283 | beta1 = job_config.optimizer.beta1 284 | beta2 = job_config.optimizer.beta2 285 | eps = job_config.optimizer.eps 286 | weight_decay = job_config.optimizer.weight_decay 287 | 288 | optim_implementation = job_config.optimizer.implementation 289 | assert optim_implementation in ["fused", "foreach", "for-loop"] 290 | 291 | fused = optim_implementation == "fused" 292 | foreach = optim_implementation == "foreach" 293 | 294 | optimizer_kwargs = { 295 | "lr": lr, 296 | "betas": (beta1, beta2), 297 | "eps": eps, 298 | "weight_decay": weight_decay, 299 | "fused": fused, 300 | "foreach": foreach, 301 | } 302 | 303 | optimizer_classes = { 304 | "Adam": torch.optim.Adam, 305 | "AdamW": torch.optim.AdamW, 306 | } 307 | if name not in optimizer_classes: 308 | raise NotImplementedError(f"Optimizer {name} not added.") 309 | optimizer_cls = optimizer_classes[name] 310 | 311 | if optim_in_bwd: 312 | return OptimizersInBackwardContainer( 313 | model_parts, optimizer_cls, optimizer_kwargs 314 | ) 315 | 316 | if ft_manager and ft_manager.enabled: 317 | return FTOptimizersContainer( 318 | model_parts, 319 | optimizer_cls, 320 | optimizer_kwargs, 321 | ft_manager.manager, 322 | use_ft_optimizer=job_config.fault_tolerance.semi_sync_method is None, 323 | ) 324 | 325 | return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) 326 | -------------------------------------------------------------------------------- /torchtitan/components/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import time 9 | from collections import namedtuple 10 | from datetime import datetime 11 | from typing import Any, TYPE_CHECKING 12 | 13 | import torch 14 | from torchtitan.components.lr_scheduler import LRSchedulersContainer 15 | from torchtitan.components.optimizer import OptimizersContainer 16 | from torchtitan.config_manager import JobConfig 17 | from torchtitan.distributed import ParallelDims 18 | from torchtitan.tools import utils 19 | from torchtitan.tools.logging import logger 20 | from torchtitan.tools.utils import Color, device_module, device_type 21 | 22 | if TYPE_CHECKING: 23 | pass 24 | 25 | 26 | # named tuple for passing device memory stats for logging 27 | DeviceMemStats = namedtuple( 28 | "DeviceMemStats", 29 | [ 30 | "max_active_gib", 31 | "max_active_pct", 32 | "max_reserved_gib", 33 | "max_reserved_pct", 34 | "num_alloc_retries", 35 | "num_ooms", 36 | ], 37 | ) 38 | 39 | 40 | class DeviceMemoryMonitor: 41 | def __init__(self, device: str = f"{device_type}:0"): 42 | self.device = torch.device(device) # device object 43 | self.device_name = device_module.get_device_name(self.device) 44 | self.device_index = device_module.current_device() 45 | self.device_capacity = device_module.get_device_properties( 46 | self.device 47 | ).total_memory 48 | self.device_capacity_gib = self._to_gib(self.device_capacity) 49 | 50 | device_module.reset_peak_memory_stats() 51 | device_module.empty_cache() 52 | 53 | def _to_gib(self, memory_in_bytes): 54 | # NOTE: GiB (gibibyte) is 1024, vs GB is 1000 55 | _gib_in_bytes = 1024 * 1024 * 1024 56 | memory_in_gib = memory_in_bytes / _gib_in_bytes 57 | return memory_in_gib 58 | 59 | def _to_pct(self, memory): 60 | return 100 * memory / self.device_capacity 61 | 62 | def get_peak_stats(self): 63 | device_info = device_module.memory_stats(self.device) 64 | 65 | max_active = device_info.get("active_bytes.all.peak", -1) 66 | max_active_gib = self._to_gib(max_active) 67 | max_active_pct = self._to_pct(max_active) 68 | 69 | max_reserved = device_info.get("reserved_bytes.all.peak", -1) 70 | max_reserved_gib = self._to_gib(max_reserved) 71 | max_reserved_pct = self._to_pct(max_reserved) 72 | 73 | num_retries = device_info.get("num_alloc_retries", -1) 74 | num_ooms = device_info.get("num_ooms", -1) 75 | 76 | if num_retries > 0: 77 | logger.warning( 78 | f"{num_retries} {device_type.upper()} memory allocation retries." 
79 | ) 80 | if num_ooms > 0: 81 | logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.") 82 | 83 | return DeviceMemStats( 84 | max_active_gib, 85 | max_active_pct, 86 | max_reserved_gib, 87 | max_reserved_pct, 88 | num_retries, 89 | num_ooms, 90 | ) 91 | 92 | def reset_peak_stats(self): 93 | device_module.reset_peak_memory_stats() 94 | 95 | 96 | def build_device_memory_monitor(): 97 | device_memory_monitor = DeviceMemoryMonitor(device_type) 98 | logger.info( 99 | f"{device_type.upper()} capacity: {device_memory_monitor.device_name} " 100 | f"with {device_memory_monitor.device_capacity_gib:.2f}GiB memory" 101 | ) 102 | return device_memory_monitor 103 | 104 | 105 | class BaseLogger: 106 | """Logger that does nothing, used when logging is disabled.""" 107 | 108 | def log(self, metrics: dict[str, Any], step: int) -> None: 109 | pass 110 | 111 | def close(self) -> None: 112 | pass 113 | 114 | 115 | class WandBLogger(BaseLogger): 116 | """Logger implementation for Weights & Biases.""" 117 | 118 | def __init__(self, log_dir: str, job_config: JobConfig, tag: str | None = None): 119 | # Import wandb here to avoid startup import 120 | import wandb 121 | 122 | self.wandb = wandb 123 | self.tag = tag 124 | 125 | # Create logging directory 126 | os.makedirs(log_dir, exist_ok=True) 127 | 128 | self.wandb.init( 129 | project=os.getenv("WANDB_PROJECT", "torchtitan"), 130 | dir=log_dir, 131 | name=job_config.job.wandb_run_name, 132 | config=job_config.to_dict(), 133 | ) 134 | logger.info("WandB logging enabled") 135 | 136 | def log(self, metrics: dict[str, Any], step: int) -> None: 137 | wandb_metrics = { 138 | (k if self.tag is None else f"{self.tag}/{k}"): v 139 | for k, v in metrics.items() 140 | } 141 | self.wandb.log(wandb_metrics, step=step) 142 | 143 | def close(self) -> None: 144 | if self.wandb.run is not None: 145 | self.wandb.finish() 146 | 147 | 148 | def ensure_pp_loss_visible( 149 | parallel_dims: ParallelDims, job_config: JobConfig, color: Color 150 | ) -> None: 151 | """ 152 | Ensures that the loss is visible on the console for pipeline-parallel training. 153 | 154 | For pipeline-parallel training, the loss is only visible on the last pipeline stage. 155 | This function checks if the appropriate rank is included in the LOG_RANK environment 156 | variable and warns if it's not. 157 | """ 158 | 159 | # V Block Schedules return loss on rank 0 160 | if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble": 161 | return 162 | 163 | # Calculate the rank where loss is visible (first rank of the last pipeline stage) 164 | world_size = parallel_dims.world_size 165 | pp_size = parallel_dims.pp 166 | loss_visible_rank = (world_size // pp_size) * (pp_size - 1) 167 | 168 | # Check if the loss-visible rank is included in LOG_RANK environment variable 169 | env_logged_ranks = os.environ.get("LOG_RANK", "").split(",") 170 | if env_logged_ranks == [""]: 171 | env_logged_ranks = [] 172 | 173 | if str(loss_visible_rank) not in env_logged_ranks: 174 | logger.warning( 175 | f"{color.red}Pipeline Parallel loss is not visible. " 176 | f"Please add {color.yellow}rank {loss_visible_rank}{color.red} " 177 | f"to LOG_RANK environment variable in run_train.sh.{color.reset}" 178 | ) 179 | 180 | 181 | def _get_metrics_rank( 182 | parallel_dims: ParallelDims, 183 | job_config: JobConfig, 184 | ) -> int: 185 | """ 186 | Determines which rank should log metrics. 
187 | 188 | Returns: 189 | int: The rank responsible for logging metrics: 190 | - Rank 0 for non-pipeline-parallel configs 191 | - Rank 0 for pipeline-parallel 'ZBVZeroBubble' schedule 192 | - The first rank of the last pipeline stage for other pipeline-parallel schedules 193 | """ 194 | # Early return for non-pipeline-parallel configurations 195 | if not parallel_dims.pp_enabled: 196 | return 0 197 | 198 | # V Block Schedules return loss on rank 0 199 | if job_config.parallelism.pipeline_parallel_schedule == "ZBVZeroBubble": 200 | return 0 201 | 202 | # Calculate first rank of the last pipeline stage 203 | world_size = parallel_dims.world_size 204 | pp_size = parallel_dims.pp 205 | return (world_size // pp_size) * (pp_size - 1) 206 | 207 | 208 | def _build_metric_logger( 209 | job_config: JobConfig, parallel_dims: ParallelDims, tag: str | None = None 210 | ) -> BaseLogger: 211 | """ 212 | Build an appropriate metric logger based on configuration. 213 | """ 214 | metrics_config = job_config.metrics 215 | 216 | # Log initial config state 217 | logger.debug(f"Building logger with config: wandb={metrics_config.enable_wandb}") 218 | 219 | # Check if any logging backend is enabled 220 | has_logging_enabled = metrics_config.enable_wandb 221 | 222 | # Determine if this rank should log 223 | should_log = has_logging_enabled 224 | if (not metrics_config.save_for_all_ranks) and should_log: 225 | metrics_rank = _get_metrics_rank(parallel_dims, job_config) 226 | should_log = torch.distributed.get_rank() == metrics_rank 227 | 228 | logger.debug( 229 | f"Logging decision: has_logging_enabled={has_logging_enabled}, should_log={should_log}" 230 | ) 231 | 232 | if not should_log: 233 | logger.debug("Returning BaseLogger due to should_log=False") 234 | return BaseLogger() 235 | 236 | # Setup logging directory 237 | dump_dir = job_config.job.dump_folder 238 | base_log_dir = os.path.join( 239 | dump_dir, metrics_config.save_tb_folder, datetime.now().strftime("%Y%m%d-%H%M") 240 | ) 241 | 242 | if job_config.fault_tolerance.enable: 243 | base_log_dir = os.path.join( 244 | base_log_dir, 245 | f"replica_{job_config.fault_tolerance.replica_id}", 246 | ) 247 | 248 | if metrics_config.save_for_all_ranks: 249 | base_log_dir = os.path.join( 250 | base_log_dir, f"rank_{torch.distributed.get_rank()}" 251 | ) 252 | 253 | # Create loggers in priority order 254 | if metrics_config.enable_wandb: 255 | logger.debug("Attempting to create WandB logger") 256 | try: 257 | return WandBLogger(base_log_dir, job_config, tag) 258 | except Exception as e: 259 | if "No module named 'wandb'" in str(e): 260 | logger.error( 261 | "Failed to create WandB logger: No module named 'wandb'. Please install it using 'pip install wandb'." 262 | ) 263 | else: 264 | logger.error(f"Failed to create WandB logger: {e}") 265 | 266 | logger.debug("No loggers enabled, returning BaseLogger") 267 | return BaseLogger() 268 | 269 | 270 | class MetricsProcessor: 271 | """Metrics processor that processes and logs metrics. 272 | 273 | The current MetricsProcessor logs some metrics to STDOUT and some metrics to 274 | WandB. 275 | 276 | Args: 277 | job_config (JobConfig): Job configuration. 278 | parallel_dims (ParallelDims): Parallel dimensions. 279 | tag (Optional[str]): Tag to use for WandB. Defaults to None. 
280 | """ 281 | 282 | logger: BaseLogger 283 | parallel_dims: ParallelDims 284 | job_config: JobConfig 285 | device_memory_monitor: DeviceMemoryMonitor 286 | color: utils.NoColor | utils.Color 287 | 288 | gpu_peak_flops: int 289 | ntokens_since_last_log: int 290 | data_loading_times: list[float] 291 | time_last_log: float 292 | 293 | num_flops_per_token: int 294 | optimizers: OptimizersContainer | None 295 | lr_schedulers: LRSchedulersContainer | None 296 | 297 | def __init__( 298 | self, 299 | job_config: JobConfig, 300 | parallel_dims: ParallelDims, 301 | tag: str | None = None, 302 | ): 303 | self.logger = _build_metric_logger(job_config, parallel_dims, tag) 304 | self.parallel_dims = parallel_dims 305 | self.job_config = job_config 306 | self.device_memory_monitor = build_device_memory_monitor() 307 | # used for colorful printing 308 | self.color = ( 309 | utils.NoColor() 310 | if job_config.metrics.disable_color_printing 311 | else utils.Color() 312 | ) 313 | 314 | self.gpu_peak_flops = utils.get_peak_flops( 315 | self.device_memory_monitor.device_name 316 | ) 317 | self.ntokens_since_last_log = 0 318 | self.data_loading_times = [] 319 | self.time_last_log = time.perf_counter() 320 | self.device_memory_monitor.reset_peak_stats() 321 | 322 | # These variables have to be set later as they depend on other components or model. 323 | self.num_flops_per_token = -1 324 | self.optimizers = None 325 | self.lr_schedulers = None 326 | 327 | def should_log(self, step: int) -> bool: 328 | return step == 1 or step % self.job_config.metrics.log_freq == 0 329 | 330 | def log( 331 | self, 332 | step: int, 333 | global_avg_loss: float, 334 | global_max_loss: float, 335 | grad_norm: float, 336 | lr: float, 337 | epoch: int, 338 | extra_metrics: dict[str, Any] | None = None, 339 | ): 340 | assert self.num_flops_per_token > 0, "num_flops_per_token must be set" 341 | 342 | time_delta = time.perf_counter() - self.time_last_log 343 | 344 | # tokens per second per device, abbreviated as tps 345 | tps = self.ntokens_since_last_log / ( 346 | time_delta * self.parallel_dims.non_data_parallel_size 347 | ) 348 | # model FLOPS utilization 349 | # For its definition and calculation, please refer to the PaLM paper: 350 | # https://arxiv.org/abs/2204.02311 351 | mfu = 100 * self.num_flops_per_token * tps / self.gpu_peak_flops 352 | tflops = self.num_flops_per_token * tps / 1e12 353 | 354 | time_end_to_end = time_delta / self.job_config.metrics.log_freq 355 | time_data_loading = sum(self.data_loading_times) / len(self.data_loading_times) 356 | time_data_loading_pct = 100 * sum(self.data_loading_times) / time_delta 357 | 358 | device_mem_stats = self.device_memory_monitor.get_peak_stats() 359 | 360 | metrics = { 361 | "loss_metrics/global_avg_loss": global_avg_loss, 362 | "loss_metrics/global_max_loss": global_max_loss, 363 | "grad_norm": grad_norm, 364 | "lr": lr, 365 | "epoch": epoch, 366 | "throughput(tps)": tps, 367 | "tflops": tflops, 368 | "mfu(%)": mfu, 369 | "time_metrics/end_to_end(s)": time_end_to_end, 370 | "time_metrics/data_loading(s)": time_data_loading, 371 | "time_metrics/data_loading(%)": time_data_loading_pct, 372 | "memory/max_active(GiB)": device_mem_stats.max_active_gib, 373 | "memory/max_active(%)": device_mem_stats.max_active_pct, 374 | "memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, 375 | "memory/max_reserved(%)": device_mem_stats.max_reserved_pct, 376 | "memory/num_alloc_retries": device_mem_stats.num_alloc_retries, 377 | "memory/num_ooms": device_mem_stats.num_ooms, 378 | } 379 | 
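# --- Illustrative aside (editorial note, not part of the original file) ------
# The throughput numbers assembled above follow directly from the code in log():
#   tps    = ntokens_since_last_log / (time_delta * non_data_parallel_size)
#   tflops = num_flops_per_token * tps / 1e12
#   mfu(%) = 100 * num_flops_per_token * tps / gpu_peak_flops
# A worked example with hypothetical numbers: for a ~1.2B-parameter model,
# num_flops_per_token is roughly 6 * 1.2e9 ≈ 7.2e9 (forward + backward, ignoring
# the attention terms), so at tps = 10_000 tokens/s/device:
#   tflops ≈ 7.2e9 * 1e4 / 1e12 = 72
# and on a device whose BF16 peak is 312e12 FLOP/s (e.g. an A100):
#   mfu ≈ 100 * 72e12 / 312e12 ≈ 23%
# ------------------------------------------------------------------------------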
380 | if extra_metrics: 381 | metrics.update(extra_metrics) 382 | 383 | self.logger.log(metrics, step) 384 | 385 | color = self.color 386 | logger.info( 387 | f"{color.red}step: {step:2} " 388 | f"{color.green}loss: {global_avg_loss:7.4f} " 389 | f"{color.orange}grad_norm: {grad_norm:7.4f} " 390 | f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" 391 | f"({device_mem_stats.max_reserved_pct:.2f}%) " 392 | f"{color.blue}tps: {round(tps):,} " 393 | f"{color.cyan}tflops: {tflops:,.2f} " 394 | f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" 395 | ) 396 | 397 | self.ntokens_since_last_log = 0 398 | self.data_loading_times.clear() 399 | self.time_last_log = time.perf_counter() 400 | self.device_memory_monitor.reset_peak_stats() 401 | 402 | def log_validation(self, loss: float, step: int): 403 | time_delta = time.perf_counter() - self.time_last_log 404 | 405 | device_mem_stats = self.device_memory_monitor.get_peak_stats() 406 | 407 | # tokens per second per device, abbreviated as tps 408 | tps = self.ntokens_since_last_log / ( 409 | time_delta * self.parallel_dims.non_data_parallel_size 410 | ) 411 | 412 | metrics = { 413 | "validation_metrics/loss": loss, 414 | "validation_metrics/throughput(tps)": tps, 415 | "validation_metrics/memory/max_active(GiB)": device_mem_stats.max_active_gib, 416 | "validation_metrics/memory/max_active(%)": device_mem_stats.max_active_pct, 417 | "validation_metrics/memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, 418 | "validation_metrics/memory/max_reserved(%)": device_mem_stats.max_reserved_pct, 419 | } 420 | self.logger.log(metrics, step) 421 | 422 | color = self.color 423 | logger.info( 424 | f"{color.yellow}validate step: {step:2} " 425 | f"{color.green}loss: {loss:7.4f} " 426 | f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" 427 | f"({device_mem_stats.max_reserved_pct:.2f}%) " 428 | f"{color.blue}tps: {round(tps):,}{color.reset}" 429 | ) 430 | 431 | self.ntokens_since_last_log = 0 432 | self.time_last_log = time.perf_counter() 433 | self.device_memory_monitor.reset_peak_stats() 434 | 435 | def close(self): 436 | self.logger.close() 437 | 438 | 439 | def build_metrics_processor( 440 | job_config: JobConfig, 441 | parallel_dims: ParallelDims, 442 | tag: str | None = None, 443 | ) -> MetricsProcessor: 444 | """Create a metrics processor. 445 | 446 | Args: 447 | job_config (JobConfig): Job configuration. 448 | parallel_dims (ParallelDims): Parallel dimensions. 449 | tag (str | None): Tag to use for WandB. Defaults to None. 450 | 451 | Returns: 452 | MetricsProcessor: A metrics processor. 453 | """ 454 | return MetricsProcessor(job_config, parallel_dims, tag) 455 | -------------------------------------------------------------------------------- /torchtitan/components/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 
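Before the tokenizer implementation that follows, here is a minimal sketch of how the metrics component above is typically driven from a training loop. It is illustrative only and based solely on the interface shown (build_metrics_processor, should_log, log, ntokens_since_last_log, data_loading_times); all loop variables (model_flops_per_token, tokens_in_batch, data_loading_seconds, avg_loss, max_loss, grad_norm, current_lr, epoch, num_steps) are placeholders, not names from this repository.

from torchtitan.components.metrics import build_metrics_processor

# Hypothetical wiring; job_config and parallel_dims come from the trainer setup.
metrics_processor = build_metrics_processor(job_config, parallel_dims)
metrics_processor.num_flops_per_token = model_flops_per_token  # must be set before log()

for step in range(1, num_steps + 1):
    # The trainer accumulates these between logging steps.
    metrics_processor.ntokens_since_last_log += tokens_in_batch
    metrics_processor.data_loading_times.append(data_loading_seconds)

    if metrics_processor.should_log(step):
        metrics_processor.log(
            step,
            global_avg_loss=avg_loss,
            global_max_loss=max_loss,
            grad_norm=grad_norm,
            lr=current_lr,
            epoch=epoch,
        )

metrics_processor.close()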
6 | 7 | 8 | import json 9 | 10 | import os 11 | from abc import ABC, abstractmethod 12 | from typing import Any, Optional, Union 13 | 14 | from tokenizers import AddedToken, Tokenizer 15 | from torchtitan.config_manager import JobConfig 16 | from torchtitan.tools.logging import logger 17 | from typing_extensions import override 18 | 19 | 20 | class BaseTokenizer(ABC): 21 | # base tokenizer interface, for typing purpose mainly 22 | def __init__(self): 23 | self.eos_id = 0 24 | 25 | @abstractmethod 26 | def encode(self, *args, **kwargs) -> list[int]: ... 27 | 28 | @abstractmethod 29 | def decode(self, *args, **kwargs) -> str: ... 30 | 31 | @abstractmethod 32 | def get_vocab_size(self) -> int: ... 33 | 34 | 35 | class HuggingFaceTokenizer(BaseTokenizer): 36 | """ 37 | A tokenizer wrapper that handles BOS/EOS token inference and encoding. 38 | 39 | This class loads tokenizer files and automatically infers BOS/EOS tokens from 40 | a configuration file (tokenizer_config.json). It provides an encode method that adds 41 | BOS/EOS tokens based on whether the underlying tokenizer adds them automatically. 42 | 43 | Args: 44 | tokenizer_path (str): Path to directory containing tokenizer files 45 | """ 46 | 47 | def __init__( 48 | self, 49 | tokenizer_path: str, 50 | ): 51 | super().__init__() 52 | self.tokenizer_path = tokenizer_path 53 | 54 | # Initialize BOS/EOS token attributes (frequently used) 55 | self.bos_id = None 56 | self.eos_id = None 57 | self.bos_token = None 58 | self.eos_token = None 59 | 60 | # Load the underlying tokenizer 61 | self.tokenizer = self._load_tokenizer_from_path(tokenizer_path) 62 | 63 | # Load configuration files 64 | self.config = self._load_config( 65 | os.path.join(tokenizer_path, "tokenizer_config.json") 66 | ) 67 | 68 | # Infer special tokens and adding BOS/EOS behavior 69 | self._infer_special_tokens() 70 | self._infer_should_add_bos_eos() 71 | 72 | def _load_config(self, config_path: str) -> Optional[dict]: 73 | """Load configuration from JSON file if it exists.""" 74 | if os.path.exists(config_path): 75 | with open(config_path, "r") as f: 76 | return json.load(f) 77 | return None 78 | 79 | def _load_tokenizer_from_path(self, tokenizer_path: str) -> Tokenizer: 80 | """Load tokenizer from various file formats.""" 81 | if not os.path.exists(tokenizer_path): 82 | raise FileNotFoundError(f"Tokenizer path '{tokenizer_path}' does not exist") 83 | 84 | # Define paths for different tokenizer file types 85 | tokenizer_json_path = os.path.join(tokenizer_path, "tokenizer.json") 86 | vocab_txt_path = os.path.join(tokenizer_path, "vocab.txt") 87 | vocab_json_path = os.path.join(tokenizer_path, "vocab.json") 88 | merges_txt_path = os.path.join(tokenizer_path, "merges.txt") 89 | 90 | # Strategy 1: Load from tokenizer.json (preferred for modern tokenizers) 91 | if os.path.exists(tokenizer_json_path): 92 | logger.info("Loading tokenizer from tokenizer.json") 93 | return Tokenizer.from_file(tokenizer_json_path) 94 | # Strategy 2: Load from vocab files (with or without merges.txt) 95 | elif os.path.exists(vocab_json_path) or os.path.exists(vocab_txt_path): 96 | # Load vocabulary 97 | if os.path.exists(vocab_json_path): 98 | logger.info("Loading vocabulary from vocab.json") 99 | with open(vocab_json_path, "r") as f: 100 | vocab = json.load(f) 101 | vocab_source = "vocab.json" 102 | else: 103 | logger.info("Loading vocabulary from vocab.txt") 104 | vocab = {} 105 | with open(vocab_txt_path, "r") as f: 106 | for i, line in enumerate(f): 107 | token = line.strip() 108 | if token: 109 | 
vocab[token] = i 110 | vocab_source = "vocab.txt" 111 | 112 | # Strategy 2a: Use BPE if merges.txt exists 113 | if os.path.exists(merges_txt_path): 114 | logger.info(f"Loading BPE tokenizer from {vocab_source} + merges.txt") 115 | from tokenizers import decoders, pre_tokenizers, processors 116 | from tokenizers.models import BPE 117 | 118 | # Load merges from file and convert to tuples 119 | merges = [] 120 | with open(merges_txt_path, "r") as f: 121 | for line in f: 122 | line = line.strip() 123 | if line and not line.startswith( 124 | "#" 125 | ): # Skip comments and empty lines 126 | parts = line.split() 127 | if len(parts) >= 2: 128 | merges.append((parts[0], parts[1])) 129 | 130 | # Create BPE model 131 | bpe_model = BPE(vocab=vocab, merges=merges) 132 | tokenizer = Tokenizer(bpe_model) 133 | 134 | # Configure GPT-2 style components for proper space handling 135 | tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel( 136 | add_prefix_space=False 137 | ) 138 | tokenizer.decoder = decoders.ByteLevel() 139 | tokenizer.post_processor = processors.ByteLevel(trim_offsets=True) 140 | 141 | return tokenizer 142 | 143 | # Strategy 2b: Use WordLevel if no merges.txt 144 | else: 145 | logger.info(f"Loading WordLevel tokenizer from {vocab_source}") 146 | from tokenizers.models import WordLevel 147 | 148 | word_level_model = WordLevel(vocab=vocab, unk_token="[UNK]") 149 | return Tokenizer(word_level_model) 150 | 151 | else: 152 | # List available files for debugging 153 | available_files = [ 154 | f 155 | for f in os.listdir(tokenizer_path) 156 | if os.path.isfile(os.path.join(tokenizer_path, f)) 157 | ] 158 | raise FileNotFoundError( 159 | f"No supported tokenizer files found in '{tokenizer_path}'. " 160 | f"Available files: {available_files}. " 161 | "Looking for: tokenizer.json, tokenizer.model, vocab.txt+merges.txt, or vocab.json+merges.txt" 162 | ) 163 | 164 | def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[str]: 165 | """ 166 | Parse special tokens from config that can be either strings or dicts. 167 | HF tokens are stored as either {'bos_token': ''} or {'bos_token': {'content': '', ...}}. 168 | """ 169 | token = config.get(key) 170 | if isinstance(token, dict): 171 | if "content" not in token: 172 | raise ValueError(f"Could not parse {key} from config") 173 | token = token["content"] 174 | elif token is not None and not isinstance(token, str): 175 | raise ValueError( 176 | f"Could not parse {key} from config - expected string or dict" 177 | ) 178 | return token 179 | 180 | def _process_special_token( 181 | self, token_str: str, token_config: dict, token_id: Optional[int] = None 182 | ) -> AddedToken: 183 | """ 184 | Process a special token and update BOS/EOS attributes if applicable. 
185 | 186 | Args: 187 | token_str: The token string content 188 | token_config: Token configuration dictionary 189 | token_id: Optional explicit token ID (for added_tokens_decoder) 190 | 191 | Returns: 192 | AddedToken object to be added to the tokenizer 193 | """ 194 | # Get reference BOS/EOS tokens from config for comparison 195 | config_bos_token = ( 196 | self._get_token_from_config(self.config, "bos_token") 197 | if self.config 198 | else None 199 | ) 200 | config_eos_token = ( 201 | self._get_token_from_config(self.config, "eos_token") 202 | if self.config 203 | else None 204 | ) 205 | 206 | # Store BOS/EOS tokens as class attributes if they match 207 | if token_str == config_bos_token: 208 | self.bos_token = token_str 209 | self.bos_id = ( 210 | token_id 211 | if token_id is not None 212 | else self.tokenizer.token_to_id(token_str) 213 | ) 214 | elif token_str == config_eos_token: 215 | self.eos_token = token_str 216 | self.eos_id = ( 217 | token_id 218 | if token_id is not None 219 | else self.tokenizer.token_to_id(token_str) 220 | ) 221 | 222 | # Create AddedToken object based on config format 223 | if isinstance(token_config, dict): 224 | if token_config.get("__type") == "AddedToken" or "content" in token_config: 225 | # Handle both AddedToken format and added_tokens_decoder format 226 | return AddedToken( 227 | content=token_str, 228 | single_word=token_config.get("single_word", False), 229 | lstrip=token_config.get("lstrip", False), 230 | rstrip=token_config.get("rstrip", False), 231 | normalized=token_config.get("normalized", True), 232 | special=token_config.get("special", True), 233 | ) 234 | 235 | # Fallback to simple special token 236 | return AddedToken(content=token_str, special=True) 237 | 238 | def _infer_special_tokens(self): 239 | """ 240 | Read special tokens from config and add them to the underlying tokenizer. 241 | Store BOS/EOS tokens as class attributes since they are frequently used. 242 | 243 | This method handles multiple token configuration formats: 244 | 1. Standard top-level keys (bos_token, eos_token, etc.) 245 | 2. 
added_tokens_decoder dictionary (used by models like Llama 3.1) 246 | """ 247 | standard_keys = [ 248 | "bos_token", 249 | "eos_token", 250 | "pad_token", 251 | "unk_token", 252 | "sep_token", 253 | "cls_token", 254 | "mask_token", 255 | ] 256 | 257 | # List to collect AddedToken objects for updating the underlying tokenizer 258 | added_tokens_to_add = [] 259 | 260 | if not self.config: 261 | return 262 | 263 | # Process standard top-level token keys 264 | for key in standard_keys: 265 | token_config = self.config.get(key) 266 | if token_config is not None: 267 | token_str = self._get_token_from_config(self.config, key) 268 | if token_str is not None: 269 | added_token = self._process_special_token(token_str, token_config) 270 | added_tokens_to_add.append(added_token) 271 | 272 | # Process added_tokens_decoder (comprehensive special token definitions) 273 | added_tokens_decoder = self.config.get("added_tokens_decoder", {}) 274 | for token_id_str, token_config in added_tokens_decoder.items(): 275 | if isinstance(token_config, dict) and "content" in token_config: 276 | token_str = token_config["content"] 277 | token_id = int(token_id_str) 278 | added_token = self._process_special_token( 279 | token_str, token_config, token_id 280 | ) 281 | added_tokens_to_add.append(added_token) 282 | 283 | # Update the underlying tokenizer with special tokens 284 | if added_tokens_to_add: 285 | self.tokenizer.add_special_tokens(added_tokens_to_add) 286 | 287 | # Update BOS/EOS token IDs after adding to tokenizer (in case they changed) 288 | if self.bos_token: 289 | self.bos_id = self.tokenizer.token_to_id(self.bos_token) 290 | if self.eos_token: 291 | self.eos_id = self.tokenizer.token_to_id(self.eos_token) 292 | 293 | def _infer_should_add_bos_eos(self): 294 | """ 295 | Determine if we should add BOS/EOS tokens based on config settings. 296 | If config explicitly specifies add_bos_token/add_eos_token, follow that. 297 | Otherwise, determine if the underlying tokenizer automatically adds them. 298 | """ 299 | self.default_add_bos = False 300 | self.default_add_eos = False 301 | self.hf_adds_bos = False 302 | self.hf_adds_eos = False 303 | 304 | # First, determine if underlying tokenizer auto-adds BOS/EOS tokens empirically 305 | encoded_empty_str = self.tokenizer.encode("").ids 306 | if self.bos_id is not None and self.bos_id in encoded_empty_str: 307 | self.hf_adds_bos = True 308 | if self.eos_id is not None and self.eos_id in encoded_empty_str: 309 | self.hf_adds_eos = True 310 | 311 | # Check tokenizer_config.json for explicit settings - these override empirical detection 312 | if self.config: 313 | config_add_bos = self.config.get("add_bos_token") 314 | config_add_eos = self.config.get("add_eos_token") 315 | if config_add_bos is not None: 316 | self.default_add_bos = bool(config_add_bos) 317 | if config_add_eos is not None: 318 | self.default_add_eos = bool(config_add_eos) 319 | 320 | def encode(self, *args, **kwargs) -> list[int]: 321 | """ 322 | Encode text into token IDs with BOS/EOS handling. 
323 | 324 | Args: 325 | text (str): The text to encode 326 | add_bos (bool): Whether to add BOS token (if not already added by tokenizer) 327 | add_eos (bool): Whether to add EOS token (if not already added by tokenizer) 328 | 329 | Returns: 330 | list[int]: List of token IDs 331 | """ 332 | # Extract arguments 333 | if len(args) >= 1: 334 | text = args[0] 335 | else: 336 | text = kwargs.get("text", "") 337 | 338 | add_bos = kwargs.get("add_bos", self.default_add_bos) 339 | add_eos = kwargs.get("add_eos", self.default_add_eos) 340 | 341 | # Get base token IDs from the underlying tokenizer 342 | token_ids = self.tokenizer.encode(text).ids 343 | 344 | # Add BOS token if requested and not already added by tokenizer 345 | if not self.hf_adds_bos and add_bos: 346 | if self.bos_id is not None: 347 | token_ids.insert(0, self.bos_id) 348 | 349 | # Add EOS token if requested and not already added by tokenizer 350 | if not self.hf_adds_eos and add_eos: 351 | if self.eos_id is not None: 352 | token_ids.append(self.eos_id) 353 | 354 | return token_ids 355 | 356 | @override 357 | def decode(self, *args, **kwargs) -> str: 358 | """ 359 | Decode token IDs back to text. 360 | 361 | Args: 362 | token_ids (list[int]): List of token IDs to decode 363 | **kwargs: Additional arguments passed to the underlying tokenizer's decode method 364 | (e.g., skip_special_tokens) 365 | 366 | Returns: 367 | str: Decoded text 368 | """ 369 | # Extract token_ids from arguments 370 | if len(args) >= 1: 371 | token_ids = args[0] 372 | # Pass through remaining kwargs 373 | return self.tokenizer.decode(token_ids, **kwargs) 374 | else: 375 | token_ids = kwargs.pop("token_ids", []) 376 | # Pass through remaining kwargs after removing token_ids 377 | return self.tokenizer.decode(token_ids, **kwargs) 378 | 379 | @property 380 | def vocab_size(self) -> int: 381 | """Get the vocabulary size.""" 382 | return self.tokenizer.get_vocab_size() 383 | 384 | def get_vocab_size(self) -> int: 385 | """Get the vocabulary size.""" 386 | return self.tokenizer.get_vocab_size() 387 | 388 | def get_vocab(self) -> dict[str, int]: 389 | """Get the vocabulary as a dictionary.""" 390 | return self.tokenizer.get_vocab() 391 | 392 | def token_to_id(self, token: str) -> Optional[int]: 393 | """Convert token to ID.""" 394 | return self.tokenizer.token_to_id(token) 395 | 396 | def id_to_token(self, token_id: int) -> Optional[str]: 397 | """Convert ID to token.""" 398 | return self.tokenizer.id_to_token(token_id) 399 | 400 | 401 | def build_hf_tokenizer( 402 | job_config: JobConfig, 403 | ) -> Union[HuggingFaceTokenizer, BaseTokenizer]: 404 | """ 405 | Builds a HuggingFaceTokenizer from the specified path. 406 | 407 | This function creates a HuggingFaceTokenizer instance that handles BOS/EOS token 408 | inference and intelligent encoding. The tokenizer automatically detects and loads 409 | from various file formats and infers special token behavior. 410 | 411 | Args: 412 | JobConfig: A JobConfig object containing the path to the tokenizer directory. 413 | 414 | Returns: 415 | tokenizer (HuggingFaceTokenizer): Loaded tokenizer instance with intelligent BOS/EOS handling 416 | """ 417 | tokenizer = HuggingFaceTokenizer(job_config.model.tokenizer_path) 418 | return tokenizer 419 | --------------------------------------------------------------------------------
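A minimal usage sketch for the tokenizer component above. The directory path is hypothetical; point it at any folder containing tokenizer.json (or vocab files) plus tokenizer_config.json, and the sample sentence is arbitrary. The add_bos/add_eos flags only add tokens that the underlying tokenizer does not already add, as described in the docstrings.

from torchtitan.components.tokenizer import HuggingFaceTokenizer

# "assets/tokenizer" is a placeholder path for illustration.
tokenizer = HuggingFaceTokenizer("assets/tokenizer")

ids = tokenizer.encode("It is a great day.", add_bos=True, add_eos=True)
text = tokenizer.decode(ids, skip_special_tokens=True)
print(tokenizer.get_vocab_size(), ids, text)

In the training path, build_hf_tokenizer(job_config) constructs the same wrapper from job_config.model.tokenizer_path.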