├── setup.cfg ├── optimizers ├── __init__.py └── igod.py ├── lib └── ditty │ ├── diffusion │ ├── __init__.py │ └── noise_schedule.py │ ├── utils.py │ ├── base.py │ ├── __init__.py │ ├── processors.py │ ├── hf_utils.py │ ├── loss.py │ ├── data.py │ ├── checkpoint.py │ ├── trainer.py │ ├── contract.py │ ├── pipeline.py │ ├── model_factory.py │ └── example.py ├── main.py ├── requirements.txt ├── ds_config.json ├── setup.py ├── pyproject.toml ├── plans ├── FSDP_NOTES.md └── pipeline_print_method.md ├── .gitignore ├── README.md └── LICENSE /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file=README.md 3 | license_files=LICENSE 4 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .igod import HebbianIGOD 2 | 3 | __all__ = ["HebbianIGOD"] 4 | -------------------------------------------------------------------------------- /lib/ditty/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .noise_schedule import Scheduler 2 | 3 | __all__ = ["Scheduler"] 4 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import fire 2 | from lib.ditty.pipeline import Pipeline 3 | 4 | if __name__ == "__main__": 5 | fire.Fire(Pipeline) 6 | 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets 4 | einops 5 | fastcore 6 | huggingface-hub 7 | numpy 8 | peft 9 | safetensors 10 | torch 11 | tqdm 12 | transformers 13 | -------------------------------------------------------------------------------- /lib/ditty/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def convert_seconds_to_string_time(seconds): 3 | day = seconds // (24 * 3600) 4 | seconds = seconds % (24 * 3600) 5 | hour = seconds // 3600 6 | seconds %= 3600 7 | minutes = seconds // 60 8 | seconds %= 60 9 | 10 | return "%d days, %02d hours, %02d minutes, %02d seconds" % (day, hour, minutes, seconds) 11 | -------------------------------------------------------------------------------- /ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "offload_optimizer": { 5 | "device": "cpu", 6 | "pin_memory": true 7 | }, 8 | "allgather_partitions": true, 9 | "allgather_bucket_size": 2e8, 10 | "reduce_scatter": true, 11 | "reduce_bucket_size": 2e8, 12 | "overlap_comm": true, 13 | "contiguous_gradients": true 14 | } 15 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | 2 | from setuptools import setup, find_packages 3 | 4 | setup( 5 | name='ditty', 6 | version='0.5.1', 7 | license='Apache V2', 8 | author="Ian T Butler (KinglyCrow)", 9 | author_email='iantbutler01@gmail.com', 10 | packages=find_packages('lib'), 11 | package_dir={'': 'lib'}, 12 | long_description=open('README.md', 'r').read(), 13 | long_description_content_type='text/markdown', 14 | url='https://github.com/iantbutler01/ditty', 15 | keywords='finetuning, llm, nlp, machine 
learning', 16 | install_requires=[ 17 | 'accelerate', 18 | 'transformers', 19 | 'datasets', 20 | 'bitsandbytes', 21 | 'fire', 22 | 'peft' 23 | ], 24 | 25 | ) 26 | -------------------------------------------------------------------------------- /lib/ditty/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for all ditty pipeline components. 3 | """ 4 | import re 5 | 6 | 7 | def camel_to_snake(name: str) -> str: 8 | """Convert CamelCase to snake_case, handling acronyms properly.""" 9 | # Handle acronyms followed by lowercase (e.g., MSELoss -> mse_loss) 10 | name = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', name) 11 | # Handle lowercase followed by uppercase (e.g., camelCase -> camel_case) 12 | name = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', name) 13 | return name.lower() 14 | 15 | 16 | class DittyBase: 17 | """Base class providing contract and name for all pipeline components.""" 18 | 19 | def __init__(self, name: str = "", contract: str = ""): 20 | self.name = name or camel_to_snake(self.__class__.__name__) 21 | self.contract = contract 22 | -------------------------------------------------------------------------------- /lib/ditty/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import DittyBase 2 | from .contract import ( 3 | Contract, 4 | TensorSpec, 5 | ContractViolation, 6 | ContractParseError, 7 | parse_contract, 8 | validate_pipeline_chain, 9 | format_pipeline_contracts, 10 | ) 11 | from .pipeline import Pipeline 12 | from .trainer import Trainer, TrainerState 13 | from .data import Data 14 | from . import diffusion 15 | from .loss import LossCalculator, LossOutput, MSELoss, L1Loss, CrossEntropyLoss, CompositeLoss 16 | from .processors import PreProcessor, PostProcessor, Context 17 | from .model_factory import ModelFactory, TokenizerFactory, FSDPConfig, QuantConfig, PeftConfig, ModelTransform 18 | from .checkpoint import CheckpointManager, Checkpoint 19 | from .example import print_pipeline 20 | 21 | __all__ = [ 22 | "DittyBase", 23 | "Contract", 24 | "TensorSpec", 25 | "ContractViolation", 26 | "ContractParseError", 27 | "parse_contract", 28 | "validate_pipeline_chain", 29 | "format_pipeline_contracts", 30 | "Pipeline", 31 | "Trainer", 32 | "TrainerState", 33 | "Data", 34 | "LossCalculator", 35 | "LossOutput", 36 | "MSELoss", 37 | "L1Loss", 38 | "CrossEntropyLoss", 39 | "CompositeLoss", 40 | "PreProcessor", 41 | "PostProcessor", 42 | "Context", 43 | "ModelFactory", 44 | "TokenizerFactory", 45 | "FSDPConfig", 46 | "QuantConfig", 47 | "PeftConfig", 48 | "ModelTransform", 49 | "CheckpointManager", 50 | "Checkpoint", 51 | "print_pipeline", 52 | ] 53 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "ditty" 7 | version = "0.7.0" 8 | description = "Distributed training library with FSDP2, pipeline contracts, and modern optimizers" 9 | readme = "README.md" 10 | license = "Apache-2.0" 11 | authors = [ 12 | { name = "Ian T Butler (KinglyCrow)", email = "iantbutler01@gmail.com" } 13 | ] 14 | keywords = ["finetuning", "llm", "nlp", "machine learning", "distributed training", "fsdp2"] 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", 17 | "Intended Audience :: Developers", 18 | "License :: OSI Approved :: Apache Software 
License", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: 3.12", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | ] 25 | requires-python = ">=3.10" 26 | dependencies = [ 27 | "accelerate", 28 | "datasets>=3.2.0", 29 | "einops>=0.8.0", 30 | "fastcore", 31 | "huggingface-hub>=0.33.0", 32 | "numpy>=2.2.1", 33 | "peft", 34 | "safetensors>=0.5.3", 35 | "torch>=2.5.1", 36 | "tqdm>=4.67.1", 37 | "transformers>=4.52.4", 38 | "bitsandbytes", 39 | "torchao", 40 | ] 41 | 42 | [project.optional-dependencies] 43 | dev = [ 44 | "pytest", 45 | "black", 46 | "ruff", 47 | ] 48 | 49 | [project.urls] 50 | Homepage = "https://github.com/iantbutler01/ditty" 51 | Repository = "https://github.com/iantbutler01/ditty" 52 | 53 | [tool.hatch.build.targets.wheel] 54 | packages = ["lib/ditty"] 55 | 56 | [tool.black] 57 | line-length = 100 58 | target-version = ["py310", "py311", "py312"] 59 | 60 | [tool.ruff] 61 | line-length = 100 62 | target-version = "py310" 63 | -------------------------------------------------------------------------------- /lib/ditty/processors.py: -------------------------------------------------------------------------------- 1 | """ 2 | PreProcessor and PostProcessor abstractions for ditty trainers. 3 | 4 | Architecture: 5 | dataset -> preprocessors -> model.forward -> postprocessors -> loss_calc 6 | 7 | Contracts use terse syntax: "input:rank:dtype -> output:rank:dtype | ctx.key:rank:dtype" 8 | """ 9 | from abc import ABC, abstractmethod 10 | from typing import Any, Dict, Tuple 11 | 12 | from .base import DittyBase 13 | 14 | 15 | Context = Dict[str, Any] 16 | 17 | 18 | class PreProcessor(DittyBase, ABC): 19 | @abstractmethod 20 | def process(self, batch: Any, ctx: Context) -> Tuple[Any, Context]: 21 | """ 22 | Transform batch for model forward. 23 | 24 | Returns: 25 | (batch_transformed, ctx) 26 | """ 27 | pass 28 | 29 | def config(self) -> Dict[str, Any]: 30 | return {} 31 | 32 | def __repr__(self): 33 | cfg = self.config() 34 | if cfg: 35 | params = ", ".join(f"{k}={v}" for k, v in cfg.items()) 36 | return f"{self.name}({params})" 37 | return self.name 38 | 39 | 40 | class PostProcessor(DittyBase, ABC): 41 | @abstractmethod 42 | def process(self, model_output: Tuple[Any, ...], ctx: Context) -> Tuple[Tuple[Any, ...], Context]: 43 | """ 44 | Transform model output for loss calculation. 45 | 46 | Args: 47 | model_output: Tuple of tensors from model forward 48 | ctx: Context dict 49 | 50 | Returns: 51 | (model_output_transformed, ctx) 52 | """ 53 | pass 54 | 55 | def config(self) -> Dict[str, Any]: 56 | return {} 57 | 58 | def __repr__(self): 59 | cfg = self.config() 60 | if cfg: 61 | params = ", ".join(f"{k}={v}" for k, v in cfg.items()) 62 | return f"{self.name}({params})" 63 | return self.name 64 | -------------------------------------------------------------------------------- /lib/ditty/hf_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging import getLogger 3 | from huggingface_hub import HfApi 4 | 5 | logger = getLogger("ditty_hf_utils") 6 | 7 | 8 | def push_to_hub(self, repo_id, token=None, accelerator=None, private=True): 9 | """ 10 | Push model to HuggingFace Hub with FSDP support. 11 | 12 | This function handles gathering FSDP sharded state dicts before pushing. 13 | Meant to be monkey-patched onto HuggingFace models. 
14 | """ 15 | if accelerator is None: 16 | self.save_pretrained(f"/tmp/ditty_push_{repo_id.replace('/', '_')}") 17 | api = HfApi(token=token) 18 | api.create_repo(repo_id, private=private, exist_ok=True) 19 | api.upload_folder( 20 | folder_path=f"/tmp/ditty_push_{repo_id.replace('/', '_')}", 21 | repo_id=repo_id, 22 | token=token, 23 | ) 24 | return 25 | 26 | accelerator.wait_for_everyone() 27 | 28 | if accelerator.distributed_type == "FSDP": 29 | state_dict = accelerator.get_state_dict(self) 30 | if accelerator.is_main_process: 31 | unwrapped = accelerator.unwrap_model(self) 32 | unwrapped.save_pretrained( 33 | f"/tmp/ditty_push_{repo_id.replace('/', '_')}", 34 | state_dict=state_dict, 35 | ) 36 | else: 37 | if accelerator.is_main_process: 38 | self.save_pretrained(f"/tmp/ditty_push_{repo_id.replace('/', '_')}") 39 | 40 | accelerator.wait_for_everyone() 41 | 42 | if accelerator.is_main_process: 43 | api = HfApi(token=token) 44 | api.create_repo(repo_id, private=private, exist_ok=True) 45 | api.upload_folder( 46 | folder_path=f"/tmp/ditty_push_{repo_id.replace('/', '_')}", 47 | repo_id=repo_id, 48 | token=token, 49 | ) 50 | logger.info(f"Pushed model to https://huggingface.co/{repo_id}") 51 | -------------------------------------------------------------------------------- /lib/ditty/diffusion/noise_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | class Scheduler: 6 | def __init__(self, timesteps=1): 7 | self.timesteps = timesteps 8 | self.alphas: Tensor 9 | self.betas: Tensor 10 | self.alphas_cumprod: Tensor 11 | 12 | def _cosine_beta_schedule(self, s=0.008, max_beta=0.999): 13 | """ 14 | Cosine schedule as proposed in Improved DDPM paper. 15 | Matches HuggingFace diffusers betas_for_alpha_bar implementation. 16 | """ 17 | import math 18 | 19 | def alpha_bar_fn(t): 20 | return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2 21 | 22 | betas = [] 23 | for i in range(self.timesteps): 24 | t1 = i / self.timesteps 25 | t2 = (i + 1) / self.timesteps 26 | betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) 27 | 28 | return torch.tensor(betas, dtype=torch.float32) 29 | 30 | def _linear_beta_schedule(self, beta_start=0.0001, beta_end=0.02): 31 | """ 32 | Linear schedule as used in original DDPM paper 33 | """ 34 | return torch.linspace(beta_start, beta_end, self.timesteps) 35 | 36 | def create_noise_schedule(self, schedule_type="cosine", F=1.0, beta_start=0.0001, beta_end=0.02): 37 | """ 38 | Create noise schedule. Matches HuggingFace diffusers implementation. 
39 | 40 | Args: 41 | schedule_type: "cosine" or "linear" 42 | F: Rescale factor for cosine schedule (betas scaled by F², default 1.0) 43 | beta_start: Starting beta for linear schedule 44 | beta_end: Ending beta for linear schedule 45 | """ 46 | if schedule_type == "linear": 47 | betas = self._linear_beta_schedule(beta_start, beta_end) 48 | elif schedule_type == "cosine": 49 | betas = self._cosine_beta_schedule() 50 | betas = betas * (F**2) 51 | else: 52 | raise ValueError(f"Unknown schedule_type: {schedule_type}") 53 | 54 | betas = torch.clip(betas, 0.0001, 0.9999) 55 | alphas = 1 - betas 56 | alphas_cumprod = torch.cumprod(alphas, dim=0) 57 | 58 | self.alphas = alphas 59 | self.betas = betas 60 | self.alphas_cumprod = alphas_cumprod 61 | -------------------------------------------------------------------------------- /plans/FSDP_NOTES.md: -------------------------------------------------------------------------------- 1 | # FSDP2 Implementation Notes 2 | 3 | ## Current State (Manual FSDP2) 4 | 5 | ModelFactory manually applies FSDP2 via `fully_shard()` in `_apply_fsdp()`. This requires: 6 | - DTensor detection hacks in trainer.py (`_has_dtensor_params()`) 7 | - Manual device placement with `torch.cuda.set_device(local_rank)` 8 | - Skipping model in `accelerator.prepare()` to avoid DDP wrapping conflict 9 | 10 | ## Why Manual Was Used 11 | 12 | The `FSDPConfig.transformer_layers` lets you pass actual class objects to specify which layers get sharded: 13 | ```python 14 | fsdp_config = FSDPConfig( 15 | enabled=True, 16 | transformer_layers=[ResidualBlock, TransformerBlock], # actual classes 17 | ) 18 | ``` 19 | 20 | ## Better Approach: Accelerate's Native FSDP2 21 | 22 | Accelerate supports FSDP2 via `FullyShardedDataParallelPlugin` with `fsdp_version=2`. 23 | 24 | ```python 25 | from accelerate import FullyShardedDataParallelPlugin, Accelerator 26 | 27 | fsdp_plugin = FullyShardedDataParallelPlugin( 28 | fsdp_version=2, 29 | transformer_layer_cls_to_wrap="ResidualBlock,TransformerBlock", # class names as strings 30 | use_orig_params=True, # needed for frozen/trainable parameter mixing 31 | ) 32 | accelerator = Accelerator(fsdp_plugin=fsdp_plugin) 33 | ``` 34 | 35 | ## Why Accelerate FSDP2 is Better 36 | 37 | 1. No DTensor detection hacks needed 38 | 2. No manual device placement 39 | 3. No skipping model in `prepare()` 40 | 4. Accelerate handles everything automatically 41 | 5. `use_orig_params=True` supports frozen layers (like our frozen decoder) 42 | 43 | ## Migration TODO 44 | 45 | 1. Remove manual FSDP2 from ModelFactory (`_apply_fsdp()`) 46 | 2. Remove DTensor detection from trainer.py 47 | 3. Remove `torch.cuda.set_device()` hack 48 | 4. Add `fsdp_plugin` parameter to Pipeline 49 | 5. Pass layer class names as comma-separated string instead of class objects 50 | 6. Update train_ditty.py to use new API 51 | 52 | ## Why Manual FSDP2 is Required for QLoRA 53 | 54 | QLoRA + FSDP2 requires special handling that accelerate's plugin doesn't fully support: 55 | 56 | 1. **Rank-based device loading**: Rank 0 loads quantized weights to CPU, other ranks load to meta device 57 | 2. **FSDP2 distributes from rank 0**: After loading, FSDP2's `fully_shard()` distributes shards to all ranks 58 | 3. 
**`bnb_4bit_quant_storage`**: Must be set to enable FSDP-QLoRA compatibility (e.g., `torch.bfloat16`) 59 | 60 | From `_load_quantized_model()` in model_factory.py: 61 | ```python 62 | parallel( 63 | load_and_quantize_parallel, 64 | iter(weights.items()), 65 | model=model, 66 | to_cpu=(local_rank == 0), # rank 0 loads to CPU 67 | to_meta=(local_rank != 0), # others load to meta 68 | ) 69 | ``` 70 | 71 | This pattern is documented in Answer.AI's fsdp_qlora implementation and bitsandbytes docs. 72 | 73 | ## References 74 | 75 | - https://huggingface.co/docs/accelerate/en/usage_guides/fsdp 76 | - https://huggingface.co/docs/accelerate/en/concept_guides/fsdp1_vs_fsdp2 77 | - https://github.com/huggingface/accelerate/issues/2873 78 | - https://huggingface.co/docs/bitsandbytes/main/en/fsdp_qlora 79 | - https://www.answer.ai/posts/2024-03-14-fsdp-qlora-deep-dive.html 80 | - https://github.com/AnswerDotAI/fsdp_qlora 81 | -------------------------------------------------------------------------------- /plans/pipeline_print_method.md: -------------------------------------------------------------------------------- 1 | # Spec: `Pipeline.print()` Method 2 | 3 | ## Location 4 | `Pipeline` class in `ditty/lib/ditty/pipeline.py` 5 | 6 | ## Signature 7 | ```python 8 | def print(self) -> None: 9 | ``` 10 | 11 | ## Behavior 12 | 13 | ### 1. Build the model 14 | - Call `self.model_factory.build()` (lazy - only if `self.model` not already set) 15 | - Store result in `self.model` for reuse 16 | 17 | ### 2. Print model architecture 18 | - Class name 19 | - Total / trainable / frozen parameter counts 20 | - Model attributes (vocab_size, embed_dim, hidden_dim, latent_dim, num_layers, num_heads) 21 | - Model contract from `self.model_factory.contract` 22 | - Actual `repr(model)` or nn.Module layer structure 23 | 24 | ### 3. Print data flow diagram 25 | ``` 26 | Dataset 27 | │ 28 | ▼ 29 | PreProcessor 1: token_masker {'mask_prob': 0.15} 30 | │ batch:2:i64 -> batch:2:i64 31 | ▼ 32 | PreProcessor 2: forward_kwargs_injector 33 | │ batch:2:i64 -> batch:2:i64 34 | ▼ 35 | ╔═══════════════════════════════════════╗ 36 | ║ MODEL: ExampleModel ║ 37 | ║ batch:2:i64 -> logits:3:f, hidden:3:f║ 38 | ╚═══════════════════════════════════════╝ 39 | │ 40 | ▼ 41 | PostProcessor 1: target_extractor 42 | │ logits:3:f, hidden:3:f -> logits:3:f, hidden:3:f 43 | ▼ 44 | LOSS: CompositeLoss 45 | • masked_cross_entropy_loss (weight=1.0) 46 | logits:3:f | ctx.target:2:i64, ctx.mask:2:f -> loss:0:f 47 | • hidden_regularizer (weight=1.0) 48 | | ctx.hidden_states:3:f -> loss:0:f 49 | ``` 50 | 51 | ### 4. Run contract validation 52 | - Parse all contracts from preprocessors, model_factory, postprocessors, loss_calculator 53 | - Run `validate_pipeline_chain()` 54 | - Print PASSED / FAILED with specific errors 55 | 56 | ## Output Format 57 | ``` 58 | ====================================================================== 59 | =================== DITTY PIPELINE ARCHITECTURE ====================== 60 | ====================================================================== 61 | 62 | >>> MODEL 63 | Class: ExampleModel 64 | Parameters: 1,234,567 total, 1,000,000 trainable, 234,567 frozen 65 | Config: {'vocab_size': 1000, 'embed_dim': 256, 'hidden_dim': 512} 66 | Contract: batch:2:i64 -> logits:3:f, hidden:3:f 67 | 68 | >>> DATA FLOW 69 | [diagram as above] 70 | 71 | >>> CONTRACT VALIDATION 72 | Status: PASSED 73 | All tensor shapes and dtypes chain correctly through the pipeline. 
74 | 75 | ====================================================================== 76 | ``` 77 | 78 | ## Changes Required 79 | 80 | 1. **Add `print()` method to `Pipeline` class** in `pipeline.py` 81 | - Access `self.model_factory.contract` for model contract 82 | - Build model if needed to get actual architecture 83 | - Use `self.preprocessors`, `self.postprocessors`, `self.loss_calculator` 84 | 85 | 2. **Remove from `example.py`:** 86 | - Delete `print_pipeline()` function 87 | - Update example functions to use `pipeline.print()` after construction 88 | 89 | 3. **Remove from `__init__.py`:** 90 | - Delete `from .example import print_pipeline` 91 | - Delete `"print_pipeline"` from `__all__` 92 | 93 | ## Usage 94 | ```python 95 | pipeline = Pipeline( 96 | model_factory=model_factory, 97 | dataset=dataset, 98 | ... 99 | ) 100 | pipeline.print() # prints full architecture 101 | pipeline.run() # runs training 102 | ``` 103 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | 3 | output/ 4 | lib/__pycache__/**/* 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | .vscode/* 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /lib/ditty/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss calculator abstraction for ditty trainers. 3 | 4 | Architecture: 5 | batch -> preprocess -> model.forward -> postprocess -> loss_calc(model_output, ctx) 6 | 7 | LossCalculator receives the full model output tuple and context dict, 8 | allowing flexible loss computation across multiple model outputs. 9 | """ 10 | from abc import ABC, abstractmethod 11 | from dataclasses import dataclass, field 12 | from typing import Dict, Tuple, Optional, Any 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | 17 | from .base import DittyBase 18 | from .processors import Context 19 | 20 | 21 | @dataclass 22 | class LossOutput: 23 | loss: torch.Tensor 24 | metrics: Dict[str, float] = field(default_factory=dict) 25 | 26 | 27 | class LossCalculator(DittyBase, ABC): 28 | def __init__( 29 | self, 30 | output_index: int = 0, 31 | target_key: str = "target", 32 | mask_key: Optional[str] = None, 33 | contract: str = "", 34 | ): 35 | super().__init__(contract=contract) 36 | self.output_index = output_index 37 | self.target_key = target_key 38 | self.mask_key = mask_key 39 | 40 | def get_prediction(self, model_output: Tuple[Any, ...]) -> torch.Tensor: 41 | return model_output[self.output_index] 42 | 43 | def get_target(self, ctx: Context) -> torch.Tensor: 44 | return ctx[self.target_key] 45 | 46 | def get_mask(self, ctx: Context) -> Optional[torch.Tensor]: 47 | return ctx.get(self.mask_key) if self.mask_key else None 48 | 49 | @abstractmethod 50 | def compute(self, model_output: Tuple[Any, ...], ctx: Context) -> LossOutput: 51 | """ 52 | Compute loss from model output and context. 
53 | 54 | Args: 55 | model_output: Tuple of tensors from model forward pass 56 | ctx: Context dict populated by preprocessors 57 | 58 | Returns: 59 | LossOutput with loss tensor and metrics dict 60 | """ 61 | pass 62 | 63 | 64 | class ReductionLoss(LossCalculator, ABC): 65 | """Base for losses with reduction and mask support (MSE, L1, etc).""" 66 | 67 | def __init__(self, reduction: str = "mean", mask_key: str = "mask", **kwargs): 68 | super().__init__(mask_key=mask_key, **kwargs) 69 | self.reduction = reduction 70 | 71 | def apply_mask(self, loss: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: 72 | if mask is not None: 73 | return loss.sum() / mask.sum().clamp(min=1) if self.reduction == "mean" else loss.sum() 74 | return loss 75 | 76 | 77 | class MSELoss(ReductionLoss): 78 | def compute(self, model_output: Tuple[Any, ...], ctx: Context) -> LossOutput: 79 | pred, target, mask = self.get_prediction(model_output), self.get_target(ctx), self.get_mask(ctx) 80 | if mask is not None: 81 | loss = F.mse_loss(pred * mask, target * mask, reduction="none") 82 | loss = self.apply_mask(loss, mask) 83 | else: 84 | loss = F.mse_loss(pred, target, reduction=self.reduction) 85 | return LossOutput(loss=loss, metrics={"mse": loss.item()}) 86 | 87 | 88 | class L1Loss(ReductionLoss): 89 | def compute(self, model_output: Tuple[Any, ...], ctx: Context) -> LossOutput: 90 | pred, target, mask = self.get_prediction(model_output), self.get_target(ctx), self.get_mask(ctx) 91 | if mask is not None: 92 | loss = F.l1_loss(pred * mask, target * mask, reduction="none") 93 | loss = self.apply_mask(loss, mask) 94 | else: 95 | loss = F.l1_loss(pred, target, reduction=self.reduction) 96 | return LossOutput(loss=loss, metrics={"l1": loss.item()}) 97 | 98 | 99 | class CrossEntropyLoss(LossCalculator): 100 | def __init__(self, ignore_index: int = -100, **kwargs): 101 | super().__init__(**kwargs) 102 | self.ignore_index = ignore_index 103 | 104 | def compute(self, model_output: Tuple[Any, ...], ctx: Context) -> LossOutput: 105 | pred, target, mask = self.get_prediction(model_output), self.get_target(ctx), self.get_mask(ctx) 106 | if pred.dim() > 2: 107 | pred = pred.reshape(-1, pred.size(-1)) 108 | if target.dim() > 1: 109 | target = target.reshape(-1) 110 | if mask is not None: 111 | mask = mask.reshape(-1) if mask.dim() > 1 else mask 112 | loss_per_token = F.cross_entropy(pred, target, reduction="none") 113 | loss = (loss_per_token * mask).sum() / mask.sum().clamp(min=1) 114 | else: 115 | loss = F.cross_entropy(pred, target, ignore_index=self.ignore_index) 116 | return LossOutput(loss=loss, metrics={"ce": loss.item()}) 117 | 118 | 119 | class CompositeLoss(LossCalculator): 120 | """Combine multiple loss calculators with weights.""" 121 | 122 | def __init__(self, losses: list[tuple[LossCalculator, float]]): 123 | super().__init__(contract="") 124 | self.losses = losses 125 | 126 | def compute(self, model_output: Tuple[Any, ...], ctx: Context) -> LossOutput: 127 | device = ctx.get("device", "cuda") 128 | total_loss = torch.tensor(0.0, device=device) 129 | all_metrics = {} 130 | 131 | for loss_calc, weight in self.losses: 132 | if weight == 0.0: 133 | continue 134 | output = loss_calc.compute(model_output, ctx) 135 | total_loss = total_loss + weight * output.loss 136 | for k, v in output.metrics.items(): 137 | all_metrics[f"{loss_calc.name}/{k}"] = v 138 | 139 | all_metrics["total"] = total_loss.item() 140 | return LossOutput(loss=total_loss, metrics=all_metrics) 
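# Usage sketch (illustrative only, not part of the library API): `logits`, `hidden`,
# and the ctx keys below are hypothetical names that preprocessors would populate.
#
#   loss_calc = CompositeLoss([
#       (CrossEntropyLoss(output_index=0, target_key="target", mask_key="mask"), 1.0),
#       (MSELoss(output_index=1, target_key="hidden_target", mask_key=None), 0.1),
#   ])
#   result = loss_calc.compute(model_output=(logits, hidden), ctx=ctx)
#   result.loss.backward()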
-------------------------------------------------------------------------------- /lib/ditty/data.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | import torch 4 | from torch.utils.data import DataLoader, RandomSampler 5 | import datasets 6 | from transformers.trainer_pt_utils import ( 7 | LabelSmoother, 8 | LengthGroupedSampler, 9 | ) 10 | from transformers.trainer_utils import RemoveColumnsCollator, set_seed 11 | from transformers.data.data_collator import ( 12 | DataCollator, 13 | DataCollatorWithPadding, 14 | DataCollatorForLanguageModeling, 15 | default_data_collator, 16 | ) 17 | from transformers import PreTrainedTokenizerBase 18 | from typing import Callable 19 | 20 | from logging import getLogger 21 | 22 | logger = getLogger() 23 | 24 | 25 | @dataclass(kw_only=True) 26 | class Data: 27 | dataset: datasets.Dataset | None = None 28 | split: str = "train" 29 | tokenizer: PreTrainedTokenizerBase 30 | seed: Optional[int] = None 31 | batch_size: int = 8 32 | grad_accum: int = 1 33 | length_column_name: Optional[str] = None 34 | group_by_length: bool = False 35 | dataloader_num_workers: int = 0 36 | dataloader_pin_memory: bool = True 37 | dataloader_drop_last: bool = False 38 | load_kwargs: Optional[dict] = None 39 | collator: Optional[DataCollator] = None 40 | remove_unused_columns: bool = False 41 | 42 | def __post_init__(self): 43 | if self.dataset is None and self.load_kwargs is None: 44 | raise ValueError( 45 | "dataset and load_kwargs cannot both be None. Please either pass an instance of Dataset or a dict of args to load the dataset with." 46 | ) 47 | 48 | if self.dataset is None: 49 | kwargs = self.load_kwargs or {} 50 | 51 | self.dataset = datasets.load_dataset(**kwargs)[self.split] 52 | 53 | if not self.collator: 54 | collator = DataCollatorForLanguageModeling( 55 | tokenizer=self.tokenizer, return_tensors="pt", mlm=False 56 | ) 57 | self.collator = collator 58 | 59 | def _get_sampler(self) -> Optional[torch.utils.data.Sampler]: 60 | generator = torch.Generator() 61 | 62 | if self.seed: 63 | generator.manual_seed(self.seed) 64 | 65 | # Build the sampler. 
66 |         if self.group_by_length:
67 |             lengths = (
68 |                 self.dataset[self.length_column_name]
69 |                 if self.length_column_name in self.dataset.column_names
70 |                 else None
71 |             )
72 |             model_input_name = self.tokenizer.model_input_names[0]
73 |             return LengthGroupedSampler(
74 |                 self.batch_size * self.grad_accum,
75 |                 dataset=self.dataset,
76 |                 lengths=lengths,
77 |                 model_input_name=model_input_name,
78 |                 generator=generator,
79 |             )
80 |         else:
81 |             return RandomSampler(self.dataset, generator=generator)
82 | 
83 |     def _get_collator_with_removed_columns(
84 |         self,
85 |         data_collator: Callable,
86 |     ) -> Callable:
87 |         """Wrap the data collator in a callable removing unused columns."""
88 |         if not self.remove_unused_columns:
89 |             return data_collator
90 | 
91 |         remove_columns_collator = RemoveColumnsCollator(
92 |             data_collator=data_collator,
93 |             signature_columns=None,
94 |             logger=logger,
95 |             description=self.split,
96 |             model_name=None,  # Data holds no model reference, so no model name is available here
97 |         )
98 |         return remove_columns_collator
99 | 
100 |     def _remove_unused_columns(self, dataset: "datasets.Dataset"):
101 |         if not self.remove_unused_columns:
102 |             return dataset
103 | 
104 |         ignored_columns = list(set(dataset.column_names))
105 | 
106 |         return dataset.remove_columns(ignored_columns)
107 | 
108 |     def prepare(self, pipeline: list[tuple[str, Callable, dict]]):
109 |         if self.dataset is None:
110 |             raise ValueError("Dataset not set.")
111 | 
112 |         for op_name, func, kwargs in pipeline:
113 |             op = getattr(self.dataset, op_name)
114 | 
115 |             if not func:
116 |                 self.dataset = op(**kwargs)
117 |             else:
118 |                 self.dataset = op(func, **kwargs)
119 | 
120 |         return self._get_dataloader()
121 | 
122 |     def _seed_worker(self, worker_id):
123 |         if not self.seed:
124 |             worker_seed = torch.initial_seed() % 2**32
125 |         else:
126 |             worker_seed = self.seed
127 | 
128 |         set_seed(worker_seed)
129 | 
130 |     def _get_dataloader(self) -> DataLoader:
131 |         """
132 |         Returns a [`~torch.utils.data.DataLoader`].
133 | 
134 |         Will use no sampler if `dataset` does not implement `__len__`, a random sampler (adapted to distributed
135 |         training if necessary) otherwise.
136 | 
137 |         Subclass and override this method if you want to inject some custom behavior.
138 |         """
139 |         dataset = self.dataset
140 |         dataset = self._remove_unused_columns(dataset)
141 | 
142 |         data_collator = self.collator
143 |         data_collator = self._get_collator_with_removed_columns(data_collator)
144 | 
145 |         if isinstance(dataset, torch.utils.data.IterableDataset):
146 |             return DataLoader(
147 |                 dataset,
148 |                 batch_size=self.batch_size,
149 |                 collate_fn=data_collator,
150 |                 num_workers=self.dataloader_num_workers,
151 |                 pin_memory=self.dataloader_pin_memory,
152 |             )
153 | 
154 |         sampler = self._get_sampler()
155 | 
156 |         return DataLoader(
157 |             dataset,
158 |             batch_size=self.batch_size,
159 |             sampler=sampler,
160 |             collate_fn=data_collator,
161 |             drop_last=self.dataloader_drop_last,
162 |             num_workers=self.dataloader_num_workers,
163 |             pin_memory=self.dataloader_pin_memory,
164 |             worker_init_fn=self._seed_worker,
165 |         )
166 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Ditty
2 | 
3 | A distributed training library for PyTorch.
4 | 
5 | ## What
6 | 
7 | A flexible library for training and finetuning models with modern distributed training support.
Integrates with the HuggingFace ecosystem (Accelerate, Transformers, Datasets, PEFT, Hub) while providing a custom training loop and pipeline architecture. Works with any PyTorch model - from pretrained HuggingFace models to custom architectures like diffusion models. 8 | 9 | Ditty has support for: 10 | - Full training and finetuning 11 | - LORA, QLORA 12 | - 8bit, 4bit quantization 13 | - FP16, BFLOAT16, FP8 (via transformer-engine) 14 | - 8bit Adam (torchao or bitsandbytes backends) 15 | - FSDP2 with DTensor-based sharding 16 | - FSDP + QLORA (needs testing with FSDP2) 17 | - torch.compile compatible 18 | - Checkpointing and resume 19 | - Pushing to HuggingFace Hub 20 | 21 | ### FSDP2 22 | 23 | Ditty uses PyTorch's FSDP2 with per-parameter DTensor sharding. This provides: 24 | - Memory-efficient training across multiple GPUs 25 | - Compatible with torchao's 8-bit optimizers 26 | - Works with torch.compile 27 | 28 | To enable FSDP2, pass an `FSDPConfig` to your `ModelFactory`: 29 | 30 | ```python 31 | from ditty import ModelFactory, FSDPConfig 32 | 33 | fsdp_config = FSDPConfig( 34 | enabled=True, 35 | transformer_layers=[MyTransformerBlock], # Layers to shard 36 | ) 37 | 38 | model_factory = ModelFactory.from_instance( 39 | my_model, 40 | fsdp_config=fsdp_config, 41 | ) 42 | ``` 43 | 44 | ### 8-bit Optimizers 45 | 46 | Two backends are available for 8-bit Adam: 47 | 48 | - `torchao` (default) - Works with FSDP2/DTensor, torch.compile compatible 49 | - `bnb` - bitsandbytes, does not work with FSDP2 50 | 51 | ```python 52 | pipeline = Pipeline( 53 | model_factory=model_factory, 54 | dataset=dataset, 55 | use_8bit_optim=True, 56 | optim_backend="torchao", # or "bnb" 57 | ... 58 | ) 59 | ``` 60 | 61 | ### FP8 Training 62 | 63 | FP8 training is supported via [NVIDIA Transformer Engine](https://github.com/NVIDIA/TransformerEngine). This provides compute speedups on supported GPUs (H100, Ada Lovelace). 64 | 65 | To use FP8: 66 | 1. Install transformer-engine: `pip install transformer-engine[pytorch]` 67 | 2. Pass `accelerator_kwargs={"mixed_precision": "fp8"}` to Pipeline 68 | 69 | ## Architecture 70 | 71 | Ditty uses a pipeline pattern for training: 72 | 73 | ``` 74 | batch -> preprocessors -> model.forward -> postprocessors -> loss_calculator 75 | ``` 76 | 77 | This allows flexible composition of training workflows without modifying the core trainer. 78 | 79 | ## Classes 80 | 81 | ### Pipeline 82 | 83 | The main entry point. 
Pass a `ModelFactory`, dataset, `LossCalculator`, and optional pre/post processors: 84 | 85 | ```python 86 | from ditty import Pipeline, ModelFactory, CompositeLoss 87 | 88 | model_factory = ModelFactory.from_instance(my_model) 89 | # or: ModelFactory.from_checkpoint(path, model_class, **kwargs) 90 | 91 | pipeline = Pipeline( 92 | model_factory=model_factory, 93 | dataset=my_dataset, 94 | collate_fn=my_collate_fn, 95 | loss_calculator=my_loss, 96 | preprocessors=[...], 97 | postprocessors=[...], 98 | output_dir="./output", 99 | fp16=True, 100 | use_8bit_optim=True, 101 | lr=2e-4, 102 | epochs=10, 103 | ) 104 | pipeline.run() 105 | ``` 106 | 107 | ### ModelFactory 108 | 109 | Handles model creation, checkpoint loading, and FSDP wrapping: 110 | 111 | - `ModelFactory.from_instance(model)` - Wrap an existing model instance 112 | - `ModelFactory.from_checkpoint(path, model_class, **kwargs)` - Load from checkpoint 113 | 114 | ### PreProcessor / PostProcessor 115 | 116 | Transform data before the model or outputs after: 117 | 118 | ```python 119 | from ditty.processors import PreProcessor, PostProcessor, Context 120 | 121 | class MyPreProcessor(PreProcessor): 122 | def __init__(self): 123 | super().__init__(contract="batch:3:i64 -> batch:3:i64 | ctx.my_key:0:i64") 124 | 125 | def process(self, batch, ctx: Context): 126 | ctx["forward_kwargs"] = ctx.get("forward_kwargs", {}) 127 | ctx["forward_kwargs"]["my_param"] = some_value 128 | return batch, ctx 129 | 130 | class MyPostProcessor(PostProcessor): 131 | def process(self, model_output, ctx: Context): 132 | ctx["target"] = extract_target(model_output, ctx["original_batch"]) 133 | return model_output, ctx 134 | ``` 135 | 136 | ### LossCalculator 137 | 138 | Compute loss from model outputs. Use `output_index` to select from tuple outputs: 139 | 140 | ```python 141 | from ditty.loss import LossCalculator, LossOutput, CompositeLoss 142 | 143 | class MyLoss(LossCalculator): 144 | def __init__(self): 145 | super().__init__(output_index=0, target_key="target", mask_key="mask") 146 | 147 | def compute(self, model_output, ctx) -> LossOutput: 148 | pred = self.get_prediction(model_output) 149 | target = self.get_target(ctx) 150 | mask = self.get_mask(ctx) 151 | loss = F.mse_loss(pred, target) 152 | return LossOutput(loss=loss, metrics={"mse": loss.item()}) 153 | 154 | # Combine multiple losses with weights 155 | loss_calculator = CompositeLoss([ 156 | (MSELoss(output_index=0), 1.0), 157 | (CrossEntropyLoss(output_index=1), 0.1), 158 | ]) 159 | ``` 160 | 161 | ### Contracts (Optional) 162 | 163 | Processors and losses can declare contracts for validation: 164 | 165 | ```python 166 | # Terse syntax: "input_shape -> output_shape | ctx.key:shape:dtype" 167 | contract = "batch:3:i64 -> batch:3:i64 | ctx.t:1:i64" 168 | ``` 169 | 170 | Pipeline validates that contracts chain together correctly at initialization. 171 | 172 | ## Setup 173 | 174 | ``` 175 | pip install ditty 176 | ``` 177 | 178 | ## Attribution 179 | 180 | ### Huggingface 181 | 182 | Portions of this library reference Huggingface's transformers Trainer class and in some cases re-implement functions from Trainer. 183 | 184 | ### Answer.ai 185 | 186 | Portions of this library implement Answer.ai's method for FSDP+QLORA. The original work can be found at: https://github.com/AnswerDotAI/fsdp_qlora 187 | 188 | ## License 189 | 190 | Apache V2 - see the LICENSE file for full text. 
191 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /lib/ditty/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from dataclasses import dataclass, field 5 | from logging import getLogger 6 | from typing import Optional, Dict, Any, List 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | logger = getLogger("ditty_checkpoint") 13 | 14 | 15 | @dataclass 16 | class Checkpoint: 17 | """Container for all checkpoint data.""" 18 | model_state: Optional[Dict[str, Any]] = None 19 | optimizer_state: Optional[Dict[str, Any]] = None 20 | scheduler_state: Optional[Dict[str, Any]] = None 21 | training_state: Dict[str, Any] = field(default_factory=dict) 22 | scaler_state: Optional[Dict[str, Any]] = None 23 | rng_states: Dict[str, Any] = field(default_factory=dict) 24 | 25 | 26 | class CheckpointManager: 27 | """ 28 | Unified checkpoint manager for ditty training. 29 | 30 | Handles saving and loading of: 31 | - Model weights 32 | - Optimizer state 33 | - Scheduler state 34 | - Training state (epoch, steps, etc.) 35 | - Gradient scaler state 36 | - RNG states for reproducibility 37 | 38 | This replaces accelerate's save_state/load_state to give us control 39 | over the loading order (load before prepare() instead of after). 
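    A minimal usage sketch (the training_state keys shown are illustrative, not
    required by the API):

        manager = CheckpointManager(output_dir="./output")
        manager.save(checkpoint_num=1, model=model, optimizer=optimizer,
                     training_state={"epoch": 0, "global_step": 100})
        checkpoint = manager.load()  # latest checkpoint, or None if none exist
        if checkpoint is not None:
            manager.apply_to_model(checkpoint, model)
            manager.apply_to_optimizer(checkpoint, optimizer)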
40 | """ 41 | 42 | def __init__(self, output_dir: str): 43 | self.output_dir = output_dir 44 | self.checkpoints_dir = os.path.join(output_dir, "checkpoints") 45 | 46 | def _get_checkpoint_path(self, checkpoint_num: int) -> str: 47 | return os.path.join(self.checkpoints_dir, f"checkpoint_{checkpoint_num}") 48 | 49 | def _get_latest_checkpoint_num(self) -> Optional[int]: 50 | if not os.path.exists(self.checkpoints_dir): 51 | return None 52 | 53 | checkpoint_dirs = [] 54 | for name in os.listdir(self.checkpoints_dir): 55 | if name.startswith("checkpoint_"): 56 | try: 57 | num = int(name.split("_")[1]) 58 | checkpoint_dirs.append(num) 59 | except (IndexError, ValueError): 60 | continue 61 | 62 | if not checkpoint_dirs: 63 | return None 64 | 65 | return max(checkpoint_dirs) 66 | 67 | def save( 68 | self, 69 | checkpoint_num: int, 70 | model: nn.Module, 71 | optimizer: torch.optim.Optimizer, 72 | training_state: Dict[str, Any], 73 | scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None, 74 | scaler: Optional[torch.amp.GradScaler] = None, 75 | is_fsdp: bool = False, 76 | rank: int = 0, 77 | local_rank: int = 0, 78 | ): 79 | """Save a complete training checkpoint.""" 80 | checkpoint_path = self._get_checkpoint_path(checkpoint_num) 81 | os.makedirs(checkpoint_path, exist_ok=True) 82 | 83 | # Save model weights (full state dict for FSDP) 84 | model_path = os.path.join(self.output_dir, "dist", "model.pt") 85 | os.makedirs(os.path.dirname(model_path), exist_ok=True) 86 | 87 | # Get the unwrapped model if compiled (avoids _orig_mod. prefix in state_dict) 88 | unwrapped_model = getattr(model, '_orig_mod', model) 89 | 90 | if is_fsdp: 91 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 92 | from torch.distributed.fsdp import StateDictType, FullStateDictConfig 93 | from torch.distributed.tensor import DTensor 94 | 95 | # For FSDP2 with fully_shard, we need to use get_model_state_dict 96 | try: 97 | from torch.distributed.checkpoint.state_dict import get_model_state_dict 98 | model_state = get_model_state_dict(unwrapped_model) 99 | # Convert DTensors to regular tensors for portable checkpoints 100 | model_state = { 101 | k: v.full_tensor().cpu() if isinstance(v, DTensor) else v.cpu() 102 | for k, v in model_state.items() 103 | } 104 | except ImportError: 105 | # Fallback for older torch versions 106 | model_state = {k: v.cpu() for k, v in unwrapped_model.state_dict().items()} 107 | 108 | if rank == 0: 109 | torch.save(model_state, model_path) 110 | else: 111 | if rank == 0: 112 | torch.save(unwrapped_model.state_dict(), model_path) 113 | 114 | # Save optimizer state 115 | optimizer_path = os.path.join(checkpoint_path, "optimizer.bin") 116 | if is_fsdp: 117 | try: 118 | from torch.distributed.checkpoint.state_dict import get_optimizer_state_dict 119 | optim_state = get_optimizer_state_dict(model, optimizer) 120 | except ImportError: 121 | optim_state = optimizer.state_dict() 122 | else: 123 | optim_state = optimizer.state_dict() 124 | 125 | if rank == 0: 126 | torch.save(optim_state, optimizer_path) 127 | 128 | # Save scheduler state 129 | if scheduler is not None and rank == 0: 130 | scheduler_path = os.path.join(checkpoint_path, "scheduler.pt") 131 | torch.save(scheduler.state_dict(), scheduler_path) 132 | 133 | # Save training state 134 | if rank == 0: 135 | training_state_path = os.path.join(checkpoint_path, "training_state.pt") 136 | torch.save(training_state, training_state_path) 137 | 138 | # Save scaler state 139 | if scaler is not None and rank == 0: 140 | 
scaler_path = os.path.join(checkpoint_path, "scaler.pt") 141 | torch.save(scaler.state_dict(), scaler_path) 142 | 143 | # Save RNG states for this rank 144 | rng_state = { 145 | "python": random.getstate(), 146 | "numpy": np.random.get_state(), 147 | "torch": torch.get_rng_state(), 148 | } 149 | if torch.cuda.is_available(): 150 | rng_state["cuda"] = torch.cuda.get_rng_state(local_rank) 151 | 152 | rng_path = os.path.join(checkpoint_path, f"rng_state_{rank}.pt") 153 | torch.save(rng_state, rng_path) 154 | 155 | if rank == 0: 156 | logger.info(f"Saved checkpoint to {checkpoint_path}") 157 | 158 | def load(self, checkpoint_num: Optional[int] = None) -> Optional[Checkpoint]: 159 | """ 160 | Load a checkpoint. If checkpoint_num is None, loads the latest. 161 | Returns None if no checkpoint exists. 162 | """ 163 | if checkpoint_num is None: 164 | checkpoint_num = self._get_latest_checkpoint_num() 165 | if checkpoint_num is None: 166 | return None 167 | 168 | checkpoint_path = self._get_checkpoint_path(checkpoint_num) 169 | if not os.path.exists(checkpoint_path): 170 | return None 171 | 172 | checkpoint = Checkpoint() 173 | 174 | # Load model weights 175 | model_path = os.path.join(self.output_dir, "dist", "model.pt") 176 | if os.path.exists(model_path): 177 | state_dict = torch.load(model_path, map_location="cpu", weights_only=False) 178 | # Strip _orig_mod. prefix added by torch.compile 179 | checkpoint.model_state = { 180 | k.replace("_orig_mod.", ""): v for k, v in state_dict.items() 181 | } 182 | 183 | # Load optimizer state 184 | optimizer_path = os.path.join(checkpoint_path, "optimizer.bin") 185 | if os.path.exists(optimizer_path): 186 | checkpoint.optimizer_state = torch.load(optimizer_path, map_location="cpu", weights_only=False) 187 | 188 | # Load scheduler state 189 | scheduler_path = os.path.join(checkpoint_path, "scheduler.pt") 190 | if os.path.exists(scheduler_path): 191 | checkpoint.scheduler_state = torch.load(scheduler_path, map_location="cpu", weights_only=False) 192 | 193 | # Load training state (new format or legacy accelerate format) 194 | training_state_path = os.path.join(checkpoint_path, "training_state.pt") 195 | if os.path.exists(training_state_path): 196 | checkpoint.training_state = torch.load(training_state_path, map_location="cpu", weights_only=False) 197 | else: 198 | # Try legacy accelerate format 199 | legacy_path = os.path.join(checkpoint_path, "custom_checkpoint_0.pkl") 200 | if os.path.exists(legacy_path): 201 | try: 202 | checkpoint.training_state = torch.load(legacy_path, map_location="cpu", weights_only=False) 203 | logger.info("Loaded training state from legacy accelerate format") 204 | except Exception as e: 205 | logger.warning(f"Failed to load legacy training state: {e}") 206 | 207 | # Load scaler state 208 | scaler_path = os.path.join(checkpoint_path, "scaler.pt") 209 | if os.path.exists(scaler_path): 210 | checkpoint.scaler_state = torch.load(scaler_path, map_location="cpu", weights_only=False) 211 | 212 | logger.info(f"Loaded checkpoint from {checkpoint_path}") 213 | return checkpoint 214 | 215 | def load_rng_state(self, checkpoint_num: Optional[int] = None, rank: int = 0, local_rank: int = 0): 216 | """Load and restore RNG states for a specific rank.""" 217 | if checkpoint_num is None: 218 | checkpoint_num = self._get_latest_checkpoint_num() 219 | if checkpoint_num is None: 220 | return 221 | 222 | checkpoint_path = self._get_checkpoint_path(checkpoint_num) 223 | rng_path = os.path.join(checkpoint_path, f"rng_state_{rank}.pt") 224 | 225 | if not 
os.path.exists(rng_path): 226 | # Try legacy format 227 | rng_path = os.path.join(checkpoint_path, f"random_states_{rank}.pkl") 228 | if not os.path.exists(rng_path): 229 | return 230 | 231 | rng_state = torch.load(rng_path, map_location="cpu", weights_only=False) 232 | 233 | # Handle our new format 234 | if "python" in rng_state: 235 | random.setstate(rng_state["python"]) 236 | if "numpy" in rng_state: 237 | np.random.set_state(rng_state["numpy"]) 238 | if "torch" in rng_state: 239 | torch.set_rng_state(rng_state["torch"]) 240 | if "cuda" in rng_state and torch.cuda.is_available(): 241 | torch.cuda.set_rng_state(rng_state["cuda"], local_rank) 242 | 243 | # Handle accelerate format 244 | if "random_state" in rng_state: 245 | random.setstate(rng_state["random_state"]) 246 | if "numpy_random_seed" in rng_state: 247 | np.random.set_state(rng_state["numpy_random_seed"]) 248 | if "torch_manual_seed" in rng_state: 249 | torch.set_rng_state(rng_state["torch_manual_seed"]) 250 | if "torch_cuda_manual_seed" in rng_state and torch.cuda.is_available(): 251 | torch.cuda.set_rng_state(rng_state["torch_cuda_manual_seed"], local_rank) 252 | 253 | def get_latest_checkpoint_num(self) -> Optional[int]: 254 | return self._get_latest_checkpoint_num() 255 | 256 | def apply_to_model(self, checkpoint: Checkpoint, model: nn.Module): 257 | """Apply checkpoint model state to a model.""" 258 | if checkpoint.model_state is not None: 259 | model.load_state_dict(checkpoint.model_state) 260 | logger.info("Loaded model weights from checkpoint") 261 | 262 | def apply_to_optimizer(self, checkpoint: Checkpoint, optimizer: torch.optim.Optimizer): 263 | """Apply checkpoint optimizer state to an optimizer.""" 264 | if checkpoint.optimizer_state is not None: 265 | optimizer.load_state_dict(checkpoint.optimizer_state) 266 | logger.info("Loaded optimizer state from checkpoint") 267 | 268 | def apply_to_scheduler(self, checkpoint: Checkpoint, scheduler: torch.optim.lr_scheduler.LRScheduler): 269 | """Apply checkpoint scheduler state to a scheduler.""" 270 | if checkpoint.scheduler_state is not None: 271 | scheduler.load_state_dict(checkpoint.scheduler_state) 272 | logger.info("Loaded scheduler state from checkpoint") 273 | 274 | def apply_to_scaler(self, checkpoint: Checkpoint, scaler: torch.amp.GradScaler): 275 | """Apply checkpoint scaler state to a gradient scaler.""" 276 | if checkpoint.scaler_state is not None: 277 | scaler.load_state_dict(checkpoint.scaler_state) 278 | logger.info("Loaded scaler state from checkpoint") 279 | -------------------------------------------------------------------------------- /lib/ditty/trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import time 3 | from .utils import convert_seconds_to_string_time 4 | from .loss import LossCalculator, MSELoss, LossOutput 5 | from .processors import PreProcessor, PostProcessor, Context 6 | from .checkpoint import CheckpointManager 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | from accelerate import Accelerator 11 | from accelerate.utils import set_seed 12 | from transformers.trainer_pt_utils import get_model_param_count 13 | import atexit 14 | import contextlib 15 | from logging import getLogger 16 | from typing import Optional, Any, List, Union, Callable 17 | import os 18 | 19 | 20 | def default_scheduler_factory(optimizer): 21 | return torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 22 | 23 | 24 | logger 
= getLogger("ditty_training") 25 | 26 | 27 | @dataclass(kw_only=True) 28 | class TrainerState: 29 | epoch: int = 0 30 | steps: int = 0 31 | total_steps: int = 0 32 | global_loss: float = 0.0 33 | 34 | def state_dict(self): 35 | return { 36 | "epoch": self.epoch, 37 | "steps": self.steps, 38 | "total_steps": self.total_steps, 39 | "global_loss": self.global_loss, 40 | } 41 | 42 | def load_state_dict(self, state_dict): 43 | self.epoch = state_dict.get("epoch", 0) 44 | self.steps = state_dict.get("steps", 0) 45 | self.total_steps = state_dict.get("total_steps", 0) 46 | self.global_loss = state_dict.get("global_loss", 0.0) 47 | 48 | 49 | @dataclass(kw_only=True) 50 | class Trainer: 51 | """ 52 | Training loop with pipeline pattern: 53 | batch -> preprocessors -> model.forward -> postprocessors -> loss_calc(pred, target) 54 | """ 55 | model: nn.Module 56 | optimizer: torch.optim.Optimizer 57 | accelerator: Accelerator 58 | dataset: DataLoader 59 | device: torch.device 60 | 61 | # Pipeline 62 | preprocessors: List[PreProcessor] = field(default_factory=list) 63 | postprocessors: List[PostProcessor] = field(default_factory=list) 64 | loss_calculator: LossCalculator = None # type: ignore[assignment] 65 | 66 | # Training config 67 | scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None 68 | use_scheduler: bool = True 69 | grad_accum: int = 1 70 | fp16: bool = False 71 | use_bfloat16: bool = False 72 | output_dir: str = "./output" 73 | checkpoint_every: int = 1000 74 | hf_hub_token: Optional[str] = None 75 | seed: Optional[int] = None 76 | metrics_logger: Optional[Any] = None 77 | log_every: int = 10 78 | max_grad_norm: Optional[float] = None 79 | shuffle_each_epoch: bool = True 80 | total_batches: Optional[int] = None 81 | is_fsdp: bool = False 82 | 83 | # Pre-loaded state (from CheckpointManager, loaded before Trainer creation) 84 | initial_state: Optional[TrainerState] = None 85 | 86 | def __post_init__(self): 87 | if self.seed: 88 | set_seed(self.seed) 89 | 90 | os.makedirs(self.output_dir, exist_ok=True) 91 | 92 | self.batch_size = self.dataset.batch_size 93 | self.preprocessors = self.preprocessors or [] 94 | self.postprocessors = self.postprocessors or [] 95 | self.loss_calculator = self.loss_calculator or MSELoss() 96 | 97 | if self.use_scheduler and not self.scheduler: 98 | self.scheduler = default_scheduler_factory(self.optimizer) 99 | 100 | if self.fp16 and self.use_bfloat16: 101 | self.f16_dtype = torch.bfloat16 102 | elif self.fp16: 103 | self.f16_dtype = torch.float16 104 | 105 | self.device = self.accelerator.device 106 | 107 | if self.is_fsdp: 108 | if self.use_scheduler: 109 | self.optimizer, self.dataset, self.scheduler = self.accelerator.prepare( 110 | self.optimizer, self.dataset, self.scheduler 111 | ) 112 | else: 113 | self.optimizer, self.dataset = self.accelerator.prepare( 114 | self.optimizer, self.dataset 115 | ) 116 | else: 117 | if self.use_scheduler: 118 | ( 119 | self.model, 120 | self.optimizer, 121 | self.dataset, 122 | self.scheduler, 123 | ) = self.accelerator.prepare( 124 | self.model, self.optimizer, self.dataset, self.scheduler 125 | ) 126 | else: 127 | self.model, self.optimizer, self.dataset = self.accelerator.prepare( 128 | self.model, self.optimizer, self.dataset 129 | ) 130 | 131 | # Use pre-loaded state if provided, otherwise start fresh 132 | if self.initial_state is not None: 133 | self.state = self.initial_state 134 | else: 135 | self.state = TrainerState() 136 | 137 | # Initialize checkpoint manager 138 | self.checkpoint_manager = 
CheckpointManager(self.output_dir) 139 | self._checkpoint_iteration = self.checkpoint_manager.get_latest_checkpoint_num() or 0 140 | if self.initial_state is not None: 141 | self._checkpoint_iteration += 1 142 | 143 | def _save(self, no_dist=False): 144 | rank = int(os.environ.get("RANK", 0)) 145 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 146 | 147 | if self.accelerator.is_main_process: 148 | logger.info(f"Saving checkpoint at step {self.state.steps} (total: {self.state.total_steps})") 149 | self.accelerator.wait_for_everyone() 150 | 151 | self.checkpoint_manager.save( 152 | checkpoint_num=self._checkpoint_iteration, 153 | model=self.accelerator.unwrap_model(self.model), 154 | optimizer=self.optimizer, 155 | training_state=self.state.state_dict(), 156 | scheduler=self.scheduler if self.use_scheduler else None, 157 | scaler=self.accelerator.scaler if hasattr(self.accelerator, 'scaler') and self.accelerator.scaler else None, 158 | is_fsdp=self.is_fsdp, 159 | rank=rank, 160 | local_rank=local_rank, 161 | ) 162 | self._checkpoint_iteration += 1 163 | 164 | def _log_pipeline(self): 165 | logger.info("Pipeline:") 166 | logger.info(f" preprocessors:") 167 | for p in self.preprocessors: 168 | logger.info(f" - {p}") 169 | logger.info(f" model: {self.model.__class__.__name__} ({get_model_param_count(self.model, trainable_only=True):,} params)") 170 | logger.info(f" postprocessors:") 171 | for p in self.postprocessors: 172 | logger.info(f" - {p}") 173 | logger.info(f" loss: {self.loss_calculator.__class__.__name__}") 174 | 175 | def _train_accelerate(self, epochs=1, max_steps=None): 176 | context_manager = contextlib.nullcontext() 177 | if self.fp16: 178 | context_manager = torch.autocast(device_type=self.device.type, dtype=self.f16_dtype) 179 | 180 | self.model.train() 181 | if self.total_batches is not None: 182 | total_batches = self.total_batches 183 | else: 184 | try: 185 | total_batches = len(self.dataset) * epochs 186 | except TypeError: 187 | total_batches = None 188 | start_time = time.time() 189 | 190 | atexit.register(self._save) 191 | 192 | for ep in range(self.state.epoch, epochs): 193 | dataset = self.dataset 194 | 195 | if self.shuffle_each_epoch and hasattr(dataset, 'set_epoch'): 196 | dataset.set_epoch(ep) 197 | 198 | if self.state.steps > 0: 199 | if self.accelerator.is_main_process: 200 | logger.info(f"Resuming from batch {self.state.steps}, skipping {self.state.steps} batches.") 201 | dataset = self.accelerator.skip_first_batches(self.dataset, self.state.steps) 202 | 203 | for batch in dataset: 204 | if batch is None: 205 | break 206 | 207 | original_batch = batch 208 | ctx: Context = { 209 | "epoch": ep, 210 | "step": self.state.steps, 211 | "total_steps": self.state.total_steps, 212 | "device": self.device, 213 | "original_batch": original_batch, 214 | } 215 | 216 | for preprocessor in self.preprocessors: 217 | result = preprocessor.process(batch, ctx) 218 | if result[0] is None: 219 | batch = None 220 | break 221 | batch, ctx = result 222 | 223 | if batch is None: 224 | continue 225 | 226 | with self.accelerator.accumulate(self.model): 227 | with context_manager: 228 | model_output = self.model(batch, **ctx.get("forward_kwargs", {})) 229 | if not isinstance(model_output, tuple): 230 | model_output = (model_output,) 231 | 232 | for postprocessor in self.postprocessors: 233 | model_output, ctx = postprocessor.process(model_output, ctx) 234 | 235 | loss_output = self.loss_calculator.compute(model_output, ctx) 236 | loss = loss_output.loss 237 | 238 | 
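                    # Under accelerator.accumulate() the loss is scaled for gradient
                    # accumulation and gradients are only synced across ranks on the
                    # final micro-batch of each accumulation window (sync_gradients).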
self.accelerator.backward(loss) 239 | if self.max_grad_norm is not None and self.accelerator.sync_gradients: 240 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 241 | batch_loss = loss.item() 242 | self.optimizer.step() 243 | if self.use_scheduler and self.scheduler: 244 | self.scheduler.step() 245 | self.optimizer.zero_grad(set_to_none=True) 246 | 247 | time_elapsed = time.time() - start_time 248 | if total_batches is not None: 249 | batches_per_epoch = total_batches // epochs if epochs > 0 else total_batches 250 | total_batches_done = ep * batches_per_epoch + self.state.steps 251 | current_epoch_decimal = total_batches_done / total_batches if total_batches > 0 else 0 252 | batches_remaining = total_batches - total_batches_done 253 | estimated_time_remaining = ( 254 | (time_elapsed / total_batches_done) * batches_remaining 255 | if total_batches_done > 0 else 0 256 | ) 257 | estimated_time_remaining_ddhhmmss = convert_seconds_to_string_time( 258 | estimated_time_remaining 259 | ) 260 | percent_done = (total_batches_done / total_batches) * 100 if total_batches > 0 else 0 261 | batch_info = f"Batch {self.state.steps}/{batches_per_epoch}" 262 | progress_info = f"{percent_done:.2f}% done | ETA: {estimated_time_remaining_ddhhmmss}" 263 | else: 264 | current_epoch_decimal = ep + (self.state.steps / 1000) 265 | batch_info = f"Batch {self.state.steps}" 266 | progress_info = f"elapsed: {convert_seconds_to_string_time(time_elapsed)}" 267 | 268 | if self.state.steps % self.log_every == 0 and self.accelerator.is_main_process: 269 | metrics_str = " | ".join(f"{k}: {v:.4f}" for k, v in loss_output.metrics.items()) 270 | logger.info( 271 | f"Epoch {current_epoch_decimal:.2f} | {batch_info} | " 272 | f"{metrics_str} | {progress_info}" 273 | ) 274 | 275 | if self.metrics_logger: 276 | for k, v in loss_output.metrics.items(): 277 | self.metrics_logger.log_scalar(f"train/{k}", v, self.state.total_steps) 278 | 279 | self.state.global_loss += batch_loss 280 | 281 | self.state.steps += 1 282 | self.state.total_steps += 1 283 | 284 | if max_steps is not None and self.state.total_steps >= max_steps: 285 | break 286 | 287 | if self.state.steps % self.checkpoint_every == 0 and self.state.steps > 0: 288 | self._save() 289 | 290 | self.accelerator.wait_for_everyone() 291 | self.state.epoch += 1 292 | self.state.steps = 0 293 | 294 | atexit.unregister(self._save) 295 | self._save() 296 | 297 | return self.state.global_loss / self.state.total_steps if self.state.total_steps > 0 else 0 298 | 299 | def train(self, epochs=1, max_steps=None): 300 | if self.accelerator.is_main_process: 301 | logger.info("***** Running training *****") 302 | try: 303 | logger.info(f" Num examples = {len(self.dataset):,}") 304 | except TypeError: 305 | logger.info(" Num examples = unknown (iterable dataset)") 306 | logger.info(f" Num Epochs = {epochs:,}") 307 | if max_steps: 308 | logger.info(f" Total optimization steps = {max_steps:,}") 309 | logger.info(f" Instantaneous batch size per device = {self.batch_size:,}") 310 | logger.info(f" Gradient Accumulation steps = {self.grad_accum}") 311 | logger.info( 312 | f" Number of trainable parameters = {get_model_param_count(self.model, trainable_only=True):,}" 313 | ) 314 | logger.info(f" Loss calculator = {self.loss_calculator.__class__.__name__}") 315 | 316 | return self._train_accelerate(epochs=epochs, max_steps=max_steps) 317 | -------------------------------------------------------------------------------- /lib/ditty/contract.py: 
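The Trainer above is normally constructed by Pipeline, but it can also be driven directly.
A minimal sketch (illustrative; assumes `net`, `opt`, and `loader` are already built):

    from accelerate import Accelerator
    from lib.ditty.trainer import Trainer
    from lib.ditty.loss import MSELoss

    accelerator = Accelerator(gradient_accumulation_steps=1)
    trainer = Trainer(
        model=net,
        optimizer=opt,
        accelerator=accelerator,
        dataset=loader,              # a torch DataLoader
        device=accelerator.device,
        loss_calculator=MSELoss(),
    )
    trainer.train(epochs=1)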
-------------------------------------------------------------------------------- 1 | """ 2 | Pipeline contract system for declarative tensor specifications. 3 | 4 | Contracts use a terse syntax to declare tensor shapes and types: 5 | "name:rank:dtype, name:rank:dtype -> output:rank:dtype" 6 | 7 | Examples: 8 | "tokens:3:i64 -> logits:4:f, z0_pred:4:f, z0_true:4:f, mask:3:b" 9 | "logits:4:f, z0_pred:4:f | ctx.input_ids:3:i64 -> loss:0:f" 10 | 11 | Dtype shorthand (rust-style): 12 | f - any float (f16, bf16, f32, f64) 13 | f16 - float16 14 | f32 - float32 15 | bf16 - bfloat16 16 | i - any int (i8, i16, i32, i64) 17 | i64 - int64 18 | i32 - int32 19 | b - bool 20 | u8 - uint8 21 | 22 | Context dependencies use | separator: 23 | "input:3:f | ctx.target:4:f, ctx.mask:3:b -> output:4:f" 24 | """ 25 | from dataclasses import dataclass 26 | from typing import Dict, List, Optional, Tuple, Any, Set 27 | import re 28 | import torch 29 | 30 | 31 | DTYPE_MAP = { 32 | "f": (torch.float16, torch.float32, torch.bfloat16, torch.float64), 33 | "f16": (torch.float16,), 34 | "f32": (torch.float32,), 35 | "f64": (torch.float64,), 36 | "bf16": (torch.bfloat16,), 37 | "i": (torch.int8, torch.int16, torch.int32, torch.int64), 38 | "i8": (torch.int8,), 39 | "i16": (torch.int16,), 40 | "i32": (torch.int32,), 41 | "i64": (torch.int64,), 42 | "b": (torch.bool,), 43 | "u8": (torch.uint8,), 44 | } 45 | 46 | 47 | class ContractViolation(Exception): 48 | """Raised when a tensor doesn't match its declared contract.""" 49 | pass 50 | 51 | 52 | class ContractParseError(Exception): 53 | """Raised when a contract string is malformed.""" 54 | pass 55 | 56 | 57 | @dataclass 58 | class TensorSpec: 59 | """Specification for a single tensor.""" 60 | name: str 61 | rank: int 62 | dtype: Optional[str] = None # None means any dtype 63 | is_ctx: bool = False # True if this is a ctx.xxx reference 64 | 65 | def __str__(self) -> str: 66 | dtype_str = f":{self.dtype}" if self.dtype else "" 67 | prefix = "ctx." if self.is_ctx else "" 68 | return f"{prefix}{self.name}:{self.rank}{dtype_str}" 69 | 70 | def validate(self, tensor: torch.Tensor) -> None: 71 | """Validate a tensor against this spec. 
Raises ContractViolation on mismatch.""" 72 | if tensor.ndim != self.rank: 73 | raise ContractViolation( 74 | f"{self.name}: expected rank {self.rank}, got {tensor.ndim} " 75 | f"(shape: {tuple(tensor.shape)})" 76 | ) 77 | 78 | if self.dtype: 79 | valid_dtypes = DTYPE_MAP.get(self.dtype) 80 | if valid_dtypes is None: 81 | raise ContractViolation(f"{self.name}: unknown dtype spec '{self.dtype}'") 82 | if tensor.dtype not in valid_dtypes: 83 | expected = self.dtype if len(valid_dtypes) > 1 else valid_dtypes[0] 84 | raise ContractViolation( 85 | f"{self.name}: expected dtype {expected}, got {tensor.dtype}" 86 | ) 87 | 88 | 89 | @dataclass 90 | class Contract: 91 | """Parsed contract with inputs, outputs, and context dependencies.""" 92 | inputs: List[TensorSpec] 93 | outputs: List[TensorSpec] 94 | ctx_deps: List[TensorSpec] # ctx.xxx dependencies 95 | raw: str # Original contract string 96 | 97 | def __str__(self) -> str: 98 | inputs_str = ", ".join(str(s) for s in self.inputs) 99 | outputs_str = ", ".join(str(s) for s in self.outputs) 100 | if self.ctx_deps: 101 | ctx_str = ", ".join(str(s) for s in self.ctx_deps) 102 | return f"{inputs_str} | {ctx_str} -> {outputs_str}" 103 | return f"{inputs_str} -> {outputs_str}" 104 | 105 | @property 106 | def input_names(self) -> Set[str]: 107 | return {s.name for s in self.inputs} 108 | 109 | @property 110 | def output_names(self) -> Set[str]: 111 | return {s.name for s in self.outputs} 112 | 113 | @property 114 | def ctx_names(self) -> Set[str]: 115 | return {s.name for s in self.ctx_deps} 116 | 117 | def validate_inputs(self, tensors: Tuple[Any, ...], ctx: Dict[str, Any]) -> None: 118 | """Validate input tensors and context against contract.""" 119 | if len(tensors) != len(self.inputs): 120 | raise ContractViolation( 121 | f"Expected {len(self.inputs)} inputs, got {len(tensors)}" 122 | ) 123 | 124 | for spec, tensor in zip(self.inputs, tensors): 125 | if not isinstance(tensor, torch.Tensor): 126 | raise ContractViolation( 127 | f"{spec.name}: expected Tensor, got {type(tensor).__name__}" 128 | ) 129 | spec.validate(tensor) 130 | 131 | for spec in self.ctx_deps: 132 | if spec.name not in ctx: 133 | raise ContractViolation(f"Missing ctx.{spec.name}") 134 | val = ctx[spec.name] 135 | if isinstance(val, torch.Tensor): 136 | spec.validate(val) 137 | 138 | def validate_outputs(self, tensors: Tuple[Any, ...]) -> None: 139 | """Validate output tensors against contract.""" 140 | if len(tensors) != len(self.outputs): 141 | raise ContractViolation( 142 | f"Expected {len(self.outputs)} outputs, got {len(tensors)}" 143 | ) 144 | 145 | for spec, tensor in zip(self.outputs, tensors): 146 | if not isinstance(tensor, torch.Tensor): 147 | raise ContractViolation( 148 | f"{spec.name}: expected Tensor, got {type(tensor).__name__}" 149 | ) 150 | spec.validate(tensor) 151 | 152 | 153 | def parse_tensor_spec(spec_str: str) -> TensorSpec: 154 | """ 155 | Parse a single tensor spec like "name:rank:dtype" or "ctx.name:rank:dtype". 156 | 157 | Examples: 158 | "logits:4:f" -> TensorSpec(name="logits", rank=4, dtype="f") 159 | "mask:3:b" -> TensorSpec(name="mask", rank=3, dtype="b") 160 | "tokens:3" -> TensorSpec(name="tokens", rank=3, dtype=None) 161 | "ctx.target:4:f" -> TensorSpec(name="target", rank=4, dtype="f", is_ctx=True) 162 | """ 163 | spec_str = spec_str.strip() 164 | if not spec_str: 165 | raise ContractParseError("Empty tensor spec") 166 | 167 | is_ctx = spec_str.startswith("ctx.") 168 | if is_ctx: 169 | spec_str = spec_str[4:] # Remove "ctx." 
prefix 170 | 171 | parts = spec_str.split(":") 172 | if len(parts) < 2: 173 | raise ContractParseError( 174 | f"Invalid tensor spec '{spec_str}': expected 'name:rank' or 'name:rank:dtype'" 175 | ) 176 | 177 | name = parts[0].strip() 178 | if not name: 179 | raise ContractParseError(f"Empty name in tensor spec '{spec_str}'") 180 | 181 | try: 182 | rank = int(parts[1].strip()) 183 | except ValueError: 184 | raise ContractParseError( 185 | f"Invalid rank '{parts[1]}' in tensor spec '{spec_str}': expected integer" 186 | ) 187 | 188 | dtype = None 189 | if len(parts) >= 3: 190 | dtype = parts[2].strip() 191 | if dtype and dtype not in DTYPE_MAP: 192 | raise ContractParseError( 193 | f"Unknown dtype '{dtype}' in tensor spec '{spec_str}'. " 194 | f"Valid dtypes: {', '.join(DTYPE_MAP.keys())}" 195 | ) 196 | 197 | return TensorSpec(name=name, rank=rank, dtype=dtype or None, is_ctx=is_ctx) 198 | 199 | 200 | def parse_tensor_list(specs_str: str) -> Tuple[List[TensorSpec], List[TensorSpec]]: 201 | """ 202 | Parse a comma-separated list of tensor specs, separating regular and ctx specs. 203 | 204 | Returns: 205 | (regular_specs, ctx_specs) 206 | """ 207 | if not specs_str.strip(): 208 | return [], [] 209 | 210 | regular = [] 211 | ctx = [] 212 | 213 | for spec_str in specs_str.split(","): 214 | spec_str = spec_str.strip() 215 | if not spec_str: 216 | continue 217 | spec = parse_tensor_spec(spec_str) 218 | if spec.is_ctx: 219 | ctx.append(spec) 220 | else: 221 | regular.append(spec) 222 | 223 | return regular, ctx 224 | 225 | 226 | def parse_contract(contract_str: str) -> Contract: 227 | """ 228 | Parse a full contract string. 229 | 230 | Format: "inputs | ctx_deps -> outputs" 231 | 232 | Examples: 233 | "tokens:3:i64 -> logits:4:f" 234 | "logits:4:f, z0_pred:4:f | ctx.input_ids:3:i64 -> loss:0:f" 235 | """ 236 | if not contract_str or not contract_str.strip(): 237 | raise ContractParseError("Empty contract string") 238 | 239 | contract_str = contract_str.strip() 240 | 241 | # Split on -> 242 | if "->" not in contract_str: 243 | raise ContractParseError( 244 | f"Invalid contract '{contract_str}': missing '->' separator" 245 | ) 246 | 247 | input_side, output_side = contract_str.split("->", 1) 248 | input_side = input_side.strip() 249 | output_side = output_side.strip() 250 | 251 | # Parse output side (no ctx deps allowed) 252 | outputs, output_ctx = parse_tensor_list(output_side) 253 | if output_ctx: 254 | raise ContractParseError( 255 | f"Invalid contract: ctx references not allowed in outputs" 256 | ) 257 | 258 | # Split input side on | for ctx deps 259 | ctx_deps = [] 260 | if "|" in input_side: 261 | main_inputs, ctx_side = input_side.split("|", 1) 262 | main_inputs = main_inputs.strip() 263 | ctx_side = ctx_side.strip() 264 | 265 | # Parse ctx deps - everything after | is a ctx dep 266 | regular_specs, ctx_specs = parse_tensor_list(ctx_side) 267 | # Mark all specs from the ctx side as ctx deps 268 | for spec in regular_specs: 269 | spec.is_ctx = True 270 | for spec in ctx_specs: 271 | spec.is_ctx = True 272 | ctx_deps = regular_specs + ctx_specs 273 | else: 274 | main_inputs = input_side 275 | 276 | # Parse main inputs 277 | inputs, input_ctx = parse_tensor_list(main_inputs) 278 | ctx_deps.extend(input_ctx) 279 | 280 | return Contract( 281 | inputs=inputs, 282 | outputs=outputs, 283 | ctx_deps=ctx_deps, 284 | raw=contract_str, 285 | ) 286 | 287 | 288 | def validate_pipeline_chain( 289 | preprocessor_contracts: List[Contract], 290 | model_contract: Contract, 291 | postprocessor_contracts: 
List[Contract], 292 | loss_contract: Contract, 293 | ) -> List[str]: 294 | """ 295 | Validate that a pipeline's contracts chain together correctly. 296 | 297 | Returns list of validation errors (empty if valid). 298 | """ 299 | errors = [] 300 | 301 | # Track what's available in ctx 302 | ctx_available: Set[str] = set() 303 | 304 | # Track current tensor outputs (name -> TensorSpec) 305 | current_outputs: Dict[str, TensorSpec] = {} 306 | 307 | # Process preprocessors 308 | for i, contract in enumerate(preprocessor_contracts): 309 | # Check ctx deps are satisfied 310 | for ctx_dep in contract.ctx_deps: 311 | if ctx_dep.name not in ctx_available: 312 | errors.append( 313 | f"Preprocessor {i}: requires ctx.{ctx_dep.name} but not available" 314 | ) 315 | 316 | # Add outputs to ctx (preprocessors output to ctx) 317 | for out in contract.outputs: 318 | if out.is_ctx: 319 | ctx_available.add(out.name) 320 | else: 321 | current_outputs[out.name] = out 322 | 323 | # Check model inputs 324 | for inp in model_contract.inputs: 325 | # Model inputs come from preprocessor outputs or ctx 326 | if inp.name not in current_outputs and inp.name not in ctx_available: 327 | errors.append( 328 | f"Model: requires input '{inp.name}' but not provided by preprocessors" 329 | ) 330 | elif inp.name in current_outputs: 331 | provided = current_outputs[inp.name] 332 | if provided.rank != inp.rank: 333 | errors.append( 334 | f"Model input '{inp.name}': expected rank {inp.rank}, " 335 | f"preprocessor provides rank {provided.rank}" 336 | ) 337 | 338 | for ctx_dep in model_contract.ctx_deps: 339 | if ctx_dep.name not in ctx_available: 340 | errors.append(f"Model: requires ctx.{ctx_dep.name} but not available") 341 | 342 | # Model outputs become current outputs 343 | current_outputs = {out.name: out for out in model_contract.outputs} 344 | 345 | # Process postprocessors 346 | for i, contract in enumerate(postprocessor_contracts): 347 | for inp in contract.inputs: 348 | if inp.name not in current_outputs: 349 | errors.append( 350 | f"Postprocessor {i}: requires input '{inp.name}' but not available" 351 | ) 352 | else: 353 | provided = current_outputs[inp.name] 354 | if provided.rank != inp.rank: 355 | errors.append( 356 | f"Postprocessor {i} input '{inp.name}': expected rank {inp.rank}, " 357 | f"got rank {provided.rank}" 358 | ) 359 | 360 | for ctx_dep in contract.ctx_deps: 361 | if ctx_dep.name not in ctx_available: 362 | errors.append( 363 | f"Postprocessor {i}: requires ctx.{ctx_dep.name} but not available" 364 | ) 365 | 366 | # Update current outputs 367 | current_outputs = {out.name: out for out in contract.outputs} 368 | # Postprocessors can also add to ctx 369 | for out in contract.outputs: 370 | if out.is_ctx: 371 | ctx_available.add(out.name) 372 | 373 | # Check loss calculator inputs 374 | for inp in loss_contract.inputs: 375 | if inp.name not in current_outputs: 376 | errors.append( 377 | f"LossCalculator: requires input '{inp.name}' but not available" 378 | ) 379 | else: 380 | provided = current_outputs[inp.name] 381 | if provided.rank != inp.rank: 382 | errors.append( 383 | f"LossCalculator input '{inp.name}': expected rank {inp.rank}, " 384 | f"got rank {provided.rank}" 385 | ) 386 | 387 | for ctx_dep in loss_contract.ctx_deps: 388 | if ctx_dep.name not in ctx_available: 389 | errors.append( 390 | f"LossCalculator: requires ctx.{ctx_dep.name} but not available" 391 | ) 392 | 393 | return errors 394 | 395 | 396 | def format_pipeline_contracts( 397 | preprocessor_contracts: List[Tuple[str, Contract]], 398 | 
model_contract: Tuple[str, Contract], 399 | postprocessor_contracts: List[Tuple[str, Contract]], 400 | loss_contract: Tuple[str, Contract], 401 | ) -> str: 402 | """Format pipeline contracts for display/debugging.""" 403 | lines = ["Pipeline Contracts:", "=" * 60] 404 | 405 | if preprocessor_contracts: 406 | lines.append("\nPreprocessors:") 407 | for name, contract in preprocessor_contracts: 408 | lines.append(f" {name}:") 409 | lines.append(f" {contract}") 410 | 411 | lines.append(f"\nModel ({model_contract[0]}):") 412 | lines.append(f" {model_contract[1]}") 413 | 414 | if postprocessor_contracts: 415 | lines.append("\nPostprocessors:") 416 | for name, contract in postprocessor_contracts: 417 | lines.append(f" {name}:") 418 | lines.append(f" {contract}") 419 | 420 | lines.append(f"\nLossCalculator ({loss_contract[0]}):") 421 | lines.append(f" {loss_contract[1]}") 422 | 423 | lines.append("=" * 60) 424 | return "\n".join(lines) 425 | -------------------------------------------------------------------------------- /optimizers/igod.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hebbian IGOD: Inertial Gradient Optimization via Decomposition with Coherence Neighborhoods 3 | 4 | An optimizer that: 5 | 1. Keeps inertia (belief) per parameter 6 | 2. Groups parameters into blocks (by index) 7 | 3. Learns which blocks "fire together" using gradient similarity (Hebbian rule) 8 | 4. Computes coherence at the neighborhood level for IGOD decomposition 9 | 5. Applies decomposed, coherence-scaled updates per parameter 10 | """ 11 | 12 | import math 13 | import torch 14 | from torch.optim import Optimizer 15 | from typing import Dict, List, Tuple, Optional, Callable 16 | from dataclasses import dataclass, field 17 | 18 | try: 19 | from torchao.optim.subclass_8bit import OptimState8bit 20 | TORCHAO_AVAILABLE = True 21 | except ImportError: 22 | TORCHAO_AVAILABLE = False 23 | 24 | 25 | @dataclass 26 | class BlockState: 27 | """State for a single block.""" 28 | g_ema: torch.Tensor # Can be OptimState8bit or regular tensor 29 | neighbors: List[int] = field(default_factory=list) 30 | weights: Dict[int, float] = field(default_factory=dict) 31 | 32 | def get_g_ema_float(self) -> torch.Tensor: 33 | if hasattr(self.g_ema, 'dequantize'): 34 | return self.g_ema.dequantize() 35 | return self.g_ema.float() 36 | 37 | 38 | class HebbianIGOD(Optimizer): 39 | """ 40 | Hebbian IGOD: Inertial Gradient Optimization via Decomposition 41 | with Coherence Neighborhoods. 
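
    Per-coherence-group update sketch (mirrors _compute_igod_update; illustrative
    notation, with B_G = ||I_G||):

        I    <- I + gamma * g                 # inertia (belief) accumulation
        u_G   = I_G / (B_G + eps)             # unit inertia direction of group G
        c_G   = <g_G, u_G>                    # gradient/inertia coherence
        g_par, g_orth = c_G * u_G, g_G - c_G * u_G
        s_par = 1 + alpha * B_G  if c_G > 0 else 1 / (1 + beta * B_G)   # 1.0 if c_G == 0
        delta_G = -lr * (s_par * g_par + (1 + delta * B_G) * g_orth)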
42 | 43 | Args: 44 | params: Model parameters 45 | lr: Learning rate (η) 46 | gamma: Inertia update rate (γ) 47 | alpha: Boost for confirming corrections (α ≥ 0) 48 | beta: Suppression for contradicting corrections (β ≥ 0) 49 | delta: Boost for orthogonal/new directions (δ ≥ 0) 50 | eps: Numerical stability constant (ε) 51 | rho: EMA rate for g_ema (ρ) 52 | lambda_w: EMA rate for Hebbian weights (λ_W) 53 | tau: Threshold for Hebbian weight to count as neighbor (τ) 54 | k: Max neighbors per block 55 | t_hebb: Frequency (in steps) to update Hebbian neighborhoods (T_hebb) 56 | block_size: Size of parameter blocks (B) 57 | use_8bit: Use 8-bit quantization for inertia state 58 | """ 59 | 60 | def __init__( 61 | self, 62 | params, 63 | lr: float = 1e-3, 64 | gamma: float = 0.1, 65 | alpha: float = 0.0, 66 | beta: float = 1.0, 67 | delta: float = 0.0, 68 | eps: float = 1e-8, 69 | rho: float = 0.01, 70 | lambda_w: float = 0.01, 71 | tau: float = 0.5, 72 | k: int = 8, 73 | t_hebb: int = 100, 74 | block_size: int = 256, 75 | use_8bit: bool = False, 76 | ): 77 | if use_8bit and not TORCHAO_AVAILABLE: 78 | raise ImportError("torchao required for 8-bit mode: pip install torchao") 79 | 80 | defaults = dict( 81 | lr=lr, 82 | gamma=gamma, 83 | alpha=alpha, 84 | beta=beta, 85 | delta=delta, 86 | eps=eps, 87 | rho=rho, 88 | lambda_w=lambda_w, 89 | tau=tau, 90 | k=k, 91 | t_hebb=t_hebb, 92 | block_size=block_size, 93 | use_8bit=use_8bit, 94 | ) 95 | super().__init__(params, defaults) 96 | 97 | self._block_registry: Dict[int, Tuple[int, int]] = {} 98 | self._block_states: Dict[int, BlockState] = {} 99 | self._param_blocks: Dict[int, List[int]] = {} 100 | self._param_names: Dict[int, str] = {} 101 | self._global_step = 0 102 | self._next_block_id = 0 103 | 104 | # Logging stats (collected during step, reset after get_stats) 105 | self._step_stats: Dict[str, list] = { 106 | 'inertia_norms': [], 107 | 'c_G_values': [], 108 | 'neighbor_counts': [], 109 | } 110 | 111 | def _get_block_slice(self, block_idx: int, num_elements: int, block_size: int) -> Tuple[int, int]: 112 | start = block_idx * block_size 113 | end = min((block_idx + 1) * block_size, num_elements) 114 | return start, end 115 | 116 | def _init_param_state(self, p: torch.Tensor, param_id: int, group: dict): 117 | state = self.state[p] 118 | block_size = group['block_size'] 119 | use_8bit = group['use_8bit'] 120 | 121 | if use_8bit and p.numel() >= 4096: 122 | state['inertia'] = OptimState8bit.zeros( 123 | p.shape, signed=True, block_size=block_size, device=p.device 124 | ) 125 | else: 126 | state['inertia'] = torch.zeros_like(p, memory_format=torch.preserve_format) 127 | 128 | num_elements = p.numel() 129 | num_blocks = math.ceil(num_elements / block_size) 130 | 131 | block_ids = [] 132 | for local_idx in range(num_blocks): 133 | global_id = self._next_block_id 134 | self._next_block_id += 1 135 | 136 | self._block_registry[global_id] = (param_id, local_idx) 137 | block_ids.append(global_id) 138 | 139 | start, end = self._get_block_slice(local_idx, num_elements, block_size) 140 | block_len = end - start 141 | 142 | if use_8bit and block_len >= 256: 143 | g_ema = OptimState8bit.zeros((block_len,), signed=True, block_size=min(block_size, block_len), device=p.device) 144 | else: 145 | g_ema = torch.zeros(block_len, device=p.device, dtype=torch.float32) 146 | 147 | self._block_states[global_id] = BlockState(g_ema=g_ema) 148 | 149 | self._param_blocks[param_id] = block_ids 150 | state['param_id'] = param_id 151 | state['initialized'] = True 152 | 153 | def 
_update_hebbian_neighborhoods(self, group: dict): 154 | lambda_w = group['lambda_w'] 155 | tau = group['tau'] 156 | k = group['k'] 157 | eps = group['eps'] 158 | 159 | all_block_ids = list(self._block_states.keys()) 160 | if len(all_block_ids) < 2: 161 | return 162 | 163 | for b in all_block_ids: 164 | b_state = self._block_states[b] 165 | g_b = b_state.get_g_ema_float() 166 | 167 | g_b_norm = torch.norm(g_b) 168 | if g_b_norm < eps: 169 | continue 170 | 171 | candidates = [j for j in all_block_ids if j != b] 172 | if len(candidates) > k * 4: 173 | indices = torch.randperm(len(candidates))[:k * 4].tolist() 174 | candidates = [candidates[i] for i in indices] 175 | 176 | for j in candidates: 177 | j_state = self._block_states[j] 178 | g_j = j_state.get_g_ema_float() 179 | 180 | g_j_norm = torch.norm(g_j) 181 | if g_j_norm < eps: 182 | continue 183 | 184 | sim = torch.dot(g_b, g_j) / (g_b_norm * g_j_norm + eps) 185 | sim = sim.item() 186 | 187 | old_w = b_state.weights.get(j, 0.0) 188 | new_w = (1 - lambda_w) * old_w + lambda_w * sim 189 | b_state.weights[j] = new_w 190 | 191 | valid_neighbors = [(j, w) for j, w in b_state.weights.items() if w > tau] 192 | valid_neighbors.sort(key=lambda x: x[1], reverse=True) 193 | top_k = valid_neighbors[:k] 194 | 195 | b_state.neighbors = [j for j, w in top_k] 196 | b_state.weights = {j: w for j, w in top_k} 197 | 198 | def _compute_igod_update( 199 | self, 200 | g_flat: torch.Tensor, 201 | I_flat: torch.Tensor, 202 | block_ids: List[int], 203 | group: dict, 204 | ) -> torch.Tensor: 205 | lr = group['lr'] 206 | alpha = group['alpha'] 207 | beta = group['beta'] 208 | delta = group['delta'] 209 | eps = group['eps'] 210 | block_size = group['block_size'] 211 | 212 | num_elements = g_flat.numel() 213 | update = torch.zeros_like(g_flat) 214 | updated_mask = torch.zeros(num_elements, dtype=torch.bool, device=g_flat.device) 215 | 216 | for local_idx, global_id in enumerate(block_ids): 217 | start, end = self._get_block_slice(local_idx, num_elements, block_size) 218 | 219 | b_state = self._block_states[global_id] 220 | group_ids = [global_id] + b_state.neighbors 221 | 222 | g_parts = [] 223 | I_parts = [] 224 | 225 | for gid in group_ids: 226 | param_id, blk_idx = self._block_registry[gid] 227 | 228 | for p in group['params']: 229 | p_state = self.state.get(p, {}) 230 | if p_state.get('param_id') == param_id: 231 | p_g = p.grad.view(-1).float() 232 | if hasattr(p_state['inertia'], 'dequantize'): 233 | p_I = p_state['inertia'].dequantize().view(-1) 234 | else: 235 | p_I = p_state['inertia'].view(-1).float() 236 | 237 | blk_start, blk_end = self._get_block_slice( 238 | blk_idx, p_g.numel(), block_size 239 | ) 240 | g_parts.append(p_g[blk_start:blk_end]) 241 | I_parts.append(p_I[blk_start:blk_end]) 242 | break 243 | 244 | if not g_parts: 245 | continue 246 | 247 | g_G = torch.cat(g_parts) 248 | I_G = torch.cat(I_parts) 249 | 250 | B_G = torch.norm(I_G).item() 251 | 252 | if B_G < eps: 253 | update[start:end] = -lr * g_flat[start:end] 254 | updated_mask[start:end] = True 255 | continue 256 | 257 | u_G = I_G / (B_G + eps) 258 | 259 | c_G = torch.dot(g_G, u_G).item() 260 | g_parallel = c_G * u_G 261 | g_orthogonal = g_G - g_parallel 262 | 263 | # Log c_G and neighbor count 264 | self._step_stats['c_G_values'].append(c_G) 265 | self._step_stats['neighbor_counts'].append(len(b_state.neighbors)) 266 | 267 | if c_G > 0: 268 | s_parallel = 1.0 + alpha * B_G 269 | elif c_G < 0: 270 | s_parallel = 1.0 / (1.0 + beta * B_G) 271 | else: 272 | s_parallel = 1.0 273 | 274 | 
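            # Corrections along the inertia direction are boosted when they agree
            # with it (c_G > 0) and damped when they oppose it; the orthogonal
            # remainder gets its own delta-scaled boost below.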
s_orthogonal = 1.0 + delta * B_G 275 | 276 | delta_G = -lr * (s_parallel * g_parallel + s_orthogonal * g_orthogonal) 277 | 278 | block_len = end - start 279 | block_update = delta_G[:block_len] 280 | 281 | update[start:end] = block_update 282 | updated_mask[start:end] = True 283 | 284 | if not updated_mask.all(): 285 | update[~updated_mask] = -lr * g_flat[~updated_mask] 286 | 287 | return update 288 | 289 | @torch.no_grad() 290 | def step(self, closure=None): 291 | loss = None 292 | if closure is not None: 293 | with torch.enable_grad(): 294 | loss = closure() 295 | 296 | self._global_step += 1 297 | param_id_counter = 0 298 | 299 | for group in self.param_groups: 300 | gamma = group['gamma'] 301 | rho = group['rho'] 302 | block_size = group['block_size'] 303 | t_hebb = group['t_hebb'] 304 | 305 | for p in group['params']: 306 | if p.grad is None: 307 | continue 308 | 309 | state = self.state[p] 310 | 311 | if 'initialized' not in state: 312 | self._init_param_state(p, param_id_counter, group) 313 | param_id_counter += 1 314 | 315 | param_id = state['param_id'] 316 | grad = p.grad.float() 317 | g_flat = grad.view(-1) 318 | 319 | I = state['inertia'] 320 | if hasattr(I, 'dequantize'): 321 | I_flat = I.dequantize().view(-1) 322 | else: 323 | I_flat = I.view(-1).float() 324 | 325 | block_ids = self._param_blocks[param_id] 326 | num_elements = g_flat.numel() 327 | 328 | for local_idx, global_id in enumerate(block_ids): 329 | start, end = self._get_block_slice(local_idx, num_elements, block_size) 330 | g_block = g_flat[start:end] 331 | 332 | b_state = self._block_states[global_id] 333 | g_ema_f32 = b_state.get_g_ema_float() 334 | new_g_ema = (1 - rho) * g_ema_f32 + rho * g_block 335 | b_state.g_ema.copy_(new_g_ema) 336 | 337 | new_inertia = I_flat + gamma * g_flat 338 | I.copy_(new_inertia.view(I.shape)) 339 | 340 | # Log inertia norm 341 | inertia_norm = torch.norm(new_inertia).item() 342 | param_name = self._param_names.get(param_id, f"param_{param_id}") 343 | self._step_stats['inertia_norms'].append((param_name, inertia_norm)) 344 | 345 | if self._global_step % t_hebb == 0: 346 | self._update_hebbian_neighborhoods(group) 347 | 348 | for p in group['params']: 349 | if p.grad is None: 350 | continue 351 | 352 | state = self.state[p] 353 | param_id = state['param_id'] 354 | block_ids = self._param_blocks[param_id] 355 | 356 | grad = p.grad.float() 357 | g_flat = grad.view(-1) 358 | 359 | I = state['inertia'] 360 | if hasattr(I, 'dequantize'): 361 | I_flat = I.dequantize().view(-1) 362 | else: 363 | I_flat = I.view(-1).float() 364 | 365 | update = self._compute_igod_update(g_flat, I_flat, block_ids, group) 366 | 367 | p.add_(update.view(p.shape).to(p.dtype)) 368 | 369 | return loss 370 | 371 | def get_stats(self, reset: bool = True) -> Dict[str, any]: 372 | """ 373 | Get logging statistics from the last step(s). 
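
        Typically polled from the training loop, e.g. (illustrative):
            stats = optimizer.get_stats()
            metrics_logger.log_scalar("igod/c_G_mean", stats["c_G_mean"], stats["step"])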
374 | 375 | Returns dict with: 376 | - inertia_norms: list of (param_name, norm) tuples 377 | - c_G_values: list of c_G scalars (confirmation vs contradiction) 378 | - c_G_mean: mean c_G (positive = confirming, negative = contradicting) 379 | - c_G_pos_ratio: fraction of blocks with c_G > 0 380 | - neighbor_counts: list of neighbor counts per block 381 | - avg_neighbors: average neighbors per block 382 | - total_blocks: total number of blocks 383 | - step: current global step 384 | """ 385 | stats = {} 386 | 387 | # Inertia norms 388 | stats['inertia_norms'] = list(self._step_stats['inertia_norms']) 389 | 390 | # c_G distribution 391 | c_vals = self._step_stats['c_G_values'] 392 | stats['c_G_values'] = list(c_vals) 393 | if c_vals: 394 | stats['c_G_mean'] = sum(c_vals) / len(c_vals) 395 | stats['c_G_pos_ratio'] = sum(1 for c in c_vals if c > 0) / len(c_vals) 396 | else: 397 | stats['c_G_mean'] = 0.0 398 | stats['c_G_pos_ratio'] = 0.0 399 | 400 | # Neighbor counts 401 | neighbor_counts = self._step_stats['neighbor_counts'] 402 | stats['neighbor_counts'] = list(neighbor_counts) 403 | if neighbor_counts: 404 | stats['avg_neighbors'] = sum(neighbor_counts) / len(neighbor_counts) 405 | else: 406 | stats['avg_neighbors'] = 0.0 407 | 408 | stats['total_blocks'] = len(self._block_states) 409 | stats['step'] = self._global_step 410 | 411 | if reset: 412 | self._step_stats = { 413 | 'inertia_norms': [], 414 | 'c_G_values': [], 415 | 'neighbor_counts': [], 416 | } 417 | 418 | return stats 419 | 420 | def set_param_names(self, model: torch.nn.Module): 421 | """Set parameter names from model for cleaner logging.""" 422 | param_to_name = {id(p): name for name, p in model.named_parameters()} 423 | for p in self.param_groups[0]['params']: 424 | state = self.state.get(p, {}) 425 | if 'param_id' in state: 426 | name = param_to_name.get(id(p), f"param_{state['param_id']}") 427 | self._param_names[state['param_id']] = name 428 | 429 | def get_neighborhood_graph(self) -> Dict[int, List[Tuple[int, float]]]: 430 | """ 431 | Get the current Hebbian neighborhood graph. 432 | 433 | Returns dict mapping block_id -> [(neighbor_id, weight), ...] 
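        e.g. {0: [(3, 0.72), (7, 0.61)], 1: []}  (illustrative values)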
434 | """ 435 | graph = {} 436 | for block_id, state in self._block_states.items(): 437 | graph[block_id] = [(n, state.weights.get(n, 0.0)) for n in state.neighbors] 438 | return graph 439 | -------------------------------------------------------------------------------- /lib/ditty/pipeline.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import types 4 | from logging import getLogger 5 | from typing import Optional, List, Dict, Any, Union, Callable 6 | import bitsandbytes as bnb 7 | from accelerate import Accelerator 8 | from accelerate.utils import ProjectConfiguration 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils.data import DataLoader 13 | from datasets import Dataset, IterableDataset 14 | from .trainer import Trainer, TrainerState 15 | from .data import Data 16 | from .hf_utils import push_to_hub 17 | from .model_factory import ModelFactory, TokenizerFactory 18 | from .loss import LossCalculator, MSELoss 19 | from .processors import PreProcessor, PostProcessor 20 | from .contract import parse_contract, validate_pipeline_chain, format_pipeline_contracts, ContractParseError 21 | from .checkpoint import CheckpointManager, Checkpoint 22 | 23 | 24 | logging.basicConfig(level=logging.INFO) 25 | 26 | logger = getLogger("ditty_pipeline") 27 | 28 | 29 | class Pipeline: 30 | def __init__( 31 | self, 32 | model_factory: ModelFactory, 33 | dataset: Union[Dataset, DataLoader], 34 | collate_fn: Optional[Callable] = None, 35 | tokenizer_factory: Optional[TokenizerFactory] = None, 36 | loss_calculator: LossCalculator = None, # type: ignore[assignment] 37 | preprocessors: Optional[List[PreProcessor]] = None, 38 | postprocessors: Optional[List[PostProcessor]] = None, 39 | output_dir: str = "./output", 40 | fp16: bool = True, 41 | use_bfloat16: bool = False, 42 | seed: Optional[int] = None, 43 | batch_size: int = 4, 44 | grad_accum: int = 1, 45 | checkpoint_every: int = 1000, 46 | load_checkpoint: bool = True, 47 | gradient_checkpointing: bool = True, 48 | use_8bit_optim: bool = False, 49 | optim_backend: str = "torchao", # "torch", "bnb", or "torchao" 50 | lr: float = 1e-4, 51 | weight_decay: float = 0.01, 52 | max_grad_norm: float = 1.0, 53 | epochs: int = 1, 54 | max_steps: Optional[int] = None, 55 | log_every: int = 10, 56 | metrics_logger: Optional[Any] = None, 57 | accelerator_kwargs: Dict[str, Any] = {}, 58 | optimizer: Optional[torch.optim.Optimizer] = None, 59 | # Hub options 60 | push_to_hub: bool = False, 61 | output_hub_repo: Optional[str] = None, 62 | hf_hub_token: Optional[str] = None, 63 | merge_adapters: bool = False, 64 | private_repo: bool = True, 65 | # Dataset options 66 | shuffle_each_epoch: bool = True, 67 | num_workers: int = 4, 68 | shuffle_buffer_size: int = 1000, 69 | ): 70 | self.model_factory = model_factory 71 | self._dataset = dataset 72 | self.collate_fn = collate_fn 73 | self.tokenizer_factory = tokenizer_factory 74 | self.loss_calculator = loss_calculator or MSELoss() 75 | self.preprocessors = preprocessors or [] 76 | self.postprocessors = postprocessors or [] 77 | self.output_dir = output_dir 78 | self.fp16 = fp16 79 | self.use_bfloat16 = use_bfloat16 80 | self.seed = seed 81 | self.batch_size = batch_size 82 | self.grad_accum = grad_accum 83 | self.checkpoint_every = checkpoint_every 84 | self.load_checkpoint = load_checkpoint 85 | self.gradient_checkpointing = gradient_checkpointing 86 | self.use_8bit_optim = use_8bit_optim 87 | self.optim_backend = optim_backend 88 | 
self.lr = lr 89 | self.weight_decay = weight_decay 90 | self.max_grad_norm = max_grad_norm 91 | self.epochs = epochs 92 | self.max_steps = max_steps 93 | self.log_every = log_every 94 | self.metrics_logger = metrics_logger 95 | self.accelerator_kwargs = accelerator_kwargs 96 | self._user_optimizer = optimizer 97 | self.push_to_hub_flag = push_to_hub 98 | self.output_hub_repo = output_hub_repo 99 | self.hf_hub_token = hf_hub_token or os.environ.get("HF_TOKEN") 100 | self.merge_adapters = merge_adapters 101 | self.private_repo = private_repo 102 | self.shuffle_each_epoch = shuffle_each_epoch 103 | self.num_workers = num_workers 104 | self.shuffle_buffer_size = shuffle_buffer_size 105 | 106 | # Checkpoint manager for unified checkpoint handling 107 | self.checkpoint_manager = CheckpointManager(output_dir) 108 | 109 | # Calculate dataset size and create dataloader 110 | self.dataloader, self.dataset_size, self.total_batches = self._prepare_dataloader() 111 | 112 | if self.push_to_hub_flag and not self.output_hub_repo: 113 | raise ValueError("Cannot enable push to hub without providing output_hub_repo.") 114 | 115 | self._validate_contracts() 116 | 117 | def _prepare_dataloader(self): 118 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 119 | rank = int(os.environ.get("RANK", 0)) 120 | 121 | if isinstance(self._dataset, DataLoader): 122 | try: 123 | dataset_size = len(self._dataset.dataset) 124 | total_batches = (dataset_size // world_size + self.batch_size - 1) // self.batch_size * self.epochs 125 | except TypeError: 126 | dataset_size = None 127 | total_batches = None 128 | return self._dataset, dataset_size, total_batches 129 | 130 | dataset_size = len(self._dataset) 131 | total_batches = (dataset_size // world_size + self.batch_size - 1) // self.batch_size * self.epochs 132 | 133 | if rank == 0: 134 | logger.info(f"Dataset: {dataset_size:,} examples, ~{total_batches // self.epochs:,} batches per GPU per epoch") 135 | 136 | iterable_dataset = self._dataset.to_iterable_dataset(num_shards=128) 137 | iterable_dataset = iterable_dataset.shuffle(seed=42, buffer_size=self.shuffle_buffer_size) 138 | 139 | dataloader = DataLoader( 140 | iterable_dataset, 141 | batch_size=self.batch_size, 142 | collate_fn=self.collate_fn, 143 | num_workers=self.num_workers, 144 | pin_memory=True, 145 | ) 146 | 147 | return dataloader, dataset_size, total_batches 148 | 149 | def _validate_contracts(self): 150 | parse_errors = [] 151 | 152 | def strict_parse(component, label): 153 | if not component.contract: 154 | return None 155 | try: 156 | return parse_contract(component.contract) 157 | except ContractParseError as e: 158 | parse_errors.append(f"{label}: {e}") 159 | return None 160 | 161 | preprocessor_contracts = [] 162 | for p in self.preprocessors: 163 | contract = strict_parse(p, p.name) 164 | if contract: 165 | preprocessor_contracts.append(contract) 166 | 167 | model_contract = strict_parse(self.model_factory, "model") 168 | 169 | postprocessor_contracts = [] 170 | for p in self.postprocessors: 171 | contract = strict_parse(p, p.name) 172 | if contract: 173 | postprocessor_contracts.append(contract) 174 | 175 | loss_contract = strict_parse(self.loss_calculator, "loss_calculator") 176 | 177 | if parse_errors: 178 | raise ContractParseError( 179 | "Invalid contracts:\n " + "\n ".join(parse_errors) 180 | ) 181 | 182 | if not model_contract or not loss_contract: 183 | logger.debug("Skipping contract validation - model or loss contract not specified") 184 | return 185 | 186 | errors = validate_pipeline_chain( 
187 | preprocessor_contracts, 188 | model_contract, 189 | postprocessor_contracts, 190 | loss_contract, 191 | ) 192 | 193 | if errors: 194 | logger.info(format_pipeline_contracts( 195 | [(p.name, strict_parse(p, p.name)) for p in self.preprocessors if strict_parse(p, p.name)], 196 | ("model", model_contract), 197 | [(p.name, strict_parse(p, p.name)) for p in self.postprocessors if strict_parse(p, p.name)], 198 | ("loss", loss_contract), 199 | )) 200 | raise ContractParseError( 201 | "Pipeline contract validation errors:\n " + "\n ".join(errors) 202 | ) 203 | 204 | def _load_checkpoint_if_exists(self) -> tuple[Optional[Checkpoint], Optional[TrainerState]]: 205 | """ 206 | Load checkpoint if it exists and load_checkpoint is True. 207 | Returns (checkpoint, trainer_state) tuple. 208 | """ 209 | if not self.load_checkpoint: 210 | return None, None 211 | 212 | checkpoint = self.checkpoint_manager.load() 213 | if checkpoint is None: 214 | return None, None 215 | 216 | rank = int(os.environ.get("RANK", 0)) 217 | if rank == 0: 218 | logger.info(f"Found checkpoint with training state: {checkpoint.training_state}") 219 | 220 | trainer_state = TrainerState() 221 | trainer_state.load_state_dict(checkpoint.training_state) 222 | 223 | return checkpoint, trainer_state 224 | 225 | def _create_optimizer(self, model: nn.Module, checkpoint: Optional[Checkpoint] = None): 226 | """Create optimizer and optionally load state from checkpoint.""" 227 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 228 | lr = self.lr * world_size if world_size > 1 else self.lr 229 | is_fsdp = self.model_factory.fsdp_config.enabled if self.model_factory.fsdp_config else False 230 | 231 | if self._user_optimizer is not None: 232 | optimizer = self._user_optimizer 233 | elif self.use_8bit_optim: 234 | if self.optim_backend == "bnb": 235 | if is_fsdp: 236 | logger.warning("bitsandbytes 8-bit optimizer not compatible with FSDP2, falling back to torchao") 237 | from torchao.optim import AdamW8bit 238 | optimizer = AdamW8bit( 239 | model.parameters(), 240 | lr=lr, 241 | weight_decay=self.weight_decay, 242 | betas=(0.9, 0.999), 243 | eps=1e-8, 244 | ) 245 | else: 246 | optimizer = bnb.optim.Adam8bit( 247 | model.parameters(), 248 | lr=lr, 249 | weight_decay=self.weight_decay, 250 | betas=(0.9, 0.999), 251 | eps=1e-8, 252 | ) 253 | elif self.optim_backend == "torchao": 254 | from torchao.optim import AdamW8bit 255 | optimizer = AdamW8bit( 256 | model.parameters(), 257 | lr=lr, 258 | weight_decay=self.weight_decay, 259 | betas=(0.9, 0.999), 260 | eps=1e-8, 261 | ) 262 | else: 263 | raise ValueError(f"Unknown optim_backend: {self.optim_backend}") 264 | else: 265 | optimizer = torch.optim.AdamW( 266 | model.parameters(), 267 | lr=lr, 268 | weight_decay=self.weight_decay, 269 | betas=(0.9, 0.999), 270 | eps=1e-8, 271 | ) 272 | 273 | # Load optimizer state from checkpoint if available 274 | if checkpoint is not None and checkpoint.optimizer_state is not None: 275 | try: 276 | self.checkpoint_manager.apply_to_optimizer(checkpoint, optimizer) 277 | except Exception as e: 278 | logger.warning(f"Failed to load optimizer state: {e}. 
Starting with fresh optimizer.") 279 | 280 | return optimizer 281 | 282 | def run(self): 283 | if self.tokenizer_factory: 284 | self.tokenizer = self.tokenizer_factory.build() 285 | 286 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 287 | rank = int(os.environ.get("RANK", 0)) 288 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 289 | 290 | if world_size > 1: 291 | logger.info(f"Distributed: rank {rank}, local_rank {local_rank}, world_size {world_size}") 292 | 293 | # Step 1: Load checkpoint if exists (before building model) 294 | checkpoint, trainer_state = self._load_checkpoint_if_exists() 295 | 296 | if checkpoint is not None and checkpoint.model_state is not None: 297 | # Inject model weights into model factory for loading 298 | # The factory will use these instead of fresh initialization 299 | self.model_factory._checkpoint_state = checkpoint.model_state 300 | if rank == 0: 301 | logger.info("Will load model weights from checkpoint") 302 | 303 | # Step 2: Build model (with checkpoint weights if available) 304 | self.model = self.model_factory.build() 305 | 306 | if self.gradient_checkpointing and hasattr(self.model, "gradient_checkpointing_enable"): 307 | self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) 308 | 309 | if hasattr(self.model, "config"): 310 | self.model.config.use_cache = not self.gradient_checkpointing 311 | 312 | # Step 3: Create optimizer (and load optimizer state from checkpoint) 313 | self.optimizer = self._create_optimizer(self.model, checkpoint) 314 | 315 | # Step 4: Load RNG states if resuming 316 | if checkpoint is not None: 317 | self.checkpoint_manager.load_rng_state(rank=rank, local_rank=local_rank) 318 | if rank == 0: 319 | logger.info(f"Resuming from epoch {trainer_state.epoch}, step {trainer_state.steps}, total_steps {trainer_state.total_steps}") 320 | 321 | # Step 5: Create accelerator 322 | acc_kwargs = { 323 | "gradient_accumulation_steps": self.grad_accum, 324 | "project_dir": self.output_dir, 325 | "project_config": ProjectConfiguration( 326 | project_dir=self.output_dir, 327 | automatic_checkpoint_naming=True, 328 | save_on_each_node=True, 329 | ), 330 | "mixed_precision": "bf16" if self.use_bfloat16 else ("fp16" if self.fp16 else "no"), 331 | } 332 | acc_kwargs.update(self.accelerator_kwargs) 333 | self.accelerator = Accelerator(**acc_kwargs) 334 | 335 | if self.accelerator.is_main_process: 336 | logger.info(f"Mixed precision: {self.accelerator.mixed_precision}") 337 | logger.info(f"Model: {self.model.__class__.__name__}") 338 | total_params = sum(p.numel() for p in self.model.parameters()) 339 | trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad) 340 | logger.info(f" Total params: {total_params:,}") 341 | logger.info(f" Trainable params: {trainable_params:,}") 342 | logger.info(f" Loss calculator: {self.loss_calculator.__class__.__name__}") 343 | 344 | # Step 6: Create trainer (prepare() happens inside trainer) 345 | trainer = Trainer( 346 | model=self.model, 347 | optimizer=self.optimizer, 348 | accelerator=self.accelerator, 349 | dataset=self.dataloader, 350 | device="cuda", 351 | preprocessors=self.preprocessors, 352 | postprocessors=self.postprocessors, 353 | loss_calculator=self.loss_calculator, 354 | grad_accum=self.grad_accum, 355 | fp16=self.fp16, 356 | use_bfloat16=self.use_bfloat16, 357 | output_dir=self.output_dir, 358 | checkpoint_every=self.checkpoint_every, 359 | seed=self.seed, 360 | use_scheduler=False, 361 | metrics_logger=self.metrics_logger, 
362 | log_every=self.log_every, 363 | max_grad_norm=self.max_grad_norm, 364 | hf_hub_token=self.hf_hub_token, 365 | shuffle_each_epoch=self.shuffle_each_epoch, 366 | total_batches=self.total_batches, 367 | is_fsdp=self.model_factory.fsdp_config.enabled if self.model_factory.fsdp_config else False, 368 | initial_state=trainer_state, 369 | ) 370 | 371 | trainer.train(epochs=self.epochs, max_steps=self.max_steps) 372 | 373 | self.accelerator.wait_for_everyone() 374 | 375 | if self.push_to_hub_flag: 376 | model = self.accelerator.unwrap_model(self.model) 377 | 378 | if self.merge_adapters and hasattr(model, "merge_and_unload"): 379 | logger.info("Merging adapters and unloading.") 380 | model = model.merge_and_unload(True) 381 | 382 | if self.accelerator.is_main_process: 383 | logger.info("Pushing to hub!") 384 | 385 | model.push_to_hub = types.MethodType(push_to_hub, model) 386 | model.push_to_hub(self.output_hub_repo, token=self.hf_hub_token, accelerator=self.accelerator, private=self.private_repo) 387 | 388 | if self.accelerator.is_main_process: 389 | logger.info("Training complete!") 390 | 391 | return self.model 392 | -------------------------------------------------------------------------------- /lib/ditty/model_factory.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | from dataclasses import dataclass, field 4 | from logging import getLogger 5 | from typing import Optional, List, Type, Dict, Any, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import safetensors 10 | from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy 11 | from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, BitsAndBytesConfig 12 | from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME 13 | from bitsandbytes.nn import Linear4bit, Params4bit 14 | from accelerate import init_empty_weights 15 | from fastcore.parallel import parallel 16 | from tqdm.auto import tqdm 17 | 18 | logger = getLogger("ditty_model_factory") 19 | 20 | 21 | class ModelTransform: 22 | """Transform applied to a model after loading. 23 | 24 | Use for operations like wrapping models, freezing layers, etc. 25 | """ 26 | def transform(self, model: nn.Module) -> nn.Module: 27 | raise NotImplementedError 28 | 29 | 30 | @dataclass 31 | class FSDPConfig: 32 | enabled: bool = False 33 | transformer_layers: List[Type[nn.Module]] = field(default_factory=list) 34 | param_dtype: Optional[torch.dtype] = None # e.g. torch.bfloat16 35 | reduce_dtype: Optional[torch.dtype] = None # None = match param_dtype, torch.float32 for accuracy 36 | reshard_after_forward: bool = True # True = FULL_SHARD, False = SHARD_GRAD_OP 37 | 38 | 39 | @dataclass 40 | class QuantConfig: 41 | enabled: bool = False 42 | bits: int = 4 # 4 or 8 43 | use_double_quant: bool = True 44 | quant_type: str = "nf4" 45 | compute_dtype: torch.dtype = torch.bfloat16 46 | quant_storage: torch.dtype = torch.bfloat16 47 | use_dora: bool = False 48 | 49 | 50 | @dataclass 51 | class PeftConfig: 52 | enabled: bool = False 53 | r: int = 8 54 | lora_alpha: int = 16 55 | lora_dropout: float = 0.1 56 | target_modules: List[str] = field(default_factory=lambda: ["k_proj", "q_proj", "v_proj", "up_proj", "down_proj", "gate_proj"]) 57 | use_dora: bool = False 58 | 59 | 60 | class ModelFactory: 61 | """ 62 | Factory for loading models and preparing them for distributed training. 
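
    A minimal usage sketch (the model id shown here is illustrative, not a default):

        factory = ModelFactory.from_huggingface(
            "mistralai/Mistral-7B-v0.1",
            fsdp_config=FSDPConfig(enabled=True, param_dtype=torch.bfloat16),
        )
        model = factory.build()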
63 | 64 | Handles: 65 | - Loading from HuggingFace Hub 66 | - Loading from local checkpoints 67 | - Wrapping existing model instances 68 | - FSDP2 sharding via fully_shard() 69 | - QLoRA 4bit/8bit quantization 70 | """ 71 | 72 | def __init__( 73 | self, 74 | model: Optional[nn.Module] = None, 75 | model_path: Optional[str] = None, 76 | model_class: Optional[Type[nn.Module]] = None, 77 | fsdp_config: Optional[Union[FSDPConfig, Dict[str, Any]]] = None, 78 | quant_config: Optional[Union[QuantConfig, Dict[str, Any]]] = None, 79 | peft_config: Optional[Union[PeftConfig, Dict[str, Any]]] = None, 80 | load_kwargs: Optional[Dict[str, Any]] = None, 81 | contract: str = "", 82 | model_transform: Optional[ModelTransform] = None, 83 | use_compile: bool = False, 84 | compile_mode: str = "default", 85 | ): 86 | self._model = model 87 | self._model_path = model_path 88 | self._model_class = model_class 89 | self._load_kwargs = load_kwargs or {} 90 | self.contract = contract 91 | self._model_transform = model_transform 92 | self.use_compile = use_compile 93 | self.compile_mode = compile_mode 94 | # Injected by Pipeline when resuming from checkpoint 95 | self._checkpoint_state: Optional[Dict[str, Any]] = None 96 | 97 | if isinstance(fsdp_config, dict): 98 | self.fsdp_config = FSDPConfig(**fsdp_config) 99 | else: 100 | self.fsdp_config = fsdp_config or FSDPConfig() 101 | 102 | if isinstance(quant_config, dict): 103 | self.quant_config = QuantConfig(**quant_config) 104 | else: 105 | self.quant_config = quant_config or QuantConfig() 106 | 107 | if isinstance(peft_config, dict): 108 | self.peft_config = PeftConfig(**peft_config) 109 | else: 110 | self.peft_config = peft_config or PeftConfig() 111 | 112 | if model is None and model_path is None: 113 | raise ValueError("Must provide either model or model_path") 114 | 115 | @classmethod 116 | def from_huggingface( 117 | cls, 118 | model_path: str, 119 | fsdp_config: Optional[Union[FSDPConfig, Dict[str, Any]]] = None, 120 | quant_config: Optional[Union[QuantConfig, Dict[str, Any]]] = None, 121 | peft_config: Optional[Union[PeftConfig, Dict[str, Any]]] = None, 122 | **load_kwargs, 123 | ) -> "ModelFactory": 124 | return cls( 125 | model_path=model_path, 126 | model_class=AutoModelForCausalLM, 127 | fsdp_config=fsdp_config, 128 | quant_config=quant_config, 129 | peft_config=peft_config, 130 | load_kwargs=load_kwargs, 131 | ) 132 | 133 | @classmethod 134 | def from_checkpoint( 135 | cls, 136 | checkpoint_path: str, 137 | model_class: Type[nn.Module], 138 | fsdp_config: Optional[Union[FSDPConfig, Dict[str, Any]]] = None, 139 | model_transform: Optional[ModelTransform] = None, 140 | use_compile: bool = False, 141 | compile_mode: str = "default", 142 | **model_kwargs, 143 | ) -> "ModelFactory": 144 | return cls( 145 | model_path=checkpoint_path, 146 | model_class=model_class, 147 | fsdp_config=fsdp_config, 148 | load_kwargs=model_kwargs, 149 | model_transform=model_transform, 150 | use_compile=use_compile, 151 | compile_mode=compile_mode, 152 | ) 153 | 154 | @classmethod 155 | def from_instance( 156 | cls, 157 | model: nn.Module, 158 | fsdp_config: Optional[Union[FSDPConfig, Dict[str, Any]]] = None, 159 | use_compile: bool = False, 160 | compile_mode: str = "default", 161 | ) -> "ModelFactory": 162 | return cls( 163 | model=model, 164 | fsdp_config=fsdp_config, 165 | use_compile=use_compile, 166 | compile_mode=compile_mode, 167 | ) 168 | 169 | def _replace_linear(self, model: nn.Module, skip_modules: List[str] = None): 170 | skip_modules = skip_modules or ["lm_head"] 
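        # Recursively walk the module tree and swap every nn.Linear for a
        # bitsandbytes Linear4bit built with the configured compute/quant/storage
        # dtypes; modules named in skip_modules (the lm_head by default) are left
        # untouched so the output projection keeps its original precision.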
171 | for name, module in model.named_children(): 172 | if name in skip_modules: 173 | continue 174 | if len(list(module.children())) > 0: 175 | self._replace_linear(module, skip_modules) 176 | if isinstance(module, nn.Linear): 177 | model._modules[name] = Linear4bit( 178 | module.in_features, 179 | module.out_features, 180 | module.bias is not None, 181 | compute_dtype=self.quant_config.compute_dtype, 182 | quant_type=self.quant_config.quant_type, 183 | quant_storage=self.quant_config.quant_storage, 184 | ) 185 | return model 186 | 187 | def _n_loading_workers(self, param_count: float): 188 | devprops = torch.cuda.get_device_properties(torch.cuda.current_device()) 189 | left = int(os.cpu_count() / torch.cuda.device_count()) 190 | right = int(8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))) 191 | return min(left, right) 192 | 193 | def _load_and_quantize(self, module: nn.Module, name: str, value: torch.Tensor, 194 | device=None, dtype=None, skip_names=None, to_cpu=False, to_meta=False): 195 | skip_names = skip_names or [] 196 | 197 | def place_on_device(value): 198 | if to_meta: 199 | return value.to(device="meta", dtype=dtype) 200 | elif to_cpu: 201 | return value.to(device="cpu", dtype=dtype) 202 | return value.to(device=device, dtype=dtype) 203 | 204 | if any(skip_name in name for skip_name in skip_names): 205 | return 206 | 207 | module_key, _, value_key = name.rpartition(".") 208 | try: 209 | submodule = module.get_submodule(module_key) 210 | except AttributeError: 211 | return 212 | 213 | try: 214 | param = submodule.get_parameter(value_key) 215 | if isinstance(param, Params4bit): 216 | if self.quant_config.use_dora: 217 | setattr(submodule, "dora_scale", value.norm(p=2, dim=1).to(dtype=dtype).to("cpu")) 218 | value = type(param)(value.to(device=device, dtype=dtype).data, **param.__dict__).cuda(device) 219 | if to_meta: 220 | value = type(param)(value.data.to("meta"), **value.__dict__) 221 | elif to_cpu: 222 | value = type(param)(value.data.to("cpu"), **value.__dict__) 223 | else: 224 | value = type(param)(place_on_device(value).data) 225 | except AttributeError: 226 | value = place_on_device(value) 227 | 228 | setattr(submodule, value_key, value) 229 | 230 | def _load_quantized_model(self) -> nn.Module: 231 | rank = int(os.environ.get("RANK", 0)) 232 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 233 | 234 | cfg = AutoConfig.from_pretrained(self._model_path, **self._load_kwargs) 235 | cfg.use_cache = False 236 | if self._load_kwargs.get("attn_implementation"): 237 | cfg.attn_implementation = self._load_kwargs["attn_implementation"] 238 | 239 | with init_empty_weights(): 240 | model = AutoModelForCausalLM.from_config(cfg) 241 | model.model = self._replace_linear(model.model) 242 | 243 | model.is_loaded_in_4bit = True 244 | 245 | try: 246 | idx = hub.cached_file(self._model_path, SAFE_WEIGHTS_INDEX_NAME) 247 | files, _ = hub.get_checkpoint_shard_files(self._model_path, idx) 248 | except OSError: 249 | try: 250 | files = [hub.cached_file(self._model_path, SAFE_WEIGHTS_NAME)] 251 | except OSError as e: 252 | raise e 253 | 254 | def load_and_quantize_parallel(name_param, model, **kwargs): 255 | name, param = name_param 256 | self._load_and_quantize(model, name, param, **kwargs) 257 | 258 | param_count = sum(p.numel() for p in model.parameters()) 259 | if local_rank == 0: 260 | logger.info(f"Total model params: {param_count}") 261 | 262 | n_workers = self._n_loading_workers(param_count) 263 | if rank == 0: 264 | logger.info(f"Using n_workers: {n_workers} for 
loading") 265 | 266 | for filename in tqdm(files, desc="Loading & Quantizing", disable=rank != 0): 267 | weights = safetensors.torch.load_file(filename) 268 | parallel( 269 | load_and_quantize_parallel, 270 | iter(weights.items()), 271 | n_workers=n_workers, 272 | threadpool=True, 273 | model=model, 274 | dtype=self.quant_config.compute_dtype, 275 | device=torch.cuda.current_device(), 276 | skip_names=[], 277 | to_cpu=(local_rank == 0), 278 | to_meta=(local_rank != 0), 279 | ) 280 | 281 | torch.cuda.empty_cache() 282 | return model 283 | 284 | def _load_model(self) -> nn.Module: 285 | if self._model is not None: 286 | model = self._model 287 | # Apply checkpoint state if injected (for resuming training) 288 | if self._checkpoint_state is not None: 289 | logger.info("Loading model weights from checkpoint state") 290 | model.load_state_dict(self._checkpoint_state) 291 | return model 292 | 293 | if self.quant_config.enabled and self.quant_config.bits == 4 and self.fsdp_config.enabled: 294 | logger.info(f"Loading 4bit quantized model: {self._model_path}") 295 | return self._load_quantized_model() 296 | 297 | if self._model_class == AutoModelForCausalLM: 298 | logger.info(f"Loading model from HuggingFace: {self._model_path}") 299 | bnb_config = None 300 | if self.quant_config.enabled: 301 | if self.quant_config.bits == 4: 302 | bnb_config = BitsAndBytesConfig( 303 | load_in_4bit=True, 304 | bnb_4bit_use_double_quant=self.quant_config.use_double_quant, 305 | bnb_4bit_quant_type=self.quant_config.quant_type, 306 | bnb_4bit_quant_storage=self.quant_config.quant_storage, 307 | bnb_4bit_compute_dtype=self.quant_config.compute_dtype, 308 | ) 309 | elif self.quant_config.bits == 8: 310 | bnb_config = BitsAndBytesConfig(load_in_8bit=True) 311 | 312 | return AutoModelForCausalLM.from_pretrained( 313 | self._model_path, 314 | quantization_config=bnb_config, 315 | **self._load_kwargs, 316 | ) 317 | 318 | # For custom model classes, create model then optionally load checkpoint 319 | if self._model_path is None or self._model_path.endswith(".pt") or self._model_path.endswith(".pth"): 320 | # Determine which state dict to use 321 | if self._checkpoint_state is not None: 322 | # Use injected checkpoint state (from Pipeline resume) 323 | logger.info("Loading model weights from checkpoint state") 324 | model = self._model_class(**self._load_kwargs) 325 | model.load_state_dict(self._checkpoint_state) 326 | return model 327 | elif self._model_path is not None: 328 | # Load from explicit checkpoint path 329 | logger.info(f"Loading model from checkpoint: {self._model_path}") 330 | state_dict = torch.load(self._model_path, map_location="cpu", weights_only=False) 331 | if "model_state_dict" in state_dict: 332 | state_dict = state_dict["model_state_dict"] 333 | model = self._model_class(**self._load_kwargs) 334 | model.load_state_dict(state_dict) 335 | return model 336 | else: 337 | # Fresh model, no weights to load 338 | model = self._model_class(**self._load_kwargs) 339 | return model 340 | 341 | raise ValueError(f"Cannot load model from {self._model_path}") 342 | 343 | def _apply_fsdp(self, model: nn.Module) -> nn.Module: 344 | rank = int(os.environ.get("RANK", 0)) 345 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 346 | logger.info(f"Applying FSDP2 sharding (rank {rank}, local_rank {local_rank})") 347 | 348 | torch.cuda.set_device(local_rank) 349 | model = model.to("cpu") 350 | 351 | mp_policy = None 352 | if self.fsdp_config.param_dtype is not None: 353 | mp_policy = MixedPrecisionPolicy( 354 | 
param_dtype=self.fsdp_config.param_dtype, 355 | reduce_dtype=self.fsdp_config.reduce_dtype, 356 | ) 357 | 358 | fsdp_kwargs = { 359 | "reshard_after_forward": self.fsdp_config.reshard_after_forward, 360 | } 361 | if mp_policy: 362 | fsdp_kwargs["mp_policy"] = mp_policy 363 | 364 | for module in model.modules(): 365 | if any( 366 | isinstance(module, layer_cls) 367 | for layer_cls in self.fsdp_config.transformer_layers 368 | ): 369 | fully_shard(module, **fsdp_kwargs) 370 | 371 | fully_shard(model, **fsdp_kwargs) 372 | return model 373 | 374 | def _setup_quantized_meta_for_peft(self, model: nn.Module): 375 | def temp_to_method(self, *args, **kwargs): 376 | return self 377 | for param in model.parameters(): 378 | if isinstance(param, Params4bit): 379 | param.quant_state._orig_to = param.quant_state.to 380 | param.quant_state.to = types.MethodType(temp_to_method, param.quant_state) 381 | 382 | def _setup_quantized_peft_meta_for_training(self, model: nn.Module): 383 | for param in model.parameters(): 384 | if isinstance(param, Params4bit) and hasattr(param.quant_state, "_orig_to"): 385 | param.quant_state.to = param.quant_state._orig_to 386 | param.quant_state._orig_to = None 387 | 388 | def _apply_peft(self, model: nn.Module) -> nn.Module: 389 | from peft import TaskType, LoraConfig, get_peft_model 390 | 391 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 392 | 393 | lora_config = LoraConfig( 394 | task_type=TaskType.CAUSAL_LM, 395 | target_modules=self.peft_config.target_modules, 396 | inference_mode=False, 397 | r=self.peft_config.r, 398 | lora_alpha=self.peft_config.lora_alpha, 399 | lora_dropout=self.peft_config.lora_dropout, 400 | bias="none", 401 | use_dora=self.peft_config.use_dora, 402 | ) 403 | 404 | model.enable_input_require_grads() 405 | 406 | if self.quant_config.enabled and local_rank != 0: 407 | self._setup_quantized_meta_for_peft(model) 408 | 409 | model = get_peft_model(model, lora_config) 410 | 411 | if self.quant_config.enabled: 412 | self._setup_quantized_peft_meta_for_training(model) 413 | 414 | return model 415 | 416 | def build(self) -> nn.Module: 417 | model = self._load_model() 418 | 419 | if self._model_transform is not None: 420 | model = self._model_transform.transform(model) 421 | 422 | if self.peft_config.enabled: 423 | model = self._apply_peft(model) 424 | 425 | if self.use_compile: 426 | logger.info(f"Compiling model with torch.compile(mode={self.compile_mode})") 427 | model = torch.compile(model, mode=self.compile_mode) 428 | 429 | if not self.fsdp_config.enabled: 430 | logger.info("FSDP disabled, returning unwrapped model") 431 | return model 432 | 433 | return self._apply_fsdp(model) 434 | 435 | 436 | class TokenizerFactory: 437 | def __init__( 438 | self, 439 | tokenizer_path: str, 440 | pad_token: Optional[str] = None, 441 | token: Optional[str] = None, 442 | **load_kwargs, 443 | ): 444 | self._tokenizer_path = tokenizer_path 445 | self._pad_token = pad_token 446 | self._token = token or os.environ.get("HF_TOKEN") 447 | self._load_kwargs = load_kwargs 448 | 449 | @classmethod 450 | def from_pretrained(cls, tokenizer_path: str, **kwargs) -> "TokenizerFactory": 451 | return cls(tokenizer_path=tokenizer_path, **kwargs) 452 | 453 | def build(self): 454 | tokenizer = AutoTokenizer.from_pretrained( 455 | self._tokenizer_path, 456 | token=self._token, 457 | **self._load_kwargs, 458 | ) 459 | if tokenizer.pad_token_id is None: 460 | if self._pad_token: 461 | tokenizer.pad_token = self._pad_token 462 | else: 463 | logger.warning("Tokenizer did not have a 
pad_token_id, set to EOS.") 464 | tokenizer.pad_token_id = tokenizer.eos_token_id 465 | return tokenizer 466 | -------------------------------------------------------------------------------- /lib/ditty/example.py: -------------------------------------------------------------------------------- 1 | """ 2 | End-to-end example demonstrating the ditty training pipeline. 3 | 4 | Components demonstrated: 5 | - Data: Dataset loading and preparation 6 | - PreProcessor: Transform batch before model forward 7 | - PostProcessor: Transform model output before loss calculation 8 | - LossCalculator: Compute loss from model output 9 | - ModelFactory: Create and configure models 10 | - Pipeline: Orchestrate the entire training flow 11 | - Contracts: Declarative tensor shape/dtype validation 12 | """ 13 | import torch 14 | import torch.nn as nn 15 | from typing import Any, Dict, Tuple, List 16 | from datasets import Dataset 17 | 18 | from .processors import PreProcessor, PostProcessor, Context 19 | from .loss import LossCalculator, LossOutput, CompositeLoss 20 | from .model_factory import ModelFactory, ModelTransform, FSDPConfig 21 | from .pipeline import Pipeline 22 | from .contract import parse_contract, format_pipeline_contracts, validate_pipeline_chain 23 | 24 | 25 | # --- Example Model --- 26 | 27 | 28 | class ExampleModel(nn.Module): 29 | """Simple encoder-decoder model for demonstration.""" 30 | 31 | def __init__(self, vocab_size: int = 1000, embed_dim: int = 256, hidden_dim: int = 512): 32 | super().__init__() 33 | self.vocab_size = vocab_size 34 | self.embed_dim = embed_dim 35 | self.hidden_dim = hidden_dim 36 | self.embedding = nn.Embedding(vocab_size, embed_dim) 37 | self.encoder = nn.Linear(embed_dim, hidden_dim) 38 | self.decoder = nn.Linear(hidden_dim, vocab_size) 39 | 40 | def forward( 41 | self, 42 | input_ids: torch.Tensor, 43 | return_hidden: bool = False, 44 | ) -> Tuple[torch.Tensor, ...]: 45 | emb = self.embedding(input_ids) 46 | hidden = self.encoder(emb) 47 | logits = self.decoder(hidden) 48 | if return_hidden: 49 | return logits, hidden 50 | return (logits,) 51 | 52 | 53 | # --- Example Preprocessors --- 54 | 55 | 56 | class TokenMasker(PreProcessor): 57 | """Randomly mask tokens for masked language modeling.""" 58 | 59 | def __init__(self, mask_prob: float = 0.15, mask_token_id: int = 0): 60 | super().__init__(contract="batch:2:i64 -> batch:2:i64") 61 | self.mask_prob = mask_prob 62 | self.mask_token_id = mask_token_id 63 | 64 | def config(self) -> Dict[str, Any]: 65 | return {"mask_prob": self.mask_prob} 66 | 67 | def process(self, batch: torch.Tensor, ctx: Context) -> Tuple[torch.Tensor, Context]: 68 | ctx["original_input_ids"] = batch.clone() 69 | mask = torch.rand_like(batch.float()) < self.mask_prob 70 | masked_batch = batch.clone() 71 | masked_batch[mask] = self.mask_token_id 72 | ctx["mask"] = mask.float() 73 | return masked_batch, ctx 74 | 75 | 76 | class ForwardKwargsInjector(PreProcessor): 77 | """Inject additional kwargs into model forward call.""" 78 | 79 | def __init__(self, contract: str = "", **kwargs): 80 | super().__init__(contract=contract) 81 | self.kwargs = kwargs 82 | 83 | def config(self) -> Dict[str, Any]: 84 | return self.kwargs 85 | 86 | def process(self, batch: Any, ctx: Context) -> Tuple[Any, Context]: 87 | ctx["forward_kwargs"] = ctx.get("forward_kwargs", {}) 88 | ctx["forward_kwargs"].update(self.kwargs) 89 | return batch, ctx 90 | 91 | 92 | # --- Example Postprocessors --- 93 | 94 | 95 | class TargetExtractor(PostProcessor): 96 | """Extract targets 
from context for loss computation.""" 97 | 98 | def __init__(self, target_key: str = "original_input_ids", contract: str = ""): 99 | super().__init__(contract=contract) 100 | self.target_key = target_key 101 | 102 | def process( 103 | self, model_output: Tuple[Any, ...], ctx: Context 104 | ) -> Tuple[Tuple[Any, ...], Context]: 105 | ctx["target"] = ctx[self.target_key] 106 | return model_output, ctx 107 | 108 | 109 | class HiddenStateExtractor(PostProcessor): 110 | """Extract hidden states for auxiliary losses.""" 111 | 112 | def __init__(self, output_index: int = 1, contract: str = ""): 113 | super().__init__(contract=contract) 114 | self.output_index = output_index 115 | 116 | def process( 117 | self, model_output: Tuple[Any, ...], ctx: Context 118 | ) -> Tuple[Tuple[Any, ...], Context]: 119 | if len(model_output) > self.output_index: 120 | ctx["hidden_states"] = model_output[self.output_index] 121 | return model_output, ctx 122 | 123 | 124 | # --- Example Loss Calculators --- 125 | 126 | 127 | class MaskedCrossEntropyLoss(LossCalculator): 128 | """Cross-entropy loss only on masked positions.""" 129 | 130 | def __init__(self, vocab_size: int = 1000, contract: str = ""): 131 | super().__init__( 132 | output_index=0, 133 | target_key="target", 134 | mask_key="mask", 135 | contract=contract, 136 | ) 137 | self.vocab_size = vocab_size 138 | 139 | def compute(self, model_output: Tuple[Any, ...], ctx: Context) -> LossOutput: 140 | logits = self.get_prediction(model_output) 141 | target = self.get_target(ctx) 142 | mask = self.get_mask(ctx) 143 | 144 | logits_flat = logits.reshape(-1, self.vocab_size) 145 | target_flat = target.reshape(-1) 146 | mask_flat = mask.reshape(-1) if mask is not None else torch.ones_like(target_flat, dtype=torch.float) 147 | 148 | loss_per_token = torch.nn.functional.cross_entropy( 149 | logits_flat, target_flat, reduction="none" 150 | ) 151 | loss = (loss_per_token * mask_flat).sum() / mask_flat.sum().clamp(min=1) 152 | return LossOutput(loss=loss, metrics={"masked_ce": loss.item()}) 153 | 154 | 155 | class HiddenRegularizer(LossCalculator): 156 | """L2 regularization on hidden states.""" 157 | 158 | def __init__(self, weight: float = 0.01, contract: str = ""): 159 | super().__init__(contract=contract) 160 | self.weight = weight 161 | 162 | def compute(self, model_output: Tuple[Any, ...], ctx: Context) -> LossOutput: 163 | hidden = ctx.get("hidden_states") 164 | if hidden is None: 165 | device = ctx.get("device", "cuda") 166 | return LossOutput(loss=torch.tensor(0.0, device=device), metrics={"hidden_reg": 0.0}) 167 | loss = self.weight * (hidden ** 2).mean() 168 | return LossOutput(loss=loss, metrics={"hidden_reg": loss.item()}) 169 | 170 | 171 | # --- Example Model Transform --- 172 | 173 | 174 | class FreezeEmbeddings(ModelTransform): 175 | """Freeze embedding layer during training.""" 176 | 177 | def transform(self, model: nn.Module) -> nn.Module: 178 | if hasattr(model, "embedding"): 179 | for param in model.embedding.parameters(): 180 | param.requires_grad = False 181 | return model 182 | 183 | 184 | # --- Architecture Printing --- 185 | 186 | 187 | def print_pipeline( 188 | model: nn.Module, 189 | preprocessors: List[PreProcessor], 190 | postprocessors: List[PostProcessor], 191 | loss_calculator: LossCalculator, 192 | model_contract: str = "", 193 | ): 194 | """ 195 | Print complete pipeline architecture with data flow and contract validation. 
196 | 197 | This gives you an end-to-end view of: 198 | - Model architecture and parameters 199 | - Data flow through preprocessors -> model -> postprocessors -> loss 200 | - Contract specifications at each stage 201 | - Contract chain validation results 202 | """ 203 | width = 70 204 | 205 | print("\n" + "=" * width) 206 | print(" DITTY PIPELINE ARCHITECTURE ".center(width, "=")) 207 | print("=" * width) 208 | 209 | # --- Model --- 210 | print("\n>>> MODEL") 211 | print(f" Class: {model.__class__.__name__}") 212 | total_params = sum(p.numel() for p in model.parameters()) 213 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 214 | frozen_params = total_params - trainable_params 215 | print(f" Parameters: {total_params:,} total, {trainable_params:,} trainable, {frozen_params:,} frozen") 216 | 217 | attrs = ["vocab_size", "embed_dim", "hidden_dim", "latent_dim", "num_layers", "num_heads"] 218 | model_attrs = {a: getattr(model, a) for a in attrs if hasattr(model, a)} 219 | if model_attrs: 220 | print(f" Config: {model_attrs}") 221 | 222 | if model_contract: 223 | print(f" Contract: {model_contract}") 224 | 225 | # --- Data Flow --- 226 | print("\n>>> DATA FLOW") 227 | print(" ┌─────────────────────────────────────────────────────────────┐") 228 | print(" │ Dataset │") 229 | print(" └─────────────────────────────────────────────────────────────┘") 230 | print(" │") 231 | print(" ▼") 232 | 233 | for i, p in enumerate(preprocessors): 234 | cfg = p.config() 235 | cfg_str = f" {cfg}" if cfg else "" 236 | print(f" ┌─ PreProcessor {i+1}: {p.name}{cfg_str}") 237 | if p.contract: 238 | print(f" │ {p.contract}") 239 | print(" └──────────────────────────────────────────────────────────────") 240 | print(" │") 241 | print(" ▼") 242 | 243 | print(" ╔═════════════════════════════════════════════════════════════╗") 244 | print(f" ║ MODEL: {model.__class__.__name__:<53} ║") 245 | if model_contract: 246 | contract_display = model_contract[:55] + "..." if len(model_contract) > 55 else model_contract 247 | print(f" ║ {contract_display:<61} ║") 248 | print(" ╚═════════════════════════════════════════════════════════════╝") 249 | print(" │") 250 | print(" ▼") 251 | 252 | for i, p in enumerate(postprocessors): 253 | print(f" ┌─ PostProcessor {i+1}: {p.name}") 254 | if p.contract: 255 | print(f" │ {p.contract}") 256 | print(" └──────────────────────────────────────────────────────────────") 257 | print(" │") 258 | print(" ▼") 259 | 260 | print(" ┌─────────────────────────────────────────────────────────────┐") 261 | if isinstance(loss_calculator, CompositeLoss): 262 | print(" │ LOSS: CompositeLoss │") 263 | for calc, weight in loss_calculator.losses: 264 | print(f" │ • {calc.name} (weight={weight})") 265 | if calc.contract: 266 | contract_short = calc.contract[:50] + "..." 
if len(calc.contract) > 50 else calc.contract 267 | print(f" │ {contract_short}") 268 | else: 269 | print(f" │ LOSS: {loss_calculator.name:<56} │") 270 | if loss_calculator.contract: 271 | print(f" │ {loss_calculator.contract:<61} │") 272 | print(" └─────────────────────────────────────────────────────────────┘") 273 | 274 | # --- Contract Validation --- 275 | print("\n>>> CONTRACT VALIDATION") 276 | 277 | pre_contracts = [(p.name, parse_contract(p.contract)) for p in preprocessors if p.contract] 278 | model_c = parse_contract(model_contract) if model_contract else None 279 | post_contracts = [(p.name, parse_contract(p.contract)) for p in postprocessors if p.contract] 280 | 281 | loss_contract_str = "" 282 | if isinstance(loss_calculator, CompositeLoss): 283 | for calc, _ in loss_calculator.losses: 284 | if calc.contract: 285 | loss_contract_str = calc.contract 286 | break 287 | else: 288 | loss_contract_str = loss_calculator.contract 289 | 290 | loss_c = parse_contract(loss_contract_str) if loss_contract_str else None 291 | 292 | if model_c and loss_c: 293 | errors = validate_pipeline_chain( 294 | [c for _, c in pre_contracts], 295 | model_c, 296 | [c for _, c in post_contracts], 297 | loss_c, 298 | ) 299 | if errors: 300 | print(" Status: FAILED") 301 | for err in errors: 302 | print(f" ✗ {err}") 303 | else: 304 | print(" Status: PASSED") 305 | print(" All tensor shapes and dtypes chain correctly through the pipeline.") 306 | else: 307 | print(" Status: SKIPPED (model or loss contract not specified)") 308 | 309 | print("\n" + "=" * width + "\n") 310 | 311 | 312 | # --- Example Data --- 313 | 314 | 315 | def create_example_dataset(num_samples: int = 1000, seq_len: int = 32, vocab_size: int = 1000): 316 | """Create a synthetic dataset for demonstration.""" 317 | return Dataset.from_dict({ 318 | "input_ids": [ 319 | torch.randint(1, vocab_size, (seq_len,)).tolist() 320 | for _ in range(num_samples) 321 | ] 322 | }) 323 | 324 | 325 | def example_collate_fn(batch): 326 | """Collate function that stacks input_ids.""" 327 | return torch.tensor([item["input_ids"] for item in batch]) 328 | 329 | 330 | # --- Main Example --- 331 | 332 | 333 | def run_example(): 334 | """ 335 | Complete end-to-end training example with contracts. 
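
    Running this module directly (e.g. `python -m ditty.example`, assuming the
    package is installed) prints the contract demo first and then starts this
    training run.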
336 | 337 | Flow: 338 | Dataset -> DataLoader 339 | -> TokenMasker (masks random tokens, stores originals in ctx) 340 | -> ForwardKwargsInjector (adds return_hidden=True) 341 | -> model.forward(batch, **forward_kwargs) 342 | -> TargetExtractor (moves original tokens to ctx["target"]) 343 | -> HiddenStateExtractor (extracts hidden states to ctx) 344 | -> CompositeLoss([MaskedCrossEntropyLoss, HiddenRegularizer]) 345 | -> backward + optimize 346 | """ 347 | vocab_size = 1000 348 | embed_dim = 256 349 | hidden_dim = 512 350 | 351 | dataset = create_example_dataset(num_samples=1000, seq_len=32, vocab_size=vocab_size) 352 | 353 | model = ExampleModel(vocab_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim) 354 | 355 | preprocessors = [ 356 | TokenMasker( 357 | mask_prob=0.15, 358 | mask_token_id=0, 359 | ), 360 | ForwardKwargsInjector( 361 | contract="batch:2:i64 -> batch:2:i64", 362 | return_hidden=True, 363 | ), 364 | ] 365 | 366 | postprocessors = [ 367 | TargetExtractor( 368 | target_key="original_input_ids", 369 | contract="logits:3:f, hidden:3:f -> logits:3:f, hidden:3:f", 370 | ), 371 | HiddenStateExtractor( 372 | output_index=1, 373 | contract="logits:3:f, hidden:3:f -> logits:3:f, hidden:3:f", 374 | ), 375 | ] 376 | 377 | loss_calculator = CompositeLoss([ 378 | (MaskedCrossEntropyLoss( 379 | vocab_size=vocab_size, 380 | contract="logits:3:f | ctx.target:2:i64, ctx.mask:2:f -> loss:0:f", 381 | ), 1.0), 382 | (HiddenRegularizer( 383 | weight=0.01, 384 | contract="| ctx.hidden_states:3:f -> loss:0:f", 385 | ), 1.0), 386 | ]) 387 | 388 | model_contract = "batch:2:i64 -> logits:3:f, hidden:3:f" 389 | 390 | print_pipeline(model, preprocessors, postprocessors, loss_calculator, model_contract) 391 | 392 | model_factory = ModelFactory.from_instance(model) 393 | 394 | pipeline = Pipeline( 395 | model_factory=model_factory, 396 | dataset=dataset, 397 | collate_fn=example_collate_fn, 398 | loss_calculator=loss_calculator, 399 | preprocessors=preprocessors, 400 | postprocessors=postprocessors, 401 | output_dir="./example_output", 402 | batch_size=16, 403 | epochs=1, 404 | lr=1e-3, 405 | fp16=False, 406 | checkpoint_every=100, 407 | log_every=10, 408 | ) 409 | 410 | pipeline.run() 411 | 412 | 413 | def run_example_with_fsdp(): 414 | """Example with FSDP for distributed training.""" 415 | vocab_size = 1000 416 | embed_dim = 256 417 | hidden_dim = 512 418 | 419 | dataset = create_example_dataset(num_samples=1000, seq_len=32, vocab_size=vocab_size) 420 | model = ExampleModel(vocab_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim) 421 | 422 | fsdp_config = FSDPConfig( 423 | enabled=True, 424 | transformer_layers=[nn.Linear], 425 | param_dtype=torch.bfloat16, 426 | reduce_dtype=torch.float32, 427 | ) 428 | 429 | preprocessors = [TokenMasker(mask_prob=0.15)] 430 | postprocessors = [TargetExtractor()] 431 | loss = MaskedCrossEntropyLoss(vocab_size=vocab_size) 432 | 433 | print_pipeline(model, preprocessors, postprocessors, loss) 434 | 435 | model_factory = ModelFactory.from_instance(model, fsdp_config=fsdp_config) 436 | 437 | pipeline = Pipeline( 438 | model_factory=model_factory, 439 | dataset=dataset, 440 | collate_fn=example_collate_fn, 441 | loss_calculator=loss, 442 | preprocessors=preprocessors, 443 | postprocessors=postprocessors, 444 | output_dir="./example_fsdp_output", 445 | batch_size=16, 446 | epochs=1, 447 | use_bfloat16=True, 448 | ) 449 | 450 | pipeline.run() 451 | 452 | 453 | def run_example_with_transform(): 454 | """Example with model transform to freeze layers.""" 
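    # Note: FreezeEmbeddings only flips requires_grad on the embedding weights,
    # so the frozen parameters still show up in model.parameters(); the
    # trainable/frozen counts printed by print_pipeline below reflect this, e.g.
    #   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)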
455 | vocab_size = 1000 456 | dataset = create_example_dataset(num_samples=500, seq_len=32, vocab_size=vocab_size) 457 | 458 | model = ExampleModel(vocab_size=vocab_size) 459 | transform = FreezeEmbeddings() 460 | transformed_model = transform.transform(model) 461 | 462 | preprocessors = [TokenMasker(mask_prob=0.15)] 463 | postprocessors = [TargetExtractor()] 464 | loss = MaskedCrossEntropyLoss(vocab_size=vocab_size) 465 | 466 | print("\n[Before Transform]") 467 | print_pipeline(model, preprocessors, postprocessors, loss) 468 | 469 | print("\n[After Transform - Embeddings Frozen]") 470 | print_pipeline(transformed_model, preprocessors, postprocessors, loss) 471 | 472 | model_factory = ModelFactory.from_instance(transformed_model) 473 | 474 | pipeline = Pipeline( 475 | model_factory=model_factory, 476 | dataset=dataset, 477 | collate_fn=example_collate_fn, 478 | loss_calculator=loss, 479 | preprocessors=preprocessors, 480 | postprocessors=postprocessors, 481 | output_dir="./example_frozen_output", 482 | batch_size=16, 483 | epochs=1, 484 | ) 485 | 486 | pipeline.run() 487 | 488 | 489 | def demo_contracts(): 490 | """Demonstrate contract parsing and validation without running training.""" 491 | print("\n" + "=" * 70) 492 | print("CONTRACT SYSTEM DEMO") 493 | print("=" * 70) 494 | 495 | print("\n[1] Parsing individual contracts:") 496 | examples = [ 497 | "tokens:2:i64 -> logits:3:f", 498 | "logits:3:f, hidden:3:f -> logits:3:f, hidden:3:f", 499 | "logits:3:f | ctx.target:2:i64, ctx.mask:2:f -> loss:0:f", 500 | ] 501 | for ex in examples: 502 | c = parse_contract(ex) 503 | print(f"\n Input: '{ex}'") 504 | print(f" Parsed: {c}") 505 | print(f" Inputs: {[str(i) for i in c.inputs]}") 506 | print(f" Outputs: {[str(o) for o in c.outputs]}") 507 | print(f" Ctx deps: {[str(d) for d in c.ctx_deps]}") 508 | 509 | print("\n\n[2] Validating a complete pipeline:") 510 | 511 | pre = [ 512 | parse_contract("batch:2:i64 -> batch:2:i64"), 513 | ] 514 | model = parse_contract("batch:2:i64 -> logits:3:f, hidden:3:f") 515 | post = [ 516 | parse_contract("logits:3:f, hidden:3:f -> logits:3:f, hidden:3:f"), 517 | ] 518 | loss = parse_contract("logits:3:f | ctx.target:2:i64 -> loss:0:f") 519 | 520 | errors = validate_pipeline_chain(pre, model, post, loss) 521 | if errors: 522 | print(" Validation errors:") 523 | for e in errors: 524 | print(f" - {e}") 525 | else: 526 | print(" Pipeline contracts are valid!") 527 | 528 | print("\n\n[3] Dtype shorthand reference:") 529 | print(" f - any float (f16, bf16, f32, f64)") 530 | print(" f16 - float16") 531 | print(" f32 - float32") 532 | print(" bf16 - bfloat16") 533 | print(" i - any int (i8, i16, i32, i64)") 534 | print(" i64 - int64") 535 | print(" i32 - int32") 536 | print(" b - bool") 537 | print(" u8 - uint8") 538 | 539 | print("\n" + "=" * 70) 540 | 541 | 542 | if __name__ == "__main__": 543 | demo_contracts() 544 | run_example() 545 | --------------------------------------------------------------------------------