├── src
│   └── llm_training
│       ├── cli
│       │   ├── __init__.py
│       │   └── main.py
│       ├── utils
│       │   ├── safetensors
│       │   │   ├── __init__.py
│       │   │   └── __init__.pyi
│       │   ├── __init__.py
│       │   ├── str_enum.py
│       │   ├── context_managers.py
│       │   └── decorators.py
│       ├── lightning
│       │   ├── loggers
│       │   │   ├── __init__.py
│       │   │   └── wandb.py
│       │   ├── strategy
│       │   │   ├── __init__.py
│       │   │   ├── deepspeed
│       │   │   │   ├── __init__.py
│       │   │   │   └── deepspeed_strategy.py
│       │   │   └── fsdp2
│       │   │       └── __init__.py
│       │   ├── cli
│       │   │   ├── __init__.py
│       │   │   ├── trainer.py
│       │   │   ├── utils.py
│       │   │   └── cli.py
│       │   ├── __init__.py
│       │   └── callbacks
│       │       ├── __init__.py
│       │       ├── tqdm_progress.py
│       │       ├── model_checkpoint.py
│       │       ├── extra_config.py
│       │       ├── save_config_callback.py
│       │       ├── training_time_estimator.py
│       │       └── output_redirection.py
│       ├── lms
│       │   ├── protos
│       │   │   ├── __init__.py
│       │   │   └── clm_proto.py
│       │   ├── clm
│       │   │   ├── __init__.py
│       │   │   ├── clm_config.py
│       │   │   └── clm.py
│       │   ├── dpo
│       │   │   ├── __init__.py
│       │   │   └── dpo_config.py
│       │   ├── orpo
│       │   │   ├── __init__.py
│       │   │   └── orpo_config.py
│       │   ├── __init__.py
│       │   ├── utils.py
│       │   ├── model_provider.py
│       │   └── base_lm_config.py
│       ├── optim
│       │   ├── __init__.py
│       │   └── master_weight_wrapper.py
│       ├── models
│       │   ├── utils
│       │   │   ├── __init__.py
│       │   │   ├── modeling_outputs.py
│       │   │   └── utils.py
│       │   ├── phi3
│       │   │   ├── __init__.py
│       │   │   └── phi3_config.py
│       │   ├── llama
│       │   │   ├── __init__.py
│       │   │   └── llama_config.py
│       │   ├── base_model
│       │   │   ├── __init__.py
│       │   │   ├── base_model_config.py
│       │   │   └── base_model.py
│       │   ├── hf_causal_lm
│       │   │   ├── __init__.py
│       │   │   ├── hf_causal_lm_config.py
│       │   │   └── hf_causal_lm.py
│       │   ├── hf_compat_model
│       │   │   ├── __init__.py
│       │   │   ├── hf_compat_config.py
│       │   │   └── hf_compat_model.py
│       │   └── __init__.py
│       ├── data
│       │   ├── hf_based
│       │   │   ├── __init__.py
│       │   │   └── hf_based_datamodule_config.py
│       │   ├── dummy
│       │   │   ├── __init__.py
│       │   │   ├── dummy_datamodule.py
│       │   │   ├── dummy_datamodule_config.py
│       │   │   └── dummy_dataset.py
│       │   ├── pre_training
│       │   │   ├── __init__.py
│       │   │   ├── pre_training_datamodule_config.py
│       │   │   └── pre_training_datacollator.py
│       │   ├── preference_tuning
│       │   │   ├── __init__.py
│       │   │   ├── preference_tuning_datamodule_config.py
│       │   │   ├── preference_tuning_datacollator.py
│       │   │   └── preference_tuning_datamodule.py
│       │   ├── instruction_tuning
│       │   │   ├── __init__.py
│       │   │   ├── instruction_tuning_datamodule_config.py
│       │   │   ├── instruction_tuning_datacollator.py
│       │   │   └── instruction_tuning_datamodule.py
│       │   ├── __init__.py
│       │   ├── base_datacollator.py
│       │   ├── base_datamodule_config.py
│       │   ├── chat_templates
│       │   │   ├── chatml.j2
│       │   │   ├── tulu-2.j2
│       │   │   ├── gemma.j2
│       │   │   ├── llama-3.j2
│       │   │   ├── phi-3.j2
│       │   │   ├── llama-2.j2
│       │   │   ├── __init__.py
│       │   │   ├── qwen2.5.j2
│       │   │   ├── llama-3.2.j2
│       │   │   └── llama-3.1.j2
│       │   ├── resumable_dataloader.py
│       │   └── base_datamodule.py
│       ├── metrics
│       │   ├── __init__.py
│       │   ├── consumed_samples.py
│       │   ├── consumed_tokens.py
│       │   ├── metric.py
│       │   └── perplexity.py
│       ├── ops
│       │   ├── __init__.py
│       │   ├── liger_kernel
│       │   │   ├── __init__.py
│       │   │   ├── rms_norm_op.py
│       │   │   ├── rope_op.py
│       │   │   ├── swiglu_op.py
│       │   │   └── cross_entropy_op.py
│       │   ├── cross_entropy_op.py
│       │   ├── rms_norm_op.py
│       │   ├── rope_op.py
│       │   └── swiglu_op.py
│       ├── lr_schedulers
│       │   ├── __init__.py
│       │   ├── constant.py
│       │   ├── cosine.py
│       │   ├── linear.py
│       │   └── warmup.py
│       └── __init__.py
├── environment.yaml
├── install.sh
├── pyproject.toml
├── docs
│   ├── model_implementations.md
│   ├── config.md
│   ├── pre_training.md
│   └── instruction_tuning.md
├── scripts
│   ├── pre_process_data.py
│   ├── train.sh
│   ├── extend_fast_tokenizer.py
│   └── convert_to_hf.py
├── config
│   └── examples
│       ├── llama-3.1
│       │   ├── llama-3.1-8b_tp_example.yaml
│       │   ├── llama-3.1-8b_pt_example.yaml
│       │   └── llama-3.1-8b_it_example.yaml
│       └── phi-3
│           ├── phi-3-mini_tp_example.yaml
│           ├── phi-3-mini_orpo_example.yaml
│           ├── phi-3-mini_dpo_example.yaml
│           ├── phi-3-mini_pt_example.yaml
│           └── phi-3-mini_it_example.yaml
├── .gitignore
└── README.md

/src/llm_training/cli/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/llm_training/utils/safetensors/__init__.py:
--------------------------------------------------------------------------------
1 | from safetensors import *
2 | 
--------------------------------------------------------------------------------
/src/llm_training/lightning/loggers/__init__.py:
--------------------------------------------------------------------------------
1 | from .wandb import WandbLogger
2 | 
--------------------------------------------------------------------------------
/src/llm_training/lms/protos/__init__.py:
--------------------------------------------------------------------------------
1 | from .clm_proto import CausalLMProto
2 | 
--------------------------------------------------------------------------------
/src/llm_training/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from .master_weight_wrapper import MasterWeightsOptimizer
2 | 
--------------------------------------------------------------------------------
/src/llm_training/lightning/strategy/__init__.py:
--------------------------------------------------------------------------------
1 | from .deepspeed import *
2 | from .fsdp2 import *
3 | 
--------------------------------------------------------------------------------
/src/llm_training/lms/clm/__init__.py:
--------------------------------------------------------------------------------
1 | from .clm import CLM
2 | from .clm_config import CLMConfig
3 | 
--------------------------------------------------------------------------------
/src/llm_training/lms/dpo/__init__.py:
--------------------------------------------------------------------------------
1 | from .dpo import DPO
2 | from .dpo_config import DPOConfig
3 | 
--------------------------------------------------------------------------------
/src/llm_training/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import init_empty_weights, init_on_device
2 | 
--------------------------------------------------------------------------------
/src/llm_training/lms/orpo/__init__.py:
--------------------------------------------------------------------------------
1 | from .orpo import ORPO
2 | from .orpo_config import ORPOConfig
3 | 
--------------------------------------------------------------------------------
/src/llm_training/lightning/cli/__init__.py:
--------------------------------------------------------------------------------
1 | from .cli import LightningCLI
2 | from .utils import HFTokenizer
3 | 
--------------------------------------------------------------------------------
/src/llm_training/lightning/strategy/deepspeed/__init__.py:
--------------------------------------------------------------------------------
1 | from .deepspeed_strategy import DeepSpeedStrategy
2 | 
--------------------------------------------------------------------------------
/src/llm_training/models/phi3/__init__.py:
--------------------------------------------------------------------------------
1 | from .phi3_config import Phi3Config
2 | from .phi3_model import Phi3
3 | 
-------------------------------------------------------------------------------- /src/llm_training/models/llama/__init__.py: -------------------------------------------------------------------------------- 1 | from .llama_config import LlamaConfig 2 | from .llama_model import Llama 3 | -------------------------------------------------------------------------------- /src/llm_training/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .context_managers import ContextManagers 2 | from .str_enum import StrEnum 3 | -------------------------------------------------------------------------------- /src/llm_training/cli/main.py: -------------------------------------------------------------------------------- 1 | from llm_training.lightning import LightningCLI 2 | 3 | 4 | def main(): 5 | LightningCLI() 6 | -------------------------------------------------------------------------------- /src/llm_training/models/base_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .base_model_config import BaseModelConfig 3 | -------------------------------------------------------------------------------- /src/llm_training/models/hf_causal_lm/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_causal_lm import HFCausalLM 2 | from .hf_causal_lm_config import HFCausalLMConfig 3 | -------------------------------------------------------------------------------- /src/llm_training/lightning/__init__.py: -------------------------------------------------------------------------------- 1 | from .callbacks import * 2 | from .cli import * 3 | from .loggers import * 4 | from .strategy import * 5 | -------------------------------------------------------------------------------- /src/llm_training/lightning/strategy/fsdp2/__init__.py: -------------------------------------------------------------------------------- 1 | from .fsdp2_precision import FSDP2Precision 2 | from .fsdp2_strategy import FSDP2Strategy 3 | -------------------------------------------------------------------------------- /src/llm_training/models/hf_compat_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_compat_config import HFCompatModelConfig 2 | from .hf_compat_model import HFCompatModel 3 | -------------------------------------------------------------------------------- /src/llm_training/data/hf_based/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_based_datamodule import HFBasedDataModule 2 | from .hf_based_datamodule_config import HFBasedDataModuleConfig 3 | -------------------------------------------------------------------------------- /src/llm_training/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .consumed_samples import ConsumedSamples 2 | from .consumed_tokens import ConsumedTokens 3 | from .perplexity import Perplexity 4 | -------------------------------------------------------------------------------- /src/llm_training/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import * 2 | from .hf_causal_lm import * 3 | from .hf_compat_model import * 4 | from .llama import * 5 | from .phi3 import * 6 | -------------------------------------------------------------------------------- 
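The model classes re-exported above are normally not constructed by hand during training; the language modules under `lms/` receive them through a `ModelProvider` (defined in `lms/model_provider.py` below), which defers building the model until the module asks for it. A minimal usage sketch, assuming the `Llama` class exported above and an illustrative checkpoint path (both the path and the variable names here are the editor's examples, not repository code):

from llm_training.lms.model_provider import ModelProvider
from llm_training.models import Llama, LlamaConfig

# Wrap the model class and its validated config; nothing is instantiated yet.
provider = ModelProvider(
    model_class=Llama,
    model_config=LlamaConfig(hf_path='meta-llama/Llama-3.1-8B')  # assumed HF checkpoint
)

# Calling the provider builds the model from the stored config.
model = provider()

In the YAML files under `config/examples/`, the same wiring is expressed declaratively through the `model_class` and `model_config` keys shown in `docs/model_implementations.md` and the example configs below.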
/src/llm_training/data/dummy/__init__.py: -------------------------------------------------------------------------------- 1 | from .dummy_datamodule import DummyDataModule 2 | from .dummy_datamodule_config import DummyDataModuleConfig 3 | from .dummy_dataset import DummyDataset 4 | -------------------------------------------------------------------------------- /src/llm_training/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy_op import shift_labels 2 | from .rms_norm_op import rms_norm 3 | from .rope_op import apply_rope, rotate_half 4 | from .swiglu_op import swiglu 5 | -------------------------------------------------------------------------------- /src/llm_training/lr_schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .constant import ConstantWarmupLR 2 | from .cosine import CosineAnnealingWarmupLR 3 | from .linear import LinearWarmupLR 4 | from .warmup import WarmupLR 5 | -------------------------------------------------------------------------------- /src/llm_training/lms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_lm import BaseLightningModule 2 | from .base_lm_config import BaseLightningModuleConfig, BaseOptimizerConfig 3 | from .clm import * 4 | from .dpo import * 5 | from .orpo import * 6 | -------------------------------------------------------------------------------- /src/llm_training/ops/liger_kernel/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy_op import cross_entropy, fused_linear_cross_entropy 2 | from .rms_norm_op import rms_norm 3 | from .rope_op import apply_rope 4 | from .swiglu_op import silu_mul, swiglu 5 | -------------------------------------------------------------------------------- /src/llm_training/data/pre_training/__init__.py: -------------------------------------------------------------------------------- 1 | from .pre_training_datacollator import PreTrainingDataCollator 2 | from .pre_training_datamodule import PreTrainingDataModule 3 | from .pre_training_datamodule_config import PreTrainingDataModuleConfig 4 | -------------------------------------------------------------------------------- /src/llm_training/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | _logger = logging.getLogger(__name__) 4 | _handler = logging.StreamHandler() 5 | _handler.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] %(message)s')) 6 | _logger.addHandler(_handler) 7 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: llm-training 2 | channels: 3 | - nvidia/label/cuda-12.4.0 4 | - nvidia 5 | dependencies: 6 | - python=3.10 7 | - pip 8 | - pytorch::pytorch=2.5.1 9 | - pytorch::pytorch-cuda=12.4 10 | - cuda 11 | - conda-forge::gxx=9 12 | -------------------------------------------------------------------------------- /src/llm_training/data/preference_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | from .preference_tuning_datacollator import PreferenceTuningDataCollator 2 | from .preference_tuning_datamodule import PreferenceTuningDataModule 3 | from .preference_tuning_datamodule_config import \ 4 | PreferenceTuningDataModuleConfig 5 | 
-------------------------------------------------------------------------------- /src/llm_training/data/instruction_tuning/__init__.py: -------------------------------------------------------------------------------- 1 | from .instruction_tuning_datacollator import InstructionTuningDataCollator 2 | from .instruction_tuning_datamodule import InstructionTuningDataModule 3 | from .instruction_tuning_datamodule_config import \ 4 | InstructionTuningDataModuleConfig 5 | -------------------------------------------------------------------------------- /src/llm_training/ops/cross_entropy_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def shift_labels(labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor: 5 | labels = labels.roll(shifts=-1, dims=1) 6 | index = torch.tensor(-1, device=labels.device) 7 | labels = labels.index_fill_(1, index, ignore_index) 8 | return labels 9 | -------------------------------------------------------------------------------- /src/llm_training/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_datacollator import BaseDataCollator 2 | from .base_datamodule import BaseDataModule 3 | from .base_datamodule_config import BaseDataModuleConfig 4 | from .dummy import * 5 | from .hf_based import * 6 | from .instruction_tuning import * 7 | from .pre_training import * 8 | from .preference_tuning import * 9 | -------------------------------------------------------------------------------- /src/llm_training/data/hf_based/hf_based_datamodule_config.py: -------------------------------------------------------------------------------- 1 | from llm_training.data.base_datamodule_config import BaseDataModuleConfig 2 | 3 | 4 | class HFBasedDataModuleConfig(BaseDataModuleConfig): 5 | dataset_kwargs: dict | None = None 6 | num_proc: int | None = None 7 | cleanup_cache_files: bool = False 8 | enable_cache: bool = True 9 | -------------------------------------------------------------------------------- /src/llm_training/lightning/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .extra_config import ExtraConfig 2 | from .model_checkpoint import ModelCheckpoint 3 | from .output_redirection import OutputRedirection 4 | from .save_config_callback import SaveConfigCallback 5 | from .tqdm_progress import TQDMProgressBar 6 | from .training_time_estimator import TrainingTimeEstimator 7 | -------------------------------------------------------------------------------- /src/llm_training/lms/clm/clm_config.py: -------------------------------------------------------------------------------- 1 | from llm_training.lms.base_lm_config import BaseLightningModuleConfig 2 | from llm_training.lms.utils import ModelType 3 | 4 | 5 | class CLMConfig(BaseLightningModuleConfig): 6 | model: ModelType 7 | ignore_index: int = -100 8 | neftune_alpha: float | None = None 9 | log_perplexity: bool = True 10 | -------------------------------------------------------------------------------- /src/llm_training/lms/orpo/orpo_config.py: -------------------------------------------------------------------------------- 1 | from llm_training.lms.base_lm_config import BaseLightningModuleConfig 2 | from llm_training.lms.utils import ModelType 3 | 4 | 5 | class ORPOConfig(BaseLightningModuleConfig): 6 | model: ModelType 7 | beta: float = 0.1 8 | ignore_index: int = -100 9 | empty_cache_threshold: int | None = None 10 | 
-------------------------------------------------------------------------------- /src/llm_training/ops/rms_norm_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rms_norm( 5 | x: torch.Tensor, 6 | weight: torch.Tensor, 7 | eps: float 8 | ) -> torch.Tensor: 9 | dtype = x.dtype 10 | x = x.to(torch.float32) 11 | variance = x.pow(2).mean(-1, keepdim=True) 12 | x = x * torch.rsqrt(variance + eps) 13 | x = weight * x.to(dtype) 14 | return x 15 | -------------------------------------------------------------------------------- /src/llm_training/lms/dpo/dpo_config.py: -------------------------------------------------------------------------------- 1 | from llm_training.lms.base_lm_config import BaseLightningModuleConfig 2 | from llm_training.lms.utils import ModelType 3 | 4 | 5 | class DPOConfig(BaseLightningModuleConfig): 6 | model: ModelType 7 | ref_model: ModelType | None = None 8 | beta: float = 0.1 9 | label_smoothing: float = 0.0 10 | ignore_index: int = -100 11 | -------------------------------------------------------------------------------- /src/llm_training/models/utils/modeling_outputs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pydantic import BaseModel, ConfigDict 3 | 4 | 5 | class ModelOutput(BaseModel): 6 | model_config = ConfigDict( 7 | arbitrary_types_allowed=True, 8 | protected_namespaces=() 9 | ) 10 | 11 | 12 | class CausalLMOutput(ModelOutput): 13 | logits: torch.Tensor 14 | last_hidden_states: torch.Tensor | None = None 15 | -------------------------------------------------------------------------------- /src/llm_training/lightning/cli/trainer.py: -------------------------------------------------------------------------------- 1 | from lightning import Trainer as _Trainer 2 | 3 | 4 | class Trainer(_Trainer): 5 | @property 6 | def estimated_stepping_batches(self) -> int | float: 7 | has_train_dataloader = self.train_dataloader is not None 8 | r = super().estimated_stepping_batches 9 | if not has_train_dataloader: 10 | self.fit_loop._combined_loader = None 11 | return r 12 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Make sure `ninja` and `packaging` is installed before installing `flash_attn` 4 | pip install ninja packaging 5 | # Force a rebuild of `flash_attn` in case .so files built with an incompatible version of CUDA is cached. 6 | 7 | FA_VERSION=$(cat pyproject.toml | grep -oE '"flash-attn[^"]+"') 8 | FA_VERSION=${FA_VERSION:1:-1} 9 | pip install $FA_VERSION --no-build-isolation --no-cache-dir 10 | 11 | pip install -e .[deepspeed] 12 | -------------------------------------------------------------------------------- /src/llm_training/data/base_datacollator.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any 3 | 4 | from .base_datamodule_config import BaseDataModuleConfig 5 | 6 | 7 | class BaseDataCollator(ABC): 8 | def __init__(self, config: BaseDataModuleConfig) -> None: 9 | super().__init__() 10 | 11 | self.config = config 12 | 13 | @abstractmethod 14 | def __call__(self, batch: list[dict[str, Any]]) -> dict[str, Any]: ... 
15 | -------------------------------------------------------------------------------- /src/llm_training/lms/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | from llm_training.lms.model_provider import ModelProvider 4 | from llm_training.models.base_model.base_model import BaseModel 5 | 6 | ModelType = ModelProvider | BaseModel | Callable[[], BaseModel] 7 | 8 | 9 | def get_model(model_or_provider: ModelType) -> BaseModel: 10 | if isinstance(model_or_provider, BaseModel): 11 | return model_or_provider 12 | return model_or_provider() 13 | -------------------------------------------------------------------------------- /src/llm_training/data/base_datamodule_config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict 2 | 3 | 4 | class BaseDataModuleConfig(BaseModel): 5 | pre_processed_data_path: str | None = None 6 | validation_split: int | float | None = None 7 | batch_size: int = 1 8 | num_workers: int = 0 9 | pin_memory: bool = False 10 | prepare_data_per_node: bool = False 11 | prefetch_factor: int | None = None 12 | 13 | model_config = ConfigDict(arbitrary_types_allowed=True) 14 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/chatml.j2: -------------------------------------------------------------------------------- 1 | {%- for message in messages %} 2 | {{- '<|im_start|>' + message.role + '\n' }} 3 | {%- set content = message.content + '<|im_end|>\n' %} 4 | {%- if message.role == 'assistant' %} 5 | {% generation %} 6 | {{- content -}} 7 | {% endgeneration %} 8 | {%- else %} 9 | {{- content }} 10 | {%- endif %} 11 | {%- endfor %} 12 | {%- if add_generation_prompt %} 13 | {{- '<|im_start|>assistant\n' }} 14 | {%- endif %} 15 | -------------------------------------------------------------------------------- /src/llm_training/ops/liger_kernel/rms_norm_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from liger_kernel.ops.rms_norm import LigerRMSNormFunction 3 | 4 | from llm_training.ops.rms_norm_op import rms_norm as rms_norm_torch 5 | 6 | 7 | def rms_norm( 8 | x: torch.Tensor, 9 | weight: torch.Tensor, 10 | eps: float 11 | ) -> torch.Tensor: 12 | if x.device.type != 'cuda': 13 | return rms_norm_torch(x, weight, eps) 14 | 15 | return LigerRMSNormFunction.apply( 16 | x, 17 | weight, 18 | eps 19 | ) 20 | -------------------------------------------------------------------------------- /src/llm_training/lightning/callbacks/tqdm_progress.py: -------------------------------------------------------------------------------- 1 | from lightning import LightningModule, Trainer 2 | from lightning.pytorch.callbacks.progress.tqdm_progress import \ 3 | TQDMProgressBar as _TQDMProgressBar 4 | 5 | 6 | class TQDMProgressBar(_TQDMProgressBar): 7 | def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: 8 | super().on_train_epoch_start(trainer, pl_module) 9 | 10 | if trainer.fit_loop.restarting: 11 | self.train_progress_bar.initial = self.trainer.fit_loop.batch_idx + 1 12 | -------------------------------------------------------------------------------- /src/llm_training/utils/str_enum.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | 4 | class StrEnum(str, Enum): 5 | def __new__(cls, value, *args, **kwargs): 6 | if not 
isinstance(value, (str, auto)): 7 | raise TypeError( 8 | f"Values of StrEnums must be strings: {value!r} is a {type(value)}" 9 | ) 10 | return super().__new__(cls, value, *args, **kwargs) 11 | 12 | def __str__(self): 13 | return str(self.value) 14 | 15 | def _generate_next_value_(name, *_): 16 | return name.lower() 17 | -------------------------------------------------------------------------------- /src/llm_training/ops/liger_kernel/rope_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from liger_kernel.ops.rope import LigerRopeFunction 3 | 4 | from llm_training.ops.rope_op import apply_rope as apply_rope_torch 5 | 6 | 7 | def apply_rope( 8 | q: torch.Tensor, 9 | k: torch.Tensor, 10 | cos: torch.Tensor, 11 | sin: torch.Tensor 12 | ) -> tuple[torch.Tensor, torch.Tensor]: 13 | if q.device.type != 'cuda': 14 | return apply_rope_torch(q, k, cos, sin) 15 | 16 | return LigerRopeFunction.apply( 17 | q, 18 | k, 19 | cos, 20 | sin 21 | ) 22 | -------------------------------------------------------------------------------- /src/llm_training/ops/rope_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rotate_half(x: torch.Tensor) -> torch.Tensor: 5 | x1 = x[..., : x.shape[-1] // 2] 6 | x2 = x[..., x.shape[-1] // 2 :] 7 | return torch.cat((-x2, x1), dim=-1) 8 | 9 | 10 | def apply_rope( 11 | q: torch.Tensor, 12 | k: torch.Tensor, 13 | cos: torch.Tensor, 14 | sin: torch.Tensor 15 | ) -> tuple[torch.Tensor, torch.Tensor]: 16 | cos = cos.unsqueeze(1) 17 | sin = sin.unsqueeze(1) 18 | q = (q * cos) + (rotate_half(q) * sin) 19 | k = (k * cos) + (rotate_half(k) * sin) 20 | return q, k 21 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/tulu-2.j2: -------------------------------------------------------------------------------- 1 | {%- for message in messages %} 2 | {%- if message.role == 'system' %} 3 | {{- '<|system|>\n' + message.content }} 4 | {%- elif message.role == 'user' %} 5 | {{- '<|user|>\n' + message.content }} 6 | {%- elif message.role == 'assistant' %} 7 | {% generation %} 8 | {{- '<|assistant|>\n' + message.content + eos_token -}} 9 | {% endgeneration %} 10 | {%- endif %} 11 | {%- if loop.last and add_generation_prompt %} 12 | {{- '<|assistant|>' }} 13 | {%- endif %} 14 | {%- endfor %} 15 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/gemma.j2: -------------------------------------------------------------------------------- 1 | {{- bos_token }} 2 | {%- for message in messages %} 3 | {%- set content = message.content | trim + '\n' %} 4 | {%- set role = 'model' if message.role == 'assistant' else message.role %} 5 | {{- '' + role + '\n' }} 6 | {%- if message.role == 'assistant' %} 7 | {% generation %} 8 | {{- content }} 9 | {% endgeneration %} 10 | {%- else %} 11 | {{- content }} 12 | {%- endif %} 13 | {%- endfor %} 14 | {%- if add_generation_prompt %}{{'model\n'}} 15 | {%- endif %} 16 | -------------------------------------------------------------------------------- /src/llm_training/metrics/consumed_samples.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .metric import Metric 3 | 4 | 5 | class ConsumedSamples(Metric): 6 | higher_is_better: bool = True 7 | full_state_update: bool = False 8 | 9 | n: torch.Tensor 10 | 11 | def __init__(self, **kwargs) -> None: 12 | 
super().__init__(**kwargs) 13 | 14 | self.add_state('n', torch.tensor(0), dist_reduce_fx='sum', persistent=True) 15 | 16 | def update(self, target: torch.Tensor) -> None: 17 | self.n += target.size(0) 18 | 19 | def compute(self) -> torch.Tensor: 20 | return self.n 21 | -------------------------------------------------------------------------------- /src/llm_training/models/hf_causal_lm/hf_causal_lm_config.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from transformers import PretrainedConfig 4 | 5 | from llm_training.models.hf_compat_model import HFCompatModelConfig 6 | 7 | 8 | class HFCausalLMConfig(HFCompatModelConfig): 9 | enable_gradient_checkpointing: bool = False 10 | enable_liger_kernel: bool = False 11 | 12 | hf_config: PretrainedConfig | None = None 13 | 14 | def __getattr__(self, name: str) -> Any: 15 | if hasattr(self.hf_config, name): 16 | return getattr(self.hf_config, name) 17 | return super().__getattr__(name) 18 | -------------------------------------------------------------------------------- /src/llm_training/utils/context_managers.py: -------------------------------------------------------------------------------- 1 | from contextlib import AbstractContextManager, ExitStack 2 | from typing import ContextManager 3 | 4 | 5 | class ContextManagers(AbstractContextManager): 6 | def __init__(self, context_managers: list[ContextManager]) -> None: 7 | self.context_managers = context_managers 8 | self.stack = ExitStack() 9 | 10 | def __enter__(self): 11 | for cm in self.context_managers: 12 | self.stack.enter_context(cm) 13 | return self 14 | 15 | def __exit__(self, __exc_type, __exc_value, __traceback): 16 | self.stack.close() 17 | -------------------------------------------------------------------------------- /src/llm_training/utils/safetensors/__init__.pyi: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors import * 3 | from typing_extensions import Self 4 | 5 | class _TensorSlice: 6 | def __getitem__(self, index: slice | tuple[slice]) -> torch.Tensor: ... 7 | 8 | def get_shape(self) -> list[int]: ... 9 | 10 | 11 | class safe_open: 12 | def __init__(self, filename: str, framework: str, device: str = 'cpu'): ... 13 | 14 | def keys(self) -> list[str]: ... 15 | 16 | def get_tensor(key: str) -> torch.Tensor: ... 17 | 18 | def get_slice(self, key: str) -> _TensorSlice: ... 19 | 20 | def __enter__(self) -> Self: ... 
21 | -------------------------------------------------------------------------------- /src/llm_training/metrics/consumed_tokens.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .metric import Metric 3 | 4 | 5 | class ConsumedTokens(Metric): 6 | higher_is_better: bool = True 7 | full_state_update: bool = False 8 | 9 | n: torch.Tensor 10 | 11 | def __init__(self, ignore_index: int = -100, **kwargs) -> None: 12 | super().__init__(**kwargs) 13 | 14 | self.ignore_index = ignore_index 15 | self.add_state('n', torch.tensor(0), dist_reduce_fx='sum', persistent=True) 16 | 17 | def update(self, target: torch.Tensor) -> None: 18 | self.n += target.ne(self.ignore_index).sum() 19 | 20 | def compute(self) -> torch.Tensor: 21 | return self.n 22 | -------------------------------------------------------------------------------- /src/llm_training/lightning/cli/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | from jsonargparse import class_from_function 4 | from transformers import AutoTokenizer, PreTrainedTokenizerBase 5 | 6 | 7 | def _load_tokenizer( 8 | path: str, 9 | pad_token: str | None = None, 10 | padding_side: Literal["left", "right"] | None = None, 11 | **kwargs 12 | ) -> PreTrainedTokenizerBase: 13 | if pad_token is not None: 14 | kwargs['pad_token'] = pad_token 15 | 16 | if padding_side is not None: 17 | kwargs['padding_side'] = padding_side 18 | 19 | return AutoTokenizer.from_pretrained(path, **kwargs) 20 | 21 | 22 | HFTokenizer = class_from_function(_load_tokenizer, name='HFTokenizer') 23 | -------------------------------------------------------------------------------- /src/llm_training/lightning/callbacks/model_checkpoint.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from lightning import LightningModule, Trainer 5 | from lightning.pytorch.callbacks.model_checkpoint import \ 6 | ModelCheckpoint as LightningModelCheckpoint 7 | 8 | from llm_training.lightning.loggers.wandb import WandbLogger 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ModelCheckpoint(LightningModelCheckpoint): 14 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: 15 | if self.dirpath is None and isinstance(trainer.logger, WandbLogger): 16 | self.dirpath = os.path.join(trainer.log_dir, 'checkpoints') 17 | 18 | super().setup(trainer, pl_module, stage) 19 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/llama-3.j2: -------------------------------------------------------------------------------- 1 | {%- set loop_messages = messages %} 2 | {%- for message in loop_messages %} 3 | {%- set header = '<|start_header_id|>' + message.role + '<|end_header_id|>\n\n' %} 4 | {%- set content = message.content | trim + '<|eot_id|>' %} 5 | {%- if loop.index0 == 0 %} 6 | {%- set header = bos_token + header %} 7 | {%- endif %} 8 | {{- header -}} 9 | {%- if message.role == 'assistant' %} 10 | {% generation %} 11 | {{- content -}} 12 | {% endgeneration %} 13 | {%- else %} 14 | {{- content }} 15 | {%- endif %} 16 | {%- endfor %} 17 | {%- if add_generation_prompt %} 18 | {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} 19 | {%- endif %} 20 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/phi-3.j2: 
-------------------------------------------------------------------------------- 1 | {%- for message in messages %} 2 | {%- set content = message.content | trim %} 3 | {%- if message.role == 'system' and content %} 4 | {{- '<|system|>\n' + content + '<|end|>\n' }} 5 | {%- elif message.role == 'user' %} 6 | {{- '<|user|>\n' + content + '<|end|>\n' }} 7 | {%- elif message.role == 'assistant' %} 8 | {{- '<|assistant|>\n' -}} 9 | {% generation %} 10 | {{- content + '<|end|>\n' -}} 11 | {% endgeneration %} 12 | {%- endif %} 13 | {%- endfor %} 14 | {%- if add_generation_prompt %} 15 | {{- '<|assistant|>\n' }} 16 | {%- else %} 17 | {% generation %} 18 | {{- eos_token -}} 19 | {% endgeneration %} 20 | {%- endif %} 21 | -------------------------------------------------------------------------------- /src/llm_training/data/dummy/dummy_datamodule.py: -------------------------------------------------------------------------------- 1 | from llm_training.data.base_datamodule import BaseDataModule, DatasetDict 2 | 3 | from .dummy_datamodule_config import DummyDataModuleConfig 4 | from .dummy_dataset import DummyDataset 5 | 6 | 7 | class DummyDataModule(BaseDataModule): 8 | config: DummyDataModuleConfig 9 | 10 | def __init__(self, config: DummyDataModuleConfig) -> None: 11 | super().__init__(config) 12 | 13 | def setup(self, stage: str | None = None) -> None: 14 | if self.trainer is not None: 15 | self.config.base_seed = self.trainer.strategy.broadcast(self.config.base_seed) 16 | 17 | super().setup(stage) 18 | 19 | def load_data(self) -> DatasetDict: 20 | return {'train': DummyDataset(self.config)} 21 | -------------------------------------------------------------------------------- /src/llm_training/metrics/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric as _Metric 3 | 4 | 5 | class Metric(_Metric): 6 | def _load_from_state_dict( 7 | self, 8 | state_dict: dict, 9 | prefix: str, 10 | local_metadata: dict, 11 | strict: bool, 12 | missing_keys: list[str], 13 | unexpected_keys: list[str], 14 | error_msgs: list[str] 15 | ) -> None: 16 | for key in self._defaults: 17 | name = prefix + key 18 | tensor = getattr(self, key, None) 19 | if name in state_dict and isinstance(tensor, torch.Tensor): 20 | tensor.copy_(state_dict.pop(name)) 21 | return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) 22 | -------------------------------------------------------------------------------- /src/llm_training/lr_schedulers/constant.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | from torch.optim.lr_scheduler import ConstantLR 3 | 4 | from .warmup import WarmupLR 5 | 6 | 7 | class ConstantWarmupLR(WarmupLR): 8 | def __init__( 9 | self, 10 | optimizer: Optimizer, 11 | factor: float = 1.0, 12 | total_iters: int = 0, 13 | num_warmup_steps: int = 0, 14 | last_epoch: int = -1 15 | ) -> None: 16 | super().__init__( 17 | optimizer=optimizer, 18 | lr_scheduler=ConstantLR( 19 | optimizer=optimizer, 20 | factor=factor, 21 | total_iters=total_iters, 22 | last_epoch=last_epoch 23 | ), 24 | num_warmup_epochs=num_warmup_steps, 25 | last_epoch=last_epoch 26 | ) 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "llm-training" 3 | version = "0.2.0" 4 
| dependencies = [ 5 | "accelerate>=1.1.1", 6 | "datasets>=3.1.0", 7 | "fire>=0.7.0", 8 | "flash-attn==2.7.0.post2", 9 | "jsonargparse[signatures]==4.34.1", 10 | "lightning==2.4.0", 11 | "liger-kernel==0.4.2", 12 | "omegaconf>=2.3.0", 13 | "protobuf>=5.29.0", 14 | "pydantic>=2.10.3", 15 | "safetensors>=0.4.5", 16 | "sentencepiece>=0.2.0", 17 | "tabulate[widechars]>=0.9.0", 18 | "tokenizers==0.20.3", 19 | "torch==2.5.1", 20 | "transformers==4.46.3", 21 | "triton==3.1.0", 22 | "wandb>=0.18.7", 23 | ] 24 | requires-python = ">=3.10" 25 | 26 | [project.optional-dependencies] 27 | peft = ["peft>=0.13.2"] 28 | 29 | deepspeed = ["deepspeed==0.16.0"] 30 | 31 | [project.scripts] 32 | llm-training = "llm_training.cli.main:main" 33 | -------------------------------------------------------------------------------- /src/llm_training/lr_schedulers/cosine.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | from torch.optim.lr_scheduler import CosineAnnealingLR 3 | 4 | from .warmup import WarmupLR 5 | 6 | 7 | class CosineAnnealingWarmupLR(WarmupLR): 8 | def __init__( 9 | self, 10 | optimizer: Optimizer, 11 | num_warmup_steps: int, 12 | num_total_steps: int, 13 | min_lr: float, 14 | last_epoch: int = -1 15 | ) -> None: 16 | super().__init__( 17 | optimizer=optimizer, 18 | lr_scheduler=CosineAnnealingLR( 19 | optimizer=optimizer, 20 | T_max=num_total_steps - num_warmup_steps, 21 | eta_min=min_lr, 22 | last_epoch=last_epoch 23 | ), 24 | num_warmup_epochs=num_warmup_steps, 25 | last_epoch=last_epoch 26 | ) 27 | -------------------------------------------------------------------------------- /src/llm_training/lms/protos/clm_proto.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from llm_training.models.utils.modeling_outputs import CausalLMOutput 7 | 8 | 9 | class CausalLMProto(Protocol): 10 | def get_input_embeddings(self) -> nn.Embedding: ... 11 | 12 | def get_output_embeddings(self) -> nn.Linear: ... 13 | 14 | def set_input_embeddings(self, embedding: nn.Embedding) -> None: ... 15 | 16 | def set_output_embeddings(self, linear: nn.Linear) -> None: ... 17 | 18 | def __call__( 19 | self, 20 | *, 21 | input_ids: torch.Tensor | None = None, 22 | attention_mask: torch.Tensor | None = None, 23 | position_ids: torch.Tensor | None = None, 24 | inputs_embeds: torch.Tensor | None = None, 25 | return_last_hidden_states: bool = False 26 | ) -> CausalLMOutput: ... 
27 | -------------------------------------------------------------------------------- /src/llm_training/models/base_model/base_model_config.py: -------------------------------------------------------------------------------- 1 | from types import UnionType 2 | from typing import Any 3 | 4 | import torch 5 | from pydantic import BaseModel, ConfigDict, ValidationInfo, field_validator 6 | 7 | 8 | class BaseModelConfig(BaseModel): 9 | pre_trained_weights: str | None = None 10 | 11 | model_config = ConfigDict(arbitrary_types_allowed=True) 12 | 13 | @field_validator('*') 14 | @classmethod 15 | def validate_torch_dtype(cls, value: Any, info: ValidationInfo) -> Any: 16 | field = cls.model_fields[info.field_name] 17 | is_torch_dtype = isinstance(field.annotation, torch.dtype) 18 | is_torch_dtype |= isinstance(field.annotation, UnionType) and torch.dtype in field.annotation.__args__ 19 | if is_torch_dtype and isinstance(value, str) and value != 'auto': 20 | value = getattr(torch, value) 21 | return value 22 | -------------------------------------------------------------------------------- /src/llm_training/ops/swiglu_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def swiglu( 6 | x: torch.Tensor, 7 | *, 8 | w3: torch.Tensor, 9 | w1w2: torch.Tensor | None = None, 10 | w1: torch.Tensor | None = None, 11 | w2: torch.Tensor | None = None, 12 | b1: torch.Tensor | None = None, 13 | b2: torch.Tensor | None = None, 14 | b3: torch.Tensor | None = None, 15 | b1b2: torch.Tensor | None = None 16 | ) -> torch.Tensor: 17 | assert ( 18 | w1w2 is not None and w1 is None and w2 is None 19 | or w1w2 is None and w1 is not None and w2 is not None 20 | ) 21 | 22 | if w1w2 is not None: 23 | x1x2 = F.linear(x, w1w2, b1b2) 24 | x1, x2 = torch.chunk(x1x2, chunks=2, dim=-1) 25 | else: 26 | x1 = F.linear(x, w1, b1) 27 | x2 = F.linear(x, w2, b2) 28 | 29 | return F.linear(F.silu(x1) * x2, w3, b3) 30 | -------------------------------------------------------------------------------- /src/llm_training/data/dummy/dummy_datamodule_config.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from pydantic import Field, ValidationInfo, field_validator 4 | 5 | from llm_training.data.base_datamodule_config import BaseDataModuleConfig 6 | 7 | 8 | class DummyDataModuleConfig(BaseDataModuleConfig): 9 | vocab_size: int 10 | max_length: int 11 | num_samples: int | None = None 12 | num_tokens: int | None = None 13 | base_seed: int | None = Field(None, validate_default=True) 14 | 15 | @field_validator('num_tokens') 16 | @classmethod 17 | def validate_num_tokens(cls, value: int | None, info: ValidationInfo): 18 | assert info.data['num_samples'] is None 19 | return value 20 | 21 | @field_validator('base_seed') 22 | @classmethod 23 | def validate_base_seed(cls, value: int | None): 24 | if value is None: 25 | value = random.randrange(0, 999999) 26 | return value 27 | -------------------------------------------------------------------------------- /src/llm_training/utils/decorators.py: -------------------------------------------------------------------------------- 1 | from functools import update_wrapper 2 | import inspect 3 | from typing import Callable, ParamSpec, TypeVar 4 | 5 | P = ParamSpec('P') 6 | T = TypeVar('T') 7 | 8 | 9 | def copy_method_signature(ref_method: Callable[P, T], passthrough: bool = True) -> Callable[[Callable], Callable[P, T]]: 10 | def decorator(method: Callable): 
11 | wrapped = method 12 | if passthrough: 13 | def wrapped(self, *args, _mro_idx: int = 0, **kwargs): 14 | f = getattr(super(type(self).mro()[_mro_idx], self), method.__name__) 15 | if '_mro_idx' in inspect.signature(f, follow_wrapped=False).parameters: 16 | kwargs['_mro_idx'] = _mro_idx + 1 17 | return f(*args, **kwargs) 18 | wrapped = update_wrapper(wrapped, method) 19 | wrapped = update_wrapper(wrapped, ref_method, [], []) 20 | return wrapped 21 | return decorator 22 | -------------------------------------------------------------------------------- /src/llm_training/models/hf_compat_model/hf_compat_config.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import torch 4 | from pydantic import Field 5 | 6 | from llm_training.models.base_model.base_model_config import BaseModelConfig 7 | 8 | 9 | class HFCompatModelConfig(BaseModelConfig): 10 | hf_path: str | None = None 11 | hf_tokenizer_path: str | None = None 12 | 13 | torch_dtype: str | torch.dtype = 'auto' 14 | trust_remote_code: bool = False 15 | low_cpu_mem_usage: bool = True 16 | revision: str = 'main' 17 | attn_implementation: Literal['eager', 'sdpa', 'flash_attention_2'] | None = None 18 | hf_extra_kwargs: dict = Field(default_factory=dict) 19 | 20 | load_hf_weights: bool = True 21 | 22 | @property 23 | def _attn_implementation(self) -> str: 24 | if self.attn_implementation is None: 25 | return 'flash_attention_2' if torch.cuda.get_device_capability()[0] >= 8 else 'sdpa' 26 | return self.attn_implementation 27 | -------------------------------------------------------------------------------- /src/llm_training/lms/model_provider.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from jsonargparse.typing import final 4 | 5 | from llm_training.models.base_model import BaseModel, BaseModelConfig 6 | 7 | 8 | @final 9 | class ModelProvider: 10 | def __init__( 11 | self, 12 | model_class: type[BaseModel], 13 | model_config: dict[str, Any] | BaseModelConfig 14 | ) -> None: 15 | self.model_class = model_class 16 | if isinstance(model_config, dict): 17 | self.model_config = model_class.config_class.model_validate(model_config) 18 | else: 19 | self.model_config = model_config 20 | 21 | def __call__(self) -> BaseModel: 22 | return self.model_class(self.model_config) 23 | 24 | def __repr__(self) -> str: 25 | return ( 26 | f'{self.__class__.__name__}(\n' 27 | f' model_class={self.model_class},\n' 28 | f' model_config={repr(self.model_config)}\n' 29 | ')' 30 | ) 31 | -------------------------------------------------------------------------------- /src/llm_training/models/llama/llama_config.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, Literal 3 | 4 | from llm_training.models.hf_compat_model import HFCompatModelConfig 5 | 6 | 7 | class LlamaConfig(HFCompatModelConfig): 8 | vocab_size: int = 32000 9 | hidden_size: int = 4096 10 | intermediate_size: int = 11008 11 | num_hidden_layers: int = 32 12 | num_attention_heads: int = 32 13 | num_key_value_heads: int = 32 14 | # hidden_act: str = 'silu' 15 | max_position_embeddings: int = 4096 16 | initializer_range: float = 0.02 17 | rms_norm_eps: float = 1e-6 18 | pad_token_id: int | None = None 19 | tie_word_embeddings: bool = False 20 | rope_theta: float = 10000.0 21 | attention_bias: bool = False 22 | attention_dropout: float = 0.0 23 | 24 | mlp_bias: bool = False 25 | rope_scaling: dict[str, 
Any] | None = None 26 | 27 | pad_token_id: int | None = None 28 | bos_token_id: int = 1 29 | eos_token_id: int = 2 30 | 31 | enable_gradient_checkpointing: bool = False 32 | recompute_granularity: Literal['full', 'selective'] = 'full' 33 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/llama-2.j2: -------------------------------------------------------------------------------- 1 | {%- if messages[0]['role'] == 'system' %} 2 | {%- set loop_messages = messages[1:] %} 3 | {%- set system_message = messages[0]['content'] %} 4 | {%- else %} 5 | {%- set loop_messages = messages %} 6 | {%- set system_message = false %} 7 | {%- endif %} 8 | {%- for message in loop_messages %} 9 | {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} 10 | {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} 11 | {%- endif %} 12 | {%- if loop.index0 == 0 and system_message != false %} 13 | {%- set content = '<>\n' + system_message + '\n<>\n\n' + message['content'] %} 14 | {%- else %} 15 | {%- set content = message['content'] %} 16 | {%- endif %} 17 | {%- if message['role'] == 'user' %} 18 | {{- bos_token + '[INST] ' + content.strip() + ' [/INST]' }} 19 | {%- elif message['role'] == 'assistant' %} 20 | {%- generation %} 21 | {{- ' ' + content.strip() + ' ' + eos_token }} 22 | {%- endgeneration %} 23 | {%- endif %} 24 | {%- endfor %} 25 | -------------------------------------------------------------------------------- /src/llm_training/data/preference_tuning/preference_tuning_datamodule_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from enum import auto 3 | 4 | from pydantic import field_validator 5 | from transformers import PreTrainedTokenizerBase 6 | 7 | from llm_training.data.chat_templates import get_chat_template 8 | from llm_training.data.hf_based.hf_based_datamodule_config import \ 9 | HFBasedDataModuleConfig 10 | from llm_training.utils.str_enum import StrEnum 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class OverlongHandlingMethod(StrEnum): 16 | DROP = auto() 17 | 18 | 19 | class PreferenceTuningDataModuleConfig(HFBasedDataModuleConfig): 20 | tokenizer: PreTrainedTokenizerBase 21 | chat_template: str | None = None 22 | max_length: int | None = None 23 | overlong_handling_method: OverlongHandlingMethod | str = OverlongHandlingMethod.DROP 24 | pad_to_multiple_of: int | None = None 25 | pad_to_max_length: bool = False 26 | 27 | @field_validator('chat_template') 28 | @classmethod 29 | def validate_chat_template(cls, value: str | None) -> str | None: 30 | if value is not None: 31 | value = get_chat_template(value) 32 | return value 33 | -------------------------------------------------------------------------------- /src/llm_training/data/dummy/dummy_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from .dummy_datamodule_config import DummyDataModuleConfig 7 | 8 | 9 | class DummyDataset(Dataset): 10 | def __init__(self, config: DummyDataModuleConfig) -> None: 11 | super().__init__() 12 | 13 | self.config = config 14 | self.base_seed = config.base_seed 15 | 16 | if self.config.num_samples is not None: 17 | self.num_samples = self.config.num_samples 18 | elif self.config.num_tokens is not None: 19 | self.num_samples = math.ceil(self.config.num_tokens / self.config.max_length) 20 | 
21 | def __len__(self) -> int: 22 | return self.num_samples 23 | 24 | def __getitem__(self, index: int) -> dict[str, torch.Tensor]: 25 | generator = torch.Generator() 26 | generator.manual_seed(self.base_seed + index) 27 | input_ids = torch.randint(0, self.config.vocab_size, (self.config.max_length,), generator=generator) 28 | return dict( 29 | input_ids=input_ids, 30 | attention_mask=torch.ones_like(input_ids), 31 | position_ids=torch.arange(input_ids.size(0)), 32 | labels=input_ids 33 | ) 34 | -------------------------------------------------------------------------------- /src/llm_training/ops/liger_kernel/swiglu_op.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from liger_kernel.ops.swiglu import LigerSiLUMulFunction 4 | 5 | 6 | def swiglu( 7 | x: torch.Tensor, 8 | *, 9 | w3: torch.Tensor, 10 | w1w2: torch.Tensor | None = None, 11 | w1: torch.Tensor | None = None, 12 | w2: torch.Tensor | None = None, 13 | b1: torch.Tensor | None = None, 14 | b2: torch.Tensor | None = None, 15 | b3: torch.Tensor | None = None, 16 | b1b2: torch.Tensor | None = None 17 | ) -> torch.Tensor: 18 | assert ( 19 | w1w2 is not None and w1 is None and w2 is None 20 | or w1w2 is None and w1 is not None and w2 is not None 21 | ) 22 | 23 | if w1w2 is not None: 24 | x1x2 = F.linear(x, w1w2, b1b2) 25 | x1, x2 = torch.chunk(x1x2, chunks=2, dim=-1) 26 | else: 27 | x1 = F.linear(x, w1, b1) 28 | x2 = F.linear(x, w2, b2) 29 | 30 | if x.device.type == 'cuda': 31 | return F.linear(LigerSiLUMulFunction.apply(x1, x2), w3) 32 | 33 | return F.linear(F.silu(x1) * x2, w3, b3) 34 | 35 | 36 | def silu_mul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 37 | if x1.device.type == 'cuda': 38 | return LigerSiLUMulFunction.apply(x1, x2) 39 | return F.silu(x1) * x2 40 | -------------------------------------------------------------------------------- /docs/model_implementations.md: -------------------------------------------------------------------------------- 1 | # Model Implementations 2 | 3 | ## Optimized Models 4 | 5 | `LLM-Training` has implemented several models that are more efficient compared to those in [Hugging Face](https://github.com/huggingface/transformers) and support additional features. 6 | Therefore, it's recommended to prioritize using these models. 7 | 8 | The optimized models currently implemented: 9 | 10 | - [Phi-3](/src/llm_training/models/phi3/phi3_model.py) 11 | - [x] RMS Norm Fusion 12 | - [x] SwiGLU Fusion 13 | - [x] RoPE Fusion 14 | - [x] RoPE Caching 15 | - [x] Selective Activation Checkpointing 16 | - [LLaMA](/src/llm_training/models/llama/llama_model.py) 17 | - [x] RMS Norm Fusion 18 | - [x] SwiGLU Fusion 19 | - [x] RoPE Fusion 20 | - [x] RoPE Caching 21 | - [x] Selective Activation Checkpointing 22 | 23 | ## Hugging Face Models 24 | 25 | If you need to use a model implemented by Hugging Face, you can use [`HFCausalLM`](/src/llm_training/models/hf_causal_lm/hf_causal_lm.py). 26 | 27 | ```yaml 28 | ... 29 | model: 30 | class_path: llm_training.lms.CLM # or other objective 31 | init_args.config: 32 | model: 33 | model_class: llm_training.models.HFCausalLM 34 | model_config: 35 | hf_path: 36 | torch_dtype: ??? 37 | attn_implementation: ??? 38 | enable_gradient_checkpointing: ??? 39 | ... 
40 | ``` 41 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | class _ChatTemplates: 7 | def _get_path_by_name(self, name: str) -> Path: 8 | p = Path(__file__).parent / name 9 | return p.with_suffix(f'{p.suffix}.j2') 10 | 11 | def __getitem__(self, name: str) -> str: 12 | if name not in self: 13 | raise KeyError(f'Template `{name}` is not found') 14 | 15 | with open(self._get_path_by_name(name)) as f: 16 | return f.read() 17 | 18 | def __contains__(self, name: str) -> bool: 19 | return self._get_path_by_name(name).exists() 20 | 21 | 22 | CHAT_TEMPLATES = _ChatTemplates() 23 | 24 | def get_chat_template(chat_template: str) -> str: 25 | if Path(chat_template).exists(): 26 | logger.info(f'Found template file at `{chat_template}`.') 27 | with open(chat_template) as f: 28 | chat_template = f.read() 29 | elif chat_template in CHAT_TEMPLATES: 30 | logger.info(f'Using pre-defined chat template `{chat_template}`.') 31 | chat_template = CHAT_TEMPLATES[chat_template] 32 | else: 33 | logger.warn( 34 | '`chat_template` is being used directly as a chat template.\n' 35 | 'If this is not the behavior you expected, please change the value to a name of pre-defined chat template or a file path.' 36 | ) 37 | return chat_template 38 | -------------------------------------------------------------------------------- /src/llm_training/lr_schedulers/linear.py: -------------------------------------------------------------------------------- 1 | from torch.optim import Optimizer 2 | from torch.optim.lr_scheduler import LRScheduler 3 | 4 | 5 | class LinearWarmupLR(LRScheduler): 6 | def __init__( 7 | self, 8 | optimizer: Optimizer, 9 | num_warmup_steps: int, 10 | num_total_steps: int, 11 | min_lr: float | list[float], 12 | last_epoch: int = -1 13 | ) -> None: 14 | self.num_warmup_steps = num_warmup_steps 15 | self.num_total_steps = num_total_steps 16 | self.min_lr = min_lr 17 | 18 | super().__init__(optimizer, last_epoch=last_epoch) 19 | 20 | assert isinstance(self.min_lr, float) or len(min_lr) == len(self.base_lrs) 21 | 22 | @property 23 | def min_lrs(self) -> list[float]: 24 | if isinstance(self.min_lr, float): 25 | return [self.min_lr] * len(self.base_lrs) 26 | return self.min_lr 27 | 28 | def get_lr(self): 29 | if self.last_epoch < self.num_warmup_steps: 30 | return [(self.last_epoch + 1) / (self.num_warmup_steps + 1) * lr for lr in self.base_lrs] 31 | 32 | lrs = [] 33 | for lr, min_lr in zip(self.base_lrs, self.min_lrs): 34 | factor = (self.num_total_steps - self.last_epoch) / (self.num_total_steps - self.num_warmup_steps) 35 | min_lr_factor = min_lr / lr 36 | factor = (1.0 - min_lr_factor) * (factor - 0.0) + min_lr_factor 37 | lrs.append(lr * factor) 38 | 39 | return lrs 40 | -------------------------------------------------------------------------------- /src/llm_training/metrics/perplexity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics.functional.text.perplexity import (_perplexity_compute, 3 | _perplexity_update) 4 | 5 | from .metric import Metric 6 | 7 | 8 | class Perplexity(Metric): 9 | is_differentiable: bool = True 10 | higher_is_better: bool = False 11 | full_state_update: bool = False 12 | total_log_probs: torch.Tensor 13 | count: torch.Tensor 14 | 15 | def __init__( 
16 | self, 17 | ignore_index: int | None = None, 18 | **kwargs 19 | ) -> None: 20 | super().__init__(**kwargs) 21 | 22 | if ignore_index is not None and not isinstance(ignore_index, int): 23 | raise ValueError(f"Argument `ignore_index` expected to either be `None` or an `int` but got {ignore_index}") 24 | 25 | self.ignore_index = ignore_index 26 | self.add_state('total_log_probs', default=torch.tensor(0.0), dist_reduce_fx='sum') 27 | self.add_state('count', default=torch.tensor(0.0), dist_reduce_fx='sum') 28 | 29 | def update(self, preds_or_loss: torch.Tensor, target: torch.Tensor | None = None) -> None: 30 | if preds_or_loss.dim() == 0: 31 | self.total_log_probs += preds_or_loss 32 | self.count += 1 33 | else: 34 | total_log_probs, count = _perplexity_update(preds_or_loss, target, self.ignore_index) 35 | self.total_log_probs += total_log_probs 36 | self.count += count 37 | 38 | def compute(self) -> torch.Tensor: 39 | return _perplexity_compute(self.total_log_probs, self.count) 40 | -------------------------------------------------------------------------------- /scripts/pre_process_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from typing import TextIO 4 | 5 | from llm_training.data import * 6 | from llm_training.lightning.cli import * 7 | from llm_training.models import * 8 | 9 | 10 | class OutputStreamRedirector: 11 | def __init__(self, *streams: TextIO) -> None: 12 | self._streams = streams 13 | 14 | def write(self, s: str) -> int: 15 | n = 0 16 | for stream in self._streams: 17 | n += stream.write(s) 18 | return n 19 | 20 | def flush(self) -> None: 21 | for s in self._streams: 22 | s.flush() 23 | 24 | 25 | def main(): 26 | cli = LightningCLI(run=False) 27 | 28 | datamodule = cli.datamodule 29 | 30 | assert isinstance(datamodule, HFBasedDataModule) 31 | 32 | config = datamodule.config 33 | 34 | pre_processed_data_path = config.pre_processed_data_path 35 | 36 | assert pre_processed_data_path is not None, "`pre_processed_data_path` should not be `None`." 
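    # Descriptive note on the branch below: the dataset is pre-processed and saved only when the
    # target directory is missing or empty; otherwise the existing pre-processed data is reused
    # and `setup()` simply loads from it.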
37 | 38 | if not os.path.exists(pre_processed_data_path) or len(os.listdir(pre_processed_data_path)) == 0: 39 | config.pre_processed_data_path = None 40 | datamodule.setup() 41 | datamodule.save_pre_processed_data(pre_processed_data_path) 42 | else: 43 | print(f'`pre_processed_data_path="{pre_processed_data_path}"` is not empty, skipping.') 44 | datamodule.setup() 45 | 46 | with open(os.path.join(pre_processed_data_path, 'info.txt'), 'w') as f: 47 | datamodule.print_dataset_info(file=OutputStreamRedirector(sys.stdout, f)) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | -------------------------------------------------------------------------------- /src/llm_training/ops/liger_kernel/cross_entropy_op.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction 6 | from liger_kernel.ops.fused_linear_cross_entropy import \ 7 | LigerFusedLinearCrossEntropyFunction 8 | 9 | 10 | def cross_entropy( 11 | logits: torch.Tensor, 12 | labels: torch.Tensor, 13 | ignore_index: int = -100, 14 | reduction: Literal['mean'] = 'mean' 15 | ) -> torch.Tensor: 16 | assert reduction == 'mean' 17 | 18 | if logits.dim() == 3 and labels.dim() == 2: 19 | logits = logits.flatten(end_dim=1) 20 | labels = labels.flatten(end_dim=1) 21 | 22 | if logits.device.type != 'cuda': 23 | return F.cross_entropy( 24 | logits, 25 | labels, 26 | ignore_index=ignore_index 27 | ) 28 | 29 | return LigerCrossEntropyFunction.apply( 30 | logits, 31 | labels, 32 | ignore_index 33 | )[0] 34 | 35 | 36 | def fused_linear_cross_entropy( 37 | hidden_states: torch.Tensor, 38 | weight: torch.Tensor, 39 | labels: torch.Tensor, 40 | ignore_index: int = -100, 41 | reduction: Literal['mean'] = 'mean' 42 | ) -> torch.Tensor: 43 | assert reduction == 'mean' 44 | 45 | if hidden_states.dim() == 3 and labels.dim() == 2: 46 | hidden_states = hidden_states.flatten(end_dim=1) 47 | labels = labels.flatten(end_dim=1) 48 | 49 | return LigerFusedLinearCrossEntropyFunction.apply( 50 | hidden_states, 51 | weight, 52 | labels, 53 | ignore_index 54 | ) 55 | -------------------------------------------------------------------------------- /config/examples/llama-3.1/llama-3.1-8b_tp_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.FSDP2Strategy 8 | init_args: 9 | tensor_parallel_size: 8 10 | reshard_after_forward: false 11 | precision: bf16-true 12 | logger: 13 | class_path: llm_training.lightning.WandbLogger 14 | init_args: 15 | name: llama-3.1-8b_tp_example 16 | job_type: example 17 | project: llm-training 18 | save_dir: logs 19 | save_code: true 20 | max_epochs: 1 21 | val_check_interval: null 22 | accumulate_grad_batches: 1 23 | gradient_clip_val: 1.0 24 | callbacks: 25 | - class_path: LearningRateMonitor 26 | - class_path: llm_training.lightning.ModelCheckpoint 27 | init_args: 28 | save_on_train_epoch_end: true 29 | save_top_k: 1 30 | 31 | model: 32 | class_path: llm_training.lms.CLM 33 | init_args.config: 34 | model: 35 | model_class: llm_training.models.Llama 36 | model_config: 37 | hf_path: meta-llama/Llama-3.1-8B 38 | enable_gradient_checkpointing: true 39 | 40 | optim: 41 | optimizer_class: torch.optim.AdamW 42 | optimizer_kwargs: 43 | lr: 3e-5 44 | lr_scheduler_class: 
llm_training.lr_schedulers.CosineAnnealingWarmupLR 45 | lr_scheduler_kwargs: 46 | num_warmup_steps: 10000 47 | min_lr: 3e-6 48 | 49 | data: 50 | class_path: llm_training.data.DummyDataModule 51 | init_args.config: 52 | batch_size: 1 53 | vocab_size: 128256 54 | max_length: 131072 55 | num_tokens: 50_000_000_000 # 50B 56 | -------------------------------------------------------------------------------- /config/examples/phi-3/phi-3-mini_tp_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.FSDP2Strategy 8 | init_args: 9 | tensor_parallel_size: 8 10 | reshard_after_forward: false 11 | precision: bf16-true 12 | logger: 13 | class_path: llm_training.lightning.WandbLogger 14 | init_args: 15 | name: phi-3-mini-128k-instruct_tp_example 16 | job_type: example 17 | project: llm-training 18 | save_dir: logs 19 | save_code: true 20 | max_epochs: 1 21 | val_check_interval: null 22 | accumulate_grad_batches: 1 23 | gradient_clip_val: 1.0 24 | callbacks: 25 | - class_path: LearningRateMonitor 26 | - class_path: llm_training.lightning.ModelCheckpoint 27 | init_args: 28 | save_on_train_epoch_end: true 29 | save_top_k: 1 30 | 31 | model: 32 | class_path: llm_training.lms.CLM 33 | init_args.config: 34 | model: 35 | model_class: llm_training.models.Phi3 36 | model_config: 37 | hf_path: microsoft/Phi-3-mini-128k-instruct 38 | enable_gradient_checkpointing: true 39 | 40 | optim: 41 | optimizer_class: torch.optim.AdamW 42 | optimizer_kwargs: 43 | lr: 3e-5 44 | lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR 45 | lr_scheduler_kwargs: 46 | num_warmup_steps: 10000 47 | min_lr: 3e-6 48 | 49 | data: 50 | class_path: llm_training.data.DummyDataModule 51 | init_args.config: 52 | batch_size: 1 53 | vocab_size: 32064 54 | max_length: 131072 55 | num_tokens: 50_000_000_000 # 50B 56 | -------------------------------------------------------------------------------- /src/llm_training/lms/base_lm_config.py: -------------------------------------------------------------------------------- 1 | from types import UnionType 2 | from typing import Any 3 | 4 | import torch 5 | from pydantic import BaseModel as PyDanticBaseModel 6 | from pydantic import ConfigDict, Field, ValidationInfo, field_validator 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import LRScheduler 9 | 10 | from llm_training.lr_schedulers import ConstantWarmupLR 11 | 12 | 13 | class BaseOptimizerConfig(PyDanticBaseModel): 14 | optimizer_class: type[Optimizer] 15 | optimizer_kwargs: dict[str, Any] 16 | lr_scheduler_class: type[LRScheduler] = ConstantWarmupLR 17 | lr_scheduler_kwargs: dict[str, Any] = Field(default_factory=dict) 18 | 19 | model_config = ConfigDict(arbitrary_types_allowed=True) 20 | 21 | 22 | class BaseLightningModuleConfig(PyDanticBaseModel): 23 | init_weights: bool = False 24 | load_weights: bool = True 25 | pre_trained_weights: str | None = None 26 | optim: BaseOptimizerConfig | None = None 27 | frozen_modules: list[str] | None = None 28 | log_grad_norm: bool = True 29 | 30 | model_config = ConfigDict( 31 | arbitrary_types_allowed=True, 32 | protected_namespaces=() 33 | ) 34 | 35 | @field_validator('*') 36 | @classmethod 37 | def validate_torch_dtype(cls, value: Any, info: ValidationInfo) -> Any: 38 | field = cls.model_fields[info.field_name] 39 | is_torch_dtype = isinstance(field.annotation, torch.dtype) 
40 | is_torch_dtype |= isinstance(field.annotation, UnionType) and torch.dtype in field.annotation.__args__ 41 | if is_torch_dtype and isinstance(value, str) and value != 'auto': 42 | value = getattr(torch, value) 43 | return value 44 | -------------------------------------------------------------------------------- /src/llm_training/data/resumable_dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Iterable, Iterator 2 | 3 | from lightning import Trainer 4 | from torch.utils.data import BatchSampler, DataLoader 5 | from torch.utils.data.sampler import Sampler 6 | 7 | 8 | class ResumableBatchSampler(BatchSampler): 9 | def __init__( 10 | self, 11 | sampler: Sampler[int] | Iterable[int], 12 | batch_size: int, 13 | drop_last: bool, 14 | trainer: Trainer | None = None 15 | ) -> None: 16 | super().__init__(sampler, batch_size, drop_last) 17 | 18 | self.trainer = trainer 19 | 20 | def __iter__(self) -> Iterator[list[int]]: 21 | for i, indices in enumerate(super().__iter__()): 22 | if self.trainer is not None and i < self.trainer.fit_loop.batch_idx: 23 | continue 24 | 25 | yield indices 26 | 27 | 28 | class ResumableDataLoader(DataLoader): 29 | def __init__( 30 | self, 31 | trainer: Trainer | None = None, 32 | *args, 33 | **kwargs 34 | ): 35 | self.trainer = trainer 36 | 37 | super().__init__(*args, **kwargs) 38 | 39 | def __setattr__(self, attr: str, val: Any): 40 | if attr == 'batch_sampler': 41 | assert isinstance(val, BatchSampler) 42 | 43 | val = ResumableBatchSampler( 44 | sampler=val.sampler, 45 | batch_size=val.batch_size, 46 | drop_last=val.drop_last, 47 | trainer=self.trainer 48 | ) 49 | 50 | return super().__setattr__(attr, val) 51 | 52 | def state_dict(self) -> dict[str, Any]: 53 | return {} 54 | 55 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 56 | ... 
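# Usage sketch (illustrative only; assumes an existing `lightning.Trainer` instance named
# `trainer` and a map-style dataset named `train_dataset`, neither of which is defined here):
#
#     loader = ResumableDataLoader(
#         trainer=trainer,
#         dataset=train_dataset,
#         batch_size=8,
#         shuffle=False,
#     )
#
# Because `__setattr__` wraps the internally created `BatchSampler` in a `ResumableBatchSampler`,
# resuming a run skips the first `trainer.fit_loop.batch_idx` batches instead of replaying them.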
57 | -------------------------------------------------------------------------------- /src/llm_training/lr_schedulers/warmup.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from torch.optim.lr_scheduler import LRScheduler 4 | from torch.optim.optimizer import Optimizer 5 | 6 | 7 | class WarmupLR(LRScheduler): 8 | def __init__( 9 | self, 10 | optimizer: Optimizer, 11 | lr_scheduler: LRScheduler, 12 | num_warmup_epochs: int, 13 | last_epoch: int = -1 14 | ) -> None: 15 | self.lr_scheduler = lr_scheduler 16 | self.num_warmup_epochs = num_warmup_epochs 17 | 18 | super().__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self) -> list[float]: 21 | if self.last_epoch >= self.num_warmup_epochs: 22 | return self.lr_scheduler.get_lr() 23 | return [(self.last_epoch + 1) / self.num_warmup_epochs * lr for lr in self.base_lrs] 24 | 25 | def step(self, epoch: int | None = None) -> None: 26 | if self.last_epoch == self.num_warmup_epochs: 27 | self.lr_scheduler.base_lrs = self.base_lrs 28 | 29 | if self.last_epoch >= self.num_warmup_epochs: 30 | epoch = None if epoch is None else epoch - self.num_warmup_epochs 31 | self.lr_scheduler.step(epoch) 32 | self._last_lr = self.lr_scheduler.get_last_lr() 33 | return super().step(epoch) 34 | 35 | def state_dict(self) -> dict[str, Any]: 36 | state_dict = {k: v for k, v in self.__dict__.items() if k not in ['optimizer', 'lr_scheduler']} 37 | state_dict['lr_scheduler_state_dict'] = self.lr_scheduler.state_dict() 38 | return state_dict 39 | 40 | def load_state_dict(self, state_dict: dict[str, Any]) -> None: 41 | lr_scheduler_state_dict = state_dict.pop('lr_scheduler_state_dict') 42 | self.lr_scheduler.load_state_dict(lr_scheduler_state_dict) 43 | super().load_state_dict(state_dict) 44 | -------------------------------------------------------------------------------- /src/llm_training/data/pre_training/pre_training_datamodule_config.py: -------------------------------------------------------------------------------- 1 | from enum import auto 2 | 3 | from pydantic import Field, ValidationInfo, field_validator 4 | from transformers import PreTrainedTokenizerBase 5 | 6 | from llm_training.data.hf_based import HFBasedDataModuleConfig 7 | from llm_training.utils.str_enum import StrEnum 8 | 9 | 10 | class PackingMethod(StrEnum): 11 | NO_PACKING = auto() 12 | NAIVE_PACKING = auto() 13 | BEST_FIT_BIN_PACKING = auto() 14 | 15 | 16 | class PreTrainingDataModuleConfig(HFBasedDataModuleConfig): 17 | tokenizer: PreTrainedTokenizerBase 18 | max_length: int | None = None 19 | stride: int | None = None 20 | packing_method: PackingMethod | str = PackingMethod.NAIVE_PACKING 21 | sample_rate: dict[str, float] = Field(default_factory=dict) 22 | pre_processing_batch_size: int = 1000 23 | pad_to_multiple_of: int | None = None 24 | pad_to_max_length: bool = False 25 | 26 | @field_validator('packing_method') 27 | @classmethod 28 | def validate_packing_method(cls, value: PackingMethod | str, info: ValidationInfo) -> PackingMethod: 29 | value = PackingMethod(value.lower()) 30 | assert value is PackingMethod.NO_PACKING or info.data['max_length'] is not None, \ 31 | "You must set `max_length` to pack data" 32 | return value 33 | 34 | @field_validator('stride') 35 | @classmethod 36 | def validate_stride(cls, value: int | None, info: ValidationInfo) -> int | None: 37 | max_length = info.data['max_length'] 38 | 39 | if value is None: 40 | value = max_length 41 | else: 42 | assert
max_length is not None, "You must also set `max_length` to use `stride`" 43 | assert value <= max_length, "`stride` must be <= `max_length`" 44 | 45 | return value 46 | -------------------------------------------------------------------------------- /config/examples/phi-3/phi-3-mini_orpo_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.DeepSpeedStrategy 8 | init_args: 9 | stage: 2 10 | precision: bf16-true 11 | logger: 12 | class_path: llm_training.lightning.WandbLogger 13 | init_args: 14 | name: phi-3-mini-128k-instruct_orpo_example 15 | job_type: example 16 | project: llm-training 17 | save_dir: logs 18 | save_code: true 19 | max_epochs: 3 20 | val_check_interval: null 21 | accumulate_grad_batches: 1 22 | gradient_clip_val: 1.0 23 | callbacks: 24 | - class_path: LearningRateMonitor 25 | - class_path: llm_training.lightning.ModelCheckpoint 26 | init_args: 27 | save_on_train_epoch_end: true 28 | save_top_k: 1 29 | 30 | model: 31 | class_path: llm_training.lms.ORPO 32 | init_args.config: 33 | model: 34 | model_class: llm_training.models.Phi3 35 | model_config: 36 | hf_path: microsoft/Phi-3-mini-128k-instruct 37 | enable_gradient_checkpointing: true 38 | 39 | beta: 0.1 40 | 41 | optim: 42 | optimizer_class: torch.optim.AdamW 43 | optimizer_kwargs: 44 | lr: 1e-6 45 | lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR 46 | lr_scheduler_kwargs: 47 | num_warmup_steps: 100 48 | min_lr: 1e-7 49 | 50 | data: 51 | class_path: llm_training.data.PreferenceTuningDataModule 52 | init_args.config: 53 | dataset_kwargs: 54 | path: trl-internal-testing/Anthropic-hh-rlhf-processed 55 | tokenizer: 56 | class_path: HFTokenizer 57 | init_args: 58 | path: microsoft/Phi-3-mini-128k-instruct 59 | chat_template: phi-3 60 | batch_size: 1 61 | max_length: 4096 62 | pad_to_multiple_of: 64 63 | validation_split: null 64 | num_proc: 4 65 | num_workers: 4 66 | enable_cache: true 67 | -------------------------------------------------------------------------------- /config/examples/phi-3/phi-3-mini_dpo_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.DeepSpeedStrategy 8 | init_args: 9 | stage: 2 10 | exclude_frozen_parameters: true 11 | precision: bf16-true 12 | logger: 13 | class_path: llm_training.lightning.WandbLogger 14 | init_args: 15 | name: phi-3-mini-128k-instruct_dpo_example 16 | job_type: example 17 | project: llm-training 18 | save_dir: logs 19 | save_code: true 20 | max_epochs: 3 21 | val_check_interval: null 22 | accumulate_grad_batches: 1 23 | gradient_clip_val: 1.0 24 | callbacks: 25 | - class_path: LearningRateMonitor 26 | - class_path: llm_training.lightning.ModelCheckpoint 27 | init_args: 28 | save_on_train_epoch_end: true 29 | save_top_k: 1 30 | 31 | model: 32 | class_path: llm_training.lms.DPO 33 | init_args.config: 34 | model: 35 | model_class: llm_training.models.Phi3 36 | model_config: 37 | hf_path: microsoft/Phi-3-mini-128k-instruct 38 | enable_gradient_checkpointing: true 39 | 40 | beta: 0.1 41 | label_smoothing: 0.0 42 | 43 | optim: 44 | optimizer_class: deepspeed.ops.adam.FusedAdam 45 | optimizer_kwargs: 46 | lr: 1e-6 47 | lr_scheduler_class: 
llm_training.lr_schedulers.CosineAnnealingWarmupLR 48 | lr_scheduler_kwargs: 49 | num_warmup_steps: 100 50 | min_lr: 1e-7 51 | 52 | data: 53 | class_path: llm_training.data.PreferenceTuningDataModule 54 | init_args.config: 55 | dataset_kwargs: 56 | path: trl-internal-testing/Anthropic-hh-rlhf-processed 57 | tokenizer: 58 | class_path: HFTokenizer 59 | init_args: 60 | path: microsoft/Phi-3-mini-128k-instruct 61 | chat_template: phi-3 62 | batch_size: 1 63 | max_length: 4096 64 | pad_to_multiple_of: 64 65 | validation_split: null 66 | num_proc: 4 67 | num_workers: 4 68 | enable_cache: true 69 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | # export TRITON_PTXAS_PATH=$CONDA_PREFIX/bin/ptxas 2 | # export TRITON_CUOBJDUMP_PATH=$CONDA_PREFIX/bin/cuobjdump 3 | # export TRITON_NVDISASM_PATH=$CONDA_PREFIX/bin/nvdisasm 4 | 5 | JOB_NAME= 6 | PARTITION= 7 | ACCOUNT= 8 | NODES= 9 | GPUS_PER_NODE= 10 | CPUS_PER_TASK= 11 | EXTRA_ARGS=( 12 | ) 13 | CONFIG=null 14 | CKPT_PATH=null 15 | 16 | 17 | COMMAND=( 18 | srun llm-training fit 19 | --config $CONFIG 20 | --trainer.num_nodes $NODES 21 | --ckpt_path $CKPT_PATH 22 | ) 23 | 24 | COMMAND=${COMMAND[@]} 25 | 26 | echo $COMMAND 27 | 28 | SBATCH_ARGS=( 29 | --partition $PARTITION 30 | --gpus-per-node $GPUS_PER_NODE 31 | --cpus-per-task $CPUS_PER_TASK 32 | --ntasks-per-node $GPUS_PER_NODE 33 | --account $ACCOUNT 34 | --nodes $NODES 35 | ) 36 | 37 | if [[ $JOB_NAME ]]; 38 | then 39 | SBATCH_ARGS+=(--job-name $JOB_NAME) 40 | fi 41 | 42 | SBATCH_ARGS+=(${EXTRA_ARGS[@]}) 43 | SBATCH_ARGS=${SBATCH_ARGS[@]} 44 | 45 | SBATCH_OUTPUT=$(sbatch $SBATCH_ARGS --wrap "$COMMAND") 46 | 47 | echo $SBATCH_OUTPUT 48 | 49 | if [[ $SBATCH_OUTPUT != "Submitted batch job"* ]]; 50 | then 51 | exit 52 | fi 53 | 54 | JOB_ID=$(echo $SBATCH_OUTPUT | sed "s/Submitted batch job //") 55 | 56 | echo "Waiting for the job to start" 57 | while [[ $JOB_STATE != "RUNNING" ]] 58 | do 59 | JOB_STATE=$(squeue -j $JOB_ID -h -o %T) 60 | sleep 1 61 | done 62 | 63 | echo "The job is running, trying to attach to the output stream ..." 
64 | sleep 3 65 | 66 | 67 | while [[ $JOB_STATE == "RUNNING" ]] 68 | do 69 | SATTACH_OUTPUT=$(sattach $JOB_ID.0 2>&1 | tee /dev/tty) 70 | if [[ $SATTACH_OUTPUT == *"Job/step already completing or completed"* ]] \ 71 | || [[ $SATTACH_OUTPUT == *"Socket timed out on send/recv operation"* ]] \ 72 | || [[ $SATTACH_OUTPUT == *"does not look like a jobid"* ]]; 73 | then 74 | break 75 | fi 76 | sleep 1 77 | done 78 | 79 | # while [[ $JOB_STATE == "RUNNING" ]] 80 | # do 81 | # JOB_STATE=$(squeue -j $JOB_ID -h -o %T) 82 | # tail -f slurm-$JOB_ID.out 83 | # sleep 1 84 | # done 85 | -------------------------------------------------------------------------------- /src/llm_training/data/pre_training/pre_training_datacollator.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any 3 | 4 | import torch 5 | 6 | from llm_training.data.base_datacollator import BaseDataCollator 7 | from .pre_training_datamodule_config import PreTrainingDataModuleConfig 8 | 9 | 10 | class PreTrainingDataCollator(BaseDataCollator): 11 | config: PreTrainingDataModuleConfig 12 | 13 | @property 14 | def tokenizer(self): 15 | return self.config.tokenizer 16 | 17 | def __init__(self, config: PreTrainingDataModuleConfig) -> None: 18 | super().__init__(config) 19 | 20 | assert 'pad_token' in config.tokenizer.special_tokens_map, '`pad_token` is not specified. Please set it manually.' 21 | 22 | def _pad_to_longest(self, x: list[list[int]]) -> list[list[int]]: 23 | n = self.config.max_length if self.config.pad_to_max_length else max(len(y) for y in x) 24 | 25 | if self.config.pad_to_multiple_of is not None: 26 | n = (math.ceil(n / self.config.pad_to_multiple_of)) * self.config.pad_to_multiple_of 27 | 28 | for y in x: 29 | num_paddings = n - len(y) 30 | paddings = [-1] * num_paddings 31 | y[:] = paddings + y if self.tokenizer.padding_side == 'left' else y + paddings 32 | return x 33 | 34 | def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]: 35 | input_ids = [x['input_ids'] for x in batch] 36 | 37 | input_ids = self._pad_to_longest(input_ids) 38 | input_ids = torch.tensor(input_ids) 39 | padding_mask = input_ids == -1 40 | input_ids[padding_mask] = self.tokenizer.pad_token_id 41 | bos_mask = input_ids == self.tokenizer.bos_token_id 42 | 43 | return { 44 | 'input_ids': input_ids, 45 | 'attention_mask': torch.ones_like(input_ids).masked_fill(padding_mask, 0), 46 | 'position_ids': torch.arange(input_ids.size(1)).unsqueeze(0), 47 | 'labels': input_ids.masked_fill(bos_mask | padding_mask, -100) 48 | } 49 | -------------------------------------------------------------------------------- /config/examples/phi-3/phi-3-mini_pt_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.DeepSpeedStrategy 8 | init_args: 9 | stage: 2 10 | precision: bf16-true 11 | logger: 12 | class_path: llm_training.lightning.WandbLogger 13 | init_args: 14 | name: phi-3-mini-128k-instruct_pt_example 15 | job_type: example 16 | project: llm-training 17 | save_dir: logs 18 | save_code: true 19 | max_epochs: 1 20 | val_check_interval: null 21 | accumulate_grad_batches: 1 22 | gradient_clip_val: 1.0 23 | callbacks: 24 | - class_path: LearningRateMonitor 25 | - class_path: llm_training.lightning.ModelCheckpoint 26 | init_args: 27 | save_on_train_epoch_end: true 28 | save_top_k: -1 29 |
- class_path: llm_training.lightning.ModelCheckpoint 30 | init_args: 31 | save_top_k: 1 32 | every_n_train_steps: 100 33 | 34 | model: 35 | class_path: llm_training.lms.CLM 36 | init_args.config: 37 | model: 38 | model_class: llm_training.models.Phi3 39 | model_config: 40 | hf_path: microsoft/Phi-3-mini-128k-instruct 41 | enable_gradient_checkpointing: true 42 | 43 | optim: 44 | optimizer_class: deepspeed.ops.adam.FusedAdam 45 | optimizer_kwargs: 46 | lr: 1e-4 47 | lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR 48 | lr_scheduler_kwargs: 49 | num_warmup_steps: 100 50 | min_lr: 1e-5 51 | 52 | data: 53 | class_path: llm_training.data.PreTrainingDataModule 54 | init_args.config: 55 | dataset_kwargs: 56 | path: Salesforce/wikitext 57 | name: wikitext-2-v1 58 | num_proc: 32 59 | # pre_processed_data_path: data/pre_processed/phi-3/wikitext-2-v1 60 | tokenizer: 61 | class_path: HFTokenizer 62 | init_args.path: microsoft/Phi-3-mini-128k-instruct 63 | batch_size: 1 64 | max_length: 4096 65 | num_proc: 32 66 | num_workers: 4 67 | -------------------------------------------------------------------------------- /src/llm_training/lightning/callbacks/extra_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Literal 4 | 5 | import torch 6 | from lightning import LightningModule, Trainer 7 | from lightning.pytorch.callbacks.callback import Callback 8 | from triton.runtime.cache import default_cache_dir 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class ExtraConfig(Callback): 14 | def __init__( 15 | self, 16 | float32_matmul_precision: Literal["medium", "high", "highest"] | None = None, 17 | logging_level: int | str = logging.INFO, 18 | env: dict[str, str] | None = None 19 | ) -> None: 20 | super().__init__() 21 | 22 | self.float32_matmul_precision = float32_matmul_precision 23 | self.logging_level = logging_level 24 | self.env = env 25 | 26 | self._configure_float32_matmul_precision() 27 | self._configure_logging_level() 28 | self._configure_environment() 29 | 30 | def _configure_float32_matmul_precision(self) -> None: 31 | if self.float32_matmul_precision is not None: 32 | torch.set_float32_matmul_precision(self.float32_matmul_precision) 33 | 34 | def _configure_logging_level(self) -> None: 35 | if isinstance(self.logging_level, str): 36 | logging_level = getattr(logging, self.logging_level.upper()) 37 | else: 38 | logging_level = self.logging_level 39 | 40 | logging.getLogger('llm_training').setLevel(logging_level) 41 | logging.getLogger('lightning').setLevel(logging_level) 42 | 43 | def _configure_environment(self) -> None: 44 | if self.env is not None: 45 | os.environ.update(self.env) 46 | 47 | def _configure_triton_cache_dir(self, rank: int) -> None: 48 | if not os.getenv('TRITON_CACHE_DIR', '').strip(): 49 | os.environ['TRITON_CACHE_DIR'] = os.path.join(default_cache_dir(), 'llm_training', f'rank_{rank}') 50 | 51 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str): 52 | self._configure_triton_cache_dir(trainer.global_rank) 53 | -------------------------------------------------------------------------------- /config/examples/llama-3.1/llama-3.1-8b_pt_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.DeepSpeedStrategy 8 | init_args: 9 | stage: 2 10 
| precision: bf16-true 11 | logger: 12 | class_path: llm_training.lightning.WandbLogger 13 | init_args: 14 | name: llama-3.1-8b_pt_example 15 | job_type: example 16 | project: llm-training 17 | save_dir: logs 18 | save_code: true 19 | max_epochs: 1 20 | val_check_interval: null 21 | accumulate_grad_batches: 1 22 | gradient_clip_val: 1.0 23 | callbacks: 24 | - class_path: LearningRateMonitor 25 | - class_path: llm_training.lightning.ModelCheckpoint 26 | init_args: 27 | save_on_train_epoch_end: true 28 | save_top_k: -1 29 | - class_path: llm_training.lightning.ModelCheckpoint 30 | init_args: 31 | save_top_k: 1 32 | every_n_train_steps: 5000 33 | 34 | model: 35 | class_path: llm_training.lms.CLM 36 | init_args.config: 37 | model: 38 | model_class: llm_training.models.Llama 39 | model_config: 40 | hf_path: meta-llama/Llama-3.1-8B 41 | enable_gradient_checkpointing: true 42 | 43 | optim: 44 | optimizer_class: deepspeed.ops.adam.FusedAdam 45 | optimizer_kwargs: 46 | lr: 1e-5 47 | lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR 48 | lr_scheduler_kwargs: 49 | num_warmup_steps: 1000 50 | min_lr: 1e-6 51 | 52 | data: 53 | class_path: llm_training.data.PreTrainingDataModule 54 | init_args.config: 55 | dataset_kwargs: 56 | path: Salesforce/wikitext 57 | name: wikitext-2-v1 58 | num_proc: 32 59 | # pre_processed_data_path: data/pre_processed/llama-3.1/wikitext-2-v1 60 | tokenizer: 61 | class_path: HFTokenizer 62 | init_args: 63 | path: meta-llama/Llama-3.1-8B 64 | pad_token: <|end_of_text|> 65 | padding_side: left 66 | batch_size: 1 67 | max_length: 4096 68 | num_proc: 32 69 | num_workers: 4 70 | -------------------------------------------------------------------------------- /config/examples/phi-3/phi-3-mini_it_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.DeepSpeedStrategy 8 | init_args: 9 | stage: 2 10 | precision: bf16-true 11 | logger: 12 | class_path: llm_training.lightning.WandbLogger 13 | init_args: 14 | name: phi-3-mini-128k-instruct_it_example 15 | job_type: example 16 | project: llm-training 17 | save_dir: logs 18 | save_code: true 19 | max_epochs: 1 20 | val_check_interval: null 21 | accumulate_grad_batches: 1 22 | gradient_clip_val: 1.0 23 | callbacks: 24 | - class_path: LearningRateMonitor 25 | - class_path: llm_training.lightning.ModelCheckpoint 26 | init_args: 27 | save_on_train_epoch_end: true 28 | save_top_k: 1 29 | 30 | model: 31 | class_path: llm_training.lms.CLM 32 | init_args.config: 33 | model: 34 | model_class: llm_training.models.Phi3 35 | model_config: 36 | hf_path: microsoft/Phi-3-mini-128k-instruct 37 | enable_gradient_checkpointing: true 38 | 39 | # neftune_alpha: 5.0 40 | 41 | optim: 42 | optimizer_class: deepspeed.ops.adam.FusedAdam 43 | optimizer_kwargs: 44 | lr: 1e-5 45 | lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR 46 | lr_scheduler_kwargs: 47 | num_warmup_steps: 100 48 | min_lr: 1e-6 49 | 50 | data: 51 | class_path: llm_training.data.InstructionTuningDataModule 52 | init_args.config: 53 | dataset_kwargs: 54 | path: ShinoharaHare/Infinity-Instruct-Reformatted 55 | name: "0625" 56 | pre_processed_data_path: null 57 | tokenizer: 58 | class_path: HFTokenizer 59 | init_args: 60 | path: microsoft/Phi-3-mini-128k-instruct 61 | batch_size: 1 62 | add_default_system_prompt_rate: 0.0 63 | default_system_prompt: "" 64 | 
chat_template: phi-3 65 | packing_method: GROUP_BY_LENGTH 66 | max_length: 4096 67 | pad_to_multiple_of: 64 68 | validation_split: null 69 | num_proc: 4 70 | num_workers: 4 71 | enable_cache: true 72 | -------------------------------------------------------------------------------- /src/llm_training/lightning/callbacks/save_config_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | import yaml 6 | from lightning import LightningModule, Trainer 7 | from lightning.pytorch.cli import SaveConfigCallback as _SaveConfigCallback 8 | 9 | from llm_training.lightning.loggers.wandb import WandbLogger 10 | 11 | 12 | class SaveConfigCallback(_SaveConfigCallback): 13 | 14 | def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: 15 | logger = trainer.logger 16 | if isinstance(logger, WandbLogger): 17 | config_path = Path(logger.save_dir, self.config_filename) 18 | 19 | self.parser.save( 20 | self.config, 21 | config_path, 22 | skip_none=False, 23 | overwrite=self.overwrite, 24 | multifile=self.multifile 25 | ) 26 | 27 | with open(config_path) as f: 28 | self.yaml_config = yaml.safe_load(f) 29 | 30 | self.yaml_config['world_size'] = trainer.world_size 31 | 32 | if 'SLURM_JOB_ID' in os.environ: 33 | self.yaml_config['slurm_job_id'] = os.environ['SLURM_JOB_ID'] 34 | self.yaml_config['slurm_job_name'] = os.environ.get('SLURM_JOB_NAME', None) 35 | self.yaml_config['slurm_num_nodes'] = os.environ.get('SLURM_NNODES', None) 36 | self.yaml_config['slurm_ntasks'] = os.environ.get('SLURM_NTASKS', None) 37 | 38 | logger.log_hyperparams(self.yaml_config) 39 | logger.experiment.save(config_path, policy='now') 40 | logger.experiment.log_code(include_fn=_wandb_code_include_fn) 41 | 42 | def on_save_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict[str, Any]) -> None: 43 | yaml_config = self.yaml_config if trainer.is_global_zero else None 44 | checkpoint['config'] = trainer.strategy.broadcast(yaml_config, src=0) 45 | 46 | 47 | def _wandb_code_include_fn(path: str, root: str): 48 | p = Path(path).relative_to(root) 49 | return p.parts[0] in ['src', 'scripts'] and p.suffix in ['.py', '.sh', '.j2'] 50 | -------------------------------------------------------------------------------- /config/examples/llama-3.1/llama-3.1-8b_it_example.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 42 2 | float32_matmul_precision: medium 3 | logging_level: DEBUG 4 | 5 | trainer: 6 | strategy: 7 | class_path: llm_training.lightning.DeepSpeedStrategy 8 | init_args: 9 | stage: 2 10 | precision: bf16-true 11 | logger: 12 | class_path: llm_training.lightning.WandbLogger 13 | init_args: 14 | name: llama-3.1-8b_it_example 15 | job_type: example 16 | project: llm-training 17 | save_dir: logs 18 | save_code: true 19 | max_epochs: 1 20 | val_check_interval: null 21 | accumulate_grad_batches: 1 22 | gradient_clip_val: 1.0 23 | callbacks: 24 | - class_path: LearningRateMonitor 25 | - class_path: llm_training.lightning.ModelCheckpoint 26 | init_args: 27 | save_on_train_epoch_end: true 28 | save_top_k: 1 29 | 30 | model: 31 | class_path: llm_training.lms.CLM 32 | init_args.config: 33 | model: 34 | model_class: llm_training.models.Llama 35 | model_config: 36 | hf_path: meta-llama/Llama-3.1-8B-Instruct 37 | enable_gradient_checkpointing: true 38 | 39 | # neftune_alpha: 5.0 40 | 41 | optim: 42 | optimizer_class: 
deepspeed.ops.adam.FusedAdam 43 | optimizer_kwargs: 44 | lr: 1e-5 45 | lr_scheduler_class: llm_training.lr_schedulers.CosineAnnealingWarmupLR 46 | lr_scheduler_kwargs: 47 | num_warmup_steps: 100 48 | min_lr: 1e-6 49 | 50 | data: 51 | class_path: llm_training.data.InstructionTuningDataModule 52 | init_args.config: 53 | dataset_kwargs: 54 | path: ShinoharaHare/Infinity-Instruct-Reformatted 55 | name: "0625" 56 | pre_processed_data_path: null 57 | tokenizer: 58 | class_path: HFTokenizer 59 | init_args: 60 | path: meta-llama/Llama-3.1-8B-Instruct 61 | pad_token: <|end_of_text|> 62 | padding_side: left 63 | batch_size: 1 64 | add_default_system_prompt_rate: 0.0 65 | default_system_prompt: "" 66 | chat_template: llama-3.1 67 | packing_method: GROUP_BY_LENGTH 68 | max_length: 4096 69 | pad_to_multiple_of: 64 70 | validation_split: null 71 | num_proc: 4 72 | num_workers: 4 73 | enable_cache: true 74 | -------------------------------------------------------------------------------- /src/llm_training/lightning/loggers/wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Literal, Optional, Union 3 | 4 | from lightning.fabric.utilities.types import _PATH 5 | from lightning.pytorch.loggers.wandb import WandbLogger as _WandbLogger 6 | from wandb.sdk.lib import RunDisabled 7 | from wandb.wandb_run import Run 8 | 9 | 10 | class WandbLogger(_WandbLogger): 11 | def __init__( 12 | self, 13 | name: Optional[str] = None, 14 | save_dir: _PATH = ".", 15 | version: Optional[str] = None, 16 | offline: bool = False, 17 | dir: Optional[_PATH] = None, 18 | id: Optional[str] = None, 19 | anonymous: Optional[bool] = None, 20 | project: Optional[str] = None, 21 | log_model: Union[Literal["all"], bool] = False, 22 | experiment: Union["Run", "RunDisabled", None] = None, 23 | prefix: str = "", 24 | checkpoint_name: Optional[str] = None, 25 | entity: str | None = None, 26 | tags: list | None = None, 27 | save_code: bool | None = None, 28 | **kwargs: Any, 29 | ) -> None: 30 | super().__init__( 31 | name=name, 32 | save_dir=save_dir, 33 | version=version, 34 | offline=offline, 35 | dir=dir, 36 | id=id, 37 | anonymous=anonymous, 38 | project=project, 39 | log_model=log_model, 40 | experiment=experiment, 41 | prefix=prefix, 42 | checkpoint_name=checkpoint_name, 43 | entity=entity, 44 | tags=tags, 45 | save_code=save_code, 46 | **kwargs 47 | ) 48 | 49 | @property 50 | def experiment(self) -> Run: 51 | os.makedirs(self._wandb_init['dir'], exist_ok=True) 52 | return super().experiment 53 | 54 | @property 55 | def name(self) -> str: 56 | return self._name 57 | 58 | @property 59 | def save_dir(self) -> str: 60 | return os.path.join( 61 | self._save_dir, 62 | self._project, 63 | self._name 64 | ) 65 | 66 | @property 67 | def log_dir(self) -> str: 68 | return os.path.join( 69 | self._save_dir, 70 | self._project, 71 | self._name 72 | ) 73 | -------------------------------------------------------------------------------- /src/llm_training/data/instruction_tuning/instruction_tuning_datamodule_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from enum import auto 3 | 4 | from pydantic import ValidationInfo, field_validator 5 | from transformers import PreTrainedTokenizerBase 6 | 7 | from llm_training.data.chat_templates import get_chat_template 8 | from llm_training.data.hf_based.hf_based_datamodule_config import \ 9 | HFBasedDataModuleConfig 10 | from llm_training.utils.str_enum import StrEnum 
11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class OverlongHandlingMethod(StrEnum): 16 | DROP = auto() 17 | TRUNCATE = auto() 18 | 19 | 20 | class PackingMethod(StrEnum): 21 | NO_PACKING = auto() 22 | GROUP_BY_LENGTH = auto() 23 | 24 | 25 | class InstructionTuningDataModuleConfig(HFBasedDataModuleConfig): 26 | tokenizer: PreTrainedTokenizerBase 27 | chat_template: str | None = None 28 | max_length: int | None = None 29 | overlong_handling_method: OverlongHandlingMethod | str = OverlongHandlingMethod.DROP 30 | packing_method: PackingMethod | str = PackingMethod.NO_PACKING 31 | pad_to_multiple_of: int | None = None 32 | pad_to_max_length: bool = False 33 | add_default_system_prompt_rate: float | None = None 34 | default_system_prompt: str | None = None 35 | 36 | @field_validator('chat_template') 37 | @classmethod 38 | def validate_chat_template(cls, value: str | None) -> str | None: 39 | if value is not None: 40 | value = get_chat_template(value) 41 | return value 42 | 43 | @field_validator('default_system_prompt') 44 | @classmethod 45 | def validate_default_system_prompt(cls, value: str | None, info: ValidationInfo) -> str | None: 46 | assert value is None or info.data['add_default_system_prompt_rate'] is not None, \ 47 | "Default system prompt must be set to use `add_default_system_prompt_rate`." 48 | return value 49 | 50 | @field_validator('overlong_handling_method') 51 | @classmethod 52 | def validate_overlong_handling_method(cls, value: OverlongHandlingMethod | str) -> OverlongHandlingMethod: 53 | return OverlongHandlingMethod(value.lower()) 54 | 55 | @field_validator('packing_method') 56 | @classmethod 57 | def validate_packing_method(cls, value: PackingMethod | str) -> PackingMethod: 58 | return PackingMethod(value.lower()) 59 | -------------------------------------------------------------------------------- /docs/config.md: -------------------------------------------------------------------------------- 1 | # Config 2 | 3 | A config file is a YAML file used to set up everything, including seeding, distributed strategy, hyper-parameters, model, data, and more. 4 | 5 | When writing your own config file, it's a good idea to refer to some examples in the [config](../config/examples/) directory for guidance. 6 | 7 | The config file can be divided into three parts, including trainer, model and data. 8 | 9 | Since the `llm-training` command is implemented using [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html#lightning.pytorch.cli.LightningCLI), you can refer to [the lightning tutorials](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_advanced.html) for more information. 10 | 11 | ## Trainer 12 | 13 | This part is used to set the parameters to be passed to the Lightning Trainer, which controls various general settings, such as distributed strategy, precision, logger, epochs, gradient clipping, checkpointing, and more. 14 | 15 | Please refer to the [Trainer API](https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api) for more information. 16 | 17 | ```yaml 18 | trainer: 19 | strategy: ... 20 | precision: ... 21 | logger: ... 22 | max_epochs: 1 23 | accumulate_grad_batches: 2 24 | gradient_clip_val: 1.0 25 | callbacks: ... 26 | ``` 27 | 28 | ## Model 29 | 30 | This part controls model-related parameters, such as model architecture, pre-trained parameters, optimizer, and more. 31 | 32 | First, you need to set the `class_path` to determine which model class to use. 
33 | 34 | Next, you can use `init_args` to set initialization parameters. 35 | 36 | Specific parameters that can be set depend on the chosen model class. 37 | 38 | ```yaml 39 | model: 40 | class_path: path.to.your.model.class 41 | init_args: 42 | key1: value1 43 | key2: value2 44 | ``` 45 | 46 | ## Data 47 | 48 | This part controls data-related parameters, such as data source, data processing pipeline, batch size, etc. 49 | 50 | Similar to the model config, you first use `class_path` to determine the data module class, then use `init_args` to set parameters for it. 51 | 52 | ```yaml 53 | data: 54 | class_path: path.to.your.datamodule.class 55 | init_args: 56 | key1: value1 57 | key2: value2 58 | ``` 59 | -------------------------------------------------------------------------------- /src/llm_training/models/base_model/base_model.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import ClassVar, Generator 3 | 4 | import safetensors.torch 5 | import torch 6 | from torch import nn 7 | from torch.distributed._composable.fsdp import (MixedPrecisionPolicy, 8 | OffloadPolicy) 9 | from torch.distributed.device_mesh import DeviceMesh 10 | 11 | from .base_model_config import BaseModelConfig 12 | 13 | 14 | class BaseModel(nn.Module): 15 | _init_weights: ClassVar[bool] = True 16 | 17 | config_class: type[BaseModelConfig] = BaseModelConfig 18 | no_split_modules: list[str] = [] 19 | 20 | def __init__(self, config: BaseModelConfig) -> None: 21 | super().__init__() 22 | 23 | self.config = config 24 | 25 | if self._init_weights: 26 | self.init_weights() 27 | 28 | @property 29 | def has_pre_trained_weights(self) -> bool: 30 | return self.config.pre_trained_weights is not None 31 | 32 | def get_pre_trained_weights(self) -> dict[str, torch.Tensor]: 33 | return safetensors.torch.load_file(self.config.pre_trained_weights) 34 | 35 | def _init_weights_impl(self, module: nn.Module) -> None: ... 
36 | 37 | def init_weights(self) -> None: 38 | self.apply(self._init_weights_impl) 39 | 40 | @classmethod 41 | @contextmanager 42 | def init_weights_context(cls, init_weights: bool) -> Generator[None, None, None]: 43 | v = cls._init_weights 44 | cls._init_weights = bool(init_weights) 45 | yield 46 | cls._init_weights = v 47 | 48 | def configure_tensor_parallel(self, tp_mesh: DeviceMesh) -> None: 49 | if tp_mesh.size() == 1: 50 | return 51 | 52 | raise NotImplementedError(f"`{self.__class__.__name__}` does not support tensor parallel.") 53 | 54 | def configure_fully_sharded_data_parallel( 55 | self, 56 | dp_mesh: DeviceMesh, 57 | reshard_after_forward: bool | int, 58 | mp_policy: MixedPrecisionPolicy, 59 | offload_policy: OffloadPolicy, 60 | **kwargs 61 | ) -> None: 62 | if dp_mesh.size() == 1: 63 | return 64 | 65 | raise NotImplementedError(f"`{self.__class__.__name__}` does not support fully sharded data parallel.") 66 | 67 | def parallelize( 68 | self, 69 | dp_mesh: DeviceMesh, 70 | tp_mesh: DeviceMesh, 71 | **fsdp_kwargs 72 | ) -> None: 73 | self.configure_tensor_parallel(tp_mesh) 74 | self.configure_fully_sharded_data_parallel(dp_mesh, **fsdp_kwargs) 75 | -------------------------------------------------------------------------------- /src/llm_training/models/utils/utils.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Generator 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | @contextmanager 9 | def init_on_device(device: torch.device, include_buffers: bool = False) -> Generator[None, None, None]: 10 | if include_buffers: 11 | with device: 12 | yield 13 | return 14 | 15 | old_register_parameter = nn.Module.register_parameter 16 | if include_buffers: 17 | old_register_buffer = nn.Module.register_buffer 18 | 19 | def register_empty_parameter(module, name, param): 20 | old_register_parameter(module, name, param) 21 | if param is not None and param.device != device: 22 | param_cls = type(module._parameters[name]) 23 | kwargs = module._parameters[name].__dict__ 24 | kwargs["requires_grad"] = param.requires_grad 25 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 26 | 27 | def register_empty_buffer(module, name, buffer, persistent=True): 28 | old_register_buffer(module, name, buffer, persistent=persistent) 29 | if buffer is not None and buffer.device != device: 30 | module._buffers[name] = module._buffers[name].to(device) 31 | 32 | # Patch tensor creation 33 | if include_buffers: 34 | tensor_constructors_to_patch = { 35 | torch_function_name: getattr(torch, torch_function_name) 36 | for torch_function_name in ["empty", "zeros", "ones", "full"] 37 | } 38 | else: 39 | tensor_constructors_to_patch = {} 40 | 41 | def patch_tensor_constructor(fn): 42 | def wrapper(*args, **kwargs): 43 | kwargs["device"] = device 44 | return fn(*args, **kwargs) 45 | 46 | return wrapper 47 | 48 | try: 49 | nn.Module.register_parameter = register_empty_parameter 50 | if include_buffers: 51 | nn.Module.register_buffer = register_empty_buffer 52 | for torch_function_name in tensor_constructors_to_patch.keys(): 53 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 54 | yield 55 | finally: 56 | nn.Module.register_parameter = old_register_parameter 57 | if include_buffers: 58 | nn.Module.register_buffer = old_register_buffer 59 | for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): 60 | 
setattr(torch, torch_function_name, old_torch_function) 61 | 62 | 63 | def init_empty_weights(include_buffers: bool = False): 64 | return init_on_device(torch.device('meta'), include_buffers=include_buffers) 65 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/qwen2.5.j2: -------------------------------------------------------------------------------- 1 | {%- if tools %} 2 | {{- '<|im_start|>system\n' }} 3 | {%- if messages[0]['role'] == 'system' %} 4 | {{- messages[0]['content'] }} 5 | {%- else %} 6 | {{- 'You are a helpful assistant.' }} 7 | {%- endif %} 8 | {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }} 9 | {%- for tool in tools %} 10 | {{- "\n" }} 11 | {{- tool | tojson }} 12 | {%- endfor %} 13 | {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }} 14 | {%- else %} 15 | {%- if messages[0]['role'] == 'system' %} 16 | {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }} 17 | {%- else %} 18 | {{- '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }} 19 | {%- endif %} 20 | {%- endif %} 21 | {%- for message in messages %} 22 | {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} 23 | {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} 24 | {%- elif message.role == "assistant" %} 25 | {{- '<|im_start|>assistant' }} 26 | {%- if message.content %} 27 | {{- '\n' }} 28 | {%- endif %} 29 | {% generation %} 30 | {{- message.content }} 31 | {%- if message.tool_calls %} 32 | {%- for tool_call in message.tool_calls %} 33 | {%- if tool_call.function is defined %} 34 | {%- set tool_call = tool_call.function %} 35 | {%- endif %} 36 | {{- '\n<tool_call>\n{"name": "' }} 37 | {{- tool_call.name }} 38 | {{- '", "arguments": ' }} 39 | {{- tool_call.arguments | tojson }} 40 | {{- '}\n</tool_call>' -}} 41 | {%- endfor %} 42 | {%- endif %} 43 | {{- '<|im_end|>' + '\n' -}} 44 | {% endgeneration %} 45 | {%- elif message.role == "tool" %} 46 | {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %} 47 | {{- '<|im_start|>user' }} 48 | {%- endif %} 49 | {{- '\n<tool_response>\n' }} 50 | {{- message.content }} 51 | {{- '\n</tool_response>' }} 52 | {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} 53 | {{- '<|im_end|>\n' }} 54 | {%- endif %} 55 | {%- endif %} 56 | {%- endfor %} 57 | {%- if add_generation_prompt %} 58 | {{- '<|im_start|>assistant\n' }} 59 | {%- endif %} 60 | -------------------------------------------------------------------------------- /src/llm_training/data/instruction_tuning/instruction_tuning_datacollator.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, TypeVar 3 | 4 | import torch 5 | 6 | from llm_training.data.base_datacollator import BaseDataCollator 7 | 8 | from .instruction_tuning_datamodule_config import ( 9 | PackingMethod, InstructionTuningDataModuleConfig) 10 | 11 | T = TypeVar('T') 12 | 13 | class InstructionTuningDataCollator(BaseDataCollator): 14 | config: InstructionTuningDataModuleConfig 15 | 16 | def __init__(self, config: InstructionTuningDataModuleConfig): 17 | super().__init__(config) 18 | 19 | assert 'pad_token' in config.tokenizer.special_tokens_map, '`pad_token` is not specified. Please set it manually.'
20 | 21 | def _pad_to_longest(self, batch: list[list[T]], padding_value: T) -> list[list[T]]: 22 | n = self.config.max_length if self.config.pad_to_max_length else max(len(y) for y in batch) 23 | 24 | if self.config.pad_to_multiple_of is not None: 25 | n = (math.ceil(n / self.config.pad_to_multiple_of)) * self.config.pad_to_multiple_of 26 | 27 | new_batch = [] 28 | for x in batch: 29 | num_paddings = n - len(x) 30 | paddings = [padding_value] * num_paddings 31 | x = paddings + x if self.config.tokenizer.padding_side == 'left' else x + paddings 32 | new_batch.append(x) 33 | 34 | return new_batch 35 | 36 | def __call__(self, batch: list[dict[str, Any]]): 37 | batch_input_ids = [] 38 | batch_attention_mask = [] 39 | batch_position_ids = [] 40 | batch_labels = [] 41 | 42 | for x in batch: 43 | input_ids = x['input_ids'] 44 | labels = x['labels'] 45 | n = len(input_ids) 46 | 47 | if self.config.packing_method == PackingMethod.NO_PACKING: 48 | position_ids = list(range(n)) 49 | attention_mask = [1] * n 50 | elif self.config.packing_method == PackingMethod.GROUP_BY_LENGTH: 51 | position_ids = list(range(n)) 52 | attention_mask = x['attention_mask'] 53 | 54 | batch_input_ids.append(input_ids) 55 | batch_attention_mask.append(attention_mask) 56 | batch_position_ids.append(position_ids) 57 | batch_labels.append(labels) 58 | 59 | batch_input_ids = self._pad_to_longest(batch_input_ids, self.config.tokenizer.pad_token_id) 60 | batch_attention_mask = self._pad_to_longest(batch_attention_mask, 0) 61 | batch_position_ids = self._pad_to_longest(batch_position_ids, 0) 62 | batch_labels = self._pad_to_longest(batch_labels, -100) 63 | 64 | input_ids = torch.tensor(batch_input_ids) 65 | attention_mask = torch.tensor(batch_attention_mask) 66 | position_ids = torch.tensor(batch_position_ids) 67 | labels = torch.tensor(batch_labels) 68 | 69 | return { 70 | 'input_ids': input_ids, 71 | 'attention_mask': attention_mask, 72 | 'position_ids': position_ids, 73 | 'labels': labels 74 | } 75 | -------------------------------------------------------------------------------- /src/llm_training/lightning/callbacks/training_time_estimator.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from typing import Any, Mapping 4 | 5 | import torch 6 | from lightning import LightningModule, Trainer 7 | from lightning.pytorch.callbacks import Callback, Checkpoint 8 | from tabulate import tabulate 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class TrainingTimeEstimator(Callback): 13 | def __init__( 14 | self, 15 | num_test_steps: int, 16 | num_warmup_steps: int = 2, 17 | enable_checkpointing: bool = False 18 | ) -> None: 19 | super().__init__() 20 | 21 | assert num_warmup_steps >= 0 22 | assert num_warmup_steps < num_test_steps 23 | 24 | self.num_test_steps = num_test_steps 25 | self.num_warmup_steps = num_warmup_steps 26 | self.enable_checkpointing = enable_checkpointing 27 | 28 | self.start_time = 0 29 | self.end_time = 0 30 | 31 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: 32 | trainer.limit_val_batches = 0 33 | trainer.accumulate_grad_batches = 1 34 | 35 | if not self.enable_checkpointing: 36 | trainer.callbacks = [c for c in trainer.callbacks if not isinstance(c, Checkpoint)] 37 | 38 | def print_estimated_training_time(self) -> None: 39 | seconds = self.end_time - self.start_time 40 | steps_per_second = (self.num_test_steps - self.num_warmup_steps) / seconds 41 | estimated_seconds = self.num_total_steps / 
steps_per_second 42 | 43 | for unit_seconds, unit_name in [ 44 | (60 * 60 * 24, 'days'), 45 | (60 * 60, 'hours'), 46 | (60, 'minutes'), 47 | (1, 'seconds') 48 | ]: 49 | if estimated_seconds >= unit_seconds: 50 | estimated_training_time = f'{estimated_seconds / unit_seconds:.2f} {unit_name}' 51 | break 52 | 53 | s = tabulate( 54 | [ 55 | ['Running Time', f'{seconds:.2f} seconds'], 56 | ['Steps per second', f'{steps_per_second:.2f} steps'], 57 | ['Estimated training time', estimated_training_time] 58 | ], 59 | tablefmt='fancy_grid' 60 | ) 61 | print(s) 62 | 63 | def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None: 64 | if trainer.global_step == self.num_warmup_steps: 65 | self.start_time = time.time() 66 | 67 | def on_train_batch_end( 68 | self, 69 | trainer: Trainer, 70 | pl_module: LightningModule, 71 | outputs: torch.Tensor | Mapping[str, Any] | None, 72 | batch: Any, 73 | batch_idx: int 74 | ) -> None: 75 | if trainer.global_step == self.num_test_steps: 76 | self.end_time = time.time() 77 | trainer.should_stop = True 78 | 79 | def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: 80 | self.num_total_steps = trainer.estimated_stepping_batches 81 | 82 | if trainer.is_global_zero: 83 | self.print_estimated_training_time() 84 | -------------------------------------------------------------------------------- /src/llm_training/optim/master_weight_wrapper.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import ClassVar, Mapping 3 | 4 | import torch 5 | from torch import nn 6 | from torch.optim import Optimizer 7 | from typing_extensions import Self 8 | 9 | 10 | class MasterWeightsOptimizer(Optimizer): 11 | _is_subclass: ClassVar[bool] = False 12 | _subclasses: ClassVar[dict[type[Optimizer], type[Self]]] = {} 13 | 14 | _parameter_mapping: Mapping[nn.Parameter, nn.Parameter] 15 | _parameters: list[nn.Parameter] 16 | 17 | def __new__(cls, optimizer: Optimizer): 18 | if cls._is_subclass: 19 | return object.__new__(cls) 20 | 21 | optimizer_class = type(optimizer) 22 | if optimizer_class not in cls._subclasses: 23 | cls._subclasses[optimizer_class] = type( 24 | optimizer_class.__name__, 25 | (cls,), 26 | {'_is_subclass': True} 27 | ) 28 | 29 | return cls._subclasses[optimizer_class](optimizer) 30 | 31 | def __init__(self, optimizer: Optimizer): 32 | self._optimizer = optimizer 33 | self._parameters = [p for g in self._optimizer.param_groups for p in g['params']] 34 | self._parameter_mapping = {} 35 | for p in self._parameters: 36 | mp = p if p.dtype == torch.float else p.detach().float().requires_grad_() 37 | self._parameter_mapping[p] = mp 38 | self._parameter_mapping[mp] = p 39 | 40 | @contextmanager 41 | def _replace_params(self, replace_state_key: bool = True): 42 | try: 43 | for group in self._optimizer.param_groups: 44 | params = group['params'] 45 | for i, p in enumerate(params): 46 | params[i] = self._parameter_mapping[p] 47 | 48 | if replace_state_key: 49 | for w in list(self._optimizer.state.keys()): 50 | self._optimizer.state[self._parameter_mapping[w]] = self._optimizer.state.pop(w) 51 | 52 | yield 53 | finally: 54 | for group in self._optimizer.param_groups: 55 | params = group['params'] 56 | for i, p in enumerate(params): 57 | params[i] = self._parameter_mapping[p] 58 | 59 | if replace_state_key: 60 | for mw in list(self._optimizer.state.keys()): 61 | self._optimizer.state[self._parameter_mapping[mw]] = 
self._optimizer.state.pop(mw) 62 | 63 | def step(self, closure=None): 64 | loss = None 65 | if closure is not None: 66 | with torch.enable_grad(): 67 | loss = closure() 68 | 69 | for p in self._parameters: 70 | if p.grad is None: 71 | continue 72 | self._parameter_mapping[p].grad = p.grad.float() 73 | 74 | with self._replace_params(): 75 | self._optimizer.step() 76 | 77 | for p in self._parameters: 78 | p.data.copy_(self._parameter_mapping[p], non_blocking=True) 79 | 80 | return loss 81 | 82 | def zero_grad(self, set_to_none = True): 83 | self._optimizer.zero_grad(set_to_none) 84 | 85 | with self._replace_params(replace_state_key=False): 86 | self._optimizer.zero_grad(set_to_none) 87 | 88 | def state_dict(self): 89 | return self._optimizer.state_dict() 90 | 91 | def load_state_dict(self, state_dict): 92 | with self._replace_params(): 93 | self._optimizer.load_state_dict(state_dict) 94 | 95 | def __getattr__(self, name): 96 | return getattr(self._optimizer, name) 97 | -------------------------------------------------------------------------------- /src/llm_training/models/phi3/phi3_config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Literal 2 | 3 | import torch 4 | from pydantic import ValidationInfo, field_validator 5 | 6 | from llm_training.models.hf_compat_model import HFCompatModelConfig 7 | 8 | 9 | class Phi3Config(HFCompatModelConfig): 10 | vocab_size: int = 32064 11 | hidden_size: int = 3072 12 | intermediate_size: int = 8192 13 | num_hidden_layers: int = 32 14 | num_attention_heads: int = 32 15 | num_key_value_heads: int | None = None 16 | resid_pdrop: float = 0.0 17 | embd_pdrop: float = 0.0 18 | attention_dropout: float = 0.0 19 | max_position_embeddings: int = 4096 20 | original_max_position_embeddings: int = 4096 21 | initializer_range: float = 0.02 22 | rms_norm_eps: float = 1e-5 23 | rope_theta: float = 10000.0 24 | rope_scaling: dict | None = None 25 | bos_token_id: int = 1 26 | eos_token_id: int = 32000 27 | pad_token_id: int = 32000 28 | sliding_window: int | None = None 29 | 30 | enable_gradient_checkpointing: bool = False 31 | recompute_granularity: Literal['full', 'selective'] = 'full' 32 | attention_compute_dtype: torch.dtype | str | None = None 33 | 34 | @field_validator('rope_scaling') 35 | @classmethod 36 | def validate_rope_scaling(cls, rope_scaling: dict[str, Any] | None, info: ValidationInfo) -> dict[str, Any] | None: 37 | """ 38 | Validate the `rope_scaling` configuration. 
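        Expected form (as enforced below): a dict with exactly the keys `type`, `short_factor` and `long_factor`,
        where `type` is 'longrope' and each factor list contains hidden_size // num_attention_heads // 2 numbers.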
39 | """ 40 | if rope_scaling is None: 41 | return rope_scaling 42 | 43 | hidden_size = info.data['hidden_size'] 44 | num_attention_heads = info.data['num_attention_heads'] 45 | 46 | if not isinstance(rope_scaling, dict) or len(rope_scaling) != 3: 47 | raise ValueError( 48 | "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " 49 | f"got {rope_scaling}" 50 | ) 51 | rope_scaling_type = rope_scaling.get('type', None) 52 | rope_scaling_short_factor = rope_scaling.get('short_factor', None) 53 | rope_scaling_long_factor = rope_scaling.get('long_factor', None) 54 | if rope_scaling_type is None or rope_scaling_type not in ['longrope']: 55 | raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") 56 | if not ( 57 | isinstance(rope_scaling_short_factor, list) 58 | and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) 59 | ): 60 | raise ValueError( 61 | f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" 62 | ) 63 | if not len(rope_scaling_short_factor) == hidden_size // num_attention_heads // 2: 64 | raise ValueError( 65 | f"`rope_scaling`'s short_factor field must have length {hidden_size // num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" 66 | ) 67 | if not ( 68 | isinstance(rope_scaling_long_factor, list) 69 | and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) 70 | ): 71 | raise ValueError( 72 | f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" 73 | ) 74 | if not len(rope_scaling_long_factor) == hidden_size // num_attention_heads // 2: 75 | raise ValueError( 76 | f"`rope_scaling`'s long_factor field must have length {hidden_size // num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" 77 | ) 78 | 79 | return rope_scaling 80 | -------------------------------------------------------------------------------- /src/llm_training/data/preference_tuning/preference_tuning_datacollator.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, TypeVar 3 | 4 | import torch 5 | from llm_training.data.base_datacollator import BaseDataCollator 6 | 7 | from .preference_tuning_datamodule_config import \ 8 | PreferenceTuningDataModuleConfig 9 | 10 | T = TypeVar('T') 11 | 12 | class PreferenceTuningDataCollator(BaseDataCollator): 13 | config: PreferenceTuningDataModuleConfig 14 | 15 | def __init__(self, config: PreferenceTuningDataModuleConfig): 16 | super().__init__(config) 17 | 18 | assert 'pad_token' in config.tokenizer.special_tokens_map, \ 19 | '`pad_token` is not specified. Please set it manually.' 
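    # _pad_to_longest pads every sequence in a batch to a common length: `max_length` when
    # `pad_to_max_length` is set, otherwise the longest sequence in the batch, optionally rounded
    # up to a multiple of `pad_to_multiple_of`; padding is applied on the side given by
    # `tokenizer.padding_side`.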
20 | 21 | def _pad_to_longest(self, batch: list[list[T]], padding_value: T) -> list[list[T]]: 22 | n = self.config.max_length if self.config.pad_to_max_length else max(len(y) for y in batch) 23 | 24 | if self.config.pad_to_multiple_of is not None: 25 | n = (math.ceil(n / self.config.pad_to_multiple_of)) * self.config.pad_to_multiple_of 26 | 27 | new_batch = [] 28 | for x in batch: 29 | num_paddings = n - len(x) 30 | paddings = [padding_value] * num_paddings 31 | x = paddings + x if self.config.tokenizer.padding_side == 'left' else x + paddings 32 | new_batch.append(x) 33 | 34 | return new_batch 35 | 36 | def __call__(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]: 37 | outputs = { 38 | 'chosen_input_ids': [], 39 | 'chosen_attention_mask': [], 40 | 'chosen_labels': [], 41 | 'rejected_input_ids': [], 42 | 'rejected_attention_mask': [], 43 | 'rejected_labels': [] 44 | } 45 | 46 | for x in batch: 47 | outputs['chosen_input_ids'].append(x['chosen_input_ids']) 48 | outputs['chosen_attention_mask'].append([1] * len(x['chosen_input_ids'])) 49 | outputs['chosen_labels'].append(x['chosen_labels']) 50 | outputs['rejected_input_ids'].append(x['rejected_input_ids']) 51 | outputs['rejected_attention_mask'].append([1] * len(x['rejected_input_ids'])) 52 | outputs['rejected_labels'].append(x['rejected_labels']) 53 | 54 | outputs['chosen_input_ids'] = self._pad_to_longest(outputs['chosen_input_ids'], self.config.tokenizer.pad_token_id) 55 | outputs['chosen_attention_mask'] = self._pad_to_longest(outputs['chosen_attention_mask'], 0) 56 | outputs['chosen_labels'] = self._pad_to_longest(outputs['chosen_labels'], -100) 57 | outputs['rejected_input_ids'] = self._pad_to_longest(outputs['rejected_input_ids'], self.config.tokenizer.pad_token_id) 58 | outputs['rejected_attention_mask'] = self._pad_to_longest(outputs['rejected_attention_mask'], 0) 59 | outputs['rejected_labels'] = self._pad_to_longest(outputs['rejected_labels'], -100) 60 | 61 | outputs['chosen_input_ids'] = torch.tensor(outputs['chosen_input_ids']) 62 | outputs['chosen_position_ids'] = torch.arange(outputs['chosen_input_ids'].size(1)).unsqueeze(0) 63 | outputs['chosen_attention_mask'] = torch.tensor(outputs['chosen_attention_mask']) 64 | outputs['chosen_labels'] = torch.tensor(outputs['chosen_labels']) 65 | outputs['rejected_input_ids'] = torch.tensor(outputs['rejected_input_ids']) 66 | outputs['rejected_position_ids'] = torch.arange(outputs['rejected_input_ids'].size(1)).unsqueeze(0) 67 | outputs['rejected_attention_mask'] = torch.tensor(outputs['rejected_attention_mask']) 68 | outputs['rejected_labels'] = torch.tensor(outputs['rejected_labels']) 69 | 70 | return outputs 71 | -------------------------------------------------------------------------------- /docs/pre_training.md: -------------------------------------------------------------------------------- 1 | # Pre-training 2 | 3 | The pre-training data processing logic is implemented through [`PreTrainingDataModule`](/src/llm_training/data/pre_training/pre_training_datamodule.py). 4 | 5 | It uses [datasets](https://github.com/huggingface/datasets) under the hood to load and process data. 6 | 7 | A valid input dataset must include a `text` field, which must be of type string. 
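For illustration, a minimal dataset satisfying this contract can be built directly with the `datasets` library; the toy records below are hypothetical and only demonstrate the required string-typed `text` column:

```python
from datasets import Dataset

# Hypothetical toy corpus: the only structural requirement described above
# is a string-typed `text` column.
dataset = Dataset.from_list([
    {"text": "The quick brown fox jumps over the lazy dog."},
    {"text": "Pre-training corpora are normally far larger than this toy example."},
])

print(dataset.features)   # {'text': Value(dtype='string', id=None)}
print(dataset[0]["text"])
```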
8 | 9 | ## Key Parameters 10 | 11 | | Parameter | Description | 12 | | :--------------- | :-------------------------------------------------------------------------------------------------------------- | 13 | | `dataset_kwargs` | kwargs to be passed to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/loading) for loading data | 14 | | `tokenizer` | A transformers tokenizer for tokenizing the data | 15 | | `max_length` | Max length of the tokenized data | 16 | | `num_proc` | Number of CPU cores for processing data | 17 | 18 | For a complete set of parameters, please refer to [`PreTrainingDataModuleConfig`](/src/llm_training/data/pre_training/pre_training_datamodule_config.py). 19 | 20 | ## Pre-processing data before training 21 | 22 | By default, the framework processes the data automatically so that everything is ready before training starts. This is particularly convenient for small training datasets. However, pre-training datasets are typically large, and the CPUs available during training are often limited, which can make this step very time-consuming. 23 | 24 | To address this, you can set `pre_processed_data_path` and run `scripts/pre_process_data.py` on a machine with many CPUs to pre-process and save the data in advance. 25 | 26 | Remember to set `num_proc` to the desired number of CPUs to utilize. 27 | 28 | ```yaml 29 | data: 30 | class_path: llm_training.data.PreTrainingDataModule 31 | init_args.config: 32 | ... 33 | pre_processed_data_path: 34 | num_proc: 35 | ``` 36 | 37 | ```bash 38 | python scripts/pre_process_data.py -c 39 | ``` 40 | 41 | ## Data Sampling 42 | 43 | `PreTrainingDataModule` also supports data sampling. 44 | You can include a `source` field in the dataset and set `sample_rate` in the config to sample data based on `source`. `sample_rate` is a dictionary whose keys are `source` values and whose values are the corresponding sampling rates. 45 | 46 | The following config downsamples `source_1` by half and upsamples `source_2` threefold (e.g. 10,000 `source_1` rows and 1,000 `source_2` rows become roughly 5,000 and 3,000 rows). 47 | ```yaml 48 | data: 49 | class_path: llm_training.data.PreTrainingDataModule 50 | init_args.config: 51 | ... 52 | sample_rate: 53 | source_1: 0.5 54 | source_2: 3.0 55 | ``` 56 | 57 | ## Example 58 | 59 | ```yaml 60 | ...
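    # Editorial annotations on the example below (not part of the original config):
    #   - everything under `dataset_kwargs` is forwarded to `datasets.load_dataset`
    #   - `pre_processed_data_path` points at data saved ahead of time by `scripts/pre_process_data.py`,
    #     as described in the pre-processing section above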
61 | data: 62 | class_path: llm_training.data.PreTrainingDataModule 63 | init_args.config: 64 | dataset_kwargs: 65 | path: HuggingFaceFW/fineweb 66 | name: sample-10BT 67 | num_proc: 32 # Number of threads for downloading 68 | pre_processed_data_path: data/pre_processed/phi-3/fineweb-sample-10bt 69 | tokenizer: # Phi-3 Tokenizer 70 | class_path: HFTokenizer 71 | init_args.path: microsoft/Phi-3-mini-128k-instruct 72 | batch_size: 1 73 | max_length: 4096 74 | validation_split: 30000 75 | num_proc: 32 # Number of cores for processing 76 | num_workers: 4 # Number of workers for data loader 77 | ``` 78 | -------------------------------------------------------------------------------- /src/llm_training/lightning/callbacks/output_redirection.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from io import StringIO 4 | from pathlib import Path 5 | from typing import TextIO 6 | 7 | from lightning import LightningModule, Trainer 8 | from lightning.pytorch.callbacks import Callback 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class OutputRedirection(Callback): 14 | LOG_FILE_SUFFIX: str = '.log' 15 | 16 | def __init__( 17 | self, 18 | log_file_name: str = '{index}-{version}', 19 | redirect_stdout: bool = True, 20 | redirect_stderr: bool = True, 21 | enabled: bool = True 22 | ) -> None: 23 | super().__init__() 24 | 25 | self.log_file_name = log_file_name 26 | self.redirect_stdout = redirect_stdout 27 | self.redirect_stderr = redirect_stderr 28 | self.enabled = enabled 29 | 30 | if not enabled: 31 | return 32 | 33 | self.stdout = sys.stdout 34 | self.stderr = sys.stderr 35 | self.log_file = None 36 | self.buffer = StringIO() 37 | 38 | if self.redirect_stdout: 39 | self.stdout_redirector = _StreamRedirector(self.stdout, self.buffer) 40 | sys.stdout = self.stdout_redirector 41 | 42 | if self.redirect_stderr: 43 | self.stderr_redirector = _StreamRedirector(self.stderr, self.buffer) 44 | sys.stderr = self.stderr_redirector 45 | 46 | self.redirect_loggers() 47 | 48 | def redirect_loggers(self) -> None: 49 | logger_names = [None, 'llm_training', 'lightning', 'lightning.fabric', 'lightning.pytorch'] 50 | for logger_name in logger_names: 51 | logger = logging.getLogger(logger_name) 52 | for handler in logger.handlers: 53 | if isinstance(handler, logging.StreamHandler): 54 | original_stream = handler.stream.src if isinstance(handler.stream, _StreamRedirector) else handler.stream 55 | if original_stream is self.stdout and self.redirect_stdout: 56 | handler.setStream(self.stdout_redirector) 57 | elif original_stream is self.stderr and self.redirect_stderr: 58 | handler.setStream(self.stderr_redirector) 59 | 60 | def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: 61 | if not self.enabled: 62 | return 63 | 64 | log_file_name = None 65 | 66 | log_dir = Path(trainer.log_dir) 67 | if trainer.is_global_zero: 68 | log_dir.mkdir(parents=True, exist_ok=True) 69 | log_file_name = self.log_file_name.format( 70 | index=len(list(log_dir.glob(f'*{self.LOG_FILE_SUFFIX}'))), 71 | version=trainer.logger.version or trainer.logger.name 72 | ) 73 | log_file_name += self.LOG_FILE_SUFFIX 74 | 75 | log_file_name = trainer.strategy.broadcast(log_file_name) 76 | self.log_file = open(log_dir / log_file_name, 'a', encoding='utf-8') 77 | 78 | self.log_file.write(self.buffer.getvalue()) 79 | self.buffer.truncate(0) 80 | self.buffer.seek(0) 81 | 82 | if self.redirect_stdout: 83 | self.stdout_redirector.dst = self.log_file 84 
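        # Do the same for stderr so both streams end up in the same log file.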
| 85 | if self.redirect_stderr: 86 | self.stderr_redirector.dst = self.log_file 87 | 88 | 89 | class _StreamRedirector: 90 | def __init__(self, src: TextIO, dst: TextIO) -> None: 91 | self.src = src 92 | self.dst = dst 93 | 94 | def write(self, s: str) -> int: 95 | r = self.src.write(s) 96 | self.dst.write(s) 97 | return r 98 | 99 | def flush(self) -> None: 100 | self.src.flush() 101 | self.dst.flush() 102 | -------------------------------------------------------------------------------- /src/llm_training/lightning/cli/cli.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Any, Callable 4 | 5 | from lightning.pytorch import LightningDataModule, LightningModule 6 | from lightning.pytorch import Trainer as _Trainer 7 | from lightning.pytorch.cli import ArgsType, LightningArgumentParser 8 | from lightning.pytorch.cli import LightningCLI as _LightningCLI 9 | from lightning.pytorch.cli import SaveConfigCallback as _SaveConfigCallback 10 | 11 | from llm_training.lightning import (ExtraConfig, OutputRedirection, 12 | SaveConfigCallback, TQDMProgressBar) 13 | 14 | from .trainer import Trainer 15 | 16 | 17 | class LightningCLI(_LightningCLI): 18 | def __init__( 19 | self, 20 | model_class: type[LightningModule] | Callable[..., LightningModule] | None = None, 21 | datamodule_class: type[LightningDataModule] | Callable[..., LightningDataModule] | None = None, 22 | save_config_callback: type[_SaveConfigCallback] | None = None, 23 | save_config_kwargs: dict[str, Any] | None = None, 24 | trainer_class: type[_Trainer] | Callable[..., _Trainer] = Trainer, 25 | trainer_defaults: dict[str, Any] | None = None, 26 | seed_everything_default: bool | int = True, 27 | parser_kwargs: dict[str, Any] | dict[str, dict[str, Any]] | None = None, 28 | subclass_mode_model: bool = False, 29 | subclass_mode_data: bool = False, 30 | args: ArgsType = None, 31 | run: bool = True, 32 | auto_configure_optimizers: bool = True 33 | ) -> None: 34 | save_config_callback = SaveConfigCallback if save_config_callback is None else save_config_callback 35 | default_save_config_kwargs = { 36 | 'overwrite': True, 37 | 'save_to_log_dir': False 38 | } 39 | save_config_kwargs = save_config_kwargs or {} 40 | save_config_kwargs = default_save_config_kwargs | save_config_kwargs 41 | 42 | default_parser_kwargs = { 43 | 'parser_mode': 'omegaconf' 44 | } 45 | parser_kwargs = parser_kwargs or {} 46 | parser_kwargs = default_parser_kwargs | parser_kwargs 47 | 48 | super().__init__( 49 | model_class=model_class, 50 | datamodule_class=datamodule_class, 51 | save_config_callback=save_config_callback, 52 | save_config_kwargs=save_config_kwargs, 53 | trainer_class=trainer_class, 54 | trainer_defaults=trainer_defaults, 55 | seed_everything_default=seed_everything_default, 56 | parser_kwargs=parser_kwargs, 57 | subclass_mode_model=subclass_mode_model, 58 | subclass_mode_data=subclass_mode_data, 59 | args=args, 60 | run=run, 61 | auto_configure_optimizers=auto_configure_optimizers 62 | ) 63 | 64 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 65 | parser.add_argument('--float32-matmul-precision', type=str | None, choices=['medium', 'high', 'highest'], default=None) 66 | parser.add_argument('--logging-level', type=str | int, default=logging.INFO) 67 | parser.add_argument('--env', type=dict | None, default=None) 68 | parser.add_lightning_class_args(OutputRedirection, 'output_redirection') 69 | parser.add_lightning_class_args(TQDMProgressBar, 
'tqdm_progress') 70 | 71 | def _instantiate_extra_config(self) -> ExtraConfig: 72 | return ExtraConfig( 73 | float32_matmul_precision=self._get(self.config, 'float32_matmul_precision'), 74 | logging_level=self._get(self.config, 'logging_level'), 75 | env=self._get(self.config, 'env') 76 | ) 77 | 78 | def _instantiate_trainer(self, config, callbacks): 79 | callbacks.insert(0, self._instantiate_extra_config()) 80 | 81 | if int(os.getenv('SLURM_NTASKS', '0')) == 1: 82 | del os.environ['SLURM_JOB_ID'] 83 | del os.environ['SLURM_NTASKS'] 84 | 85 | return super()._instantiate_trainer(config, callbacks) 86 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /logs/* 2 | /config/* 3 | !/config/examples/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | cover/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | .pybuilder/ 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | # For a library or package, you might want to ignore these files since the code is 91 | # intended to run in multiple environments; otherwise, check them in: 92 | # .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # poetry 102 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 103 | # This is especially recommended for binary packages to ensure reproducibility, and is more 104 | # commonly ignored for libraries. 105 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 106 | #poetry.lock 107 | 108 | # pdm 109 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 110 | #pdm.lock 111 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 112 | # in version control. 
113 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 114 | .pdm.toml 115 | .pdm-python 116 | .pdm-build/ 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 166 | #.idea/ 167 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/llama-3.2.j2: -------------------------------------------------------------------------------- 1 | {{- bos_token }} 2 | {%- if custom_tools is defined %} 3 | {%- set tools = custom_tools %} 4 | {%- endif %} 5 | {%- if not tools_in_user_message is defined %} 6 | {%- set tools_in_user_message = true %} 7 | {%- endif %} 8 | {%- if not date_string is defined %} 9 | {%- if strftime_now is defined %} 10 | {%- set date_string = strftime_now("%d %b %Y") %} 11 | {%- else %} 12 | {%- set date_string = "26 Jul 2024" %} 13 | {%- endif %} 14 | {%- endif %} 15 | {%- if not tools is defined %} 16 | {%- set tools = none %} 17 | {%- endif %} 18 | 19 | {#- This block extracts the system message, so we can slot it into the right place. #} 20 | {%- if messages[0]['role'] == 'system' %} 21 | {%- set system_message = messages[0]['content']|trim %} 22 | {%- set messages = messages[1:] %} 23 | {%- else %} 24 | {%- set system_message = "" %} 25 | {%- endif %} 26 | 27 | {#- System message #} 28 | {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} 29 | {%- if tools is not none %} 30 | {{- "Environment: ipython\n" }} 31 | {%- endif %} 32 | {{- "Cutting Knowledge Date: December 2023\n" }} 33 | {{- "Today Date: " + date_string + "\n\n" }} 34 | {%- if tools is not none and not tools_in_user_message %} 35 | {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} 36 | {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' 
}} 37 | {{- "Do not use variables.\n\n" }} 38 | {%- for t in tools %} 39 | {{- t | tojson(indent=4) }} 40 | {{- "\n\n" }} 41 | {%- endfor %} 42 | {%- endif %} 43 | {{- system_message }} 44 | {{- "<|eot_id|>" }} 45 | 46 | {#- Custom tools are passed in a user message with some extra guidance #} 47 | {%- if tools_in_user_message and not tools is none %} 48 | {#- Extract the first user message so we can plug it in here #} 49 | {%- if messages | length != 0 %} 50 | {%- set first_user_message = messages[0]['content']|trim %} 51 | {%- set messages = messages[1:] %} 52 | {%- else %} 53 | {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} 54 | {%- endif %} 55 | {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} 56 | {{- "Given the following functions, please respond with a JSON for a function call " }} 57 | {{- "with its proper arguments that best answers the given prompt.\n\n" }} 58 | {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} 59 | {{- "Do not use variables.\n\n" }} 60 | {%- for t in tools %} 61 | {{- t | tojson(indent=4) }} 62 | {{- "\n\n" }} 63 | {%- endfor %} 64 | {{- first_user_message + "<|eot_id|>"}} 65 | {%- endif %} 66 | 67 | {%- for message in messages %} 68 | {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} 69 | {{- '<|start_header_id|>' + message.role + '<|end_header_id|>\n\n' }} 70 | {%- set content = message.content | trim + '<|eot_id|>' %} 71 | {%- if message.role == 'assistant' %} 72 | {% generation %} 73 | {{- content -}} 74 | {% endgeneration %} 75 | {% else %} 76 | {{- content }} 77 | {%- endif %} 78 | {%- elif 'tool_calls' in message %} 79 | {%- if not message.tool_calls|length == 1 %} 80 | {{- raise_exception("This model only supports single tool-calls at once!") }} 81 | {%- endif %} 82 | {%- set tool_call = message.tool_calls[0].function %} 83 | {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} 84 | {% generation %} 85 | {{- '{"name": "' + tool_call.name + '", ' }} 86 | {{- '"parameters": ' }} 87 | {{- tool_call.arguments | tojson }} 88 | {{- "}" }} 89 | {{- "<|eot_id|>" -}} 90 | {% endgeneration %} 91 | {%- elif message.role == "tool" or message.role == "ipython" %} 92 | {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} 93 | {%- if message.content is mapping or message.content is iterable %} 94 | {{- message.content | tojson }} 95 | {%- else %} 96 | {{- message.content }} 97 | {%- endif %} 98 | {{- "<|eot_id|>" }} 99 | {%- endif %} 100 | {%- endfor %} 101 | {%- if add_generation_prompt %} 102 | {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} 103 | {%- endif %} 104 | -------------------------------------------------------------------------------- /src/llm_training/models/hf_causal_lm/hf_causal_lm.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.distributed 5 | from liger_kernel.transformers import _apply_liger_kernel_to_instance 6 | from torch import nn 7 | from torch.distributed._composable.fsdp import fully_shard 8 | from transformers import (AutoConfig, AutoModelForCausalLM, 9 | modeling_flash_attention_utils) 10 | from transformers.modeling_utils import no_init_weights 11 | 12 | from llm_training.models.hf_compat_model import HFCompatModel 13 | from llm_training.models.utils.modeling_outputs import CausalLMOutput 14 | from llm_training.ops.attention_op import 
_get_unpad_data 15 | from llm_training.utils.decorators import copy_method_signature 16 | 17 | from .hf_causal_lm_config import HFCausalLMConfig 18 | 19 | # Patch for packed attention masks (FA only) 20 | modeling_flash_attention_utils._get_unpad_data = _get_unpad_data 21 | 22 | class HFCausalLM(HFCompatModel): 23 | config: HFCausalLMConfig 24 | 25 | config_class = HFCausalLMConfig 26 | hf_model_class = AutoModelForCausalLM 27 | hf_config_class = AutoConfig 28 | 29 | def __init__(self, config: HFCausalLMConfig) -> None: 30 | super().__init__(config) 31 | 32 | self.config.hf_config = self.hf_config 33 | 34 | with no_init_weights(not self._init_weights): 35 | self.hf_model = self.construct_hf_model() 36 | 37 | if self.config.enable_gradient_checkpointing: 38 | self.hf_model.gradient_checkpointing_enable({'use_reentrant': False}) 39 | 40 | self.hf_model.tie_weights() 41 | 42 | if self.config.enable_liger_kernel: 43 | _apply_liger_kernel_to_instance(self.hf_model, rope=False) 44 | 45 | def convert_state_dict_from_hf(self, hf_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 46 | return {'hf_model.' + k: v for k, v in hf_state_dict.items()} 47 | 48 | def convert_state_dict_to_hf(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 49 | return {k.removeprefix('hf_model.'): v for k, v in state_dict.items()} 50 | 51 | def forward( 52 | self, 53 | input_ids: torch.Tensor | None = None, 54 | attention_mask: torch.Tensor | None = None, 55 | position_ids: torch.Tensor | None = None, 56 | inputs_embeds: torch.Tensor | None = None, 57 | return_last_hidden_states: bool = False 58 | ) -> CausalLMOutput: 59 | if self.hf_config._attn_implementation != 'flash_attention_2': 60 | attention_mask = attention_mask.clamp_max(1) 61 | 62 | outputs = self.hf_model( 63 | input_ids=input_ids, 64 | attention_mask=attention_mask, 65 | position_ids=position_ids, 66 | inputs_embeds=inputs_embeds, 67 | output_hidden_states=return_last_hidden_states 68 | ) 69 | 70 | last_hidden_states = None 71 | if return_last_hidden_states: 72 | last_hidden_states = outputs.hidden_states[-1] 73 | 74 | return CausalLMOutput( 75 | logits=outputs.logits, 76 | last_hidden_states=last_hidden_states 77 | ) 78 | 79 | @copy_method_signature(forward) 80 | def __call__(): ... 
81 | 82 | def get_input_embeddings(self) -> nn.Embedding: 83 | return self.hf_model.get_input_embeddings() 84 | 85 | def get_output_embeddings(self) -> nn.Linear: 86 | return self.hf_model.get_output_embeddings() 87 | 88 | def configure_fully_sharded_data_parallel(self, dp_mesh, reshard_after_forward, mp_policy, offload_policy, **kwargs): 89 | if dp_mesh.size() == 1: 90 | return 91 | 92 | fully_shard_ = partial( 93 | fully_shard, 94 | mesh=dp_mesh, 95 | reshard_after_forward=reshard_after_forward, 96 | mp_policy=mp_policy, 97 | offload_policy=offload_policy 98 | ) 99 | 100 | fully_sharded_module_names = [] 101 | no_split_modules = self.hf_model._no_split_modules or [] 102 | for n, m in self.named_modules(): 103 | if m.__class__.__name__ in no_split_modules: 104 | fully_shard_(m) 105 | fully_sharded_module_names.append(n) 106 | 107 | # shard the rest modules for gradient clipping as a workaround 108 | for n, m in self.named_modules(): 109 | for o in fully_sharded_module_names: 110 | if o.startswith(n) or n.startswith(o): 111 | break 112 | else: 113 | if any(True for _ in m.parameters()): 114 | fully_shard_(m) 115 | -------------------------------------------------------------------------------- /src/llm_training/data/base_datamodule.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | from functools import partial 4 | from typing import Mapping, TextIO 5 | 6 | import lightning as L 7 | from torch.utils.data import DataLoader, Dataset 8 | 9 | from .base_datacollator import BaseDataCollator 10 | from .base_datamodule_config import BaseDataModuleConfig 11 | from .resumable_dataloader import ResumableDataLoader 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | DatasetDict = Mapping[str, Dataset] 16 | 17 | 18 | class BaseDataModule(L.LightningDataModule): 19 | datacollator_class: type[BaseDataCollator] | None = None 20 | 21 | def __init__(self, config: BaseDataModuleConfig) -> None: 22 | super().__init__() 23 | 24 | self.config = config 25 | self.datacollator = self.datacollator_class(config) if self.datacollator_class is not None else None 26 | self.prepare_data_per_node = config.prepare_data_per_node 27 | self.raw_dataset_dict = None 28 | self.pre_processed_dataset_dict = None 29 | self.dataset_dict = None 30 | 31 | def load_data(self) -> DatasetDict: 32 | raise NotImplementedError() 33 | 34 | def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: 35 | return dataset_dict 36 | 37 | def split_data(self, dataset_dict: DatasetDict): 38 | return dataset_dict 39 | 40 | def post_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: 41 | dataset_dict = self.split_data(dataset_dict) 42 | return dataset_dict 43 | 44 | def prepare_data(self) -> None: 45 | if self.config.pre_processed_data_path is None: 46 | dataset_dict = self.load_data() 47 | dataset_dict = self.pre_process_data(dataset_dict) 48 | 49 | def save_pre_processed_data(self, path: str | None = None) -> None: 50 | raise NotImplementedError() 51 | 52 | def load_pre_processed_data(self, path: str | None = None) -> None: 53 | raise NotImplementedError() 54 | 55 | def print_dataset_info(self, file: TextIO | None = None) -> None: 56 | print_ = partial(print, file=file) 57 | def print_header(header: str) -> None: 58 | n = shutil.get_terminal_size().columns 59 | m = (n - len(header) - 2) // 2 60 | divider = '─' * m 61 | header = f'{divider} {header} {divider}' 62 | print_(f'{header:^{n}}', end='\n\n') 63 | 64 | print_header('Raw Dataset') 65 | 
print_(self.raw_dataset_dict, end='\n\n') 66 | print_header('Pre-processed Dataset') 67 | print_(self.pre_processed_dataset_dict, end='\n\n') 68 | print_header('Final Dataset') 69 | print_(self.dataset_dict, end='\n\n') 70 | 71 | def _get_dataloader(self, split: str): 72 | dataloader_class = DataLoader 73 | dataloader_kwargs = dict( 74 | dataset=self.dataset_dict[split], 75 | batch_size=self.config.batch_size, 76 | num_workers=self.config.num_workers, 77 | collate_fn=self.datacollator, 78 | pin_memory=self.config.pin_memory, 79 | prefetch_factor=self.config.prefetch_factor 80 | ) 81 | 82 | if split == 'train': 83 | dataloader_class = ResumableDataLoader 84 | dataloader_kwargs['shuffle'] = True 85 | dataloader_kwargs['trainer'] = self.trainer 86 | 87 | return dataloader_class(**dataloader_kwargs) 88 | 89 | def setup(self, stage: str | None = None) -> None: 90 | if self.config.pre_processed_data_path is None: 91 | self.raw_dataset_dict = self.load_data() 92 | self.pre_processed_dataset_dict = self.pre_process_data(self.raw_dataset_dict) 93 | else: 94 | logger.info('Load pre-processed data') 95 | self.load_pre_processed_data(self.config.pre_processed_data_path) 96 | logger.info('Done') 97 | 98 | self.dataset_dict = self.post_process_data(self.pre_processed_dataset_dict) 99 | 100 | mapping = { 101 | 'train': 'train_dataloader', 102 | 'validation': 'val_dataloader', 103 | 'test': 'test_dataloader', 104 | 'predict': 'predict_dataloader' 105 | } 106 | 107 | for k, v in mapping.items(): 108 | if k in self.dataset_dict: 109 | setattr(self, v, partial(self._get_dataloader, k)) 110 | else: 111 | setattr(self, v, getattr(super(), v)) 112 | 113 | def train_dataloader(self) -> DataLoader | None: ... 114 | 115 | def val_dataloader(self) -> DataLoader | None: ... 116 | 117 | def test_dataloader(self) -> DataLoader | None: ... 118 | 119 | def predict_dataloader(self) -> DataLoader | None: ... 120 | -------------------------------------------------------------------------------- /src/llm_training/models/hf_compat_model/hf_compat_model.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Generator 3 | 4 | import torch 5 | from accelerate import init_empty_weights 6 | from transformers import (AutoConfig, AutoModel, AutoTokenizer, 7 | GenerationConfig, PretrainedConfig, PreTrainedModel, 8 | PreTrainedTokenizerBase) 9 | from transformers.models.auto.auto_factory import _BaseAutoModelClass 10 | 11 | from llm_training.models.base_model.base_model import BaseModel 12 | 13 | from .hf_compat_config import HFCompatModelConfig 14 | 15 | 16 | class HFCompatModel(BaseModel): 17 | config: HFCompatModelConfig 18 | 19 | hf_config_class: type[AutoConfig] | type[PretrainedConfig] = AutoConfig 20 | hf_model_class: type[_BaseAutoModelClass] | type[PreTrainedModel] = AutoModel 21 | 22 | @property 23 | def has_pre_trained_weights(self) -> bool: 24 | return ( 25 | super().has_pre_trained_weights 26 | or (self.config.hf_path is not None and self.config.load_hf_weights) 27 | ) 28 | 29 | def __init__(self, config: HFCompatModelConfig) -> None: 30 | super().__init__(config) 31 | 32 | if self.config.hf_path is not None: 33 | self.hf_config = self.load_hf_config() 34 | self.merge_hf_config(self.hf_config) 35 | 36 | def merge_hf_config(self, hf_config: PretrainedConfig) -> None: ... 
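    # Subclass hook: merge fields from the freshly loaded HF config into this model's own config (no-op by default).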
37 | 38 | def load_hf_config(self, **kwargs) -> PretrainedConfig: 39 | default_kwargs = { 40 | 'trust_remote_code': self.config.trust_remote_code, 41 | 'revision': self.config.revision, 42 | 'attn_implementation': self.config._attn_implementation, 43 | **self.config.hf_extra_kwargs 44 | } 45 | kwargs = default_kwargs | kwargs 46 | return self.hf_config_class.from_pretrained(self.config.hf_path, **kwargs) 47 | 48 | def load_hf_model(self, **kwargs) -> "PreTrainedModel": 49 | default_kwargs = { 50 | 'low_cpu_mem_usage': self.config.low_cpu_mem_usage, 51 | 'torch_dtype': self.config.torch_dtype, 52 | 'trust_remote_code': self.config.trust_remote_code, 53 | 'revision': self.config.revision, 54 | 'attn_implementation': self.config._attn_implementation, 55 | **self.config.hf_extra_kwargs 56 | } 57 | kwargs = default_kwargs | kwargs 58 | return self.hf_model_class.from_pretrained(self.config.hf_path, **kwargs) 59 | 60 | def load_hf_tokenizer(self, **kwargs) -> PreTrainedTokenizerBase: 61 | path = self.config.hf_tokenizer_path or self.config.hf_path 62 | default_kwargs = { 63 | 'trust_remote_code': self.config.trust_remote_code, 64 | 'revision': self.config.revision 65 | } 66 | kwargs = default_kwargs | kwargs 67 | return AutoTokenizer.from_pretrained(path, **kwargs) 68 | 69 | @contextmanager 70 | def torch_dtype_context(self) -> Generator[None, None, None]: 71 | original_dtype = torch.get_default_dtype() 72 | torch_dtype = self.config.torch_dtype 73 | torch_dtype = original_dtype if torch_dtype == 'auto' else torch_dtype 74 | torch.set_default_dtype(torch_dtype) 75 | yield 76 | torch.set_default_dtype(original_dtype) 77 | 78 | def construct_hf_model(self, **kwargs) -> PreTrainedModel: 79 | with self.torch_dtype_context(): 80 | if issubclass(self.hf_model_class, _BaseAutoModelClass): 81 | default_kwargs = {} 82 | default_kwargs['trust_remote_code'] = self.config.trust_remote_code 83 | default_kwargs['attn_implementation'] = self.config._attn_implementation 84 | kwargs = default_kwargs | kwargs 85 | model = self.hf_model_class.from_config(self.hf_config, **kwargs) 86 | else: 87 | model = self.hf_model_class(self.hf_config) 88 | 89 | try: 90 | model.generation_config = GenerationConfig.from_pretrained(self.config.hf_path) 91 | except OSError: 92 | ... 
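            # No generation_config is available for this checkpoint; fall back to the model's default one.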
93 | 94 | return model 95 | 96 | def convert_state_dict_from_hf(self, hf_state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 97 | return hf_state_dict 98 | 99 | def convert_state_dict_to_hf(self, state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: 100 | return state_dict 101 | 102 | def get_hf_pre_trained_weights(self) -> dict[str, torch.Tensor]: 103 | model = self.load_hf_model() 104 | state_dict = model.state_dict() 105 | state_dict = self.convert_state_dict_from_hf(state_dict) 106 | return state_dict 107 | 108 | def get_pre_trained_weights(self) -> dict[str, torch.Tensor]: 109 | if self.config.hf_path is not None: 110 | return self.get_hf_pre_trained_weights() 111 | return super().get_pre_trained_weights() 112 | 113 | def get_hf_model(self) -> PreTrainedModel: 114 | with init_empty_weights(include_buffers=False): 115 | hf_model = self.construct_hf_model() 116 | state_dict = self.convert_state_dict_to_hf(self.state_dict()) 117 | hf_model.load_state_dict(state_dict, assign=True) 118 | hf_model.tie_weights() 119 | return hf_model 120 | -------------------------------------------------------------------------------- /src/llm_training/data/chat_templates/llama-3.1.j2: -------------------------------------------------------------------------------- 1 | {{- bos_token }} 2 | {%- if custom_tools is defined %} 3 | {%- set tools = custom_tools %} 4 | {%- endif %} 5 | {%- if not tools_in_user_message is defined %} 6 | {%- set tools_in_user_message = true %} 7 | {%- endif %} 8 | {%- if not date_string is defined %} 9 | {%- set date_string = "26 Jul 2024" %} 10 | {%- endif %} 11 | {%- if not tools is defined %} 12 | {%- set tools = none %} 13 | {%- endif %} 14 | 15 | {#- This block extracts the system message, so we can slot it into the right place. #} 16 | {%- if messages[0]['role'] == 'system' %} 17 | {%- set system_message = messages[0]['content']|trim %} 18 | {%- set messages = messages[1:] %} 19 | {%- else %} 20 | {%- set system_message = "" %} 21 | {%- endif %} 22 | 23 | {#- System message + builtin tools #} 24 | {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} 25 | {%- if builtin_tools is defined or tools is not none %} 26 | {{- "Environment: ipython\n" }} 27 | {%- endif %} 28 | {%- if builtin_tools is defined %} 29 | {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}} 30 | {%- endif %} 31 | {{- "Cutting Knowledge Date: December 2023\n" }} 32 | {{- "Today Date: " + date_string + "\n\n" }} 33 | {%- if tools is not none and not tools_in_user_message %} 34 | {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} 35 | {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' 
}} 36 | {{- "Do not use variables.\n\n" }} 37 | {%- for t in tools %} 38 | {{- t | tojson(indent=4) }} 39 | {{- "\n\n" }} 40 | {%- endfor %} 41 | {%- endif %} 42 | {{- system_message }} 43 | {{- "<|eot_id|>" }} 44 | 45 | {#- Custom tools are passed in a user message with some extra guidance #} 46 | {%- if tools_in_user_message and not tools is none %} 47 | {#- Extract the first user message so we can plug it in here #} 48 | {%- if messages | length != 0 %} 49 | {%- set first_user_message = messages[0]['content']|trim %} 50 | {%- set messages = messages[1:] %} 51 | {%- else %} 52 | {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} 53 | {%- endif %} 54 | {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} 55 | {{- "Given the following functions, please respond with a JSON for a function call " }} 56 | {{- "with its proper arguments that best answers the given prompt.\n\n" }} 57 | {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} 58 | {{- "Do not use variables.\n\n" }} 59 | {%- for t in tools %} 60 | {{- t | tojson(indent=4) }} 61 | {{- "\n\n" }} 62 | {%- endfor %} 63 | {{- first_user_message + "<|eot_id|>"}} 64 | {%- endif %} 65 | 66 | {%- for message in messages %} 67 | {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} 68 | {{- '<|start_header_id|>' + message.role + '<|end_header_id|>\n\n' }} 69 | {%- set content = message.content | trim + '<|eot_id|>' %} 70 | {%- if message.role == 'assistant' %} 71 | {% generation %} 72 | {{- content -}} 73 | {% endgeneration %} 74 | {% else %} 75 | {{- content }} 76 | {%- endif %} 77 | {%- elif 'tool_calls' in message %} 78 | {%- if not message.tool_calls|length == 1 %} 79 | {{- raise_exception("This model only supports single tool-calls at once!") }} 80 | {%- endif %} 81 | {%- set tool_call = message.tool_calls[0].function %} 82 | {%- if builtin_tools is defined and tool_call.name in builtin_tools %} 83 | {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} 84 | {% generation %} 85 | {{- "<|python_tag|>" + tool_call.name + ".call(" -}} 86 | {% endgeneration %} 87 | {%- for arg_name, arg_val in tool_call.arguments | items %} 88 | {{- arg_name + '="' + arg_val + '"' }} 89 | {%- if not loop.last %} 90 | {{- ", " }} 91 | {%- endif %} 92 | {%- endfor %} 93 | {{- ")" }} 94 | {%- else %} 95 | {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} 96 | {% generation %} 97 | {{- '{"name": "' + tool_call.name + '", ' }} 98 | {{- '"parameters": ' }} 99 | {{- tool_call.arguments | tojson }} 100 | {{- "}" -}} 101 | {% endgeneration %} 102 | {%- endif %} 103 | {% generation %} 104 | {%- if builtin_tools is defined %} 105 | {#- This means we're in ipython mode #} 106 | {{- "<|eom_id|>" }} 107 | {%- else %} 108 | {{- "<|eot_id|>" }} 109 | {%- endif -%} 110 | {% endgeneration %} 111 | {%- elif message.role == "tool" or message.role == "ipython" %} 112 | {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} 113 | {%- if message.content is mapping or message.content is iterable %} 114 | {{- message.content | tojson }} 115 | {%- else %} 116 | {{- message.content }} 117 | {%- endif %} 118 | {{- "<|eot_id|>" }} 119 | {%- endif %} 120 | {%- endfor %} 121 | {%- if add_generation_prompt %} 122 | {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} 123 | {%- endif %} 124 | -------------------------------------------------------------------------------- /docs/instruction_tuning.md: 
-------------------------------------------------------------------------------- 1 | # Instruction Tuning 2 | 3 | Instruction tuning is implemented by the [`InstructionTuningDataModule`](/src/llm_training/data/instruction_tuning/instruction_tuning_datamodule.py). 4 | 5 | Same as `PreTrainingDataModule`, it uses [datasets](https://github.com/huggingface/datasets) under the hood to load and process data. 6 | 7 | A valid dataset must include a `messages` field, which is an array. 8 | The format can be referenced as follows: 9 | ```json 10 | [ 11 | {"role": "user", "content": "Hello, how are you?"}, 12 | {"role": "assistant", "content": "I'm doing great. How can I help you today?"}, 13 | {"role": "user", "content": "I'd like to show off how chat templating works!"} 14 | ] 15 | ``` 16 | 17 | See [Templates for Chat Models](https://huggingface.co/docs/transformers/main/en/chat_templating) for more details. 18 | 19 | ## Key Parameters 20 | 21 | | Parameter | Description | 22 | | :--------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | 23 | | `dataset_kwargs` | kwargs to be passed to [`datasets.load_dataset`](https://huggingface.co/docs/datasets/loading) for loading data. | 24 | | `tokenizer` | A transformers tokenizer for tokenizing the data. | 25 | | `chat_template` | If the value exists in the predefined templates, the predefined template will be selected. Otherwise, the value will be used directly as a Jinja2 syntax template. If the value is None, the tokenizer's built-in template will be used. Note that using the original templates like llama-3, phi-3, etc., directly often leads to incorrect labels. Therefore, it is recommended to use predefined templates. If the desired predefined template does not exist, you should modify the original template yourself. | 26 | | `packing_method` | Methods for concatenating data. `NO_PACKING` will do nothing. `GROUP_BY_LENGTH` will group data based on the length of each entry, ensuring that the combined length of each group does not exceed `max_length` | 27 | | `max_length` | Max length of the tokenized data. | 28 | | `num_proc` | Number of CPU cores for processing data. | 29 | 30 | For a complete set of parameters, please refer to [`InstructionTuningDataModuleConfig`](/src/llm_training/data/instruction_tuning/instruction_tuning_datamodule_config.py). 31 | 32 | ## Example 33 | 34 | ```yaml 35 | ... 
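    # Editorial annotations on the example below (not part of the original config):
    #   - `chat_template: phi-3` selects the predefined template rather than the tokenizer's built-in one
    #   - `packing_method: GROUP_BY_LENGTH` groups examples so each packed group stays within `max_length`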
36 | data: 37 | class_path: llm_training.data.InstructionTuningDataModule 38 | init_args.config: 39 | dataset_kwargs: 40 | path: ShinoharaHare/Infinity-Instruct-Reformatted 41 | name: "0625" 42 | tokenizer: 43 | class_path: HFTokenizer 44 | init_args: 45 | path: microsoft/Phi-3-mini-128k-instruct 46 | batch_size: 1 47 | add_default_system_prompt_rate: 0.0 48 | default_system_prompt: "" 49 | chat_template: phi-3 50 | packing_method: GROUP_BY_LENGTH 51 | max_length: 4096 52 | pad_to_multiple_of: 64 53 | validation_split: null 54 | num_proc: 4 55 | num_workers: 4 56 | enable_cache: true 57 | ``` 58 | -------------------------------------------------------------------------------- /src/llm_training/data/preference_tuning/preference_tuning_datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import tokenizers 4 | from datasets import Features, Sequence, Value 5 | from packaging.version import Version 6 | from transformers import PreTrainedTokenizerBase 7 | 8 | from llm_training.data.hf_based.hf_based_datamodule import (DatasetDict, 9 | HFBasedDataModule) 10 | 11 | from .preference_tuning_datacollator import PreferenceTuningDataCollator 12 | from .preference_tuning_datamodule_config import ( 13 | OverlongHandlingMethod, PreferenceTuningDataModuleConfig) 14 | 15 | 16 | class PreferenceTuningDataModule(HFBasedDataModule): 17 | config: PreferenceTuningDataModuleConfig 18 | datacollator_class = PreferenceTuningDataCollator 19 | 20 | def __init__(self, config: PreferenceTuningDataModuleConfig) -> None: 21 | super().__init__(config) 22 | 23 | if Version(tokenizers.__version__) < Version('0.20.1'): 24 | raise ValueError( 25 | "`tokenizers` must be at least version 0.20.1, " 26 | "otherwise LLaMA 3 tokenizer will produce incorrect prompt/response mask." 
27 | ) 28 | 29 | @classmethod 30 | def _apply_chat_template_and_tokenize( 31 | cls, 32 | batch: dict[str, list[str]], 33 | tokenizer: PreTrainedTokenizerBase, 34 | chat_template: str | None 35 | ): 36 | new_batch = { 37 | 'chosen_input_ids': [], 38 | 'chosen_labels': [], 39 | 'chosen_length': [], 40 | 'rejected_input_ids': [], 41 | 'rejected_labels': [], 42 | 'rejected_length': [] 43 | } 44 | 45 | chosen_messages = [] 46 | rejected_messages = [] 47 | for prompt, chosen, rejected in zip( 48 | batch['prompt'], 49 | batch['chosen'], 50 | batch['rejected'] 51 | ): 52 | chosen_messages.append([ 53 | {'role': 'user', 'content': prompt}, 54 | {'role': 'assistant', 'content': chosen} 55 | ]) 56 | 57 | rejected_messages.append([ 58 | {'role': 'user', 'content': prompt}, 59 | {'role': 'assistant', 'content': rejected} 60 | ]) 61 | 62 | kwargs = dict( 63 | chat_template=chat_template, 64 | return_dict=True, 65 | return_assistant_tokens_mask=True, 66 | tokenizer_kwargs=dict( 67 | return_attention_mask=False, 68 | verbose=False 69 | ) 70 | ) 71 | 72 | chosen_batch_encoding = tokenizer.apply_chat_template(chosen_messages, **kwargs) 73 | for input_ids, assistant_masks in zip( 74 | chosen_batch_encoding['input_ids'], 75 | chosen_batch_encoding['assistant_masks'] 76 | ): 77 | labels = [i if a == 1 else -100 for i, a in zip(input_ids, assistant_masks)] 78 | new_batch['chosen_input_ids'].append(input_ids) 79 | new_batch['chosen_labels'].append(labels) 80 | new_batch['chosen_length'].append(len(input_ids)) 81 | 82 | rejected_batch_encoding = tokenizer.apply_chat_template(rejected_messages, **kwargs) 83 | for input_ids, assistant_masks in zip( 84 | rejected_batch_encoding['input_ids'], 85 | rejected_batch_encoding['assistant_masks'] 86 | ): 87 | labels = [i if a == 1 else -100 for i, a in zip(input_ids, assistant_masks)] 88 | new_batch['rejected_input_ids'].append(input_ids) 89 | new_batch['rejected_labels'].append(labels) 90 | new_batch['rejected_length'].append(len(input_ids)) 91 | 92 | return new_batch 93 | 94 | @classmethod 95 | def _drop_overlong_examples( 96 | cls, 97 | batch: dict[str, Any], 98 | max_length: int 99 | ): 100 | indices = [ 101 | i for i in range(len(batch['chosen_length'])) 102 | if max(batch['chosen_length'][i], batch['rejected_length'][i]) <= max_length 103 | ] 104 | return {k: [v[i] for i in indices] for k, v in batch.items()} 105 | 106 | @classmethod 107 | def _pre_process_data( 108 | cls, 109 | batch: dict[str, Any], 110 | tokenizer: PreTrainedTokenizerBase, 111 | chat_template: str | None, 112 | max_length: int | None, 113 | overlong_handling_method: OverlongHandlingMethod 114 | ): 115 | batch = cls._apply_chat_template_and_tokenize( 116 | batch, 117 | tokenizer=tokenizer, 118 | chat_template=chat_template 119 | ) 120 | 121 | if max_length is not None: 122 | if overlong_handling_method == OverlongHandlingMethod.DROP: 123 | batch = cls._drop_overlong_examples(batch, max_length) 124 | 125 | return batch 126 | 127 | def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: 128 | dataset_dict = self.map_dataset_dict( 129 | dataset_dict, 130 | self._pre_process_data, 131 | fn_kwargs=dict( 132 | tokenizer=self.config.tokenizer, 133 | chat_template=self.config.chat_template, 134 | max_length=self.config.max_length, 135 | overlong_handling_method=self.config.overlong_handling_method 136 | ), 137 | batched=True, 138 | remove_columns=True, 139 | num_proc=self.config.num_proc, 140 | features=Features({ 141 | 'chosen_input_ids': Sequence(Value('int32')), 142 | 'chosen_labels': 
Sequence(Value('int32')), 143 | 'chosen_length': Value('uint32'), 144 | 'rejected_input_ids': Sequence(Value('int32')), 145 | 'rejected_labels': Sequence(Value('int32')), 146 | 'rejected_length': Value('uint32') 147 | }), 148 | desc='Pre-processing data' 149 | ) 150 | return dataset_dict 151 | -------------------------------------------------------------------------------- /scripts/extend_fast_tokenizer.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import fire 4 | from tokenizers import Tokenizer, decoders, pre_tokenizers 5 | from transformers import AutoTokenizer, PreTrainedTokenizerFast 6 | 7 | 8 | def main( 9 | tokenizer_path: str, 10 | vocab_path: str, 11 | output_path: str, 12 | add_missing_vocab: bool = True, 13 | pad_token: str | None = None 14 | ) -> None: 15 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True) 16 | 17 | with open(vocab_path) as f: 18 | vocab = json.load(f) 19 | 20 | new_tokenizer = extend_vocab(tokenizer, vocab, add_missing_vocab=add_missing_vocab) 21 | new_tokenizer.slow_tokenizer_class = None 22 | 23 | if pad_token is not None: 24 | new_tokenizer.pad_token = pad_token 25 | 26 | new_tokenizer.save_pretrained(output_path) 27 | 28 | 29 | def get_byte_level_pre_tokenizer(tokenizer: PreTrainedTokenizerFast) -> pre_tokenizers.ByteLevel | None: 30 | rust_tokenizer = tokenizer.backend_tokenizer 31 | tokenizer_json = json.loads(rust_tokenizer.to_str()) 32 | pre_tokenizer_json = tokenizer_json['pre_tokenizer'] 33 | 34 | if pre_tokenizer_json is None: 35 | return None 36 | 37 | pre_tokenizers_ = [] 38 | 39 | if pre_tokenizer_json['type'] == 'ByteLevel': 40 | pre_tokenizers_ = [pre_tokenizer_json] 41 | elif pre_tokenizer_json['type'] == 'Sequence': 42 | pre_tokenizers_ = pre_tokenizer_json['pretokenizers'] 43 | 44 | pre_tokenizer = None 45 | for pt in pre_tokenizers_: 46 | if pt['type'] == 'ByteLevel': 47 | kwargs = {k: v for k, v in pt.items() if k != 'type'} 48 | pre_tokenizer = pre_tokenizers.ByteLevel(**kwargs) 49 | 50 | return pre_tokenizer 51 | 52 | 53 | def get_byte_level_decoder(tokenizer: PreTrainedTokenizerFast) -> decoders.ByteLevel | None: 54 | rust_tokenizer = tokenizer.backend_tokenizer 55 | tokenizer_json = json.loads(rust_tokenizer.to_str()) 56 | decoder_json = tokenizer_json['decoder'] 57 | 58 | decoders_ = [] 59 | if decoder_json['type'] == 'ByteLevel': 60 | decoders_ = [decoder_json] 61 | elif decoder_json['type'] == 'Sequence': 62 | decoders_ = decoder_json['decoders'] 63 | 64 | decoder = None 65 | for d in decoders_: 66 | if d['type'] == 'ByteLevel': 67 | kwargs = {k: v for k, v in d.items() if k != 'type'} 68 | decoder = decoders.ByteLevel(**kwargs) 69 | 70 | return decoder 71 | 72 | 73 | def iterate_vocab(vocab: dict[str, int]): 74 | for k, v in sorted(vocab.items(), key=lambda x: x[1]): 75 | yield k, v 76 | 77 | 78 | def compute_merges( 79 | vocab: dict[str, int], 80 | decoder: decoders.ByteLevel | None = None 81 | ) -> tuple[list[tuple[str, str]], set[str]]: 82 | merges = [] 83 | missing_vocab = set() 84 | for merge, piece_score in vocab.items(): 85 | local = [] 86 | missing_pieces = [] 87 | for index in range(1, len(merge)): 88 | piece_l, piece_r = merge[:index], merge[index:] 89 | if piece_l in vocab and piece_r in vocab: 90 | local.append((piece_l, piece_r, piece_score)) 91 | else: 92 | missing_pieces.append((piece_l, piece_r)) 93 | local = sorted(local, key=lambda x: (vocab[x[0]], vocab[x[1]])) 94 | missing_pieces = sorted(missing_pieces, key=lambda x: 
(vocab.get(x[0], len(vocab)), (vocab.get(x[1], len(vocab))))) 95 | 96 | if len(merges) > 1 and not local: 97 | merge_str = merge.replace('\n', '\\n').replace('\x9d', '\\x9d') 98 | print(f'Missing merge: {merge_str}', end='') 99 | 100 | if decoder is not None: 101 | print(f'({decoder.decode([merge])})', end='') 102 | 103 | print() 104 | 105 | if missing_pieces: 106 | piece_l, piece_r = missing_pieces[0] 107 | if piece_l not in vocab: 108 | missing_vocab.add(piece_l) 109 | if piece_r not in vocab: 110 | missing_vocab.add(piece_r) 111 | 112 | merges.extend(local) 113 | 114 | merges = sorted(merges, key=lambda val: val[2], reverse=True) 115 | merges = [(m[0], m[1]) for m in merges] 116 | return merges, missing_vocab 117 | 118 | 119 | def extend_vocab( 120 | tokenizer: PreTrainedTokenizerFast, 121 | vocab: dict[str, int], 122 | add_missing_vocab: bool 123 | ) -> PreTrainedTokenizerFast: 124 | byte_level_pre_tokenizer = get_byte_level_pre_tokenizer(tokenizer) 125 | byte_level_decoder = get_byte_level_decoder(tokenizer) 126 | 127 | if byte_level_pre_tokenizer is not None: 128 | vocab_tmp = {} 129 | for k, v in iterate_vocab(vocab): 130 | k = byte_level_pre_tokenizer.pre_tokenize_str(k)[0][0] 131 | vocab_tmp[k] = len(vocab) 132 | vocab = vocab_tmp 133 | 134 | tokenizer_json = json.loads(tokenizer.backend_tokenizer.to_str()) 135 | added_vocab = tokenizer.get_added_vocab() 136 | tokenizer_vocab = tokenizer_json['model']['vocab'] 137 | 138 | # Make sure added tokens are in the vocab 139 | for k, v in iterate_vocab(added_vocab): 140 | if k in tokenizer_vocab: 141 | continue 142 | tokenizer_vocab[k] = v 143 | 144 | for k, v in iterate_vocab(vocab): 145 | if k in tokenizer_vocab: 146 | continue 147 | tokenizer_vocab[k] = len(tokenizer_vocab) 148 | 149 | merges, missing_vocab = compute_merges( 150 | {k: v for k, v in tokenizer_vocab.items() if k not in added_vocab}, 151 | byte_level_decoder 152 | ) 153 | while add_missing_vocab and missing_vocab: 154 | for v in missing_vocab: 155 | tokenizer_vocab[v] = len(tokenizer_vocab) 156 | 157 | merges, missing_vocab = compute_merges( 158 | {k: v for k, v in tokenizer_vocab.items() if k not in added_vocab}, 159 | byte_level_decoder 160 | ) 161 | 162 | tokenizer_merges = tokenizer_json['model']['merges'] 163 | merges_set = {tuple(m) for m in tokenizer_merges} 164 | 165 | new_merges = [] 166 | for m in merges: 167 | if m in merges_set: 168 | continue 169 | new_merges.append(m) 170 | merges_set.add(m) 171 | 172 | tokenizer_merges += new_merges 173 | 174 | return tokenizer.__class__( 175 | tokenizer_object=Tokenizer.from_str(json.dumps(tokenizer_json)), 176 | **tokenizer.init_kwargs 177 | ) 178 | 179 | 180 | if __name__ == '__main__': 181 | fire.Fire(main) 182 | -------------------------------------------------------------------------------- /scripts/convert_to_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | import fire 6 | import torch 7 | import yaml 8 | from accelerate import init_empty_weights 9 | from lightning import LightningDataModule, LightningModule 10 | from lightning.pytorch.cli import LightningArgumentParser 11 | 12 | from llm_training.data import * 13 | from llm_training.lightning.cli import * 14 | from llm_training.lms import BaseLightningModule 15 | from llm_training.models import HFCompatModel 16 | 17 | 18 | def main( 19 | checkpoint_path: str | Path, 20 | output_dir: str | Path | None = None, 21 | config_path: str | None = None, 22 | 
eos_token_id: int | list[int] | None = None, 23 | dtype: torch.dtype | None = None 24 | ) -> None: 25 | checkpoint_path = Path(checkpoint_path) 26 | 27 | if output_dir is None: 28 | output_dir = checkpoint_path.parent / 'hf' / checkpoint_path.stem 29 | else: 30 | output_dir = Path(output_dir) 31 | 32 | print('Converting checkpoint') 33 | checkpoint = convert_checkpoint(checkpoint_path) 34 | 35 | if config_path is not None: 36 | with open(config_path) as f: 37 | config = yaml.safe_load(f) 38 | else: 39 | config = checkpoint['config'] 40 | 41 | dtype = dtype or get_dtype_from_config(config) 42 | 43 | lightning_module, datamodule = instantiate_model_and_datamodule(config) 44 | 45 | assert isinstance(lightning_module, BaseLightningModule) 46 | 47 | lightning_module.config.load_weights = False 48 | lightning_module.config.init_weights = False 49 | with init_empty_weights(include_buffers=False): 50 | lightning_module.configure_model() 51 | 52 | state_dict = checkpoint['state_dict'] 53 | 54 | required_keys = lightning_module.required_keys 55 | missing_keys = required_keys - state_dict.keys() 56 | 57 | if len(missing_keys) > 0: 58 | print(f'There are {len(missing_keys)} keys missing from the checkpoint, trying to take them from the pre-trained weights.') 59 | original_state_dict = lightning_module.get_pre_trained_weights() 60 | state_dict |= {k: original_state_dict[k] for k in missing_keys} 61 | 62 | incompatible_keys = lightning_module.load_state_dict(state_dict, strict=False, assign=True) 63 | 64 | missing_keys = set(incompatible_keys.missing_keys) 65 | required_missing_keys = missing_keys & required_keys 66 | assert len(required_missing_keys) == 0, f"Missing keys: {required_missing_keys}" 67 | 68 | model = lightning_module.get_model() 69 | assert isinstance(model, HFCompatModel), f"{model.__class__} is not supported to be converted to HF version."
70 | model.config.torch_dtype = dtype 71 | hf_model = model.get_hf_model() 72 | 73 | if eos_token_id is not None: 74 | hf_model.config.eos_token_id = eos_token_id if isinstance(eos_token_id, int) else eos_token_id[0] 75 | hf_model.generation_config.eos_token_id = eos_token_id 76 | 77 | print('Saving model') 78 | hf_model.to(dtype).save_pretrained(output_dir) 79 | 80 | tokenizer = None 81 | if isinstance( 82 | datamodule, 83 | ( 84 | PreTrainingDataModule, 85 | InstructionTuningDataModule, 86 | PreferenceTuningDataModule 87 | ) 88 | ): 89 | tokenizer = datamodule.config.tokenizer 90 | 91 | if tokenizer is not None: 92 | print('Saving tokenizer') 93 | tokenizer.model_max_length = max(tokenizer.model_max_length, datamodule.config.max_length) 94 | chat_template = getattr(datamodule.config, 'chat_template', None) 95 | if chat_template is not None: 96 | tokenizer.chat_template = chat_template 97 | tokenizer.save_pretrained(output_dir) 98 | 99 | 100 | def convert_checkpoint(path: Path) -> dict[str, Any]: 101 | if ( 102 | path.is_dir() 103 | and path.joinpath('checkpoint').is_dir() 104 | and path.joinpath('latest').is_file() 105 | and path.joinpath('zero_to_fp32.py').is_file() 106 | ): 107 | print('DeepSpeed checkpoint detected') 108 | from lightning.pytorch.utilities.deepspeed import \ 109 | convert_zero_checkpoint_to_fp32_state_dict 110 | return convert_zero_checkpoint_to_fp32_state_dict(path, os.devnull) 111 | 112 | if ( 113 | path.is_dir() 114 | and path.joinpath('meta.pt').is_file() 115 | and len(list(path.glob('*.distcp'))) > 0 116 | ): 117 | print('FSDP checkpoint detected') 118 | return convert_fsdp_checkpoint(path) 119 | 120 | return torch.load(path, 'cpu') 121 | 122 | 123 | def convert_fsdp_checkpoint(path: Path) -> dict[str, Any]: 124 | from lightning.fabric.utilities.load import _METADATA_FILENAME 125 | from torch.distributed.checkpoint import FileSystemReader, load 126 | from torch.distributed.checkpoint.metadata import (BytesStorageMetadata, 127 | TensorStorageMetadata) 128 | 129 | reader = FileSystemReader(path) 130 | metadata = reader.read_metadata() 131 | 132 | tensor_names = [n for n in metadata.state_dict_metadata.keys() if n.startswith('state_dict.')] 133 | state_dict = {} 134 | for tensor_name in tensor_names: 135 | sd_metadata = metadata.state_dict_metadata[tensor_name] 136 | 137 | if isinstance(sd_metadata, BytesStorageMetadata): 138 | state_dict[tensor_name] = '' 139 | elif isinstance(sd_metadata, TensorStorageMetadata): 140 | state_dict[tensor_name] = torch.empty( 141 | size=sd_metadata.size, 142 | dtype=sd_metadata.properties.dtype, 143 | device=torch.device('cpu'), 144 | memory_format=sd_metadata.properties.memory_format, 145 | layout=sd_metadata.properties.layout, 146 | requires_grad=sd_metadata.properties.requires_grad, 147 | pin_memory=sd_metadata.properties.pin_memory 148 | ) 149 | else: 150 | raise NotImplementedError() 151 | 152 | load(state_dict=state_dict, storage_reader=reader) 153 | 154 | state_dict = {k.removeprefix('state_dict.'): v for k, v in state_dict.items()} 155 | checkpoint = {'state_dict': state_dict} 156 | # This is the extra file saved by Fabric, with user data separate from weights and optimizer states 157 | extra_file = path / _METADATA_FILENAME 158 | extra = torch.load(extra_file, map_location='cpu') if extra_file.is_file() else {} 159 | checkpoint.update(extra) 160 | 161 | return checkpoint 162 | 163 | 164 | def instantiate_model_and_datamodule(config: dict[str, Any]) -> tuple[LightningModule, LightningDataModule]: 165 | parser = 
LightningArgumentParser() 166 | parser.add_lightning_class_args(LightningModule, 'model', subclass_mode=True) 167 | parser.add_lightning_class_args(LightningDataModule, 'data', subclass_mode=True) 168 | classes = parser.instantiate_classes({k: config[k] for k in ['model', 'data']}) 169 | return classes['model'], classes['data'] 170 | 171 | 172 | def get_dtype_from_config(config: dict[str, Any]) -> torch.dtype: 173 | dtype_mapping = { 174 | '16-true': torch.half, 175 | 'bf16-true': torch.bfloat16 176 | } 177 | return dtype_mapping.get(config['trainer']['precision'], torch.float) 178 | 179 | 180 | if __name__ == '__main__': 181 | fire.Fire(main) 182 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM-Training 2 | 3 | A distributed training framework for large language models powered by Lightning. 4 | 5 | ## Supported Training Methods 6 | 7 | | Method | Full Training | Tensor Parallelism | 8 | | ------------------ | ------------------ | ------------------ | 9 | | Pre-training | :white_check_mark: | :white_check_mark: | 10 | | Instruction Tuning | :white_check_mark: | :white_check_mark: | 11 | | DPO | :white_check_mark: | :white_check_mark: | 12 | | ORPO | :white_check_mark: | :white_check_mark: | 13 | 14 | ### Pre-training 15 | 16 | - Supports Best-fit Bin Packing for less truncation. 17 | - Supports dynamic data sampling via configs, allowing flexible control of data sampling from multiple sources. 18 | 19 | ### Instruction Tuning 20 | 21 | - Supports data packing without cross-contamination. 22 | - Supports NEFTune. 23 | 24 | ## Supported Models 25 | 26 | All GPT-like models supported by HuggingFace are compatible. 27 | However, only text models are supported currently. 28 | 29 | Besides, alternative implementations that support additional features for specific model architectures are available. 30 | 31 | | Architecture | Selective Activation Checkpointing | Liger Kernel | Tensor Parallelism | Sequence Parallelism | 32 | | -------------- | ---------------------------------- | ------------------ | ------------------ | -------------------- | 33 | | LLaMA(2/3/3.x) | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | 34 | | Phi-3(3.5/4) | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: | 35 | 36 | ## Installation 37 | 38 | It is recommended to use [conda](https://github.com/conda/conda)/[mamba](https://github.com/mamba-org/mamba) for environment management. 39 | 40 | ```bash 41 | # Clone this repository 42 | git clone https://github.com/ShinoharaHare/LLM-Training.git && cd LLM-Training 43 | 44 | # Optional: Choose the version of LLM Training 45 | # By default, the main branch is used, which includes the latest features and changes but may come with instability. 46 | # Alternatively, you can switch to a specific release from the release page for more stability. 47 | # In most cases, using the latest release is recommended. 48 | git checkout vX.X.X 49 | 50 | # Create conda environment 51 | conda env create -f environment.yaml 52 | # or 53 | mamba env create -f environment.yaml 54 | 55 | # Activate the created conda environment 56 | conda activate llm-training 57 | 58 | # Install LLM Training 59 | ./install.sh 60 | ``` 61 | 62 | ## Usage 63 | 64 | > [!TIP] 65 | > The current documentation is not very comprehensive, as I haven’t had enough time to write it. 
66 | > I can only provide brief usage examples, but many details and customizable parameters are not listed or explained in full. 67 | > As a result, you may need to refer to the source code to understand the purpose and usage of some parameters. 68 | > If this does not meet your expectations, you might want to consider using other open-source training frameworks, as there are likely many available in the community. 69 | 70 | ### Config 71 | 72 | To start a training run, you first need to write your own config file. 73 | 74 | A config file is a YAML file used to set up everything, including seeding, distributed strategy, hyper-parameters, model, data, and more. 75 | 76 | You can refer to the files under the [config](config/examples) directory to write your own config file. 77 | 78 | See the [documentation](docs/config.md) for more information. 79 | 80 | ### Start a training run 81 | 82 | ```bash 83 | llm-training fit --config <config_path> 84 | ``` 85 | 86 | ### Multi-node training with SLURM 87 | 88 | You can launch multi-node training using SLURM. 89 | 90 | ```bash 91 | srun llm-training fit --config <config_path> --trainer.num_nodes <num_nodes> 92 | ``` 93 | 94 | See [train.sh](scripts/train.sh) for an sbatch script template. 95 | 96 | ### Convert to Hugging Face 97 | 98 | ```bash 99 | python scripts/convert_to_hf.py <checkpoint_path> 100 | ``` 101 | 102 | Note that `<checkpoint_path>` could be either a file or a folder, depending on the parallelization strategy you are using. 103 | By default, its name will follow this format: `epoch=xxx-step=yyy.ckpt`. 104 | 105 | ## Hints 106 | 107 | ### Cross-contamination Attention 108 | 109 | To improve training efficiency, we typically perform data packing, where multiple sequences of different lengths are merged into a single sequence so that the packed sequences have similar total lengths. 110 | However, without proper handling, the attention mechanism may attend to tokens from unrelated sequences in the same pack, increasing the risk of hallucination in the model. 111 | 112 | The model architecture implemented in LLM-Training has already addressed this issue. 113 | On the other hand, if you are using the model architecture provided by HuggingFace, this issue is only handled when Flash Attention 2 is enabled. 114 | 115 | Reference: https://github.com/MeetKai/functionary/tree/main/functionary/train/packing 116 | 117 | ### Faulty Gradient Accumulation 118 | 119 | Gradient accumulation is a commonly used technique to simulate large-batch training under limited GPU memory. However, the Unsloth AI team discovered an issue in previous implementations: the accumulated gradients are inconsistent with those from full-batch training. 120 | The root cause of this problem lies in improper loss normalization, which can also occur in distributed training scenarios. 121 | 122 | Currently, LLM-Training has not addressed this issue in its `main` branch, but the `fix-ga-dp` branch includes a fix for `CLM`. 123 | However, our experiments show that the corrected loss calculation does not significantly improve model performance and may even lead to a slight decrease. 124 | 125 | If you observe different experimental results, we encourage you to share them. 126 | 127 | Reference: https://unsloth.ai/blog/gradient 128 | 129 | ### Difference between DeepSpeed and FSDP 130 | 131 | DeepSpeed and FSDP are both implementations of distributed training, with their algorithms based on ZeRO. 132 | As a result, they are generally considered to deliver similar performance.
133 | However, there are some differences in their details, particularly in parameter precision settings, which are discussed in this [blog](https://huggingface.co/blog/deepspeed-to-fsdp-and-back) post. 134 | 135 | In FSDP2's mixed-precision training, we observed that it does not appear to store full-precision parameters separately. 136 | This causes both gradients and optimizer states to remain in half precision, which can significantly degrade training performance. 137 | 138 | To address this issue, we implemented an optimizer wrapper that automatically maintains a copy of full-precision parameters. 139 | The optimizer operates on these full-precision parameters and then synchronizes the updates back to the half precision parameters, ensuring training performance. 140 | 141 | ## Issues 142 | 143 | If you encounter any issue while using this framework, please avoid directly contacting the author. 144 | Instead, consider submitting an issue on the repository. 145 | This makes it easier to manage and address errors while also serving as a reference for others who may face the same problem in the future. 146 | 147 | Currently, there is no specific format for submitting an Issue. However, when reporting a problem, please provide as much relevant information as possible, such as: 148 | 149 | - The version or commit ID of LLM-Training you are using 150 | - The training config file 151 | - The full error message 152 | 153 | This will help ensure the issue can be resolved more efficiently. 154 | -------------------------------------------------------------------------------- /src/llm_training/lms/clm/clm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from contextlib import nullcontext 3 | from typing import Any 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torch.distributed.tensor import DTensor 9 | from torch.distributed.tensor.parallel import loss_parallel 10 | 11 | from llm_training.lightning.strategy import FSDP2Strategy 12 | from llm_training.lms.base_lm import BaseLightningModule 13 | from llm_training.lms.protos import CausalLMProto 14 | from llm_training.lms.utils import get_model 15 | from llm_training.metrics import ConsumedSamples, ConsumedTokens, Perplexity 16 | from llm_training.models.base_model.base_model import BaseModel 17 | from llm_training.ops import shift_labels 18 | from llm_training.ops.liger_kernel import cross_entropy 19 | 20 | from .clm_config import CLMConfig 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class CLM(BaseLightningModule): 26 | config: CLMConfig 27 | model: CausalLMProto | BaseModel | None 28 | 29 | def __init__(self, config: CLMConfig) -> None: 30 | super().__init__(config) 31 | 32 | self.model = None 33 | 34 | @property 35 | def has_pre_trained_weights(self) -> bool: 36 | if self.model is None: 37 | return False 38 | return self.model.has_pre_trained_weights 39 | 40 | def get_pre_trained_weights(self) -> dict[str, torch.Tensor]: 41 | state_dict = self.model.get_pre_trained_weights() 42 | state_dict = {f'model.{k}': v for k, v in state_dict.items()} 43 | return state_dict 44 | 45 | def neftune_forward_hook( 46 | self, 47 | module: nn.Module, 48 | input: torch.Tensor, 49 | output: torch.Tensor 50 | ) -> torch.Tensor: 51 | if module.training: 52 | attention_mask = getattr(self, '_current_attention_mask', None) 53 | if attention_mask is None: 54 | attention_mask = torch.ones_like(input) 55 | 56 | # For packed attention mask 57 | 
attention_mask = attention_mask.bool().to(output.dtype) 58 | 59 | noise = torch.empty( 60 | output.shape, 61 | dtype=output.dtype, 62 | device=output.device 63 | ) 64 | noise = noise.uniform_(-1, 1) 65 | input_lengths = attention_mask.sum(1) 66 | delta = noise * attention_mask.unsqueeze(2) 67 | dims = input_lengths * output.size(-1) 68 | magnitude = self.config.neftune_alpha / torch.sqrt(dims) 69 | delta = (delta * magnitude.view(-1, 1, 1)).detach() 70 | if isinstance(output, DTensor): 71 | delta = DTensor.from_local( 72 | delta, 73 | device_mesh=output.device_mesh, 74 | placements=output.placements, 75 | run_check=False 76 | ) 77 | output = output + delta 78 | return output 79 | 80 | def register_neftune_hook(self) -> None: 81 | embedding = self.model.get_input_embeddings() 82 | self._neftune_hook_handle = embedding.register_forward_hook(self.neftune_forward_hook) 83 | 84 | def configure_model(self) -> None: 85 | process_group = self.strategy.dp_mesh.get_group() if isinstance(self.strategy, FSDP2Strategy) else None 86 | self.consumed_samples = ConsumedSamples(process_group=process_group) 87 | self.consumed_tokens = ConsumedTokens( 88 | ignore_index=self.config.ignore_index, 89 | process_group=process_group 90 | ) 91 | if self.config.log_perplexity: 92 | self.train_perplexity = Perplexity( 93 | ignore_index=self.config.ignore_index, 94 | process_group=process_group 95 | ) 96 | self.val_perplexity = Perplexity( 97 | ignore_index=self.config.ignore_index, 98 | process_group=process_group 99 | ) 100 | 101 | self.model = get_model(self.config.model) 102 | 103 | if self.global_rank == 0: 104 | logger.info(f'Config:\n{repr(self.config)}') 105 | logger.info(f'Model:\n{self.model}') 106 | 107 | if self.config.neftune_alpha is not None: 108 | self.register_neftune_hook() 109 | 110 | def on_fsdp_parallelize_model(self, **kwargs) -> None: 111 | self.model.parallelize(**kwargs) 112 | 113 | def compute_loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 114 | if isinstance(self.strategy, FSDP2Strategy) and self.strategy.tp_size > 1: 115 | with loss_parallel(): 116 | return F.cross_entropy( 117 | logits.flatten(end_dim=1), 118 | labels.flatten(end_dim=1), 119 | ignore_index=self.config.ignore_index 120 | ) 121 | 122 | return cross_entropy( 123 | logits=logits, 124 | labels=labels, 125 | ignore_index=self.config.ignore_index 126 | ) 127 | 128 | def backward(self, loss: torch.Tensor, *args, **kwargs) -> None: 129 | backward_ctx = nullcontext() 130 | if isinstance(self.strategy, FSDP2Strategy) and self.strategy.tp_size > 1: 131 | backward_ctx = loss_parallel() 132 | 133 | with backward_ctx: 134 | super().backward(loss, *args, **kwargs) 135 | 136 | def training_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int) -> torch.Tensor: 137 | labels = shift_labels(batch['labels'], self.config.ignore_index) 138 | 139 | if self.config.neftune_alpha is not None: 140 | self._current_attention_mask = batch['attention_mask'] 141 | 142 | outputs = self.model( 143 | input_ids=batch['input_ids'], 144 | attention_mask=batch['attention_mask'], 145 | position_ids=batch.get('position_ids', None) 146 | ) 147 | logits = outputs.logits.float() 148 | 149 | if self.config.neftune_alpha is not None: 150 | self.log('NEFTune Alpha', self.config.neftune_alpha) 151 | self._current_attention_mask = None 152 | 153 | loss = self.compute_loss(logits, labels) 154 | 155 | self.log('loss', loss, prog_bar=True, logger=False) 156 | self.log('Loss/Train/Step', loss) 157 | 158 | if self.config.log_perplexity: 159 | 
self.train_perplexity(loss) 160 | self.log('Perplexity/Train/Step', self.train_perplexity) 161 | 162 | self.consumed_samples.update(labels) 163 | self.consumed_tokens.update(labels) 164 | self.logger.log_metrics({ 165 | 'Consumed Samples': self.consumed_samples.compute(), 166 | 'Consumed Tokens': self.consumed_tokens.compute() 167 | }) 168 | return loss 169 | 170 | def validation_step(self, batch: dict[str, torch.Tensor | Any], batch_idx: int, dataloader_idx: int = 0): 171 | batch_size = batch['input_ids'].size(0) 172 | labels = shift_labels(batch['labels'], self.config.ignore_index) 173 | outputs = self.model( 174 | input_ids=batch['input_ids'], 175 | attention_mask=batch['attention_mask'], 176 | position_ids=batch.get('position_ids', None) 177 | ) 178 | logits = outputs.logits.float() 179 | 180 | loss = self.compute_loss(logits, labels) 181 | 182 | if isinstance(loss, DTensor): 183 | loss = loss.full_tensor() 184 | 185 | self.log('Loss/Val', loss, batch_size=batch_size, sync_dist=True) 186 | 187 | if self.config.log_perplexity: 188 | self.val_perplexity.update(loss) 189 | self.log('Perplexity/Val', self.val_perplexity) 190 | 191 | def get_model(self) -> BaseModel: 192 | return self.model 193 | -------------------------------------------------------------------------------- /src/llm_training/lightning/strategy/deepspeed/deepspeed_strategy.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | from typing import Any, Dict, List 4 | 5 | import lightning as L 6 | import torch 7 | from lightning.fabric.plugins import ClusterEnvironment 8 | from lightning.fabric.utilities.types import _PATH 9 | from lightning.pytorch.accelerators import Accelerator 10 | from lightning.pytorch.plugins import Precision 11 | from lightning.pytorch.strategies.deepspeed import \ 12 | DeepSpeedStrategy as _DeepSpeedStrategy 13 | from lightning.pytorch.utilities.types import STEP_OUTPUT 14 | 15 | 16 | class DeepSpeedStrategy(_DeepSpeedStrategy): 17 | def __init__( 18 | self, 19 | accelerator: Accelerator | None = None, 20 | zero_optimization: bool = True, 21 | stage: int = 2, 22 | remote_device: str | None = None, 23 | offload_optimizer: bool = False, 24 | offload_parameters: bool = False, 25 | offload_params_device: str = 'cpu', 26 | nvme_path: str = '/local_nvme', 27 | params_buffer_count: int = 5, 28 | params_buffer_size: int = 100000000, 29 | max_in_cpu: int = 1000000000, 30 | offload_optimizer_device: str = 'cpu', 31 | optimizer_buffer_count: int = 4, 32 | block_size: int = 1048576, 33 | queue_depth: int = 8, 34 | single_submit: bool = False, 35 | overlap_events: bool = True, 36 | thread_count: int = 1, 37 | pin_memory: bool = False, 38 | sub_group_size: int = 1000000000000, 39 | contiguous_gradients: bool = True, 40 | overlap_comm: bool = True, 41 | allgather_partitions: bool = True, 42 | reduce_scatter: bool = True, 43 | allgather_bucket_size: int = 200000000, 44 | reduce_bucket_size: int = 200000000, 45 | zero_allow_untested_optimizer: bool = True, 46 | logging_batch_size_per_gpu: str | int = 'auto', 47 | config: _PATH | Dict[str, Any] | None = None, 48 | logging_level: int | str = logging.WARN, 49 | parallel_devices: List[torch.device] | None = None, 50 | cluster_environment: ClusterEnvironment | None = None, 51 | loss_scale: float = 0, 52 | initial_scale_power: int = 16, 53 | loss_scale_window: int = 1000, 54 | hysteresis: int = 2, 55 | min_loss_scale: int = 1, 56 | partition_activations: bool = False, 57 | 
cpu_checkpointing: bool = False, 58 | contiguous_memory_optimization: bool = False, 59 | synchronize_checkpoint_boundary: bool = False, 60 | load_full_weights: bool = False, 61 | precision_plugin: Precision | None = None, 62 | process_group_backend: str | None = None, 63 | exclude_frozen_parameters: bool = True, 64 | raise_error_at_min_scale: bool | None = None, 65 | zero3_leaf_modules: list[type] | None = None, 66 | stage3_max_live_parameters: int | float = 1e9, 67 | stage3_max_reuse_distance: int | float = 1e9, 68 | stage3_prefetch_bucket_size: int | float = 5e8, 69 | stage3_param_persistence_threshold: int | float = 1e6, 70 | zero_hpz_partition_size: int = 1, 71 | zero_quantized_weights: bool = False, 72 | zero_quantized_gradients: bool = False 73 | ): 74 | if isinstance(logging_level, str): 75 | logging_level = getattr(logging, logging_level.upper()) 76 | 77 | self.exclude_frozen_parameters = exclude_frozen_parameters 78 | self.raise_error_at_min_scale = raise_error_at_min_scale 79 | self.zero3_leaf_modules = zero3_leaf_modules 80 | self.stage3_max_live_parameters = stage3_max_live_parameters 81 | self.stage3_max_reuse_distance = stage3_max_reuse_distance 82 | self.stage3_prefetch_bucket_size = stage3_prefetch_bucket_size 83 | self.stage3_param_persistence_threshold = stage3_param_persistence_threshold 84 | self.zero_hpz_partition_size = zero_hpz_partition_size 85 | self.zero_quantized_weights = zero_quantized_weights 86 | self.zero_quantized_gradients = zero_quantized_gradients 87 | 88 | super().__init__(accelerator, zero_optimization, stage, remote_device, offload_optimizer, offload_parameters, offload_params_device, nvme_path, params_buffer_count, params_buffer_size, max_in_cpu, offload_optimizer_device, optimizer_buffer_count, block_size, queue_depth, single_submit, overlap_events, thread_count, pin_memory, sub_group_size, contiguous_gradients, overlap_comm, allgather_partitions, reduce_scatter, allgather_bucket_size, reduce_bucket_size, zero_allow_untested_optimizer, logging_batch_size_per_gpu, config, logging_level, parallel_devices, cluster_environment, loss_scale, initial_scale_power, loss_scale_window, hysteresis, min_loss_scale, partition_activations, cpu_checkpointing, contiguous_memory_optimization, synchronize_checkpoint_boundary, load_full_weights, precision_plugin, process_group_backend) 89 | 90 | @property 91 | def is_fp16(self) -> bool: 92 | return self.precision_plugin.precision.startswith('16') 93 | 94 | def _create_default_config(self, *args, **kwargs) -> dict[str, Any]: 95 | kwargs.setdefault('stage3_max_live_parameters', self.stage3_max_live_parameters) 96 | kwargs.setdefault('stage3_max_reuse_distance', self.stage3_max_reuse_distance) 97 | kwargs.setdefault('stage3_prefetch_bucket_size', self.stage3_prefetch_bucket_size) 98 | kwargs.setdefault('stage3_param_persistence_threshold', self.stage3_param_persistence_threshold) 99 | kwargs.setdefault('zero_hpz_partition_size', self.zero_hpz_partition_size) 100 | kwargs.setdefault('zero_quantized_weights', self.zero_quantized_weights) 101 | kwargs.setdefault('zero_quantized_gradients', self.zero_quantized_gradients) 102 | return super()._create_default_config(*args, **kwargs) 103 | 104 | def _set_raise_error_at_min_scale(self): 105 | optimizer = getattr(self.deepspeed_engine, 'optimizer', None) 106 | loss_scaler = getattr(optimizer, 'loss_scaler', None) 107 | if self.raise_error_at_min_scale is not None and loss_scaler is not None: 108 | loss_scaler.raise_error_at_min_scale = self.raise_error_at_min_scale 109 | 110 | def 
_convert_metrics(self): 111 | from torchmetrics import Metric 112 | 113 | for m in self.model.modules(): 114 | if isinstance(m, Metric): 115 | m.to(self.root_device) 116 | m.set_dtype(m.dtype) 117 | 118 | def init_deepspeed(self) -> None: 119 | import deepspeed # type: ignore 120 | 121 | if self.zero3_leaf_modules: 122 | deepspeed.utils.set_z3_leaf_modules(self.model, self.zero3_leaf_modules) 123 | super().init_deepspeed() 124 | 125 | def setup(self, trainer: L.Trainer) -> None: 126 | super().setup(trainer) 127 | 128 | self._set_raise_error_at_min_scale() 129 | self._convert_metrics() 130 | 131 | def _maybe_add_skipped_steps_to_progress_bar(self): 132 | if not self.is_fp16: 133 | return 134 | 135 | progress_bar_metrics = self.lightning_module.trainer.progress_bar_metrics 136 | progress_bar_metrics['skipped_steps'] = self.deepspeed_engine.skipped_steps 137 | 138 | def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: 139 | self.lightning_module._grad_norm = None 140 | output = super().training_step(*args, **kwargs) 141 | self._maybe_add_skipped_steps_to_progress_bar() 142 | return output 143 | 144 | def optimizer_step(self, optimizer, closure, model = None, **kwargs): 145 | output = super().optimizer_step(optimizer, closure, model, **kwargs) 146 | self.lightning_module._grad_norm = self.deepspeed_engine.get_global_grad_norm() 147 | return output 148 | 149 | def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Any | None = None) -> None: 150 | save_checkpoint = self.deepspeed_engine.save_checkpoint 151 | self.deepspeed_engine.save_checkpoint = partial( 152 | save_checkpoint, 153 | exclude_frozen_parameters=self.exclude_frozen_parameters 154 | ) 155 | super().save_checkpoint(checkpoint, filepath, storage_options) 156 | self.deepspeed_engine.save_checkpoint = save_checkpoint 157 | -------------------------------------------------------------------------------- /src/llm_training/data/instruction_tuning/instruction_tuning_datamodule.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any 3 | 4 | import tokenizers 5 | from datasets import Features, Sequence, Value 6 | from packaging.version import Version 7 | from transformers import PreTrainedTokenizerBase 8 | 9 | from llm_training.data.hf_based.hf_based_datamodule import (DatasetDict, 10 | HFBasedDataModule) 11 | 12 | from .instruction_tuning_datacollator import InstructionTuningDataCollator 13 | from .instruction_tuning_datamodule_config import ( 14 | InstructionTuningDataModuleConfig, OverlongHandlingMethod, PackingMethod) 15 | 16 | 17 | class InstructionTuningDataModule(HFBasedDataModule): 18 | config: InstructionTuningDataModuleConfig 19 | datacollator_class = InstructionTuningDataCollator 20 | 21 | def __init__(self, config: InstructionTuningDataModuleConfig) -> None: 22 | super().__init__(config) 23 | 24 | if Version(tokenizers.__version__) < Version('0.20.1'): 25 | raise ValueError( 26 | "`tokenizers` must be at least version 0.20.1, " 27 | "otherwise LLaMA 3 tokenizer will produce incorrect prompt/response mask." 
28 | ) 29 | 30 | @classmethod 31 | def _apply_chat_template_and_tokenize( 32 | cls, 33 | batch: dict[str, list[str]], 34 | tokenizer: PreTrainedTokenizerBase, 35 | chat_template: str | None, 36 | default_system_prompt: str | None, 37 | add_default_system_prompt_rate: float | None 38 | ): 39 | new_batch = { 40 | 'input_ids': [], 41 | 'attention_mask': [], 42 | 'labels': [], 43 | 'length': [] 44 | } 45 | 46 | for messages in batch['messages']: 47 | # Add an empty system prompt randomly if it does not exist. 48 | has_system_prompt = any(m['role'] == 'system' for m in messages) 49 | if ( 50 | not has_system_prompt 51 | and default_system_prompt is not None 52 | and add_default_system_prompt_rate is not None 53 | and random.random() < add_default_system_prompt_rate 54 | ): 55 | messages.insert(0, {'role': 'system', 'content': default_system_prompt}) 56 | 57 | batch_encoding = tokenizer.apply_chat_template( 58 | batch['messages'], 59 | chat_template=chat_template, 60 | return_dict=True, 61 | return_assistant_tokens_mask=True, 62 | tokenizer_kwargs=dict( 63 | return_attention_mask=False, 64 | verbose=False 65 | ) 66 | ) 67 | 68 | for input_ids, assistant_masks in zip( 69 | batch_encoding['input_ids'], 70 | batch_encoding['assistant_masks'] 71 | ): 72 | labels = [i if a == 1 else -100 for i, a in zip(input_ids, assistant_masks)] 73 | new_batch['input_ids'].append(input_ids) 74 | new_batch['attention_mask'].append([1] * len(input_ids)) 75 | new_batch['labels'].append(labels) 76 | new_batch['length'].append(len(input_ids)) 77 | 78 | return new_batch 79 | 80 | @classmethod 81 | def _drop_overlong_examples( 82 | cls, 83 | batch: dict[str, Any], 84 | max_length: int 85 | ): 86 | indices = [i for i, n in enumerate(batch['length']) if n <= max_length] 87 | return {k: [v[i] for i in indices] for k, v in batch.items()} 88 | 89 | @classmethod 90 | def _truncate_overlong_examples( 91 | cls, 92 | batch: dict[str, Any], 93 | max_length: int 94 | ): 95 | for i in range(len(batch['input_ids'])): 96 | if batch['length'][i] > max_length: 97 | batch['input_ids'][i] = batch['input_ids'][:max_length] 98 | batch['labels'][i] = batch['labels'][:max_length] 99 | batch['length'][i] = max_length 100 | return batch 101 | 102 | @classmethod 103 | def _group_indices_by_length(cls, lengths: list[int], max_length: int) -> list[list[int]]: 104 | groups = [] 105 | current_group = [] 106 | current_sum = 0 107 | 108 | for i, l in sorted(enumerate(lengths), key=lambda x: x[1]): 109 | if current_sum + l + len(current_group) <= max_length: 110 | current_group.append(i) 111 | current_sum += l 112 | else: 113 | groups.append(current_group) 114 | current_group = [i] 115 | current_sum = l 116 | 117 | if current_group: 118 | groups.append(current_group) 119 | 120 | return groups 121 | 122 | @classmethod 123 | def _group_by_length(cls, batch: dict[str, list[list[int]]], max_length: int): 124 | new_batch = { 125 | 'input_ids': [], 126 | 'attention_mask': [], 127 | 'labels': [], 128 | 'length': [] 129 | } 130 | 131 | groups = cls._group_indices_by_length(batch['length'], max_length) 132 | for group in groups: 133 | input_ids = [] 134 | attention_mask = [] 135 | labels = [] 136 | for local_idx, global_idx in enumerate(group): 137 | input_ids += batch['input_ids'][global_idx] 138 | attention_mask += [local_idx + 1] * batch['length'][global_idx] 139 | labels += batch['labels'][global_idx] 140 | new_batch['input_ids'].append(input_ids) 141 | new_batch['attention_mask'].append(attention_mask) 142 | new_batch['labels'].append(labels) 143 | 
new_batch['length'].append(len(input_ids)) 144 | 145 | return new_batch 146 | 147 | @classmethod 148 | def _pre_process_data( 149 | cls, 150 | batch: dict[str, list], 151 | tokenizer: PreTrainedTokenizerBase, 152 | chat_template: str | None, 153 | default_system_prompt: str | None, 154 | add_default_system_prompt_rate: float | None, 155 | max_length: int | None, 156 | overlong_handling_method: OverlongHandlingMethod, 157 | packing_method: PackingMethod 158 | ) -> dict[str, list]: 159 | batch = cls._apply_chat_template_and_tokenize( 160 | batch, 161 | tokenizer=tokenizer, 162 | chat_template=chat_template, 163 | default_system_prompt=default_system_prompt, 164 | add_default_system_prompt_rate=add_default_system_prompt_rate 165 | ) 166 | 167 | if max_length is not None: 168 | if overlong_handling_method == OverlongHandlingMethod.DROP: 169 | batch = cls._drop_overlong_examples(batch, max_length) 170 | elif overlong_handling_method == OverlongHandlingMethod.TRUNCATE: 171 | batch = cls._truncate_overlong_examples(batch, max_length) 172 | 173 | if packing_method == PackingMethod.GROUP_BY_LENGTH: 174 | batch = cls._group_by_length(batch, max_length) 175 | 176 | return batch 177 | 178 | def pre_process_data(self, dataset_dict: DatasetDict) -> DatasetDict: 179 | dataset_dict = self.map_dataset_dict( 180 | dataset_dict, 181 | self._pre_process_data, 182 | fn_kwargs=dict( 183 | tokenizer=self.config.tokenizer, 184 | chat_template=self.config.chat_template, 185 | default_system_prompt=self.config.default_system_prompt, 186 | add_default_system_prompt_rate=self.config.add_default_system_prompt_rate, 187 | max_length=self.config.max_length, 188 | overlong_handling_method=self.config.overlong_handling_method, 189 | packing_method=self.config.packing_method 190 | ), 191 | batched=True, 192 | remove_columns=True, 193 | num_proc=self.config.num_proc, 194 | features=Features({ 195 | 'input_ids': Sequence(Value('int32')), 196 | 'attention_mask': Sequence(Value('uint16')), 197 | 'labels': Sequence(Value('int32')), 198 | 'length': Value('uint32') 199 | }), 200 | desc='Pre-processing data' 201 | ) 202 | return dataset_dict 203 | --------------------------------------------------------------------------------
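A note on the packed `attention_mask` built by `_group_by_length` above: it is not a 0/1 mask; each token stores the index of the packed segment it belongs to (`local_idx + 1`, so `1` for the first original sequence in the pack, `2` for the second, and so on). The snippet below is not part of this repository; it is a minimal, hypothetical sketch of how such a segment-id mask could be expanded into the block-diagonal attention mask that avoids the cross-contamination discussed in the README's Hints section.

```python
import torch


def segment_ids_to_block_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """Expand a packed segment-id mask into a boolean attention mask.

    ``segment_ids`` has shape (batch, seq_len) and holds 0 for padding and
    1..k for tokens of the k sequences packed into each row, matching the
    values produced by ``_group_by_length`` (``local_idx + 1``).
    The result has shape (batch, seq_len, seq_len); position (i, j) is True
    only when tokens i and j belong to the same packed segment.
    """
    same_segment = segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)
    not_padding = (segment_ids != 0).unsqueeze(2) & (segment_ids != 0).unsqueeze(1)
    return same_segment & not_padding


# Two sequences of lengths 3 and 2 packed into one row, plus one padding token.
packed = torch.tensor([[1, 1, 1, 2, 2, 0]])
mask = segment_ids_to_block_mask(packed)
# mask[0, :3, :3] and mask[0, 3:5, 3:5] are all True; every cross-segment
# or padding position is False.
```

For causal LM training such a block mask would still be combined with the usual lower-triangular causal mask; per the README, the model implementations in LLM-Training already handle packed masks internally, so this sketch is illustrative only.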