├── distillkit ├── __init__.py ├── compression │ ├── __init__.py │ ├── densify.py │ ├── bitpack.py │ ├── compressor.py │ ├── config.py │ ├── legacy.py │ └── monotonic_logprobs.py ├── lossfuncs │ ├── cross_entropy.py │ ├── __init__.py │ ├── hingeloss.py │ ├── hidden_state.py │ ├── logistic_ranking.py │ ├── common.py │ ├── kl.py │ ├── tvd.py │ └── jsd.py ├── hsd_mapping.py ├── signals.py ├── trainer.py ├── pack_logits.py ├── configuration.py ├── monkey_patch_packing.py ├── sample_common.py ├── sample_logits_vllm.py └── main.py ├── pyproject.toml ├── examples ├── afm_test.yml ├── llama_70b_base.yml └── mistral3.yaml ├── tests └── test_bitpack.py ├── .gitignore ├── LICENSE └── README.md /distillkit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /distillkit/compression/__init__.py: -------------------------------------------------------------------------------- 1 | from distillkit.compression.bitpack import ( 2 | pack_to_bytes, 3 | unpack_from_bytes, 4 | ) 5 | from distillkit.compression.compressor import ( 6 | LogprobCompressor, 7 | ) 8 | from distillkit.compression.config import ( 9 | DistributionQuantizationConfig, 10 | LegacyLogitCompressionConfig, 11 | QuantizationBin, 12 | ) 13 | from distillkit.compression.densify import ( 14 | densify, 15 | ) 16 | from distillkit.compression.legacy import ( 17 | LogitCompressor as LegacyLogitCompressor, 18 | ) 19 | 20 | __all__ = [ 21 | "pack_to_bytes", 22 | "unpack_from_bytes", 23 | "QuantizationBin", 24 | "DistributionQuantizationConfig", 25 | "LogprobCompressor", 26 | "densify", 27 | "LegacyLogitCompressor", 28 | "LegacyLogitCompressionConfig", 29 | ] 30 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/cross_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.modeling_outputs import CausalLMOutput 3 | from typing_extensions import override 4 | 5 | from distillkit.hsd_mapping import HiddenStateMapping 6 | from distillkit.lossfuncs.common import ( 7 | LossFunctionBase, 8 | ) 9 | from distillkit.signals import TeacherSignal 10 | 11 | 12 | class CrossEntropyLoss(LossFunctionBase): 13 | @override 14 | @classmethod 15 | def name(cls) -> str: 16 | return "cross_entropy" 17 | 18 | @override 19 | def __init__(self): ... 20 | 21 | @override 22 | def __call__( 23 | self, 24 | student_outputs: CausalLMOutput, 25 | signal: TeacherSignal, 26 | mask: torch.Tensor | None = None, 27 | hidden_state_mapping: HiddenStateMapping | None = None, 28 | num_items_in_batch: int | None = None, 29 | ) -> torch.Tensor: 30 | return student_outputs.loss 31 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "distillkit" 7 | description = "Tools for online and offline distillation of large language models." 
8 | version = "1.0.0" 9 | authors = [ 10 | { name = "Charles Goddard", email = "charles@arcee.ai" }, 11 | { name = "Lucas Atkins", email = "lucas@arcee.ai" }, 12 | ] 13 | requires-python = ">=3.10" 14 | dependencies = [ 15 | "trl~=0.25.1", 16 | "accelerate~=1.11.0", 17 | "transformers>=4.50.3", 18 | "datasets>=3.5.0", 19 | "pydantic~=2.12.4", 20 | "click~=8.3.0", 21 | "torch>=2.0.0", 22 | "wandb", 23 | "tqdm", 24 | ] 25 | 26 | 27 | [project.scripts] 28 | distillkit = "distillkit.main:main" 29 | 30 | [project.optional-dependencies] 31 | dev = ["pytest~=9.0.2", "ruff~=0.14.9"] 32 | capture = [ 33 | "vllm>=0.12.0", 34 | "pyarrow", 35 | ] 36 | 37 | [project.urls] 38 | repository = "https://github.com/arcee-ai/distillkit" 39 | 40 | [tool.setuptools] 41 | packages = ["distillkit", "distillkit.compression", "distillkit.lossfuncs"] 42 | -------------------------------------------------------------------------------- /examples/afm_test.yml: -------------------------------------------------------------------------------- 1 | project_name: distillkit-afm-online-test 2 | trust_remote_code: true 3 | model: arcee-ai/AFM-4.5B-Base 4 | output_path: /workspace/models/afm-4p5b-instructdistill 5 | use_flash_attention: true 6 | sequence_length: 8192 7 | dataset: 8 | train_dataset: 9 | repo_id: allenai/tulu-3-sft-mixture 10 | split: train 11 | seed: 42 12 | loss_functions: 13 | - function: cross_entropy 14 | weight: 0.25 15 | - function: kl 16 | weight: 0.25 17 | temperature: 2.0 18 | - function: hs_cosine 19 | weight: 0.5 20 | layer_mapping: all 21 | teacher: 22 | kind: hf 23 | path: arcee-ai/AFM-4.5B 24 | kwargs: 25 | attn_implementation: flash_attention_2 26 | torch_dtype: bfloat16 27 | training_args: 28 | dataset_text_field: text 29 | packing: True 30 | num_train_epochs: 1 31 | per_device_train_batch_size: 1 32 | gradient_accumulation_steps: 8 33 | save_steps: 200 34 | save_total_limit: 1 35 | logging_steps: 1 36 | learning_rate: 1.0e-5 37 | weight_decay: 0.05 38 | warmup_ratio: 0.025 39 | lr_scheduler_type: cosine 40 | bf16: true 41 | max_grad_norm: 0.5 42 | optim: adamw_torch 43 | gradient_checkpointing: true 44 | gradient_checkpointing_kwargs: 45 | use_reentrant: false 46 | report_to: wandb 47 | push_to_hub: false 48 | dataset_num_proc: 96 49 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/__init__.py: -------------------------------------------------------------------------------- 1 | from distillkit.lossfuncs.common import LossFunctionBase, MissingProbabilityHandling 2 | from distillkit.lossfuncs.cross_entropy import CrossEntropyLoss 3 | from distillkit.lossfuncs.hidden_state import HiddenStateCosineLoss, HiddenStateMSELoss 4 | from distillkit.lossfuncs.hingeloss import HingeLoss, sparse_hinge_loss 5 | from distillkit.lossfuncs.jsd import JSDLoss, sparse_js_div 6 | from distillkit.lossfuncs.kl import KLDLoss, sparse_kl_div 7 | from distillkit.lossfuncs.logistic_ranking import ( 8 | LogisticRankingLoss, 9 | sparse_logistic_ranking_loss, 10 | ) 11 | from distillkit.lossfuncs.tvd import TVDLoss, sparse_tvd 12 | 13 | ALL_LOSS_CLASSES = [ 14 | KLDLoss, 15 | JSDLoss, 16 | TVDLoss, 17 | HingeLoss, 18 | LogisticRankingLoss, 19 | HiddenStateCosineLoss, 20 | HiddenStateMSELoss, 21 | CrossEntropyLoss, 22 | ] 23 | 24 | __all__ = [ 25 | "sparse_kl_div", 26 | "sparse_js_div", 27 | "sparse_tvd", 28 | "sparse_hinge_loss", 29 | "sparse_logistic_ranking_loss", 30 | "MissingProbabilityHandling", 31 | "KLDLoss", 32 | "JSDLoss", 33 | "TVDLoss", 34 | "HingeLoss", 35 
| "LogisticRankingLoss", 36 | "HiddenStateCosineLoss", 37 | "HiddenStateMSELoss", 38 | "CrossEntropyLoss", 39 | "LossFunctionBase", 40 | "ALL_LOSS_CLASSES", 41 | ] 42 | -------------------------------------------------------------------------------- /distillkit/compression/densify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from distillkit.configuration import MissingProbabilityHandling 4 | 5 | 6 | def densify( 7 | top_indices: torch.LongTensor, 8 | top_values: torch.Tensor, 9 | vocab_size: int, 10 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 11 | renormalize: bool = False, 12 | fill_value: float = -float("inf"), 13 | ) -> torch.Tensor: 14 | """Expand a sparse set of logits to a dense tensor. 15 | 16 | Fills missing logits with -inf.""" 17 | 18 | if missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM: 19 | # compute total probability mass 20 | log_total_mass = torch.logsumexp(top_values, dim=-1, keepdim=True) 21 | missing = 1 - log_total_mass.exp() 22 | # equally spread missing mass over all missing tokens 23 | fill_value = torch.log( 24 | (missing / (vocab_size - top_values.shape[-1])).clamp(min=1e-8) 25 | ) 26 | elif missing == MissingProbabilityHandling.ZERO: 27 | # just fill with -inf (or whatever fill_value is) 28 | pass 29 | 30 | expanded_logits = ( 31 | torch.zeros( 32 | tuple(top_indices.shape[:-1]) + (vocab_size,), 33 | device=top_indices.device, 34 | dtype=top_values.dtype, 35 | ) 36 | + fill_value 37 | ) 38 | expanded_logits.scatter_(-1, top_indices, top_values) 39 | if renormalize: 40 | expanded_logits = torch.log_softmax(expanded_logits, dim=-1) 41 | return expanded_logits 42 | -------------------------------------------------------------------------------- /examples/llama_70b_base.yml: -------------------------------------------------------------------------------- 1 | project_name: llama3p3-70b-dsv3tok 2 | model: /workspace/models/Llama-3.3-70B-DSV3Tok 3 | model_auto_class: AutoModelForCausalLM 4 | output_path: /workspace/models/Llama-3.3-70B-DSV3Tok-BaseDistill-T1 5 | resize_embeddings_to_multiple_of: 64 6 | use_flash_attention: true 7 | sequence_length: 8192 8 | dataset: 9 | train_dataset: 10 | repo_id: arcee-ai/DeepSeek-DCLM-Logits-Packed-8192 11 | split: train 12 | seed: 9 13 | prepacked: true 14 | teacher: 15 | kind: dataset 16 | legacy_logit_compression: 17 | vocab_size: 129280 18 | k: 128 19 | exact_k: 32 20 | polynomial_degree: 8 21 | with_sqrt_term: false 22 | term_dtype: float32 23 | invert_polynomial: true 24 | loss_functions: 25 | - function: cross_entropy 26 | weight: 0.5 27 | - function: kl 28 | weight: 0.5 29 | temperature: 1.0 30 | missing_probability_handling: zero 31 | sparse_chunk_length: 1024 32 | 33 | training_args: 34 | num_train_epochs: 1 35 | per_device_train_batch_size: 1 36 | gradient_accumulation_steps: 8 37 | save_steps: 512 38 | save_total_limit: 4 39 | logging_steps: 1 40 | learning_rate: 2.e-7 41 | weight_decay: 0.01 42 | warmup_ratio: 0.025 43 | lr_scheduler_type: linear 44 | bf16: true 45 | remove_unused_columns: false 46 | optim: paged_adamw_8bit 47 | gradient_checkpointing: true 48 | gradient_checkpointing_kwargs: 49 | use_reentrant: false 50 | report_to: wandb 51 | push_to_hub: true 52 | hub_model_id: arcee-train/Llama-3.3-70B-DSV3Tok-BaseDistill-T1 53 | hub_strategy: every_save 54 | -------------------------------------------------------------------------------- /examples/mistral3.yaml: 
-------------------------------------------------------------------------------- 1 | project_name: mistral3-24b-dsv3tok 2 | model: arcee-train/Mistral3-24B-DSV3Tok-BaseDistill-Multilingual-Slerp 3 | model_auto_class: AutoModelForImageTextToText 4 | output_path: /workspace/models/Mistral3-24B-DSV3Tok-MixedReasoning-V1 5 | resize_embeddings_to_multiple_of: 64 6 | use_flash_attention: true 7 | frozen_modules: 8 | - vision_tower 9 | - multi_modal_projector 10 | sequence_length: 16384 11 | dataset: 12 | train_dataset: 13 | repo_id: arcee-ai/DeepSeek-MixedModeReasoning-Logits-Packed-16384 14 | split: train 15 | seed: 58347 16 | prepacked: true 17 | loss_functions: 18 | - function: cross_entropy 19 | weight: 0.5 20 | - function: kl 21 | weight: 0.5 22 | temperature: 1.0 23 | missing_probability_handling: zero 24 | sparse_chunk_length: 1024 25 | functionary_packing: true 26 | teacher: 27 | kind: dataset 28 | vocab_size: 129280 29 | legacy_logit_compression: 30 | vocab_size: 129280 31 | k: 32 32 | exact_k: 32 33 | polynomial_degree: 0 34 | with_sqrt_term: false 35 | term_dtype: float32 36 | invert_polynomial: true 37 | training_args: 38 | num_train_epochs: 1 39 | per_device_train_batch_size: 1 40 | gradient_accumulation_steps: 8 41 | save_steps: 512 42 | save_total_limit: 4 43 | logging_steps: 1 44 | learning_rate: 2.e-7 45 | weight_decay: 0.05 46 | warmup_ratio: 0.025 47 | lr_scheduler_type: linear 48 | bf16: true 49 | remove_unused_columns: false 50 | optim: paged_adamw_8bit 51 | gradient_checkpointing: true 52 | gradient_checkpointing_kwargs: 53 | use_reentrant: false 54 | report_to: wandb 55 | push_to_hub: true 56 | hub_model_id: arcee-train/Mistral3-24B-DSV3Tok-MixedReasoning-V1 57 | hub_strategy: every_save 58 | -------------------------------------------------------------------------------- /tests/test_bitpack.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from distillkit.compression.bitpack import pack_to_bytes, unpack_from_bytes 5 | 6 | 7 | def test_pack_unpack_elem_bits_8(): 8 | x = torch.tensor([255, 128, 64, 0], dtype=torch.long) 9 | packed = pack_to_bytes(x, 8) 10 | unpacked = unpack_from_bytes(packed, 8, x.size(-1)) 11 | assert (x == unpacked).all() 12 | 13 | 14 | def test_pack_unpack_elem_bits_3(): 15 | x = torch.tensor([5, 3], dtype=torch.long) # 0b101, 0b011 16 | packed = pack_to_bytes(x, 3) 17 | assert packed.size(-1) == 1 18 | unpacked = unpack_from_bytes(packed, 3, x.size(-1)) 19 | assert (x == unpacked).all() 20 | 21 | 22 | def test_pack_unpack_elem_bits_1(): 23 | x = torch.tensor([1, 0, 1, 1, 0, 1, 0, 0, 1], dtype=torch.long) # 9 bits → 2 bytes 24 | packed = pack_to_bytes(x, 1) 25 | assert packed.size(-1) == 2 26 | unpacked = unpack_from_bytes(packed, 1, x.size(-1)) 27 | assert (x == unpacked).all() 28 | 29 | 30 | def test_pack_unpack_elem_bits_16(): 31 | x = torch.tensor([65535, 32768], dtype=torch.long) 32 | packed = pack_to_bytes(x, 16) 33 | unpacked = unpack_from_bytes(packed, 16, x.size(-1)) 34 | assert (x == unpacked).all() 35 | 36 | 37 | @pytest.mark.parametrize("elem_bits", [5, 9, 17, 53]) 38 | def test_pack_unpack_random(elem_bits): 39 | for _ in range(10): 40 | original_num_elements = torch.randint(100, 128, (1,)).item() 41 | x = torch.randint(0, 2**elem_bits, (original_num_elements,), dtype=torch.long) 42 | packed = pack_to_bytes(x, elem_bits) 43 | unpacked = unpack_from_bytes(packed, elem_bits, original_num_elements) 44 | assert x.shape == unpacked.shape, ( 45 | f"Shape mismatch for 
elem_bits={elem_bits}, original_num_elements={original_num_elements}" 46 | ) 47 | assert (x == unpacked).all(), ( 48 | f"Random test failed for elem_bits={elem_bits}, original_num_elements={original_num_elements}" 49 | ) 50 | -------------------------------------------------------------------------------- /distillkit/hsd_mapping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import transformers 4 | 5 | 6 | class HiddenStateMapping: 7 | layer_mapping: list[tuple[int, int]] 8 | projections: torch.nn.ModuleList | None 9 | 10 | def __init__( 11 | self, 12 | student: transformers.PreTrainedModel, 13 | teacher_hidden_size: int, 14 | layer_mapping: list[tuple[int, int]], 15 | init_strategy: str = "xavier", 16 | force_projection: bool = False, 17 | ): 18 | student_hidden_size = student.config.hidden_size 19 | need_projection = force_projection or ( 20 | teacher_hidden_size != student_hidden_size 21 | ) 22 | 23 | self.layer_mapping = layer_mapping 24 | if need_projection: 25 | self.projections = nn.ModuleList( 26 | [ 27 | nn.Linear(student_hidden_size, teacher_hidden_size, bias=False) 28 | for _ in layer_mapping 29 | ] 30 | ) 31 | 32 | # init projections 33 | for proj in self.projections: 34 | if init_strategy == "xavier": 35 | nn.init.xavier_uniform_(proj.weight) 36 | elif init_strategy == "kaiming": 37 | nn.init.kaiming_uniform_(proj.weight, nonlinearity="linear") 38 | elif init_strategy == "zero": 39 | nn.init.zeros_(proj.weight) 40 | elif init_strategy == "identity": 41 | # Initialize as truncated identity matrix 42 | nn.init.zeros_(proj.weight) 43 | min_dim = min(student_hidden_size, teacher_hidden_size) 44 | with torch.no_grad(): 45 | proj.weight[:min_dim, :min_dim] = torch.eye(min_dim) 46 | else: 47 | raise ValueError(f"Unknown projection_init: {init_strategy}") 48 | 49 | # slap 'em on the student so they're trained and saved 50 | student.add_module("distillation_projections", self.projections) 51 | else: 52 | self.projections = None 53 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/hingeloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.modeling_outputs import CausalLMOutput 3 | from typing_extensions import override 4 | 5 | from distillkit.hsd_mapping import HiddenStateMapping 6 | from distillkit.lossfuncs.common import LossFunctionBase 7 | from distillkit.signals import DenseSignal, TeacherSignal 8 | 9 | 10 | def sparse_hinge_loss( 11 | logits: torch.Tensor, 12 | target_ids: torch.LongTensor, 13 | target_values: torch.Tensor, 14 | mask: torch.Tensor | None = None, 15 | eps: float = 1e-8, 16 | log_target: bool = True, 17 | margin: float | None = None, 18 | ) -> torch.Tensor: 19 | # Validate input shapes 20 | assert logits.size()[:2] == target_ids.size()[:2] == target_values.size()[:2], ( 21 | "Batch and sequence length must match" 22 | ) 23 | B, S, K = target_ids.shape 24 | assert target_values.shape == (B, S, K), "Target values shape must match target_ids" 25 | 26 | # Gather the logits for the target_ids 27 | student_probs = torch.softmax(logits, dim=-1) # Shape: [B, S, V] 28 | student_target_probs = student_probs.gather(-1, target_ids) # Shape: [B, S, K] 29 | 30 | if log_target: 31 | teacher_probs = torch.exp(target_values) 32 | else: 33 | teacher_probs = target_values 34 | 35 | prob_diff = student_target_probs.unsqueeze(-1) - student_target_probs.unsqueeze( 36
| -2 37 | ) # Shape: [B, S, K, K] 38 | if margin is None: 39 | margin_values = teacher_probs.unsqueeze(-1) - teacher_probs.unsqueeze(-2) 40 | else: 41 | margin_values = margin * torch.ones_like(prob_diff) 42 | loss_terms = margin_values - prob_diff 43 | max_terms = torch.relu(loss_terms) # Shape: [B, S, K, K] 44 | 45 | actually_supported_mask = teacher_probs > eps 46 | supported_k = actually_supported_mask.unsqueeze(-1) # [B, S, K, 1] 47 | supported_l = actually_supported_mask.unsqueeze(-2) # [B, S, 1, K] 48 | pair_is_genuinely_supported_mask = supported_k & supported_l # [B, S, K, K] 49 | 50 | preference_mask = teacher_probs.unsqueeze(-1) > (teacher_probs.unsqueeze(-2) + eps) 51 | valid_mask = preference_mask & pair_is_genuinely_supported_mask 52 | active_terms = max_terms * valid_mask.float() 53 | 54 | num_contributing_pairs = valid_mask.float() 55 | if mask is not None: 56 | mask_expanded = mask.unsqueeze(-1).unsqueeze(-1).float() 57 | active_terms = active_terms * mask_expanded 58 | num_contributing_pairs = num_contributing_pairs * mask_expanded 59 | 60 | total_loss = active_terms.sum() / (num_contributing_pairs.sum() + eps) 61 | return total_loss 62 | 63 | 64 | class HingeLoss(LossFunctionBase): 65 | margin: float 66 | 67 | @override 68 | @classmethod 69 | def name(cls) -> str: 70 | return "hinge" 71 | 72 | @override 73 | def __init__(self, margin: float) -> None: 74 | self.margin = margin 75 | 76 | @override 77 | def __call__( 78 | self, 79 | student_outputs: CausalLMOutput, 80 | signal: TeacherSignal, 81 | mask: torch.Tensor | None = None, 82 | hidden_state_mapping: HiddenStateMapping | None = None, 83 | num_items_in_batch: int | None = None, 84 | ) -> torch.Tensor: 85 | if isinstance(signal, DenseSignal): 86 | raise RuntimeError("Hinge loss is not supported for dense predictions") 87 | 88 | return sparse_hinge_loss( 89 | logits=student_outputs.logits, 90 | target_ids=signal.sparse_ids, 91 | target_values=signal.sparse_values, 92 | mask=mask, 93 | log_target=signal.log_values, 94 | margin=self.margin, 95 | ) 96 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/hidden_state.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers.modeling_outputs import CausalLMOutput 4 | from typing_extensions import override 5 | 6 | from distillkit.hsd_mapping import HiddenStateMapping 7 | from distillkit.lossfuncs.common import ( 8 | LossFunctionBase, 9 | ) 10 | from distillkit.signals import TeacherSignal 11 | 12 | 13 | def compute_hs_loss( 14 | kind: str, 15 | student_outputs: CausalLMOutput, 16 | signal: TeacherSignal, 17 | mask: torch.Tensor | None = None, 18 | hidden_state_mapping: HiddenStateMapping | None = None, 19 | ): 20 | assert hidden_state_mapping is not None, ( 21 | "Hidden state losses require HiddenStateMapping" 22 | ) 23 | assert len(hidden_state_mapping.layer_mapping) > 0, ( 24 | "No layers specified in hidden state mapping" 25 | ) 26 | assert student_outputs.hidden_states is not None 27 | assert signal.hidden_states is not None 28 | 29 | if mask is None: 30 | mask = torch.ones( 31 | student_outputs.hidden_states[0].shape[:-1], 32 | dtype=torch.bool, 33 | device=student_outputs.hidden_states[0].device, 34 | ) 35 | 36 | if mask is not None and mask.dim() == 2: 37 | mask = mask.unsqueeze(-1) 38 | 39 | total_loss = torch.tensor(0.0, device=student_outputs.hidden_states[0].device) 40 | for i, (student_layer_idx, teacher_layer_idx) in enumerate( 41 | hidden_state_mapping.layer_mapping 42 | ): 43 | student_h = student_outputs.hidden_states[student_layer_idx] 44 | teacher_h = signal.hidden_states[teacher_layer_idx] 45 | 46 | if hidden_state_mapping.projections is not None: 47 | student_h = hidden_state_mapping.projections[i](student_h) 48 | 49 | if kind == "mse": 50 | squared_error = (student_h - teacher_h) ** 2 51 | masked_error = squared_error * mask 52 | layer_loss = masked_error.sum() / (mask.sum() * student_h.shape[-1]) 53 | elif kind == "cosine": 54 | cosine_sim = F.cosine_similarity(student_h, teacher_h, dim=-1) 55 | cosine_distance = (1 - cosine_sim) * mask.squeeze(-1) 56 | layer_loss = cosine_distance.sum() / mask.sum() 57 | else: 58 | raise RuntimeError(f"Unimplemented hidden state loss type {repr(kind)}") 59 | 60 | total_loss += layer_loss 61 | 62 | return total_loss / len(hidden_state_mapping.layer_mapping) 63 | 64 | 65 | class HiddenStateCosineLoss(LossFunctionBase): 66 | @override 67 | @classmethod 68 | def name(cls) -> str: 69 | return "hs_cosine" 70 | 71 | @override 72 | def requires_hidden_states(self) -> bool: 73 | return True 74 | 75 | @override 76 | def __init__(self): ... 77 | 78 | @override 79 | def __call__( 80 | self, 81 | student_outputs: CausalLMOutput, 82 | signal: TeacherSignal, 83 | mask: torch.Tensor | None = None, 84 | hidden_state_mapping: HiddenStateMapping | None = None, 85 | num_items_in_batch: int | None = None, 86 | ) -> torch.Tensor: 87 | return compute_hs_loss( 88 | "cosine", student_outputs, signal, mask, hidden_state_mapping 89 | ) 90 | 91 | 92 | class HiddenStateMSELoss(LossFunctionBase): 93 | @override 94 | @classmethod 95 | def name(cls) -> str: 96 | return "hs_mse" 97 | 98 | @override 99 | def requires_hidden_states(self) -> bool: 100 | return True 101 | 102 | @override 103 | def __init__(self): ... 
104 | 105 | @override 106 | def __call__( 107 | self, 108 | student_outputs: CausalLMOutput, 109 | signal: TeacherSignal, 110 | mask: torch.Tensor | None = None, 111 | hidden_state_mapping: HiddenStateMapping | None = None, 112 | num_items_in_batch: int | None = None, 113 | ) -> torch.Tensor: 114 | return compute_hs_loss( 115 | "mse", student_outputs, signal, mask, hidden_state_mapping 116 | ) 117 | -------------------------------------------------------------------------------- /distillkit/compression/bitpack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pack_to_bytes( 5 | x: torch.LongTensor, 6 | elem_bits: int, 7 | ) -> torch.ByteTensor: 8 | """ 9 | Pack a tensor of integers into a byte tensor. 10 | 11 | Args: 12 | x (torch.LongTensor): The input tensor of integers, with shape (..., N). 13 | elem_bits (int): The number of bits per element. Must be between 1 and 64. 14 | 15 | Returns: 16 | torch.ByteTensor: The packed byte tensor, with shape (..., ceil(N * elem_bits / 8)). 17 | """ 18 | assert 1 <= elem_bits <= 64, "elem_bits must be between 1 and 64" 19 | 20 | # Mask the input to elem_bits 21 | mask = (1 << elem_bits) - 1 22 | x = x & mask 23 | 24 | # Generate positions of each bit, from MSB to LSB within each element 25 | bit_positions = torch.arange(elem_bits - 1, -1, -1, device=x.device) 26 | 27 | # Expand each element into its constituent bits as uint8 (..., N, elem_bits) 28 | bits = ((x.unsqueeze(-1) >> bit_positions) & 1).to(torch.uint8) 29 | 30 | # Flatten the bits into a single bit stream (..., N * elem_bits) 31 | original_shape = x.shape 32 | bits = bits.view(*original_shape[:-1], -1) 33 | 34 | # Calculate padding needed to make total bits a multiple of 8 35 | total_bits = bits.size(-1) 36 | pad_length = (8 - (total_bits % 8)) % 8 37 | if pad_length > 0: 38 | bits = torch.nn.functional.pad(bits, (0, pad_length)) 39 | bits = bits.contiguous() # Ensure contiguous for efficient reshaping 40 | 41 | # Reshape into groups of 8 bits and convert to bytes 42 | bits = bits.view(*original_shape[:-1], -1, 8) 43 | # Precompute power as uint8 for efficiency 44 | power = torch.tensor( 45 | [128, 64, 32, 16, 8, 4, 2, 1], dtype=torch.uint8, device=x.device 46 | ) 47 | bytes = (bits * power).sum(dim=-1).to(torch.uint8) 48 | 49 | return bytes 50 | 51 | 52 | def unpack_from_bytes( 53 | bytes_tensor: torch.ByteTensor, 54 | elem_bits: int, 55 | original_num_elements: int, 56 | ) -> torch.LongTensor: 57 | """ 58 | Unpack a byte tensor back into the original integers. 59 | 60 | Args: 61 | bytes_tensor (torch.ByteTensor): The packed byte tensor, with shape (..., ceil(N * elem_bits / 8)). 62 | elem_bits (int): The number of bits per element used during packing. Must be between 1 and 64. 63 | original_num_elements (int): The number of elements in the original tensor along the last dimension (N). 64 | 65 | Returns: 66 | torch.LongTensor: The unpacked tensor of integers, with shape (..., N). 
67 | """ 68 | assert 1 <= elem_bits <= 64, "elem_bits must be between 1 and 64" 69 | assert original_num_elements >= 0, "original_num_elements must be non-negative" 70 | 71 | total_bits_needed = original_num_elements * elem_bits 72 | original_shape = bytes_tensor.shape 73 | M = original_shape[-1] 74 | total_bits_available = M * 8 75 | 76 | if total_bits_needed > total_bits_available: 77 | raise ValueError( 78 | f"original_num_elements {original_num_elements} with elem_bits {elem_bits} " 79 | f"requires {total_bits_needed} bits, but only {total_bits_available} available" 80 | ) 81 | 82 | # Convert each byte to 8 bits (MSB to LSB) as uint8 83 | bit_positions = torch.arange(7, -1, -1, device=bytes_tensor.device) # 7, 6, ..., 0 84 | bits = ((bytes_tensor.unsqueeze(-1) >> bit_positions) & 1).to( 85 | torch.uint8 86 | ) # (..., M, 8) 87 | bits_flat = bits.view(*original_shape[:-1], -1) # (..., M*8) 88 | 89 | # Slice to get the needed bits and discard padding 90 | bits_needed = bits_flat[..., :total_bits_needed] 91 | 92 | # Reshape into (..., original_num_elements, elem_bits) ensuring contiguous 93 | new_shape = list(original_shape[:-1]) + [original_num_elements, elem_bits] 94 | bits_needed = bits_needed.contiguous().view(*new_shape) 95 | 96 | # Convert bits to integers using appropriate power tensor 97 | powers = 2 ** torch.arange(elem_bits - 1, -1, -1, device=bits_needed.device) 98 | result = (bits_needed * powers).sum(dim=-1).long() 99 | 100 | return result 101 | -------------------------------------------------------------------------------- /distillkit/compression/compressor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any 3 | 4 | import torch 5 | 6 | from distillkit.compression.bitpack import ( 7 | pack_to_bytes, 8 | unpack_from_bytes, 9 | ) 10 | from distillkit.compression.config import ( 11 | DistributionQuantizationConfig, 12 | LegacyLogitCompressionConfig, 13 | ) 14 | from distillkit.compression.legacy import LogitCompressor as LegacyLogitCompressor 15 | from distillkit.compression.monotonic_logprobs import ( 16 | compress_monotonic_logprobs, 17 | decompress_monotonic_logprobs, 18 | ) 19 | 20 | LOG = logging.getLogger(__name__) 21 | 22 | 23 | class LogprobCompressor: 24 | config: DistributionQuantizationConfig | None 25 | legacy_compressor: LegacyLogitCompressor | None 26 | 27 | def __init__( 28 | self, 29 | config: DistributionQuantizationConfig | None = None, 30 | legacy_config: LegacyLogitCompressionConfig | None = None, 31 | ): 32 | if config is not None and legacy_config is not None: 33 | raise ValueError( 34 | "At most one of `config` or `legacy_config` should be provided." 
35 | ) 36 | self.config = config 37 | if legacy_config is not None: 38 | self.legacy_compressor = LegacyLogitCompressor(legacy_config) 39 | self.vocab_index_bits = int(self.legacy_compressor.vocab_index_bits) 40 | else: 41 | self.legacy_compressor = None 42 | self.vocab_index_bits = int( 43 | torch.log2(torch.tensor(self.config.d, dtype=torch.float32)) 44 | .ceil() 45 | .item() 46 | ) 47 | 48 | def compress_from_sparse( 49 | self, indices: torch.LongTensor, logprobs: torch.Tensor 50 | ) -> dict[str, torch.Tensor]: 51 | if self.legacy_compressor is not None: 52 | packed_indices, exact_values, coeffs = ( 53 | self.legacy_compressor.compress_from_sparse(indices, logprobs) 54 | ) 55 | return { 56 | "packed_indices": packed_indices, 57 | "exact_values": exact_values, 58 | "coeffs": coeffs, 59 | } 60 | elif self.config is None: 61 | raise ValueError("No config provided for compression.") 62 | 63 | # enforce monotonicity 64 | _, sorted_indices = torch.sort(logprobs, descending=True, dim=-1) 65 | sorted_values = logprobs.gather(-1, sorted_indices) 66 | sorted_indices = indices.gather(-1, sorted_indices) 67 | 68 | logprob_bytes = compress_monotonic_logprobs( 69 | sorted_values, 70 | self.config, 71 | ) 72 | indices_bytes = pack_to_bytes( 73 | sorted_indices, 74 | self.vocab_index_bits, 75 | ) 76 | return { 77 | "compressed_logprobs": logprob_bytes, 78 | "bytepacked_indices": indices_bytes, 79 | } 80 | 81 | def decompress_to_sparse( 82 | self, 83 | row: dict[str, Any], 84 | ) -> tuple[torch.LongTensor, torch.Tensor]: 85 | if "top_values" in row and "token_ids" in row: 86 | return row["token_ids"], row["top_values"] 87 | elif "packed_indices" in row and "exact_values" in row and "coeffs" in row: 88 | if self.legacy_compressor is None: 89 | raise ValueError("Row is in legacy format, but compressor is not.") 90 | return self.legacy_compressor.decompress_to_sparse( 91 | row["packed_indices"], row["exact_values"], row["coeffs"] 92 | ) 93 | elif "compressed_logprobs" in row and "bytepacked_indices" in row: 94 | if self.config is None: 95 | raise ValueError("Row is in new format, but compressor is not.") 96 | logprobs = decompress_monotonic_logprobs( 97 | row["compressed_logprobs"].to(torch.uint8), 98 | self.config, 99 | ) 100 | indices = unpack_from_bytes( 101 | row["bytepacked_indices"].to(torch.uint8), 102 | self.vocab_index_bits, 103 | original_num_elements=logprobs.shape[-1], 104 | ) 105 | return indices, logprobs 106 | else: 107 | raise ValueError( 108 | "Unknown row format. Expected either raw top-k, legacy compressed, or new compressed format." 
109 | ) 110 | 111 | def compress( 112 | self, 113 | logprobs: torch.Tensor, 114 | ) -> dict[str, torch.Tensor]: 115 | if self.legacy_compressor is not None: 116 | k = self.legacy_compressor.config.k 117 | else: 118 | k = self.config.k 119 | 120 | sparse_logprobs, sparse_indices = torch.topk(logprobs, k, dim=-1) 121 | return self.compress_from_sparse(sparse_indices, sparse_logprobs) 122 | -------------------------------------------------------------------------------- /distillkit/signals.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from dataclasses import dataclass 3 | from typing import Any, TypeAlias 4 | 5 | import torch 6 | import transformers 7 | from typing_extensions import override 8 | 9 | from distillkit.compression import LogprobCompressor 10 | 11 | 12 | @dataclass 13 | class TeacherSignalBase: 14 | generation_temperature: float 15 | hidden_states: tuple[torch.Tensor, ...] | None 16 | vocab_size: int 17 | 18 | 19 | @dataclass 20 | class DenseSignal(TeacherSignalBase): 21 | logits: torch.Tensor 22 | 23 | 24 | @dataclass 25 | class SparseSignal(TeacherSignalBase): 26 | sparse_ids: torch.LongTensor 27 | sparse_values: torch.Tensor 28 | log_values: bool # if True, values are logprobs 29 | 30 | 31 | TeacherSignal: TypeAlias = SparseSignal | DenseSignal 32 | 33 | 34 | class SignalSource(ABC): 35 | @abstractmethod 36 | def supports_hidden_states(self) -> bool: ... 37 | 38 | @abstractmethod 39 | def get_signal( 40 | self, batch: dict[str, Any], return_hidden_states: bool = False 41 | ) -> TeacherSignal: ... 42 | 43 | 44 | class OfflineSignalSource(SignalSource): 45 | compressor: LogprobCompressor 46 | preapplied_temperature: float 47 | vocab_size: int 48 | log_values: bool 49 | 50 | def __init__( 51 | self, 52 | compressor: LogprobCompressor, 53 | vocab_size: int, 54 | preapplied_temperature: float = 1.0, 55 | log_values: bool = True, 56 | ): 57 | self.compressor = compressor 58 | self.vocab_size = vocab_size 59 | self.preapplied_temperature = preapplied_temperature 60 | self.log_values = log_values 61 | 62 | @override 63 | def supports_hidden_states(self) -> bool: 64 | return False 65 | 66 | @override 67 | def get_signal( 68 | self, batch: dict[str, Any], return_hidden_states: bool = False 69 | ) -> SparseSignal: 70 | if return_hidden_states: 71 | raise RuntimeError( 72 | "Hidden states requested but signal source is precomputed logits" 73 | ) 74 | with torch.no_grad(): 75 | sparse_ids, sparse_values = self.compressor.decompress_to_sparse(batch) 76 | return SparseSignal( 77 | sparse_ids=sparse_ids, 78 | sparse_values=sparse_values, 79 | log_values=self.log_values, 80 | generation_temperature=self.preapplied_temperature, 81 | hidden_states=None, 82 | vocab_size=self.vocab_size, 83 | ) 84 | 85 | 86 | class OnlineSignalSource(SignalSource): 87 | teacher_model: transformers.PreTrainedModel 88 | vocab_size: int 89 | sparsify_top_k: int | None 90 | teacher_kwargs: dict[str, Any] 91 | 92 | def __init__( 93 | self, 94 | teacher_model: transformers.PreTrainedModel, 95 | vocab_size: int, 96 | sparsify_top_k: int | None = None, 97 | teacher_kwargs: dict[str, Any] | None = None, 98 | ): 99 | self.teacher_model = teacher_model.eval() 100 | self.vocab_size = vocab_size 101 | self.sparsify_top_k = sparsify_top_k 102 | self.teacher_kwargs = teacher_kwargs or {} 103 | 104 | for param in self.teacher_model.parameters(): 105 | param.requires_grad_(False) 106 | 107 | @override 108 | def supports_hidden_states(self) -> bool: 
109 | return True 110 | 111 | @override 112 | def get_signal( 113 | self, batch: dict[str, Any], return_hidden_states: bool = False 114 | ) -> TeacherSignal: 115 | with torch.no_grad(): 116 | teacher_outputs = self.teacher_model( 117 | input_ids=batch["input_ids"], 118 | attention_mask=batch.get("attention_mask", None), 119 | output_hidden_states=return_hidden_states, 120 | **self.teacher_kwargs, 121 | ) 122 | 123 | real_vocab_size = teacher_outputs.logits.shape[-1] 124 | vocab_size = min(real_vocab_size, self.vocab_size) 125 | 126 | if self.sparsify_top_k is not None: 127 | logprobs = torch.log_softmax( 128 | teacher_outputs.logits[..., :vocab_size], dim=-1 129 | ) 130 | values, indices = torch.topk(logprobs, self.sparsify_top_k, dim=-1) 131 | return SparseSignal( 132 | sparse_ids=indices, 133 | sparse_values=values, 134 | log_values=True, 135 | generation_temperature=1.0, 136 | hidden_states=teacher_outputs.hidden_states, 137 | vocab_size=vocab_size, 138 | ) 139 | 140 | return DenseSignal( 141 | logits=teacher_outputs.logits[..., :vocab_size], 142 | hidden_states=teacher_outputs.hidden_states, 143 | generation_temperature=1.0, 144 | vocab_size=vocab_size, 145 | ) 146 | -------------------------------------------------------------------------------- /distillkit/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Charles O. Goddard 2 | 3 | import torch 4 | from transformers import ( 5 | PreTrainedModel, 6 | ) 7 | from trl import SFTTrainer 8 | 9 | from distillkit.configuration import DistillationRunConfig, LossFunctionConfig 10 | from distillkit.hsd_mapping import HiddenStateMapping 11 | from distillkit.lossfuncs import ALL_LOSS_CLASSES, LossFunctionBase 12 | from distillkit.signals import OnlineSignalSource, SignalSource, TeacherSignal 13 | 14 | 15 | def create_loss_func(cfg: LossFunctionConfig) -> LossFunctionBase: 16 | for cls in ALL_LOSS_CLASSES: 17 | if cfg.function.value == cls.name(): 18 | return cls( 19 | **cfg.model_dump(exclude=["function", "weight"], exclude_none=True) 20 | ) 21 | raise RuntimeError(f"Unknown loss function '{cfg.function}'") 22 | 23 | 24 | class DistillationTrainer(SFTTrainer): 25 | def __init__( 26 | self, 27 | model: PreTrainedModel, 28 | config: DistillationRunConfig, 29 | signal_source: SignalSource, 30 | true_vocab_size: int, 31 | *args, 32 | hidden_state_mapping: HiddenStateMapping | None = None, 33 | **kwargs, 34 | ): 35 | super().__init__(model, *args, **kwargs) 36 | self.true_vocab_size = true_vocab_size 37 | self.config = config 38 | 39 | self.loss_functions = [create_loss_func(lfc) for lfc in config.loss_functions] 40 | self.need_hidden_states = any( 41 | lf.requires_hidden_states() for lf in self.loss_functions 42 | ) 43 | 44 | self.signal_source = signal_source 45 | self.hidden_state_mapping = hidden_state_mapping 46 | 47 | if self.need_hidden_states and not self.signal_source.supports_hidden_states(): 48 | raise ValueError( 49 | "Configuration requests hidden state loss, but the provided Teacher " 50 | "(Offline/Dataset) does not support hidden states." 51 | ) 52 | 53 | if (self.hidden_state_mapping is None) and self.need_hidden_states: 54 | raise ValueError( 55 | "Must define a hidden state mapping to use hidden state losses." 
56 | ) 57 | 58 | if isinstance(self.signal_source, OnlineSignalSource): 59 | self.signal_source.teacher_model = self.signal_source.teacher_model.to( 60 | self.accelerator.device 61 | ) 62 | 63 | self.model_accepts_loss_kwargs = False 64 | 65 | def compute_loss( 66 | self, 67 | model: PreTrainedModel, 68 | inputs: dict[str, torch.Tensor], 69 | return_outputs: bool = False, 70 | **kwargs, 71 | ): 72 | if "labels" not in inputs: 73 | inputs["labels"] = inputs["input_ids"] 74 | if self.config.dataset.eos_label_token_ids: 75 | inputs["labels"] = inputs["labels"].clone() 76 | for tok_id in self.config.dataset.eos_label_token_ids: 77 | inputs["labels"][inputs["labels"] == tok_id] = ( 78 | self.model.config.eos_token_id 79 | ) 80 | 81 | student_model = model.module if hasattr(model, "module") else model 82 | student_outputs = student_model( 83 | **{ 84 | k: inputs[k] 85 | for k in ["input_ids", "attention_mask", "labels"] 86 | if k in inputs 87 | }, 88 | return_dict=True, 89 | output_hidden_states=self.need_hidden_states, 90 | **kwargs, 91 | ) 92 | if student_outputs.logits.shape[-1] != self.true_vocab_size: 93 | # truncate any extra logits from padding 94 | student_outputs.logits = student_outputs.logits[..., : self.true_vocab_size] 95 | 96 | total_loss = self.total_distillation_loss( 97 | student_outputs, 98 | inputs, 99 | num_items_in_batch=None, 100 | ) 101 | return (total_loss, student_outputs) if return_outputs else total_loss 102 | 103 | def total_distillation_loss( 104 | self, student_outputs, inputs, num_items_in_batch: int | None = None 105 | ): 106 | valid_mask = (inputs["labels"] >= 0).unsqueeze(-1) 107 | signal: TeacherSignal = self.signal_source.get_signal( 108 | inputs, 109 | return_hidden_states=self.need_hidden_states, 110 | ) 111 | 112 | losses = [] 113 | loss_fns = [] 114 | weights = [] 115 | for idx, loss_fn in enumerate(self.loss_functions): 116 | cfg = self.config.loss_functions[idx] 117 | loss = loss_fn( 118 | student_outputs, 119 | signal, 120 | mask=valid_mask, 121 | hidden_state_mapping=self.hidden_state_mapping, 122 | num_items_in_batch=num_items_in_batch, 123 | ) 124 | losses.append(loss) 125 | loss_fns.append(cfg.function.value) 126 | weights.append(cfg.weight) 127 | 128 | total_loss = 0.0 129 | for loss, weight in zip(losses, weights): 130 | total_loss += loss * weight 131 | total_loss = total_loss / sum(weights) 132 | self.log( 133 | { 134 | f"distillation_loss/{idx + 1}_{loss_fn}": loss.item() 135 | for idx, (loss, loss_fn) in enumerate(zip(losses, loss_fns)) 136 | } 137 | ) 138 | return total_loss 139 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/logistic_ranking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers.modeling_outputs import CausalLMOutput 4 | from typing_extensions import override 5 | 6 | from distillkit.hsd_mapping import HiddenStateMapping 7 | from distillkit.lossfuncs.common import LossFunctionBase 8 | from distillkit.signals import DenseSignal, TeacherSignal 9 | 10 | 11 | def sparse_logistic_ranking_loss( 12 | student_logits: torch.Tensor, 13 | teacher_target_ids: torch.LongTensor, 14 | teacher_target_values: torch.Tensor, 15 | log_target: bool = True, 16 | sequence_mask: torch.Tensor | None = None, 17 | eps: float = 1e-9, 18 | ) -> torch.Tensor: 19 | """ 20 | Computes a logistic ranking loss between student logits and sparse teacher probabilities. 
21 | The loss encourages the student to rank the teacher's *actually supported* tokens 22 | (those with probability > eps) in the same relative order as the teacher. 23 | 24 | Args: 25 | student_logits: Logits predicted by the student model. 26 | Shape: [batch_size, seq_len, vocab_size] 27 | teacher_target_ids: Indices of the tokens in the teacher's sparse set (up to k_sparse). 28 | 0 <= teacher_target_ids < vocab_size. 29 | Shape: [batch_size, seq_len, k_sparse] 30 | teacher_target_values: Probabilities or log-probabilities for the tokens in 31 | teacher_target_ids. Shape: [batch_size, seq_len, k_sparse] 32 | It's assumed that if k_sparse is a max capacity and fewer items 33 | are truly supported, the corresponding values here reflect that 34 | (e.g., 0 for probability, or very small/negative logprob). 35 | log_target: Boolean, True if teacher_target_values are log-probabilities. 36 | sequence_mask: Optional boolean mask for sequence positions. True indicates active, 37 | False indicates padded/ignored. Shape: [batch_size, seq_len] 38 | eps: Small value to avoid numerical issues. Default is 1e-9. 39 | 40 | Returns: 41 | A scalar tensor representing the mean logistic ranking loss over active sequence positions. 42 | """ 43 | if log_target: 44 | teacher_probs = torch.exp(teacher_target_values) 45 | else: 46 | teacher_probs = teacher_target_values 47 | 48 | student_logits_at_targets = torch.gather( 49 | student_logits, 50 | dim=2, 51 | index=teacher_target_ids, 52 | ) # [B, S, k_sparse] 53 | 54 | # Mask for tokens that are *actually* supported by the teacher (prob > eps) 55 | actually_supported_mask = teacher_probs > eps # [B, S, k_sparse] 56 | 57 | # Teacher preferences: teacher_probs_k > teacher_probs_l 58 | # teacher_probs_k.shape [B, S, k_sparse, 1], teacher_probs_l.shape [B, S, 1, k_sparse] 59 | teacher_prefers_k_over_l_mask = teacher_probs.unsqueeze(-1) > ( 60 | teacher_probs.unsqueeze(-2) + eps 61 | ) # [B, S, k_sparse, k_sparse] 62 | student_logit_diff_k_minus_l = student_logits_at_targets.unsqueeze( 63 | -1 64 | ) - student_logits_at_targets.unsqueeze(-2) # [B, S, k_sparse, k_sparse] 65 | 66 | # Pair activity mask: both tokens in the pair must be actually supported 67 | supported_k = actually_supported_mask.unsqueeze(-1) # [B, S, k_sparse, 1] 68 | supported_l = actually_supported_mask.unsqueeze(-2) # [B, S, 1, k_sparse] 69 | pair_is_supported_mask = supported_k & supported_l # [B, S, k_sparse, k_sparse] 70 | 71 | # Valid preference pairs: teacher prefers k over l and both k and l are supported 72 | valid_preference_pair_mask = teacher_prefers_k_over_l_mask & pair_is_supported_mask 73 | 74 | pair_loss = F.softplus(-student_logit_diff_k_minus_l) # [B, S, k_sparse, k_sparse] 75 | masked_pair_loss = pair_loss * valid_preference_pair_mask.float() 76 | sum_pair_loss_per_pos = masked_pair_loss.sum(dim=(2, 3)) # [B, S] 77 | 78 | if sequence_mask is None: 79 | active_sequence_mask = torch.ones_like( 80 | sum_pair_loss_per_pos, dtype=torch.bool, device=student_logits.device 81 | ) 82 | else: 83 | active_sequence_mask = sequence_mask.bool() # Ensure boolean 84 | 85 | final_summed_loss = (sum_pair_loss_per_pos * active_sequence_mask.float()).sum() 86 | 87 | num_contributing_pairs = ( 88 | valid_preference_pair_mask.float() 89 | * active_sequence_mask.float().unsqueeze(-1).unsqueeze(-1) 90 | ).sum() 91 | loss = final_summed_loss / (num_contributing_pairs + eps) 92 | 93 | return loss 94 | 95 | 96 | class LogisticRankingLoss(LossFunctionBase): 97 | @override 98 | @classmethod 99 | def name(cls) -> 
str: 100 | return "logistic_ranking" 101 | 102 | @override 103 | def __init__(self) -> None: 104 | pass 105 | 106 | @override 107 | def __call__( 108 | self, 109 | student_outputs: CausalLMOutput, 110 | signal: TeacherSignal, 111 | mask: torch.Tensor | None = None, 112 | hidden_state_mapping: HiddenStateMapping | None = None, 113 | num_items_in_batch: int | None = None, 114 | ) -> torch.Tensor: 115 | if isinstance(signal, DenseSignal): 116 | raise RuntimeError( 117 | "Logistic ranking loss is not supported for dense predictions" 118 | ) 119 | 120 | return sparse_logistic_ranking_loss( 121 | student_logits=student_outputs.logits, 122 | teacher_target_ids=signal.sparse_ids, 123 | teacher_target_values=signal.sparse_values, 124 | sequence_mask=mask, 125 | log_target=signal.log_values, 126 | ) 127 | -------------------------------------------------------------------------------- /distillkit/compression/config.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | from pydantic import BaseModel, Field, model_validator 5 | 6 | 7 | class SpecialTerm(Enum): 8 | SQRT = "sqrt" 9 | EXP = "exp" 10 | 11 | 12 | class TermDtype(Enum): 13 | FLOAT16 = "float16" 14 | BFLOAT16 = "bfloat16" 15 | FLOAT32 = "float32" 16 | FLOAT64 = "float64" 17 | 18 | def bit_width(self) -> int: 19 | if self == TermDtype.FLOAT16: 20 | return 16 21 | elif self == TermDtype.BFLOAT16: 22 | return 16 23 | elif self == TermDtype.FLOAT32: 24 | return 32 25 | elif self == TermDtype.FLOAT64: 26 | return 64 27 | else: 28 | raise ValueError(f"Unsupported dtype: {self}") 29 | 30 | def dtype(self) -> torch.dtype: 31 | if self == TermDtype.FLOAT16: 32 | return torch.float16 33 | elif self == TermDtype.BFLOAT16: 34 | return torch.bfloat16 35 | elif self == TermDtype.FLOAT32: 36 | return torch.float32 37 | elif self == TermDtype.FLOAT64: 38 | return torch.float64 39 | else: 40 | raise ValueError(f"Unsupported dtype: {self}") 41 | 42 | 43 | class QuantizationBin(BaseModel): 44 | scale_dtype: TermDtype = Field( 45 | ..., 46 | description="Data type for the scale value.", 47 | ) 48 | element_bits: int = Field( 49 | ..., 50 | description="Number of bits per element.", 51 | gt=0, 52 | le=64, 53 | ) 54 | num_elements: int = Field( 55 | ..., 56 | description="Number of elements in the bin.", 57 | gt=0, 58 | ) 59 | 60 | # note that element_bits can be anything from 1 to 64 61 | # *could* constrain it to 1-8, 16, 32, 64 62 | # but would be better not to 63 | 64 | 65 | class DistributionQuantizationConfig(BaseModel): 66 | d: int = Field( 67 | ..., 68 | description="Dimension of the unquantized distribution.", 69 | gt=0, 70 | ) 71 | k: int = Field( 72 | ..., 73 | description="Number of non-zero values in the quantized distribution.", 74 | gt=0, 75 | ) 76 | exact_k: int = Field( 77 | ..., 78 | description="Number of top values to store unquantized.", 79 | ge=0, 80 | ) 81 | exact_dtype: TermDtype = Field( 82 | TermDtype.FLOAT32, 83 | description="Data type for the top `exact_k` values.", 84 | ) 85 | polynomial_terms: list[SpecialTerm | int] | None = Field( 86 | None, 87 | description="Terms to use in the polynomial approximation. 
Integer values represent power terms, " 88 | "SpecialTerm values represent special non-polynomial terms.", 89 | ) 90 | term_dtype: TermDtype = Field( 91 | TermDtype.FLOAT32, 92 | description="Data type for the polynomial terms.", 93 | ) 94 | residual_bins: list[QuantizationBin] = Field( 95 | ..., 96 | description="List of bins for quantized residuals.", 97 | ) 98 | delta_encoding: bool = Field(True, description="Whether to use delta encoding.") 99 | error_diffusion: bool = Field( 100 | False, 101 | description="Whether to use error diffusion.", 102 | ) 103 | normalize_t: bool = Field( 104 | default=False, description="Map t to [0,1] instead of [0,k]." 105 | ) 106 | 107 | @model_validator(mode="after") 108 | def check_valid(self) -> "DistributionQuantizationConfig": 109 | if self.k < self.exact_k: 110 | raise ValueError("k must be >= exact_k") 111 | if ( 112 | self.exact_k < self.k 113 | and (not self.polynomial_terms) 114 | and (not self.residual_bins) 115 | ): 116 | raise ValueError( 117 | "If exact_k < k, at least one of polynomial_terms or residual_bins must be provided." 118 | ) 119 | approx_terms = self.k - self.exact_k 120 | bin_elems = sum([bin.num_elements for bin in self.residual_bins]) 121 | if bin_elems > approx_terms: 122 | raise ValueError( 123 | "Sum of num_elements in residual_bins must be <= k - exact_k" 124 | ) 125 | return self 126 | 127 | def logprob_bits(self) -> int: 128 | res = 0 129 | res += self.exact_k * self.exact_dtype.bit_width() 130 | res += len(self.polynomial_terms or []) * self.term_dtype.bit_width() 131 | for bin in self.residual_bins: 132 | bin_bits = bin.scale_dtype.bit_width() + bin.element_bits * bin.num_elements 133 | if bin_bits % 8 != 0: 134 | bin_bits += 8 - (bin_bits % 8) 135 | res += bin_bits 136 | if res % 8 != 0: 137 | res += 8 - (res % 8) 138 | return res 139 | 140 | def total_bits(self) -> int: 141 | lpb = self.logprob_bits() 142 | vocab_index_bits = int( 143 | torch.log2(torch.tensor(self.d, dtype=torch.float32)).ceil().item() 144 | ) 145 | index_bytes = (self.k * vocab_index_bits + 7) // 8 146 | total = lpb + index_bytes * 8 147 | return total 148 | 149 | 150 | class LegacyLogitCompressionConfig(BaseModel): 151 | """Configuration for legacy polynomial logit compression. 
152 | 153 | Args: 154 | k (int): Total number of logits per token 155 | exact_k (int): Number of exact logits to keep 156 | polynomial_degree (int): Degree of the polynomial to approximate the remaining logits 157 | invert_polynomial (bool): Whether to invert the polynomial terms 158 | with_sqrt_term (bool): Whether to include a square root term in the polynomial 159 | term_dtype (str): Data type for the polynomial terms (float16, bfloat16, float32, float64) 160 | """ 161 | 162 | k: int 163 | exact_k: int 164 | polynomial_degree: int 165 | vocab_size: int 166 | invert_polynomial: bool = True 167 | with_sqrt_term: bool = False 168 | term_dtype: str = "float32" 169 | -------------------------------------------------------------------------------- /distillkit/pack_logits.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import click 4 | import datasets 5 | 6 | 7 | def pack_pass(data: dict[str, list], max_len: int) -> dict[str, list]: 8 | # pack a set of rows into a (hopefully smaller) set of batches 9 | # truncate rows that are too long 10 | row_lens = [ 11 | (len(input_ids), idx) for (idx, input_ids) in enumerate(data["input_ids"]) 12 | ] 13 | row_lens.sort() 14 | data_out = {key: [] for key in data} 15 | data_out["attention_mask"] = [] 16 | 17 | while row_lens and row_lens[-1][0] >= max_len: 18 | row_l, idx = row_lens.pop() 19 | for key in data: 20 | if key == "attention_mask": 21 | continue 22 | data_out[key].append(data[key][idx][:max_len]) 23 | assert len(data_out[key][-1]) == max_len, ( 24 | f"{key} {len(data_out[key][-1])} != {max_len}" 25 | ) 26 | data_out["attention_mask"].append([1] * max_len) 27 | 28 | if not row_lens: 29 | return data_out 30 | 31 | # greedily pack remaining rows 32 | # take longest first then fill with shortest 33 | row_lens = [row_lens[-1]] + row_lens[:-1] 34 | current_batch = {key: [] for key in data} 35 | current_batch_indices = [] 36 | current_batch["attention_mask"] = [] 37 | current_token_count = 0 38 | current_num_examples = 0 39 | while row_lens: 40 | _, idx = row_lens.pop(0) 41 | space = max_len - current_token_count 42 | tokens_to_take = min(space, len(data["input_ids"][idx])) 43 | for key in data: 44 | if key == "attention_mask": 45 | continue 46 | current_batch[key] += data[key][idx][:space] 47 | current_batch["attention_mask"] += [current_num_examples + 1] * tokens_to_take 48 | current_batch_indices.append(idx) 49 | current_token_count += tokens_to_take 50 | current_num_examples += 1 51 | if current_token_count >= max_len: 52 | token_ct = len(current_batch["input_ids"]) 53 | for key in current_batch: 54 | assert len(current_batch[key]) == token_ct, ( 55 | f"{key} {len(current_batch[key])} != {token_ct}" 56 | ) 57 | data_out[key].append(current_batch[key]) 58 | current_batch = {key: [] for key in data} 59 | current_batch["attention_mask"] = [] 60 | current_batch_indices = [] 61 | current_token_count = 0 62 | current_num_examples = 0 63 | if row_lens: 64 | row_lens = [row_lens[-1]] + row_lens[:-1] 65 | if current_num_examples > 0: 66 | for idx in current_batch_indices: 67 | for key in data: 68 | if key == "attention_mask": 69 | continue 70 | data_out[key].append(data[key][idx][:max_len]) 71 | data_out["attention_mask"].append([1] * len(data_out["input_ids"][-1])) 72 | return data_out 73 | 74 | 75 | def _truncate_row(row): 76 | res = {} 77 | if "token_ids" in row: 78 | expected_len = len(row["token_ids"]) 79 | elif "packed_indices" in row: 80 | expected_len = len(row["packed_indices"]) 81 | elif 
"bytepacked_indices" in row: 82 | expected_len = len(row["bytepacked_indices"]) 83 | else: 84 | raise ValueError("No token_ids or packed_indices found in row") 85 | for key in ["input_ids", "labels", "attention_mask"]: 86 | if key in row: 87 | res[key] = row[key][:expected_len] 88 | return res 89 | 90 | 91 | def iterative_packing( 92 | ds: datasets.Dataset, 93 | max_len: int, 94 | num_proc: int | None = None, 95 | max_iters: int = 4, 96 | batch_size: int = 32, 97 | ) -> datasets.Dataset: 98 | # truncate input_ids to match packed_indices 99 | # because if the logits came from vllm, we may not have logits for the final token 100 | ds = ds.map( 101 | _truncate_row, 102 | num_proc=num_proc, 103 | desc="Truncating input_ids, labels, and attention_mask", 104 | ) 105 | 106 | ds_out = [] 107 | ds_current = ds 108 | print(f"Batch size: {batch_size}") 109 | for iter_idx in range(max_iters): 110 | print(f"len(ds_current): {len(ds_current)}") 111 | print(f"len(ds_out): {len(ds_out)}") 112 | if num_proc: 113 | batched_num_proc = min(num_proc, len(ds_current) // batch_size) 114 | else: 115 | batched_num_proc = None 116 | if batched_num_proc is not None and batched_num_proc < 1: 117 | batched_num_proc = None 118 | print(f"Starting iteration {iter_idx}") 119 | print(f"num_proc: {num_proc}, batched_num_proc: {batched_num_proc}") 120 | ds_p = ds_current.map( 121 | pack_pass, 122 | batched=True, 123 | batch_size=batch_size, 124 | fn_kwargs={"max_len": max_len}, 125 | num_proc=batched_num_proc, 126 | desc=f"Packing iteration {iter_idx}", 127 | ) 128 | ds_done = ds_p.filter( 129 | lambda x: len(x) == max_len, 130 | num_proc=num_proc, 131 | input_columns=["input_ids"], 132 | desc="Finding full batches", 133 | ) 134 | ds_current = ds_p.filter( 135 | lambda x: len(x) != max_len, 136 | num_proc=num_proc, 137 | input_columns=["input_ids"], 138 | desc="Finding partial batches", 139 | ) 140 | print( 141 | f"Finished iteration {iter_idx}, {len(ds_done)} batches packed, {len(ds_current)} examples remaining" 142 | ) 143 | ds_out.append(ds_done) 144 | if len(ds_current) <= 1: 145 | break 146 | 147 | if len(ds_current) > 0: 148 | ds_out.append(ds_current) 149 | return datasets.concatenate_datasets(ds_out) 150 | 151 | 152 | @click.command("distillkit-pack-logits") 153 | @click.option("--dataset", type=str, required=True) 154 | @click.option("--split", type=str, required=False) 155 | @click.option("--max-len", type=int, required=True) 156 | @click.option("--output", type=str, required=True) 157 | @click.option("--num-proc", type=int, default=None) 158 | @click.option("--shuffle-seed", type=int, default=None) 159 | @click.option("--max-iters", type=int, default=4) 160 | @click.option("--batch-size", type=int, default=32) 161 | @click.option("--remove-columns", type=str, multiple=True) 162 | def pack_logits_cli( 163 | dataset: str, 164 | split: str | None, 165 | max_len: int, 166 | output: str, 167 | num_proc: int | None, 168 | shuffle_seed: int | None, 169 | max_iters: int, 170 | batch_size: int, 171 | remove_columns: list[str] | None, 172 | ): 173 | if os.path.exists(dataset): 174 | ds = datasets.load_from_disk(dataset) 175 | if split is not None: 176 | assert isinstance(ds, datasets.DatasetDict), ( 177 | "Expected a DatasetDict, got a Dataset" 178 | ) 179 | ds = ds[split] 180 | else: 181 | ds = datasets.load_dataset(dataset, split=split) 182 | if shuffle_seed is not None: 183 | ds = ds.shuffle(seed=shuffle_seed) 184 | if remove_columns: 185 | ds = ds.remove_columns(remove_columns) 186 | ds_out = iterative_packing( 187 | 
ds, max_len, num_proc=num_proc, max_iters=max_iters, batch_size=batch_size 188 | ) 189 | ds_out.save_to_disk(output) 190 | 191 | 192 | if __name__ == "__main__": 193 | pack_logits_cli() 194 | -------------------------------------------------------------------------------- /distillkit/configuration.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Arcee AI 2 | from enum import Enum 3 | from typing import Any, Literal 4 | 5 | from pydantic import BaseModel, Field 6 | from typing_extensions import TypeAlias 7 | 8 | from distillkit.compression.config import ( 9 | DistributionQuantizationConfig, 10 | LegacyLogitCompressionConfig, 11 | ) 12 | 13 | 14 | class LossFunction(str, Enum): 15 | CROSS_ENTROPY = "cross_entropy" 16 | KL = "kl" 17 | JSD = "jsd" 18 | TVD = "tvd" 19 | HINGE = "hinge" 20 | LOGISTIC_RANKING = "logistic_ranking" 21 | HIDDEN_STATE_COSINE = "hs_cosine" 22 | HIDDEN_STATE_MSE = "hs_mse" 23 | 24 | 25 | class MissingProbabilityHandling(Enum): 26 | ZERO = "zero" 27 | SYMMETRIC_UNIFORM = "symmetric_uniform" 28 | 29 | 30 | class LossFunctionConfig(BaseModel): 31 | function: LossFunction = Field( 32 | ..., 33 | description="Type of loss function to use.", 34 | ) 35 | weight: float = Field( 36 | ..., 37 | description="Weight for the loss function.", 38 | ) 39 | temperature: float | None = Field( 40 | default=None, 41 | description="Temperature for loss, if applicable.", 42 | ) 43 | missing_probability_handling: MissingProbabilityHandling | None = Field( 44 | default=None, 45 | description="Missing probability handling mode for sparse divergence functions.", 46 | ) 47 | sparse_chunk_length: int | None = Field( 48 | default=None, 49 | description="Chunk length for sparse divergence functions. None to disable chunking.", 50 | ) 51 | margin: float | None = Field( 52 | default=None, 53 | description="Margin for hinge loss, if applicable.", 54 | ) 55 | 56 | 57 | class HfRepoDataset(BaseModel): 58 | repo_id: str = Field( 59 | description="Hugging Face repository ID of the dataset.", 60 | ) 61 | revision: str | None = Field( 62 | default=None, 63 | description="Revision of the dataset to use.", 64 | ) 65 | config_name: str | None = Field( 66 | default=None, 67 | description="Configuration name of the dataset.", 68 | ) 69 | split: str | None = Field( 70 | default=None, 71 | description="Split of the dataset to use.", 72 | ) 73 | 74 | 75 | class LocalDataset(BaseModel): 76 | disk_path: str = Field( 77 | description="Path to the local dataset or dataset dict directory.", 78 | ) 79 | split: str | None = Field( 80 | default=None, 81 | description="Split of the dataset to use.", 82 | ) 83 | 84 | 85 | DatasetPath: TypeAlias = HfRepoDataset | LocalDataset 86 | 87 | 88 | class DatasetConfiguration(BaseModel): 89 | train_dataset: DatasetPath = Field( 90 | description="Dataset to use for training.", 91 | ) 92 | eval_dataset: DatasetPath | None = Field( 93 | default=None, 94 | description="Dataset to use for evaluation.", 95 | ) 96 | seed: int | None = Field( 97 | default=42, 98 | description="Random seed for shuffling datasets.", 99 | ) 100 | num_samples: int | None = Field( 101 | default=None, 102 | description="Number of samples to use from the dataset.", 103 | ) 104 | num_eval_samples: int | None = Field( 105 | default=None, 106 | description="Number of samples to use from the evaluation dataset.", 107 | ) 108 | eos_label_token_ids: list[int] | None = Field( 109 | default=None, 110 | description="List of token IDs to replace with EOS token IDs in 
the labels.", 111 | ) 112 | prepared_dataset_path: str | None = Field( 113 | default=None, 114 | description="Path to store prepared dataset.", 115 | ) 116 | prepacked: bool = Field( 117 | default=False, 118 | description="Assume dataset is pretokenized and packed, skip TRL packing.", 119 | ) 120 | 121 | 122 | class TeacherModelConfig(BaseModel): 123 | kind: Literal["hf"] = "hf" 124 | 125 | path: str 126 | kwargs: dict[str, Any] | None = None 127 | 128 | top_k: int | None = None 129 | 130 | 131 | class TeacherDatasetConfig(BaseModel): 132 | kind: Literal["dataset"] = "dataset" 133 | legacy_logit_compression: LegacyLogitCompressionConfig | None = Field( 134 | default=None, 135 | description="Legacy logit compression configuration. Must match configuration used to capture logits.", 136 | ) 137 | logprob_compressor: DistributionQuantizationConfig | None = Field( 138 | default=None, 139 | description="Logit compression configuration. Must match configuration used to capture logits.", 140 | ) 141 | 142 | 143 | class DistillationRunConfig(BaseModel): 144 | project_name: str = Field( 145 | default="distillkit", 146 | description="Project name for logging.", 147 | ) 148 | train_model: str = Field( 149 | description="Model to train.", 150 | alias="model", 151 | ) 152 | dataset: DatasetConfiguration 153 | teacher: TeacherModelConfig | TeacherDatasetConfig = Field( 154 | ..., discriminator="kind" 155 | ) 156 | sequence_length: int = Field( 157 | description="Sequence length for training.", 158 | ) 159 | output_path: str = Field( 160 | description="Path to save the model.", 161 | ) 162 | resize_embeddings_to_multiple_of: int | None = Field( 163 | default=None, 164 | description="Resize embeddings to a multiple of this value.", 165 | ) 166 | use_flash_attention: bool = Field( 167 | default=True, 168 | description="Use flash attention for training.", 169 | ) 170 | 171 | loss_functions: list[LossFunctionConfig] = Field( 172 | description="List of loss functions to use for distillation.", 173 | default_factory=lambda: [ 174 | LossFunctionConfig( 175 | function=LossFunction.CROSS_ENTROPY, 176 | weight=0.5, 177 | ), 178 | LossFunctionConfig( 179 | function=LossFunction.KL, 180 | weight=0.5, 181 | temperature=1.0, 182 | missing_probability_handling=MissingProbabilityHandling.ZERO, 183 | ), 184 | ], 185 | ) 186 | layer_mapping: list[tuple[int, int]] | Literal["all"] | None = Field( 187 | default=None, 188 | description='List of (student_layer_idx, teacher_layer_idx) pairs (or "all" for a complete one-to-one mapping.)', 189 | ) 190 | force_hidden_state_projection: bool = Field( 191 | default=False, 192 | description="Use linear layers to project between teacher and student hidden states even if sizes are equal.", 193 | ) 194 | functionary_packing: bool = Field( 195 | default=False, 196 | description="Use functionary's packing code. 
Requires flash attention and may not be compatible with all models.", 197 | ) 198 | training_args: dict[str, Any] = Field( 199 | default_factory=dict, 200 | description="Additional arguments for the trainer.", 201 | ) 202 | model_kwargs: dict[str, Any] = Field( 203 | default_factory=dict, 204 | description="Additional arguments for the model.", 205 | ) 206 | model_auto_class: str | None = Field( 207 | default="AutoModelForCausalLM", 208 | description="Auto class for the model.", 209 | ) 210 | trust_remote_code: bool = Field( 211 | default=False, 212 | description="Trust remote code when loading the model.", 213 | ) 214 | frozen_modules: list[str] | None = Field( 215 | default=None, 216 | description="List of modules to freeze during training.", 217 | ) 218 | frozen_res: list[str] | None = Field( 219 | default=None, 220 | description="List of regular expressions matching names of parameters to freeze during training.", 221 | ) 222 | -------------------------------------------------------------------------------- /distillkit/compression/legacy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Charles O. Goddard 2 | 3 | import torch 4 | 5 | from distillkit.compression.config import ( 6 | LegacyLogitCompressionConfig as LogitCompressionConfig, 7 | ) 8 | from distillkit.compression.densify import densify 9 | 10 | 11 | class LogitCompressor: 12 | """Compresses and decompresses logits using polynomial approximation.""" 13 | 14 | def __init__(self, config: LogitCompressionConfig): 15 | self.config = config 16 | self.vocab_index_bits = ( 17 | torch.log2(torch.tensor(self.config.vocab_size, dtype=torch.float32)) 18 | .ceil() 19 | .item() 20 | ) 21 | self._validate_config() 22 | self._setup_polynomial_terms() 23 | 24 | def _validate_config(self): 25 | assert self.config.exact_k <= self.config.k, ( 26 | "exact_k must be less than or equal to k" 27 | ) 28 | assert self.config.k > 0, "k must be greater than 0" 29 | 30 | def _setup_polynomial_terms(self): 31 | a = torch.arange(self.config.k - self.config.exact_k, dtype=torch.float32) + 1 32 | exponents = [ 33 | -i if self.config.invert_polynomial else i 34 | for i in range(self.config.polynomial_degree + 1) 35 | ] 36 | terms = [a**exp for exp in exponents] 37 | if self.config.with_sqrt_term: 38 | terms.append(a.sqrt()) 39 | self.X = torch.stack(terms, dim=-1).unsqueeze(0).unsqueeze(0) 40 | 41 | def compress_from_sparse( 42 | self, top_indices: torch.LongTensor, top_values: torch.Tensor 43 | ): 44 | exact_values = top_values[..., : self.config.exact_k] 45 | approx_values = top_values[..., self.config.exact_k : self.config.k] 46 | 47 | if self.config.exact_k < self.config.k: 48 | X = self.X.to(top_values.device, top_values.dtype) 49 | y = approx_values.unsqueeze(-1) 50 | coeffs = self._solve_least_squares(X, y).squeeze(-1) 51 | 52 | if self.config.term_dtype != "float32": 53 | coeffs = coeffs.to(self._str_to_dtype(self.config.term_dtype)) 54 | else: 55 | coeffs = torch.zeros( 56 | top_values.shape[:-1] + (0,), 57 | device=top_values.device, 58 | dtype=self._str_to_dtype(self.config.term_dtype), 59 | ) 60 | 61 | packed_indices = pack_tensor(top_indices, int(self.vocab_index_bits)) 62 | return packed_indices, exact_values.to(dtype=torch.float16), coeffs 63 | 64 | def compress( 65 | self, logits: torch.Tensor 66 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 67 | mean = logits.mean(dim=-1, keepdim=True) 68 | centered_logits = logits - mean 69 | top_values, top_indices = 
torch.topk(centered_logits, self.config.k, dim=-1) 70 | return self.compress_from_sparse(top_indices, top_values) 71 | 72 | def decompress_to_sparse( 73 | self, 74 | packed_indices: torch.Tensor, 75 | exact_values: torch.Tensor, 76 | coeffs: torch.Tensor, 77 | ) -> tuple[torch.LongTensor, torch.Tensor]: 78 | batch_size, seq_len = packed_indices.shape[:2] 79 | top_indices = unpack_tensor( 80 | packed_indices, 81 | torch.Size([batch_size, seq_len, self.config.k]), 82 | int(self.vocab_index_bits), 83 | ) 84 | approx_logits = torch.sum( 85 | self.X.to(coeffs.device, coeffs.dtype) 86 | * coeffs.to(dtype=self.X.dtype).unsqueeze(-2), 87 | dim=-1, 88 | ) 89 | top_values = torch.cat([exact_values, approx_logits], dim=-1) 90 | return top_indices, top_values 91 | 92 | def decompress( 93 | self, 94 | packed_indices: torch.Tensor, 95 | exact_values: torch.Tensor, 96 | coeffs: torch.Tensor, 97 | ) -> torch.Tensor: 98 | top_indices, top_values = self.decompress_to_sparse( 99 | packed_indices, exact_values, coeffs 100 | ) 101 | return densify(top_indices, top_values, self.config.vocab_size) 102 | 103 | @staticmethod 104 | def _solve_least_squares(A, B): 105 | # because for some reason torch.linalg.lstsq only works for full-rank matrices on GPU 106 | U, S, Vh = torch.linalg.svd(A, full_matrices=False) 107 | tol = 1e-5 108 | Spinv = torch.zeros_like(S) 109 | Spinv[S > tol] = 1 / S[S > tol] 110 | UhB = U.transpose(-1, -2) @ B 111 | SpinvUhB = Spinv.unsqueeze(-1) * UhB 112 | return Vh.transpose(-1, -2) @ SpinvUhB 113 | 114 | @staticmethod 115 | def _str_to_dtype(dtype_str: str): 116 | return { 117 | "float16": torch.float16, 118 | "bfloat16": torch.bfloat16, 119 | "float32": torch.float32, 120 | "float64": torch.float64, 121 | }.get(dtype_str, torch.float32) 122 | 123 | def bytes_per_token(self) -> int: 124 | index_bits = self.vocab_index_bits * self.config.k 125 | index_longs = (index_bits + 63) // 64 126 | index_bytes = index_longs * 8 127 | return ( 128 | self.config.exact_k * 2 129 | + self.config.polynomial_degree 130 | * (self._str_to_dtype(self.config.term_dtype).itemsize) 131 | + index_bytes 132 | ) 133 | 134 | 135 | def pack_tensor(x: torch.Tensor, bits: int) -> torch.Tensor: 136 | """Bit-packs a tensor of integers into a tensor of longs. 
137 | 138 | Args: 139 | x (torch.Tensor): Input tensor of integers 140 | bits (int): Number of bits to use for each element 141 | """ 142 | # written by Claude Sonnet 3.5, thanx bro 143 | assert x.dtype == torch.long, "Input tensor must be of type torch.long" 144 | assert 1 <= bits <= 63, "Number of bits must be between 1 and 63" 145 | 146 | device = x.device 147 | max_value = 2**bits - 1 148 | assert torch.all(x >= 0) and torch.all(x <= max_value), ( 149 | f"All values must be between 0 and {max_value}" 150 | ) 151 | 152 | # Calculate the number of elements that can fit in 64 bits 153 | elements_per_64bits = 64 // bits 154 | 155 | # Pad the last dimension to be a multiple of elements_per_64bits 156 | pad_size = ( 157 | elements_per_64bits - x.shape[-1] % elements_per_64bits 158 | ) % elements_per_64bits 159 | x_padded = torch.nn.functional.pad(x, (0, pad_size)) 160 | 161 | # Reshape the tensor to group elements that will be packed together 162 | x_reshaped = x_padded.reshape(*x_padded.shape[:-1], -1, elements_per_64bits) 163 | 164 | packed = torch.zeros(*x_reshaped.shape[:-1], dtype=torch.int64, device=device) 165 | 166 | for i in range(elements_per_64bits): 167 | packed |= x_reshaped[..., i] << (bits * i) 168 | 169 | return packed 170 | 171 | 172 | def unpack_tensor( 173 | packed: torch.Tensor, original_size: torch.Size, bits: int 174 | ) -> torch.Tensor: 175 | """Unpacks a bit-packed tensor of longs into a tensor of integers. 176 | 177 | Inverse operation of pack_tensor. 178 | 179 | Args: 180 | packed (torch.Tensor): Input tensor of longs 181 | original_size (torch.Size): Original size of the unpacked tensor 182 | bits (int): Number of bits used for each element 183 | """ 184 | assert packed.dtype == torch.long, "Packed tensor must be of type torch.long" 185 | assert 1 <= bits <= 63, "Number of bits must be between 1 and 63" 186 | 187 | device = packed.device 188 | elements_per_64bits = 64 // bits 189 | mask = (1 << bits) - 1 190 | 191 | unpacked = torch.zeros( 192 | *packed.shape[:-1], 193 | packed.shape[-1] * elements_per_64bits, 194 | dtype=torch.long, 195 | device=device, 196 | ) 197 | 198 | for i in range(elements_per_64bits): 199 | unpacked[..., i::elements_per_64bits] = (packed >> (bits * i)) & mask 200 | 201 | # Trim the unpacked tensor to the original size 202 | return unpacked[..., : original_size[-1]] 203 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/common.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Charles O. 
Goddard 2 | import math 3 | from abc import ABC, abstractmethod 4 | from typing import Callable 5 | 6 | import torch 7 | from transformers.modeling_outputs import CausalLMOutput 8 | 9 | from distillkit.configuration import MissingProbabilityHandling 10 | from distillkit.hsd_mapping import HiddenStateMapping 11 | from distillkit.signals import TeacherSignal 12 | 13 | 14 | def get_target_logprobs( 15 | values_in: torch.Tensor, 16 | log_target: bool, 17 | distillation_temperature: float, 18 | target_generation_temperature: float, 19 | missing: MissingProbabilityHandling, 20 | ) -> torch.Tensor: 21 | temperature_change = not math.isclose( 22 | distillation_temperature, target_generation_temperature 23 | ) 24 | if log_target: 25 | if ( 26 | not temperature_change 27 | and missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM 28 | ): 29 | # if we are not changing the temperature and don't need to 30 | # renormalize to sum to 1 as in the ZERO case, we can just 31 | # return the input values 32 | return values_in 33 | # we want to divide the target logits by temperature 34 | # but unfortunately, we only have topk logprobs 35 | # general approach: convert to probs, apply temperature, renormalize 36 | alpha = target_generation_temperature / distillation_temperature 37 | 38 | # compute total mass of unscaled target probs 39 | max_in, _ = values_in.max(dim=-1, keepdim=True) 40 | lse_in = torch.logsumexp(values_in - max_in, dim=-1, keepdim=True) + max_in 41 | target_sum = lse_in.exp() 42 | leftover = (1.0 - target_sum).clamp(0, 1) 43 | 44 | alpha_values_in = alpha * values_in 45 | # We need sum( p_i^alpha ). In log space: sum( exp(alpha * log p_i) ). 46 | max_alpha, _ = alpha_values_in.max(dim=-1, keepdim=True) 47 | alpha_lse = ( 48 | torch.logsumexp(alpha_values_in - max_alpha, dim=-1, keepdim=True) 49 | + max_alpha 50 | ) 51 | alpha_sum = alpha_lse.exp() 52 | 53 | if missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM: 54 | # assume the remaining probability mass is distributed uniformly 55 | # over the missing token indices in both the teacher and student 56 | # distributions (symmetrically) - this lets us act as though 57 | # instead of a N-dimensional distribution we have a k+1 dimensional 58 | # distribution, where the k+1th dimension is the probability of all 59 | # other tokens. 
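# Concretely, with alpha = target_generation_temperature / distillation_temperature, the rescaled top-k probabilities become p_i**alpha / (sum_j p_j**alpha + (1 - sum_j p_j)**alpha); the leftover mass is rescaled with the same exponent as a single extra bucket, as implemented below.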
60 | leftover_alpha = leftover.pow(alpha) 61 | divisor = alpha_sum + leftover_alpha 62 | final_lse = divisor.log() 63 | else: 64 | # assume the missing tokens have zero probability 65 | final_lse = alpha_lse 66 | # final log-probabilities are alpha * log p_i - log(Z) 67 | return alpha_values_in - final_lse 68 | else: 69 | # we have logits, praise be 70 | if temperature_change: 71 | logits = values_in * ( 72 | target_generation_temperature / distillation_temperature 73 | ) 74 | else: 75 | logits = values_in 76 | sparse_max = torch.max(logits, dim=-1, keepdim=True).values 77 | sparse_lse = ( 78 | torch.logsumexp( 79 | (logits - sparse_max).to(torch.float32), dim=-1, keepdim=True 80 | ) 81 | + sparse_max 82 | ).to(values_in.dtype) 83 | return logits - sparse_lse 84 | 85 | 86 | def get_logprobs( 87 | logits: torch.Tensor, 88 | target_ids: torch.LongTensor, 89 | target_values: torch.Tensor, 90 | eps: float = 1e-8, 91 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 92 | log_target: bool = True, 93 | distillation_temperature: float = 1.0, 94 | target_generation_temperature: float = 1.0, 95 | student_generation_temperature: float = 1.0, 96 | ) -> tuple[torch.Tensor, torch.Tensor]: 97 | batch_size, seq_len, vocab_size = logits.shape 98 | assert target_ids.shape[:-1] == (batch_size, seq_len), ( 99 | f"Target ids shape {target_ids.shape[:-1]} does not match logits shape " 100 | f"{logits.shape}" 101 | ) 102 | assert target_values.shape == target_ids.shape, ( 103 | f"Target values shape {target_values.shape} does not match target ids shape " 104 | f"{target_ids.shape}" 105 | ) 106 | assert distillation_temperature > eps, ( 107 | f"Temperature must be positive and non-zero, got {distillation_temperature}" 108 | ) 109 | out_dtype = logits.dtype 110 | 111 | if (not log_target) and (missing != MissingProbabilityHandling.ZERO): 112 | raise ValueError( 113 | "For log_target=False (teacher inputs are logits), " 114 | "MissingProbabilityHandling.SYMMETRIC_UNIFORM is ill-defined. " 115 | "The teacher distribution is only over the provided sparse logits. " 116 | "Use MissingProbabilityHandling.ZERO." 117 | ) 118 | 119 | if not math.isclose(distillation_temperature, student_generation_temperature): 120 | logits = logits * (student_generation_temperature / distillation_temperature) 121 | student_lse = torch.logsumexp(logits.to(torch.float32), dim=-1, keepdim=True).to( 122 | out_dtype 123 | ) 124 | sparse_student_logprobs = logits.gather(-1, target_ids) - student_lse 125 | del student_lse, logits 126 | 127 | with torch.no_grad(): 128 | sparse_target_logprobs = get_target_logprobs( 129 | target_values.to(torch.float32), 130 | log_target=log_target, 131 | distillation_temperature=distillation_temperature, 132 | target_generation_temperature=target_generation_temperature, 133 | missing=missing, 134 | ).to(out_dtype) 135 | del target_values 136 | 137 | return sparse_student_logprobs, sparse_target_logprobs 138 | 139 | 140 | def accumulate_over_chunks( 141 | logits: torch.Tensor, 142 | target_ids: torch.LongTensor, 143 | target_values: torch.Tensor, 144 | mask: torch.Tensor | None, 145 | chunk_length: int | None, 146 | fn: Callable, 147 | *args, 148 | **kwargs, 149 | ) -> torch.Tensor: 150 | """Accumulate the result of a function over chunks of the input tensors. 151 | Args: 152 | logits (torch.Tensor): The logits tensor. 153 | target_ids (torch.LongTensor): The target IDs tensor. 154 | target_values (torch.Tensor): The target values tensor. 
mask (torch.Tensor | None): Optional boolean mask; True marks tokens to include. 155 | chunk_length (int | None): The length of each chunk. If None, the entire sequence is used. 156 | fn (Callable): The function to apply to each chunk. 157 | *args: Additional arguments to pass to the function. 158 | **kwargs: Additional keyword arguments to pass to the function. 159 | Returns: 160 | torch.Tensor: The accumulated result. 161 | """ 162 | seq_len = logits.shape[1] 163 | if chunk_length is None: 164 | chunk_length = seq_len 165 | 166 | total = 0.0 167 | 168 | for start_idx in range(0, seq_len, chunk_length): 169 | if mask is not None: 170 | cur_mask = mask[:, start_idx : start_idx + chunk_length] 171 | else: 172 | cur_mask = None 173 | end_idx = min(start_idx + chunk_length, seq_len) 174 | total += fn( 175 | logits[:, start_idx:end_idx], 176 | target_ids[:, start_idx:end_idx], 177 | target_values[:, start_idx:end_idx], 178 | cur_mask, 179 | *args, 180 | **kwargs, 181 | ) 182 | return total 183 | 184 | 185 | class LossFunctionBase(ABC): 186 | @classmethod 187 | @abstractmethod 188 | def name(cls) -> str: ... 189 | 190 | def requires_hidden_states(self) -> bool: 191 | return False 192 | 193 | @abstractmethod 194 | def __init__(self, **kwargs) -> None: ... 195 | 196 | @abstractmethod 197 | def __call__( 198 | self, 199 | student_outputs: CausalLMOutput, 200 | signal: TeacherSignal, 201 | mask: torch.Tensor | None = None, 202 | hidden_state_mapping: HiddenStateMapping | None = None, 203 | num_items_in_batch: int | None = None, 204 | ) -> torch.Tensor: ... 205 | -------------------------------------------------------------------------------- /distillkit/monkey_patch_packing.py: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/MeetKai/functionary 2 | import sys 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import transformers 7 | 8 | 9 | def get_max_seqlen_in_batch(attention_mask): 10 | max_num = torch.max(attention_mask) 11 | # attention_mask: B x N 12 | counts = [] 13 | for i in range(1, max_num + 1): 14 | counts.append( 15 | torch.sum(attention_mask == i, axis=-1) 16 | ) # shape: B, count length of data point masked with i 17 | result = torch.stack(counts, axis=1) 18 | result = result.flatten() 19 | return result[result.nonzero()].squeeze(-1).to(dtype=torch.int32) 20 | 21 | 22 | def get_unpad_data(attention_mask): 23 | seqlens_in_batch = get_max_seqlen_in_batch( 24 | attention_mask 25 | ) # attention_mask.sum(dim=-1, dtype=torch.int32) 26 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 27 | max_seqlen_in_batch = seqlens_in_batch.max().item() 28 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 29 | return ( 30 | indices, 31 | cu_seqlens, 32 | max_seqlen_in_batch, 33 | ) 34 | 35 | 36 | # Copied from the original implementation of modeling_mixtral.py in transformers; only the new_attention_mask handling is changed 37 | def load_balancing_loss_func( 38 | gate_logits: torch.Tensor, 39 | num_experts: int | None = None, 40 | top_k=2, 41 | attention_mask: torch.Tensor | None = None, 42 | ) -> float: 43 | r""" 44 | Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. 45 | 46 | See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss 47 | function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between 48 | experts is too unbalanced.
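In rough form, the returned value is num_experts * sum_e f_e * P_e, where f_e is the fraction of (non-padding) tokens routed to expert e across the top-k choices and P_e is the mean router probability assigned to expert e; this mirrors equations (4)-(6) up to normalization.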
49 | 50 | Args: 51 | gate_logits (`torch.Tensor` | tuple[torch.Tensor]): 52 | Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of 53 | shape [batch_size X sequence_length, num_experts]. 54 | attention_mask (`torch.Tensor`, None): 55 | The attention_mask used in forward function 56 | shape [batch_size X sequence_length] if not None. 57 | num_experts (`int`, *optional*): 58 | Number of experts 59 | 60 | Returns: 61 | The auxiliary loss. 62 | """ 63 | if gate_logits is None or not isinstance(gate_logits, tuple): 64 | return 0 65 | 66 | if isinstance(gate_logits, tuple): 67 | compute_device = gate_logits[0].device 68 | concatenated_gate_logits = torch.cat( 69 | [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 70 | ) 71 | 72 | routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) 73 | 74 | _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) 75 | 76 | expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) 77 | 78 | if attention_mask is None: 79 | # Compute the percentage of tokens routed to each experts 80 | tokens_per_expert = torch.mean(expert_mask.float(), dim=0) 81 | 82 | # Compute the average probability of routing to these experts 83 | router_prob_per_expert = torch.mean(routing_weights, dim=0) 84 | else: 85 | # ONLY ADD THIS LINE OF CODE, AND REPLACE attention_mask WITH new_attention_mask 86 | new_attention_mask = (attention_mask != 0).int().to(attention_mask.device) 87 | batch_size, sequence_length = new_attention_mask.shape 88 | num_hidden_layers = concatenated_gate_logits.shape[0] // ( 89 | batch_size * sequence_length 90 | ) 91 | 92 | # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask 93 | expert_attention_mask = ( 94 | new_attention_mask[None, :, :, None, None] 95 | .expand( 96 | (num_hidden_layers, batch_size, sequence_length, top_k, num_experts) 97 | ) 98 | .reshape(-1, top_k, num_experts) 99 | .to(compute_device) 100 | ) 101 | 102 | # Compute the percentage of tokens routed to each experts 103 | tokens_per_expert = torch.sum( 104 | expert_mask.float() * expert_attention_mask, dim=0 105 | ) / torch.sum(expert_attention_mask, dim=0) 106 | 107 | # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert 108 | router_per_expert_attention_mask = ( 109 | new_attention_mask[None, :, :, None] 110 | .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) 111 | .reshape(-1, num_experts) 112 | .to(compute_device) 113 | ) 114 | 115 | # Compute the average probability of routing to these experts 116 | router_prob_per_expert = torch.sum( 117 | routing_weights * router_per_expert_attention_mask, dim=0 118 | ) / torch.sum(router_per_expert_attention_mask, dim=0) 119 | 120 | overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) 121 | return overall_loss * num_experts 122 | 123 | 124 | def monkey_patch_for_model_with_name(model_type: str, modelling_type: str): 125 | """For example for llama: model_package = llama, modelling_module=modeling_llama 126 | 127 | Args: 128 | model_package (_type_): _description_ 129 | modelling_module (_type_): _description_ 130 | """ 131 | module = getattr(getattr(transformers, model_type), modelling_type) 132 | if hasattr(module, "_get_unpad_data"): 133 | module._get_unpad_data = get_unpad_data 134 | else: 135 | print( 136 | f"cannot packing llama because _get_unpad_data was not found in transformers.{model_type}.{modelling_type}.py or 
transformers.modeling_flash_attention_utils._get_unpad_data" 137 | ) 138 | sys.exit(1) 139 | 140 | 141 | def monkey_patch_packing_for_model(pretrained_model): 142 | # Monkey-patch flash attention if this version of transformers has already merged: https://github.com/huggingface/transformers/commit/e314395277d784a34ee99526f48155d4d62cff3d 143 | # this will work for all models using flash attention: Llama, Mistral, Qwen2, Phi3, ... 144 | model_config = transformers.AutoConfig.from_pretrained(pretrained_model) 145 | config_type = type(model_config).__name__.lower() 146 | if hasattr(transformers, "modeling_flash_attention_utils"): 147 | transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data 148 | else: # if this is an old version of transformers 149 | model_type, modelling_type = "", "" 150 | if config_type == "mistralconfig": 151 | print("monkey_patch_packing for Mistral ") 152 | transformers.models.mistral.modeling_mistral._get_unpad_data = ( 153 | get_unpad_data 154 | ) 155 | 156 | elif config_type == "llamaconfig": 157 | print("monkey_patch_packing for Llama ") 158 | transformers.models.llama.modeling_llama._get_unpad_data = get_unpad_data 159 | 160 | elif config_type == "mixtralconfig": 161 | print("monkey_patch_packing for Mixtral") 162 | transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( 163 | get_unpad_data 164 | ) 165 | 166 | elif config_type == "qwen2config": 167 | print("monkey_patch_packing for Qwen2") 168 | # transformers.models.qwen2.modeling_qwen2 169 | model_type, modelling_type = "qwen2", "modeling_qwen2" 170 | transformers.models.qwen2.modeling_qwen2._get_unpad_data = get_unpad_data 171 | 172 | elif config_type == "phi3config": 173 | # transformers.models.phi3.modeling_phi3 174 | print("monkey_patch_packing for Phi3") 175 | transformers.models.phi3.modeling_phi3._get_unpad_data = get_unpad_data 176 | else: 177 | raise Exception( 178 | f"{config_type} is not supported, currently we only support: Mistral, Mixtral, Llama, Qwen2, Phi3 for monkey-patch-packing" 179 | ) 180 | 181 | monkey_patch_for_model_with_name(model_type, modelling_type) 182 | 183 | if config_type == "mixtralconfig": 184 | # if it is mixtral, we need to monkey-patch the load_balancing_loss_func 185 | transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func = ( 186 | load_balancing_loss_func 187 | ) 188 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/kl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Charles O. Goddard 2 | 3 | import torch 4 | from transformers.modeling_outputs import CausalLMOutput 5 | from typing_extensions import override 6 | 7 | from distillkit.hsd_mapping import HiddenStateMapping 8 | from distillkit.lossfuncs.common import ( 9 | LossFunctionBase, 10 | MissingProbabilityHandling, 11 | accumulate_over_chunks, 12 | get_logprobs, 13 | ) 14 | from distillkit.signals import DenseSignal, TeacherSignal 15 | 16 | 17 | def sparse_kl_div_inner( 18 | logits: torch.Tensor, 19 | target_ids: torch.LongTensor, 20 | target_values: torch.Tensor, 21 | mask: torch.Tensor | None = None, 22 | eps: float = 1e-8, 23 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 24 | log_target: bool = True, 25 | temperature: float = 1.0, 26 | target_generation_temperature: float = 1.0, 27 | student_generation_temperature: float = 1.0, 28 | ) -> torch.Tensor: 29 | """Compute the KL divergence between a dense set of predictions and a sparse set of target logits.
30 | 31 | See `sparse_kl_div` for details. 32 | """ 33 | batch_size, seq_len, vocab_size = logits.shape 34 | out_dtype = logits.dtype 35 | sparse_student_logprobs, sparse_target_logprobs = get_logprobs( 36 | logits, 37 | target_ids, 38 | target_values, 39 | eps=eps, 40 | missing=missing, 41 | log_target=log_target, 42 | distillation_temperature=temperature, 43 | target_generation_temperature=target_generation_temperature, 44 | student_generation_temperature=student_generation_temperature, 45 | ) 46 | 47 | # Terms for non-zero target probabilities 48 | teacher_sparse_probs = torch.exp(sparse_target_logprobs) 49 | teacher_prob_sum = teacher_sparse_probs.to(torch.float32).sum(dim=-1) 50 | inner_sum = torch.sum( 51 | teacher_sparse_probs * (sparse_target_logprobs - sparse_student_logprobs), 52 | dim=-1, 53 | ) 54 | del sparse_target_logprobs, teacher_sparse_probs 55 | 56 | # Compute the contribution of missing logits to KL divergence 57 | if missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM: 58 | # if the teacher's logprobs don't sum to 1, we assume the remaining 59 | # probability mass in *both* the teacher and student is distributed 60 | # uniformly over the token indices missing from the teacher's distribution 61 | log_teacher_missing = torch.log1p(-teacher_prob_sum.clamp(min=eps, max=1 - eps)) 62 | student_probs = sparse_student_logprobs.to(torch.float32).exp_() 63 | student_prob_sum = student_probs.sum(dim=-1) 64 | del student_probs 65 | log_student_missing = torch.log1p(-student_prob_sum.clamp(min=eps, max=1 - eps)) 66 | del student_prob_sum 67 | missing_kl = torch.exp(log_teacher_missing) * ( 68 | log_teacher_missing - log_student_missing 69 | ) 70 | else: 71 | # in this case we assume zero probability mass for missing tokens 72 | # in the teacher distribution, and thus zero contribution to KL divergence 73 | missing_kl = None 74 | del sparse_student_logprobs 75 | 76 | if mask is not None: 77 | if mask.dim() == 3: 78 | mask = mask.squeeze(-1) 79 | inner_sum *= mask 80 | if missing_kl is not None: 81 | missing_kl *= mask 82 | del mask 83 | 84 | if missing_kl is not None: 85 | inner_sum += missing_kl 86 | 87 | return torch.sum(inner_sum).to(out_dtype) 88 | 89 | 90 | def sparse_kl_div( 91 | logits: torch.Tensor, 92 | target_ids: torch.LongTensor, 93 | target_values: torch.Tensor, 94 | mask: torch.Tensor | None = None, 95 | eps: float = 1e-8, 96 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 97 | log_target: bool = True, 98 | temperature: float = 1.0, 99 | target_generation_temperature: float = 1.0, 100 | student_generation_temperature: float = 1.0, 101 | chunk_length: int | None = None, 102 | ) -> torch.Tensor: 103 | """Compute the KL divergence between a dense set of predictions and a sparse set of target logits. 104 | 105 | Uses a chunked approach to avoid memory issues with large sequences. 106 | 107 | Args: 108 | logits: Dense tensor of predictions. 109 | target_ids: Tensor of indices for target logits. 110 | target_values: Tensor of values for target logits or log probabilities. 111 | mask: Optional boolean mask tensor. True indicates tokens to include, False to exclude. 112 | eps: Small value to prevent numerical instability. 113 | missing: How to handle missing probabilities in the target distribution. If ZERO, missing 114 | probabilities are assumed to be zero. If SYMMETRIC_UNIFORM, missing probabilities are 115 | assumed to be distributed uniformly over the missing tokens in both the teacher and 116 | student distributions. 
117 | log_target: Whether the target values are already log probabilities. 118 | temperature: Temperature to apply to the distributions. 119 | target_generation_temperature: Temperature already applied to the target logits/logprobs. 120 | student_generation_temperature: Temperature already applied to the student logits. 121 | chunk_length: Number of tokens per chunk. If None, the entire sequence is processed at once. 122 | """ 123 | return accumulate_over_chunks( 124 | logits, 125 | target_ids, 126 | target_values, 127 | mask, 128 | chunk_length, 129 | sparse_kl_div_inner, 130 | eps=eps, 131 | missing=missing, 132 | log_target=log_target, 133 | temperature=temperature, 134 | target_generation_temperature=target_generation_temperature, 135 | student_generation_temperature=student_generation_temperature, 136 | ) 137 | 138 | 139 | def dense_kl_div( 140 | logits: torch.Tensor, 141 | target_logits: torch.Tensor, 142 | mask: torch.Tensor | None = None, 143 | temperature: float = 1.0, 144 | ) -> torch.Tensor: 145 | """Compute the KL divergence between a dense set of predictions and a dense set of target logits. 146 | 147 | Args: 148 | logits: Dense tensor of predictions (Student). 149 | target_logits: Dense tensor of target logits (Teacher). 150 | mask: Optional boolean mask tensor. True indicates tokens to include, False to exclude. 151 | temperature: Temperature to apply to the distributions. 152 | """ 153 | out_dtype = logits.dtype 154 | 155 | student_logprobs = torch.log_softmax(logits.float() / temperature, dim=-1) 156 | teacher_logprobs = torch.log_softmax(target_logits.float() / temperature, dim=-1) 157 | 158 | kl_per_element = torch.nn.functional.kl_div( 159 | input=student_logprobs, 160 | target=teacher_logprobs, 161 | reduction="none", 162 | log_target=True, 163 | ) 164 | # Sum over the vocabulary dimension (dim=-1) to get KL per token 165 | kl_per_token = torch.sum(kl_per_element, dim=-1) 166 | 167 | if mask is not None: 168 | if mask.dim() == 3: 169 | mask = mask.squeeze(-1) 170 | kl_per_token = kl_per_token * mask.float() 171 | 172 | return torch.sum(kl_per_token).to(out_dtype) 173 | 174 | 175 | class KLDLoss(LossFunctionBase): 176 | temperature: float 177 | missing: MissingProbabilityHandling 178 | chunk_length: int | None 179 | 180 | @override 181 | @classmethod 182 | def name(cls) -> str: 183 | return "kl" 184 | 185 | @override 186 | def __init__( 187 | self, 188 | temperature: float, 189 | missing_probability_handling: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 190 | sparse_chunk_length: int | None = None, 191 | ) -> None: 192 | self.temperature = temperature 193 | self.missing = missing_probability_handling 194 | self.chunk_length = sparse_chunk_length 195 | 196 | @override 197 | def __call__( 198 | self, 199 | student_outputs: CausalLMOutput, 200 | signal: TeacherSignal, 201 | mask: torch.Tensor | None = None, 202 | hidden_state_mapping: HiddenStateMapping | None = None, 203 | num_items_in_batch: int | None = None, 204 | ) -> torch.Tensor: 205 | if num_items_in_batch is None: 206 | if mask is not None: 207 | num_items_in_batch = mask.float().sum() 208 | else: 209 | num_items_in_batch = ( 210 | student_outputs.logits.shape[0] * student_outputs.logits.shape[1] 211 | ) 212 | if isinstance(signal, DenseSignal): 213 | res = dense_kl_div( 214 | student_outputs.logits, 215 | signal.logits, 216 | mask=mask, 217 | temperature=self.temperature, 218 | ) 219 | else: 220 | res = sparse_kl_div( 221 | logits=student_outputs.logits, 222 | target_ids=signal.sparse_ids, 223 | 
target_values=signal.sparse_values, 224 | mask=mask, 225 | missing=self.missing, 226 | log_target=signal.log_values, 227 | temperature=self.temperature, 228 | target_generation_temperature=signal.generation_temperature, 229 | chunk_length=self.chunk_length, 230 | ) 231 | return res * (self.temperature**2) / num_items_in_batch 232 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/tvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers.modeling_outputs import CausalLMOutput 3 | from typing_extensions import override 4 | 5 | from distillkit.hsd_mapping import HiddenStateMapping 6 | from distillkit.lossfuncs.common import ( 7 | LossFunctionBase, 8 | MissingProbabilityHandling, 9 | accumulate_over_chunks, 10 | get_logprobs, 11 | ) 12 | from distillkit.signals import DenseSignal, TeacherSignal 13 | 14 | 15 | def sparse_tvd_inner( 16 | logits: torch.Tensor, 17 | target_ids: torch.LongTensor, 18 | target_values: torch.Tensor, 19 | mask: torch.Tensor | None = None, 20 | eps: float = 1e-8, 21 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 22 | log_target: bool = True, 23 | temperature: float = 1.0, 24 | target_generation_temperature: float = 1.0, 25 | student_generation_temperature: float = 1.0, 26 | ) -> torch.Tensor: 27 | """Compute the Total Variation Distance (TVD) between dense student predictions and sparse teacher targets. 28 | 29 | See `sparse_tvd_div` for details. 30 | """ 31 | batch_size, seq_len, vocab_size = logits.shape 32 | sparse_student_logprobs, sparse_target_logprobs = get_logprobs( 33 | logits, 34 | target_ids, 35 | target_values, 36 | eps=eps, 37 | missing=missing, 38 | log_target=log_target, 39 | distillation_temperature=temperature, 40 | target_generation_temperature=target_generation_temperature, 41 | student_generation_temperature=student_generation_temperature, 42 | ) 43 | sparse_student_probs = torch.exp(sparse_student_logprobs) 44 | del sparse_student_logprobs 45 | sparse_teacher_probs = torch.exp(sparse_target_logprobs) 46 | del sparse_target_logprobs 47 | 48 | # --- 3. Compute TVD for sparse indices --- 49 | # sum_i |P_i - Q_i| for i in target_ids 50 | tvd_sparse_terms_sum = torch.sum( 51 | torch.abs(sparse_teacher_probs - sparse_student_probs), dim=-1 52 | ) # Shape: (B, S) 53 | 54 | # --- 4. Compute TVD contribution from missing indices --- 55 | if missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM: 56 | # In this case, get_target_logprobs has normalized P such that P_sparse + P_miss = 1 57 | # (where P_miss is the (k+1)th synthetic category). 58 | # So, P_miss = 1 - sum(P_sparse) 59 | teacher_prob_sum_sparse = sparse_teacher_probs.sum(dim=-1) # B, S 60 | # Clamp to avoid small numerical errors making this < 0 or > 1 due to float precision 61 | teacher_missing_prob_mass = (1.0 - teacher_prob_sum_sparse).clamp( 62 | min=0.0, max=1.0 63 | ) 64 | 65 | # Similarly for student Q, Q_miss = 1 - sum(Q_sparse) 66 | student_prob_sum_sparse = sparse_student_probs.sum(dim=-1) # B, S 67 | student_missing_prob_mass = (1.0 - student_prob_sum_sparse).clamp( 68 | min=0.0, max=1.0 69 | ) 70 | 71 | tvd_missing_contrib = torch.abs( 72 | teacher_missing_prob_mass - student_missing_prob_mass 73 | ) # B, S 74 | 75 | else: 76 | # Teacher P_j = 0 for j not in target_ids. 
77 | # So, sum_{j not in target_ids} |P_j - Q_j| = sum_{j not in target_ids} |0 - Q_j| 78 | # = sum_{j not in target_ids} Q_j 79 | # This is the total probability mass of the student for tokens NOT in target_ids. 80 | student_prob_sum_sparse = sparse_student_probs.sum(dim=-1) # B, S 81 | student_total_missing_prob_mass = (1.0 - student_prob_sum_sparse).clamp( 82 | min=0.0 83 | ) # B, S 84 | tvd_missing_contrib = student_total_missing_prob_mass 85 | 86 | del sparse_teacher_probs, sparse_student_probs 87 | tvd_token_level = 0.5 * (tvd_sparse_terms_sum + tvd_missing_contrib) 88 | 89 | # --- 6. Masking and Aggregation --- 90 | if mask is not None: 91 | if mask.dim() == 2: # B, S 92 | mask_squozed = mask 93 | else: # B, S, 1 94 | mask_squozed = mask.squeeze(-1) 95 | tvd_token_level *= mask_squozed 96 | 97 | return torch.sum(tvd_token_level) 98 | 99 | 100 | def sparse_tvd( 101 | logits: torch.Tensor, 102 | target_ids: torch.LongTensor, 103 | target_values: torch.Tensor, 104 | mask: torch.Tensor | None = None, 105 | eps: float = 1e-8, 106 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 107 | log_target: bool = True, 108 | temperature: float = 1.0, 109 | target_generation_temperature: float = 1.0, 110 | student_generation_temperature: float = 1.0, 111 | chunk_length: int | None = None, 112 | ) -> torch.Tensor: 113 | """Compute the Total Variation Distance (TVD) between a dense set of student predictions 114 | and a sparse set of teacher targets. 115 | 116 | Uses a chunked approach to avoid memory issues with large sequences. 117 | 118 | TVD = 0.5 * sum_i |P_i - Q_i| 119 | 120 | Args: 121 | logits: Dense tensor of student predictions (batch_size, seq_len, vocab_size). 122 | target_ids: Tensor of indices for teacher target probabilities/logits 123 | (batch_size, seq_len, num_sparse_targets). 124 | target_values: Tensor of values for teacher target probabilities/logits 125 | (batch_size, seq_len, num_sparse_targets). 126 | mask: Optional boolean mask tensor (batch_size, seq_len). True indicates tokens to include. 127 | eps: Small value to avoid numerical issues. Default is 1e-8. 128 | missing: How to handle missing probabilities in the target distribution. 129 | ZERO: Missing teacher probabilities are zero. Student's missing mass contributes to TVD. 130 | SYMMETRIC_UNIFORM: Missing mass in P and Q is treated as a (k+1)th category. 131 | log_target: Whether the target_values are log probabilities (True) or logits (False). 132 | temperature: Distillation temperature to apply to both student and teacher. 133 | target_generation_temperature: Temperature originally used to generate teacher targets. 134 | student_generation_temperature: Temperature originally used to generate student logits. 135 | chunk_length: Number of tokens per chunk. If None, entire sequence is processed at once. 136 | """ 137 | return accumulate_over_chunks( 138 | logits, 139 | target_ids, 140 | target_values, 141 | mask, 142 | chunk_length, 143 | sparse_tvd_inner, 144 | eps=eps, 145 | missing=missing, 146 | log_target=log_target, 147 | temperature=temperature, 148 | target_generation_temperature=target_generation_temperature, 149 | student_generation_temperature=student_generation_temperature, 150 | ) 151 | 152 | 153 | def dense_tvd( 154 | logits: torch.Tensor, 155 | target_logits: torch.Tensor, 156 | mask: torch.Tensor | None = None, 157 | temperature: float = 1.0, 158 | ) -> torch.Tensor: 159 | """Compute the Total Variation Distance (TVD) between dense student predictions 160 | and dense teacher targets. 
161 | 162 | TVD = 0.5 * sum_i |P_i - Q_i| 163 | 164 | Args: 165 | logits: Student logits (batch_size, seq_len, vocab_size). 166 | target_logits: Teacher logits (batch_size, seq_len, vocab_size). 167 | mask: Optional mask (batch_size, seq_len) or (batch_size, seq_len, 1). 168 | temperature: Distillation temperature. 169 | """ 170 | # 1. Apply temperature and compute probabilities 171 | student_probs = torch.softmax(logits.float() / temperature, dim=-1) 172 | teacher_probs = torch.softmax(target_logits.float() / temperature, dim=-1) 173 | 174 | # 2. Compute TVD per token 175 | # Sum over the vocabulary dimension (dim=-1) 176 | # Shape becomes: (Batch, Seq_Len) 177 | tvd_token_level = 0.5 * torch.sum(torch.abs(teacher_probs - student_probs), dim=-1) 178 | 179 | # 3. Apply Mask 180 | if mask is not None: 181 | if mask.dim() == 3: 182 | mask = mask.squeeze(-1) 183 | tvd_token_level = tvd_token_level * mask 184 | 185 | return torch.sum(tvd_token_level) 186 | 187 | 188 | class TVDLoss(LossFunctionBase): 189 | temperature: float 190 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO 191 | chunk_length: int | None = None 192 | 193 | @override 194 | @classmethod 195 | def name(cls) -> str: 196 | return "tvd" 197 | 198 | @override 199 | def __init__( 200 | self, 201 | temperature: float, 202 | missing_probability_handling: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 203 | sparse_chunk_length: int | None = None, 204 | ) -> None: 205 | self.temperature = temperature 206 | self.missing = missing_probability_handling 207 | self.chunk_length = sparse_chunk_length 208 | 209 | @override 210 | def __call__( 211 | self, 212 | student_outputs: CausalLMOutput, 213 | signal: TeacherSignal, 214 | mask: torch.Tensor | None = None, 215 | hidden_state_mapping: HiddenStateMapping | None = None, 216 | num_items_in_batch: int | None = None, 217 | ) -> torch.Tensor: 218 | if num_items_in_batch is None: 219 | if mask is not None: 220 | num_items_in_batch = mask.float().sum() 221 | else: 222 | num_items_in_batch = ( 223 | student_outputs.logits.shape[0] * student_outputs.logits.shape[1] 224 | ) 225 | if isinstance(signal, DenseSignal): 226 | res = dense_tvd( 227 | student_outputs.logits, 228 | signal.logits, 229 | mask=mask, 230 | temperature=self.temperature, 231 | ) 232 | else: 233 | res = sparse_tvd( 234 | logits=student_outputs.logits, 235 | target_ids=signal.sparse_ids, 236 | target_values=signal.sparse_values, 237 | mask=mask, 238 | missing=self.missing, 239 | log_target=signal.log_values, 240 | temperature=self.temperature, 241 | target_generation_temperature=signal.generation_temperature, 242 | chunk_length=self.chunk_length, 243 | ) 244 | return res * (self.temperature**2) / num_items_in_batch 245 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /distillkit/sample_common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import queue 4 | import threading 5 | from typing import Any 6 | 7 | import datasets 8 | import pyarrow 9 | import pyarrow.parquet as pq 10 | import torch 11 | import transformers 12 | 13 | ROLE_MAP = { 14 | "gpt": "assistant", 15 | "human": "user", 16 | "system": "system", 17 | "user": "user", 18 | "assistant": "assistant", 19 | "tool": "tool", 20 | } 21 | 22 | 23 | def maybe_trim_bos(text: str, tokenizer: transformers.PreTrainedTokenizerBase): 24 | if ( 25 | tokenizer.bos_token_id is not None 26 | and getattr(tokenizer, "add_bos_token", False) 27 | and text.startswith(tokenizer.bos_token) 28 | ): 29 | return text[len(tokenizer.bos_token) :] 30 | return text 31 | 32 | 33 | def do_chat_template(row: dict, tokenizer: transformers.PreTrainedTokenizerBase): 34 | if "text" in row: 35 | return row["text"] 36 | elif "instruction" in row and "output" in row: 37 | res = tokenizer.apply_chat_template( 38 | [ 39 | {"role": "user", "content": row["instruction"]}, 40 | {"role": "assistant", "content": row["output"]}, 41 | ], 42 | tokenize=False, 43 | ) 44 | return maybe_trim_bos(res, tokenizer) 45 | elif "inputs" in row and "targets" in row: 46 | res = tokenizer.apply_chat_template( 47 | [ 48 | {"role": "user", "content": row["inputs"]}, 49 | {"role": "assistant", "content": row["targets"]}, 50 | ], 51 | tokenize=False, 52 | ) 53 | return maybe_trim_bos(res, tokenizer) 54 | elif "tools" in row and "turns" in row: 55 | tool_defs = [json.loads(tool) for tool in row["tools"]] 56 | turns = [] 57 | for turn in row["turns"]: 58 | turn_out = { 59 | key: turn[key] 60 | for key in ["role", "content", "tool_calls"] 61 | if (key in turn and turn[key] is not None) 62 | } 63 | if "content" not in turn_out: 64 | turn_out["content"] = "" 65 | if "tool_calls" in turn_out: 66 | old_tcs = list(turn_out["tool_calls"]) 67 | new_tcs = [] 68 | for tc in old_tcs: 69 | tool_name = tc["name"] 70 | arguments = json.loads(tc["arguments"]) 71 | tc_out = { 72 | "type": "function", 73 | "function": {"name": tool_name, "arguments": arguments}, 74 | } 75 | new_tcs.append(tc_out) 76 | turn_out["tool_calls"] = new_tcs 77 | turns.append(turn_out) 78 | res = tokenizer.apply_chat_template(turns, tools=tool_defs, tokenize=False) 79 | return maybe_trim_bos(res, tokenizer) 80 | elif "conversations" in row: 81 | msgs = [ 82 | {"role": ROLE_MAP[msg["from"]], "content": msg["value"]} 83 | for msg in row["conversations"] 84 | ] 85 | res = tokenizer.apply_chat_template(msgs, tokenize=False) 86 | return maybe_trim_bos(res, tokenizer) 87 | elif "messages" in row: 88 | res = tokenizer.apply_chat_template(row["messages"], tokenize=False) 89 | return maybe_trim_bos(res, tokenizer) 90 | else: 91 | raise ValueError("row must contain 'text' or 'conversations' or 'messages' key") 92 | 93 | 94 | def load_preprocess_data( 95 | *, 96 | dataset: str, 97 | configuration: str | None, 98 | split: 
str, 99 | samples: int | None, 100 | seed: int, 101 | max_seq_len: int, 102 | tokenizer: transformers.PreTrainedTokenizerBase, 103 | add_extra_pad_token: bool = False, 104 | apply_chat_template: bool = False, 105 | ): 106 | ds = datasets.load_dataset(dataset, name=configuration, split=split) 107 | ds = ds.shuffle(seed=seed) 108 | if samples is not None: 109 | ds = ds.select(range(samples)) 110 | if tokenizer.pad_token_id is None: 111 | tokenizer.pad_token_id = tokenizer.eos_token_id 112 | if apply_chat_template: 113 | ds = ds.map( 114 | lambda x: { 115 | "text": do_chat_template(x, tokenizer), 116 | }, 117 | num_proc=64, 118 | ) 119 | ds = ds.filter(lambda row: row["text"] and row["text"].strip(), num_proc=64) 120 | ds = ds.map( 121 | lambda x: { 122 | "input_ids": truncate_tokens( 123 | x["text"], tokenizer, max_seq_len, add_extra_pad_token 124 | ) 125 | }, 126 | num_proc=64, 127 | ).filter(lambda x: len(x["input_ids"]) > 0, num_proc=64) 128 | return ds 129 | 130 | 131 | def truncate_tokens( 132 | text: str, tokenizer, max_seq_len: int, add_extra_pad_token: bool = False 133 | ): 134 | tokens: torch.Tensor = tokenizer(text, return_tensors="pt")["input_ids"][0] 135 | if ( 136 | add_extra_pad_token 137 | and tokens.shape[0] < max_seq_len 138 | and tokens.shape[0] > 0 139 | and tokens[-1] != tokenizer.pad_token_id 140 | ): 141 | # add single padding token 142 | # so that we don't have to look at sampled_logprobs and can 143 | # just stick with prompt_logprobs 144 | tokens = torch.cat( 145 | [ 146 | tokens, 147 | torch.tensor( 148 | [tokenizer.pad_token_id], 149 | dtype=torch.long, 150 | device=tokens.device, 151 | ), 152 | ], 153 | dim=0, 154 | ) 155 | return tokens[:max_seq_len] 156 | 157 | 158 | class StreamingParquetWriter: 159 | def __init__( 160 | self, 161 | output_path: str, 162 | schema: pyarrow.Schema, 163 | file_max_rows: int, 164 | write_batch_size: int = 1000, 165 | queue_maxsize: int | None = None, 166 | ): 167 | """ 168 | Initializes the StreamingParquetWriter. 169 | 170 | Args: 171 | output_path (str): The directory where Parquet files will be saved. 172 | schema (pyarrow.Schema): The schema of the Parquet files. 173 | file_max_rows (int): The maximum number of rows per Parquet file. 174 | write_batch_size (int): The number of rows to buffer in memory before writing to disk. 175 | A larger batch size can improve performance but uses more memory. 176 | """ 177 | self.output_path = output_path 178 | os.makedirs(self.output_path, exist_ok=True) # Ensure output directory exists 179 | self.schema = schema 180 | 181 | if not (file_max_rows > 0): 182 | raise ValueError("file_max_rows must be a positive integer.") 183 | if not (write_batch_size > 0): 184 | raise ValueError("write_batch_size must be a positive integer.") 185 | 186 | self.file_max_rows = file_max_rows 187 | self.write_batch_size = write_batch_size 188 | 189 | self.pq_writer = None # The actual pyarrow.parquet.ParquetWriter instance 190 | self._current_rows_in_physical_file = ( 191 | 0 # Tracks rows written to the currently open .parquet file 192 | ) 193 | self.file_index = ( 194 | 0 # Used for naming output files (data_0.parquet, data_1.parquet, ...) 
195 | ) 196 | 197 | self._write_queue = queue.Queue(maxsize=queue_maxsize) 198 | self._writer_thread = threading.Thread(target=self._writer_loop, daemon=True) 199 | self._shutdown_event = threading.Event() 200 | 201 | def start(self): 202 | self._writer_thread.start() 203 | 204 | def _ensure_writer_open(self): 205 | """Opens a new Parquet file writer if one is not already open.""" 206 | if self.pq_writer is None: 207 | file_path = os.path.join( 208 | self.output_path, f"data_{self.file_index}.parquet" 209 | ) 210 | self.pq_writer = pq.ParquetWriter(file_path, schema=self.schema) 211 | self._current_rows_in_physical_file = 0 # Reset row count for the new file 212 | 213 | def _write_batch_to_parquet(self, batch_data: list[dict[str, Any]]): 214 | if not batch_data: 215 | return 216 | 217 | self._ensure_writer_open() 218 | 219 | # Convert list of dicts to columnar for pyarrow.Table 220 | # This assumes batch_data is a list of row_dicts 221 | columnar_data = {name: [] for name in self.schema.names} 222 | for row_dict in batch_data: 223 | for name in self.schema.names: 224 | columnar_data[name].append(row_dict[name]) 225 | 226 | arrays = [] 227 | for name in self.schema.names: 228 | field_type = self.schema.field(name).type 229 | # Data from queue should be CPU numpy arrays or Python lists/values 230 | # If tensors were put on queue, .cpu().numpy() here 231 | # For example, if 'input_ids' was a tensor: 232 | # if name == 'input_ids' and isinstance(columnar_data[name][0], torch.Tensor): 233 | # data_to_convert = [t.cpu().numpy() for t in columnar_data[name]] 234 | # else: 235 | # data_to_convert = columnar_data[name] 236 | # arrays.append(pyarrow.array(data_to_convert, type=field_type)) 237 | arrays.append(pyarrow.array(columnar_data[name], type=field_type)) 238 | 239 | table = pyarrow.Table.from_arrays(arrays, schema=self.schema) 240 | self.pq_writer.write_table(table) 241 | self._current_rows_in_physical_file += len(batch_data) 242 | 243 | if self._current_rows_in_physical_file >= self.file_max_rows: 244 | if self.pq_writer is not None: 245 | self.pq_writer.close() 246 | self.pq_writer = None 247 | self.file_index += 1 248 | 249 | def _writer_loop(self): 250 | batch_to_write = [] 251 | while not self._shutdown_event.is_set() or not self._write_queue.empty(): 252 | try: 253 | # Wait for a short timeout to check shutdown_event periodically 254 | row_data = self._write_queue.get(timeout=0.1) 255 | if row_data is None: # Sentinel for shutdown 256 | self._write_queue.task_done() 257 | break 258 | batch_to_write.append(row_data) 259 | self._write_queue.task_done() 260 | 261 | if len(batch_to_write) >= self.write_batch_size: 262 | self._write_batch_to_parquet(batch_to_write) 263 | batch_to_write = [] 264 | except queue.Empty: 265 | continue # Loop again to check shutdown_event or new items 266 | 267 | # Flush any remaining items after loop ends 268 | if batch_to_write: 269 | self._write_batch_to_parquet(batch_to_write) 270 | 271 | if self.pq_writer is not None: 272 | self.pq_writer.close() 273 | self.pq_writer = None 274 | 275 | def write(self, row_data: dict[str, Any]): 276 | """ 277 | Adds a single row of data (as a dictionary) to the write queue. 278 | The dictionary keys should match schema names. 279 | Values should be CPU data (Python lists/values, or NumPy arrays). 280 | Tensors should be .cpu().numpy() or .tolist() BEFORE putting on queue. 
281 | """ 282 | # Expects row_data to be a dictionary now for clarity 283 | # e.g., {"input_ids": [...], "compressed_logprobs": [...], ...} 284 | # Conversion of tensors to lists/numpy happens *before* this call. 285 | self._write_queue.put(row_data) 286 | 287 | def close(self): 288 | # Signal shutdown and wait for writer thread 289 | if self._writer_thread.is_alive(): 290 | self._write_queue.put(None) # Sentinel 291 | self._shutdown_event.set() 292 | self._writer_thread.join() 293 | 294 | def __enter__(self): 295 | self.start() 296 | return self 297 | 298 | def __exit__(self, exc_type, exc_val, traceback): 299 | self.close() 300 | return False 301 | 302 | 303 | def legacy_compressed_logit_schema() -> pyarrow.Schema: 304 | return pyarrow.schema( 305 | [ 306 | pyarrow.field("input_ids", pyarrow.list_(pyarrow.uint64())), 307 | pyarrow.field( 308 | "packed_indices", pyarrow.list_(pyarrow.list_(pyarrow.uint64())) 309 | ), 310 | pyarrow.field( 311 | "exact_values", pyarrow.list_(pyarrow.list_(pyarrow.float32())) 312 | ), 313 | pyarrow.field("coeffs", pyarrow.list_(pyarrow.list_(pyarrow.float32()))), 314 | ] 315 | ) 316 | 317 | 318 | def compressed_logit_schema() -> pyarrow.Schema: 319 | return pyarrow.schema( 320 | [ 321 | pyarrow.field("input_ids", pyarrow.list_(pyarrow.uint64())), 322 | pyarrow.field( 323 | "compressed_logprobs", pyarrow.list_(pyarrow.list_(pyarrow.uint8())) 324 | ), 325 | pyarrow.field( 326 | "bytepacked_indices", pyarrow.list_(pyarrow.list_(pyarrow.uint8())) 327 | ), 328 | ] 329 | ) 330 | -------------------------------------------------------------------------------- /distillkit/sample_logits_vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import click 5 | import transformers 6 | 7 | try: 8 | import vllm 9 | except ImportError: 10 | raise ImportError("VLLM must be installed to use this script.") 11 | import logging 12 | from concurrent.futures import ThreadPoolExecutor 13 | 14 | import torch 15 | import tqdm 16 | import yaml 17 | from vllm.logprobs import FlatLogprobs, PromptLogprobs 18 | 19 | from distillkit.compression import DistributionQuantizationConfig, LogprobCompressor 20 | from distillkit.sample_common import ( 21 | StreamingParquetWriter, 22 | compressed_logit_schema, 23 | load_preprocess_data, 24 | ) 25 | 26 | 27 | @click.command("sample-logits") 28 | @click.option("--model", type=str, required=True) 29 | @click.option("--dataset", type=str, required=True) 30 | @click.option("--dataset-configuration", type=str, default=None) 31 | @click.option("--output", type=str, required=True) 32 | @click.option("--split", type=str, default="train") 33 | @click.option("--samples", type=int, default=None) 34 | @click.option("--seed", type=int, default=42) 35 | @click.option("--tokenizer", type=str, default=None) 36 | @click.option("--apply-chat-template/--no-apply-chat-template", default=False) 37 | @click.option("--max-seq-len", type=int, default=1024) 38 | @click.option("--max-model-len", type=int, default=None) 39 | @click.option("--tensor-parallel-size", type=int, default=1) 40 | @click.option("--pipeline-parallel-size", type=int, default=1) 41 | @click.option( 42 | "--enable-expert-parallel/--no-enable-expert-parallel", type=bool, default=False 43 | ) 44 | @click.option("--dtype", type=str, default=None) 45 | @click.option("--quantization", type=str, default=None) 46 | @click.option("--trust-remote-code/--no-trust-remote-code", default=False) 47 | @click.option("--gpu-memory-utilization", 
type=float, default=0.9) 48 | @click.option("--compression-config", type=str, required=True) 49 | @click.option("--macrobatch-size", type=int, default=256) 50 | @click.option("--max-workers", type=int, default=None) 51 | @click.option("--auto-vocab-size/--no-auto-vocab-size", type=bool, default=True) 52 | def sample_logits( 53 | model: str, 54 | dataset: str, 55 | dataset_configuration: str | None, 56 | split: str, 57 | output: str, 58 | samples: int | None, 59 | tokenizer: str | None, 60 | apply_chat_template: bool, 61 | seed: int, 62 | max_seq_len: int, 63 | max_model_len: int | None, 64 | tensor_parallel_size: int, 65 | pipeline_parallel_size: int, 66 | enable_expert_parallel: bool, 67 | dtype: str | None, 68 | quantization: str | None, 69 | trust_remote_code: bool, 70 | gpu_memory_utilization: float, 71 | compression_config: str, 72 | macrobatch_size: int, 73 | max_workers: int | None, 74 | auto_vocab_size: bool, 75 | ): 76 | logging.basicConfig(level=logging.INFO) 77 | 78 | tok = transformers.AutoTokenizer.from_pretrained( 79 | tokenizer or model, trust_remote_code=trust_remote_code 80 | ) 81 | 82 | # load compression config 83 | with open(compression_config, "r") as f: 84 | cfg = DistributionQuantizationConfig.model_validate(yaml.safe_load(f)) 85 | k = cfg.k 86 | 87 | tok_vocab = tok.get_vocab() 88 | tok_vocab_size = max(len(tok_vocab) + 1, max(tok_vocab.values())) 89 | if cfg.d != tok_vocab_size: 90 | if auto_vocab_size: 91 | cfg.d = tok_vocab_size 92 | logging.warning( 93 | f"Automatically set compressor vocab size to {tok_vocab_size}" 94 | ) 95 | elif cfg.d < tok_vocab_size: 96 | logging.error("Compression config has too small vocabulary size!") 97 | logging.error( 98 | f"cfg.d: {cfg.d}, effective tokenizer vocab size: {tok_vocab_size}" 99 | ) 100 | sys.exit(-1) 101 | elif ( 102 | abs(cfg.d - tok_vocab_size) > 32 103 | ): # allow a little wiggle room for common padding 104 | logging.warning( 105 | f"Vocabulary size in compression config ({cfg.d}) is larger than needed ({tok_vocab_size}). " 106 | "This will work but may consume more space than needed - double check that this is what you want." 
107 | ) 108 | 109 | os.makedirs(output, exist_ok=True) 110 | with open( 111 | os.path.join(output, "compression_config.yaml"), "w", encoding="utf-8" 112 | ) as f: 113 | yaml.safe_dump(cfg.model_dump(mode="json"), f) 114 | 115 | logging.info(f"Loading and preprocessing data from {dataset} ({split})") 116 | ds = load_preprocess_data( 117 | dataset=dataset, 118 | configuration=dataset_configuration, 119 | split=split, 120 | samples=samples, 121 | seed=seed, 122 | max_seq_len=max_seq_len, 123 | tokenizer=tok, 124 | add_extra_pad_token=True, 125 | apply_chat_template=apply_chat_template, 126 | ) 127 | 128 | llm = vllm.LLM( 129 | model=model, 130 | tokenizer=tokenizer, 131 | dtype=dtype, 132 | quantization=quantization, 133 | trust_remote_code=trust_remote_code, 134 | tensor_parallel_size=tensor_parallel_size, 135 | pipeline_parallel_size=pipeline_parallel_size, 136 | enable_expert_parallel=enable_expert_parallel, 137 | gpu_memory_utilization=gpu_memory_utilization, 138 | max_logprobs=k, 139 | logprobs_mode="raw_logprobs", 140 | max_model_len=max_model_len, 141 | ) 142 | 143 | compressor = LogprobCompressor( 144 | config=cfg, 145 | ) 146 | 147 | sampling_params = vllm.SamplingParams( 148 | temperature=1.0, 149 | top_p=1, 150 | min_p=0, 151 | top_k=-1, 152 | frequency_penalty=0, 153 | presence_penalty=0, 154 | repetition_penalty=1, 155 | prompt_logprobs=k, 156 | logprobs=k, 157 | flat_logprobs=True, 158 | max_tokens=1, # vLLM wants at least 1 generated token 159 | detokenize=False, 160 | skip_special_tokens=False, 161 | ) 162 | 163 | logging.info(f"Generating logits for {len(ds)} samples") 164 | 165 | def process_and_write_sample( 166 | req_out: vllm.RequestOutput, 167 | input_ids_sample: list[int], 168 | k: int, 169 | compressor: LogprobCompressor, 170 | writer: StreamingParquetWriter, 171 | ) -> None: 172 | """Process a single sample: extract logprobs, compress, and write to disk.""" 173 | top_indices, top_values = process_prompt_logprobs(req_out.prompt_logprobs, k=k) 174 | top_indices.unsqueeze_(0) 175 | top_values.unsqueeze_(0) 176 | 177 | row_out = compressor.compress_from_sparse( 178 | top_indices, 179 | top_values, 180 | ) 181 | 182 | compressed_logprobs_list = ( 183 | row_out["compressed_logprobs"].cpu().squeeze(0).tolist() 184 | ) 185 | bytepacked_indices_list = ( 186 | row_out["bytepacked_indices"].cpu().squeeze(0).tolist() 187 | ) 188 | 189 | writer.write( 190 | { 191 | "input_ids": input_ids_sample, 192 | "compressed_logprobs": compressed_logprobs_list, 193 | "bytepacked_indices": bytepacked_indices_list, 194 | } 195 | ) 196 | 197 | try: 198 | with StreamingParquetWriter( 199 | output, 200 | schema=compressed_logit_schema(), 201 | file_max_rows=macrobatch_size, 202 | queue_maxsize=macrobatch_size * 2, 203 | ) as writer: 204 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 205 | futures = [] 206 | 207 | for i0 in tqdm.tqdm( 208 | range(0, len(ds), macrobatch_size), desc="Logit Batches" 209 | ): 210 | batch_input_ids = ds[i0 : i0 + macrobatch_size]["input_ids"] 211 | # Submit CPU processing tasks to background thread 212 | for idx, req_out in enumerate( 213 | llm.generate( 214 | [{"prompt_token_ids": x} for x in batch_input_ids], 215 | sampling_params=sampling_params, 216 | ) 217 | ): 218 | future = executor.submit( 219 | process_and_write_sample, 220 | req_out, 221 | batch_input_ids[idx], 222 | k, 223 | compressor, 224 | writer, 225 | ) 226 | futures.append(future) 227 | 228 | # Limit writes in flight to avoid unbounded memory growth 229 | while len(futures) > 
macrobatch_size * 2: 230 | futures.pop(0).result() 231 | 232 | for future in futures: 233 | future.result() 234 | 235 | logging.info(f"Logits saved to {output}") 236 | finally: 237 | del llm 238 | 239 | 240 | def process_prompt_logprobs( 241 | prompt_logprobs: PromptLogprobs, k: int 242 | ) -> tuple[torch.LongTensor, torch.Tensor]: 243 | # Fast path: directly access FlatLogprobs data without materializing dicts 244 | if isinstance(prompt_logprobs, FlatLogprobs): 245 | # Skip first position if it's empty (first token has no logprobs) 246 | start_pos = 0 247 | if len(prompt_logprobs) > 0: 248 | first_start = prompt_logprobs.start_indices[0] 249 | first_end = prompt_logprobs.end_indices[0] 250 | if first_end - first_start == 0: 251 | start_pos = 1 252 | 253 | num_prompt_tokens = len(prompt_logprobs) - start_pos 254 | if num_prompt_tokens <= 0: 255 | return torch.empty((0, 0), dtype=torch.long), torch.empty( 256 | (0, 0), dtype=torch.float32 257 | ) 258 | 259 | top_indices = torch.empty( 260 | (num_prompt_tokens, k), dtype=torch.long, device="cpu" 261 | ) 262 | top_values = torch.full( 263 | (num_prompt_tokens, k), 264 | fill_value=float("-inf"), 265 | dtype=torch.float32, 266 | device="cpu", 267 | ) 268 | 269 | # Build index arrays for vectorized assignment 270 | seq_ids = [] 271 | rank_ids = [] 272 | token_ids_to_copy = [] 273 | logprobs_to_copy = [] 274 | 275 | for pos_id in range(start_pos, len(prompt_logprobs)): 276 | seq_id = pos_id - start_pos 277 | start_idx = prompt_logprobs.start_indices[pos_id] 278 | end_idx = prompt_logprobs.end_indices[pos_id] 279 | 280 | for i in range(start_idx, end_idx): 281 | rank = prompt_logprobs.ranks[i] 282 | if rank is None or rank > k: 283 | # None: vLLM returns the actual prompt token even when not in top-k 284 | # rank > k: Truncate to only the k values we requested 285 | continue 286 | seq_ids.append(seq_id) 287 | rank_ids.append(rank - 1) 288 | token_ids_to_copy.append(prompt_logprobs.token_ids[i]) 289 | logprobs_to_copy.append(prompt_logprobs.logprobs[i]) 290 | 291 | # Vectorized assignment using advanced indexing 292 | if seq_ids: 293 | seq_idx_tensor = torch.tensor(seq_ids, dtype=torch.long) 294 | rank_idx_tensor = torch.tensor(rank_ids, dtype=torch.long) 295 | top_indices[seq_idx_tensor, rank_idx_tensor] = torch.tensor( 296 | token_ids_to_copy, dtype=top_indices.dtype, device=top_indices.device 297 | ) 298 | top_values[seq_idx_tensor, rank_idx_tensor] = torch.tensor( 299 | logprobs_to_copy, dtype=top_values.dtype, device=top_values.device 300 | ) 301 | 302 | return top_indices, top_values 303 | 304 | # Slow path: handle legacy list format 305 | else: 306 | valid_logprobs = [lp for lp in prompt_logprobs] 307 | if valid_logprobs[0] is None or len(valid_logprobs[0]) < 1: 308 | valid_logprobs.pop(0) 309 | 310 | if not valid_logprobs: 311 | return torch.empty((0, 0), dtype=torch.long), torch.empty( 312 | (0, 0), dtype=torch.float32 313 | ) 314 | 315 | num_prompt_tokens = len(valid_logprobs) 316 | 317 | top_indices = torch.empty((num_prompt_tokens, k), dtype=torch.long) 318 | top_values = torch.full( 319 | (num_prompt_tokens, k), fill_value=float("-inf"), dtype=torch.float32 320 | ) 321 | for seq_id, logprobs in enumerate(valid_logprobs): 322 | assert logprobs is not None, ( 323 | f"Missing logprobs for token at position {seq_id + 1} (expected logprobs for all non-first tokens)" 324 | ) 325 | for tok_id, logprob in logprobs.items(): 326 | if logprob.rank is None or logprob.rank > k: 327 | continue 328 | top_indices[seq_id, logprob.rank - 1] = tok_id 
329 | top_values[seq_id, logprob.rank - 1] = logprob.logprob 330 | 331 | return top_indices, top_values 332 | 333 | 334 | if __name__ == "__main__": 335 | sample_logits() 336 | -------------------------------------------------------------------------------- /distillkit/lossfuncs/jsd.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from transformers.modeling_outputs import CausalLMOutput 5 | from typing_extensions import override 6 | 7 | from distillkit.hsd_mapping import HiddenStateMapping 8 | from distillkit.lossfuncs.common import ( 9 | LossFunctionBase, 10 | MissingProbabilityHandling, 11 | accumulate_over_chunks, 12 | get_logprobs, 13 | ) 14 | from distillkit.signals import DenseSignal, TeacherSignal 15 | 16 | 17 | def sparse_jsd_inner( 18 | logits: torch.Tensor, 19 | target_ids: torch.LongTensor, 20 | target_values: torch.Tensor, 21 | mask: torch.Tensor | None = None, 22 | eps: float = 1e-8, 23 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 24 | log_target: bool = True, 25 | temperature: float = 1.0, 26 | target_generation_temperature: float = 1.0, 27 | student_generation_temperature: float = 1.0, 28 | ) -> torch.Tensor: 29 | batch_size, seq_len, vocab_size = logits.shape 30 | out_dtype = logits.dtype 31 | sparse_student_logprobs, sparse_target_logprobs = get_logprobs( 32 | logits, 33 | target_ids, 34 | target_values, 35 | eps=eps, 36 | missing=missing, 37 | log_target=log_target, 38 | distillation_temperature=temperature, 39 | target_generation_temperature=target_generation_temperature, 40 | student_generation_temperature=student_generation_temperature, 41 | ) 42 | sparse_student_probs = torch.exp(sparse_student_logprobs) 43 | sparse_teacher_probs = torch.exp(sparse_target_logprobs) 44 | 45 | # --- 3. Compute Mixture M (log M) --- 46 | # Common sparse part for M 47 | # Ensure P_i + Q_i is not zero before log. Add eps to P+Q or use M.clamp(min=eps) 48 | # log(0.5 * (P_i + Q_i)) = log(0.5) + log(P_i + Q_i) 49 | # Using .clamp for safety if P_i and Q_i can both be zero for some index. 50 | M_sparse_probs = 0.5 * (sparse_teacher_probs + sparse_student_probs) 51 | log_M_sparse = torch.log(M_sparse_probs.clamp(min=eps)) 52 | 53 | if missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM: 54 | teacher_prob_sum_sparse = sparse_teacher_probs.to(torch.float32).sum( 55 | dim=-1, keepdim=True 56 | ) 57 | # log1p for log(1-x) is good 58 | log_teacher_missing_prob = torch.log1p( 59 | -teacher_prob_sum_sparse.clamp(min=eps, max=1.0 - eps) 60 | ) 61 | teacher_missing_prob = torch.exp(log_teacher_missing_prob) 62 | 63 | student_prob_sum_sparse = sparse_student_probs.to(torch.float32).sum( 64 | dim=-1, keepdim=True 65 | ) 66 | log_student_missing_prob = torch.log1p( 67 | -student_prob_sum_sparse.clamp(min=eps, max=1.0 - eps) 68 | ) 69 | student_missing_prob = torch.exp(log_student_missing_prob) 70 | 71 | M_missing_prob = 0.5 * (teacher_missing_prob + student_missing_prob) 72 | log_M_missing = torch.log(M_missing_prob.clamp(min=eps)) 73 | 74 | # --- 4. Compute KL(P || M) --- 75 | # P_i * (log P_i - log M_i) 76 | # Handle P_i = 0 case: 0 * log(0/M_i) = 0. This is implicitly handled if P_i is small. 77 | # If sparse_teacher_probs can be exactly zero, ensure 0 * -inf is 0. 78 | # PyTorch's 0 * -inf is nan. 
Use torch.where(sparse_teacher_probs > eps, ..., 0.0) 79 | kl_P_M_sparse_terms = sparse_teacher_probs * (sparse_target_logprobs - log_M_sparse) 80 | kl_P_M_sparse_sum = torch.sum( 81 | torch.where( 82 | sparse_teacher_probs > eps, 83 | kl_P_M_sparse_terms, 84 | torch.zeros_like(kl_P_M_sparse_terms), 85 | ), 86 | dim=-1, 87 | ) 88 | 89 | if missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM: 90 | # teacher_missing_prob * (log_teacher_missing_prob - log_M_missing) 91 | kl_P_M_missing_term = teacher_missing_prob * ( 92 | log_teacher_missing_prob - log_M_missing 93 | ) 94 | kl_P_M = kl_P_M_sparse_sum + torch.where( 95 | teacher_missing_prob.squeeze(-1) > eps, 96 | kl_P_M_missing_term.squeeze(-1), 97 | torch.zeros_like(kl_P_M_missing_term.squeeze(-1)), 98 | ) 99 | else: # ZERO 100 | kl_P_M = kl_P_M_sparse_sum 101 | del sparse_target_logprobs # No longer needed for P 102 | 103 | # --- 5. Compute KL(Q || M) --- 104 | # Q_i * (log Q_i - log M_i) 105 | kl_Q_M_sparse_terms = sparse_student_probs * ( 106 | sparse_student_logprobs - log_M_sparse 107 | ) 108 | kl_Q_M_sparse_sum = torch.sum( 109 | torch.where( 110 | sparse_student_probs > eps, 111 | kl_Q_M_sparse_terms, 112 | torch.zeros_like(kl_Q_M_sparse_terms), 113 | ), 114 | dim=-1, 115 | ) 116 | 117 | if missing == MissingProbabilityHandling.SYMMETRIC_UNIFORM: 118 | kl_Q_M_missing_term = student_missing_prob * ( 119 | log_student_missing_prob - log_M_missing 120 | ) 121 | kl_Q_M = kl_Q_M_sparse_sum + torch.where( 122 | student_missing_prob.squeeze(-1) > eps, 123 | kl_Q_M_missing_term.squeeze(-1), 124 | torch.zeros_like(kl_Q_M_missing_term.squeeze(-1)), 125 | ) 126 | del ( 127 | student_missing_prob, 128 | log_student_missing_prob, 129 | M_missing_prob, 130 | log_M_missing, 131 | ) # Free memory 132 | else: # ZERO 133 | # Contribution from Q_missing_i * log(2) 134 | # We need total prob mass of student for tokens NOT in target_ids 135 | student_prob_sum_sparse = sparse_student_probs.sum(dim=-1) # B, S 136 | # total probability is 1, so mass for missing is 1 - sum_sparse 137 | student_total_missing_prob_mass = (1.0 - student_prob_sum_sparse).clamp( 138 | min=0.0 139 | ) # B, S 140 | kl_Q_M_missing_contrib = student_total_missing_prob_mass * math.log(2.0) 141 | kl_Q_M = kl_Q_M_sparse_sum + kl_Q_M_missing_contrib 142 | 143 | del ( 144 | sparse_student_logprobs, 145 | sparse_student_probs, 146 | M_sparse_probs, 147 | log_M_sparse, 148 | ) 149 | 150 | # --- 6. Combine for JSD --- 151 | jsd_terms = 0.5 * (kl_P_M + kl_Q_M) 152 | 153 | # --- 7. Masking and Aggregation --- 154 | if mask is not None: 155 | if mask.dim() == 2: # B, S 156 | mask_squozed = mask 157 | else: # B, S, 1 158 | mask_squozed = mask.squeeze(-1) 159 | jsd_terms *= mask_squozed 160 | 161 | return torch.sum(jsd_terms).to(out_dtype) 162 | 163 | 164 | def sparse_js_div( 165 | logits: torch.Tensor, 166 | target_ids: torch.LongTensor, 167 | target_values: torch.Tensor, 168 | mask: torch.Tensor | None = None, 169 | eps: float = 1e-8, 170 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 171 | log_target: bool = True, 172 | temperature: float = 1.0, 173 | target_generation_temperature: float = 1.0, 174 | student_generation_temperature: float = 1.0, 175 | chunk_length: int | None = None, 176 | ) -> torch.Tensor: 177 | """Compute the Jensen-Shannon Divergence (JSD) between a dense set of predictions and a sparse set of target logits. 178 | 179 | Uses a chunked approach to avoid memory issues with large sequences. 
180 | 181 | Args: 182 | logits: Dense tensor of predictions. 183 | target_ids: Tensor of indices for target logits. 184 | target_values: Tensor of values for target logits or log probabilities. 185 | mask: Optional boolean mask tensor. True indicates tokens to include, False to exclude. 186 | eps: Small value to prevent numerical instability. 187 | missing: How to handle missing probabilities in the target distribution. If ZERO, missing 188 | probabilities are assumed to be zero. If SYMMETRIC_UNIFORM, missing probabilities are 189 | assumed to be distributed uniformly over the missing tokens in both the teacher and 190 | student distributions. 191 | log_target: Whether the target values are already log probabilities. 192 | temperature: Temperature to apply to the distributions. 193 | target_generation_temperature: Temperature already applied to the target logits/logprobs. 194 | student_generation_temperature: Temperature already applied to the student logits. 195 | chunk_length: Number of tokens per chunk. If None, the entire sequence is processed at once. 196 | """ 197 | return accumulate_over_chunks( 198 | logits, 199 | target_ids, 200 | target_values, 201 | mask, 202 | chunk_length, 203 | sparse_jsd_inner, 204 | eps=eps, 205 | missing=missing, 206 | log_target=log_target, 207 | temperature=temperature, 208 | target_generation_temperature=target_generation_temperature, 209 | student_generation_temperature=student_generation_temperature, 210 | ) 211 | 212 | 213 | def dense_js_div( 214 | logits: torch.Tensor, 215 | target_logits: torch.Tensor, 216 | mask: torch.Tensor | None = None, 217 | temperature: float = 1.0, 218 | ) -> torch.Tensor: 219 | """ 220 | Compute the Jensen-Shannon Divergence (JSD) between dense student predictions and dense target logits. 221 | 222 | JSD(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M) 223 | where M = 0.5 * (P + Q) 224 | 225 | Args: 226 | logits: Student logits (Batch, Seq, Vocab). 227 | target_logits: Teacher/Target logits (Batch, Seq, Vocab). 228 | mask: Optional boolean mask tensor. True indicates tokens to include. 229 | temperature: Temperature to apply to both distributions. 230 | 231 | Returns: 232 | torch.Tensor: Scalar JSD loss averaged over the batch. 233 | """ 234 | out_dtype = logits.dtype 235 | 236 | # 1. Apply temperature scaling 237 | student_logits = logits / temperature 238 | teacher_logits = target_logits / temperature 239 | 240 | # 2. Compute log probabilities 241 | # Using log_softmax ensures numerical stability compared to softmax -> log 242 | student_log_probs = torch.log_softmax(student_logits.float(), dim=-1) 243 | teacher_log_probs = torch.log_softmax(teacher_logits.float(), dim=-1) 244 | 245 | # 3. Compute probabilities 246 | student_probs = torch.exp(student_log_probs) 247 | teacher_probs = torch.exp(teacher_log_probs) 248 | 249 | # 4. Compute Mixture Distribution M (in log space) 250 | # M = 0.5 * (P + Q) 251 | # log(M) = log(0.5) + log(exp(log_P) + exp(log_Q)) 252 | # We use logaddexp for numerical stability to avoid underflow/overflow 253 | mixture_log_probs = torch.logaddexp( 254 | student_log_probs, teacher_log_probs 255 | ) - math.log(2.0) 256 | 257 | # 5. 
Compute KL Divergences 258 | # KL(P || M) = sum( P * (log_P - log_M) ) 259 | # Note: We compute manually rather than using F.kl_div to utilize precomputed log_probs 260 | kl_student_mixture = torch.sum( 261 | student_probs * (student_log_probs - mixture_log_probs), dim=-1 262 | ) 263 | kl_teacher_mixture = torch.sum( 264 | teacher_probs * (teacher_log_probs - mixture_log_probs), dim=-1 265 | ) 266 | 267 | # 6. Combine for JSD 268 | jsd_terms = 0.5 * (kl_student_mixture + kl_teacher_mixture) 269 | 270 | # 7. Apply Masking 271 | if mask is not None: 272 | if mask.dim() == 2: # B, S 273 | mask_squeezed = mask 274 | else: # B, S, 1 275 | mask_squeezed = mask.squeeze(-1) 276 | jsd_terms = jsd_terms * mask_squeezed 277 | 278 | return torch.sum(jsd_terms).to(out_dtype) 279 | 280 | 281 | class JSDLoss(LossFunctionBase): 282 | temperature: float 283 | missing: MissingProbabilityHandling = MissingProbabilityHandling.ZERO 284 | chunk_length: int | None = None 285 | 286 | @override 287 | @classmethod 288 | def name(cls) -> str: 289 | return "jsd" 290 | 291 | @override 292 | def __init__( 293 | self, 294 | temperature: float, 295 | missing_probability_handling: MissingProbabilityHandling = MissingProbabilityHandling.ZERO, 296 | sparse_chunk_length: int | None = None, 297 | ) -> None: 298 | self.temperature = temperature 299 | self.missing = missing_probability_handling 300 | self.chunk_length = sparse_chunk_length 301 | 302 | @override 303 | def __call__( 304 | self, 305 | student_outputs: CausalLMOutput, 306 | signal: TeacherSignal, 307 | mask: torch.Tensor | None = None, 308 | hidden_state_mapping: HiddenStateMapping | None = None, 309 | num_items_in_batch: int | None = None, 310 | ) -> torch.Tensor: 311 | if num_items_in_batch is None: 312 | if mask is not None: 313 | num_items_in_batch = mask.float().sum() 314 | else: 315 | num_items_in_batch = ( 316 | student_outputs.logits.shape[0] * student_outputs.logits.shape[1] 317 | ) 318 | if isinstance(signal, DenseSignal): 319 | res = dense_js_div( 320 | student_outputs.logits, 321 | signal.logits, 322 | mask=mask, 323 | temperature=self.temperature, 324 | ) 325 | else: 326 | res = sparse_js_div( 327 | logits=student_outputs.logits, 328 | target_ids=signal.sparse_ids, 329 | target_values=signal.sparse_values, 330 | mask=mask, 331 | missing=self.missing, 332 | log_target=signal.log_values, 333 | temperature=self.temperature, 334 | target_generation_temperature=signal.generation_temperature, 335 | chunk_length=self.chunk_length, 336 | ) 337 | return res * (self.temperature**2) / num_items_in_batch 338 | -------------------------------------------------------------------------------- /distillkit/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 Arcee AI 2 | import hashlib 3 | import json 4 | import logging 5 | import os 6 | import re 7 | from typing import Any 8 | 9 | import click 10 | import datasets 11 | import torch 12 | import transformers 13 | import trl 14 | import yaml 15 | from accelerate import Accelerator 16 | 17 | from distillkit.compression import LogprobCompressor 18 | from distillkit.configuration import ( 19 | DatasetConfiguration, 20 | DatasetPath, 21 | DistillationRunConfig, 22 | HfRepoDataset, 23 | LocalDataset, 24 | TeacherDatasetConfig, 25 | TeacherModelConfig, 26 | ) 27 | from distillkit.hsd_mapping import HiddenStateMapping 28 | from distillkit.monkey_patch_packing import monkey_patch_packing_for_model 29 | from distillkit.signals import OfflineSignalSource, 
OnlineSignalSource, SignalSource 30 | from distillkit.trainer import DistillationTrainer 31 | 32 | LOG = logging.getLogger(__name__) 33 | 34 | 35 | def _format_row( 36 | example: dict[str, Any], tokenizer: transformers.PreTrainedTokenizer 37 | ) -> dict[str, Any]: 38 | if ("input_ids" in example) or ("text" in example): 39 | # either pretokenized or raw completion - no formatting needed 40 | return {} 41 | elif "conversations" in example: 42 | conversations = example["conversations"] 43 | 44 | messages = [] 45 | for conversation in conversations: 46 | role_map = { 47 | "human": "user", 48 | "user": "user", 49 | "gpt": "assistant", 50 | "assistant": "assistant", 51 | "system": "system", 52 | } 53 | role = role_map.get(conversation.get("from", ""), None) 54 | if role: 55 | messages.append( 56 | {"role": role, "content": conversation.get("value", "")} 57 | ) 58 | 59 | # Apply chat template to create a single string. SFTTrainer will handle tokenization. 60 | text = tokenizer.apply_chat_template( 61 | messages, tokenize=False, add_generation_prompt=False 62 | ) 63 | return {"text": text} 64 | elif "messages" in example: 65 | text = tokenizer.apply_chat_template( 66 | example["messages"], tokenize=False, add_generation_prompt=False 67 | ) 68 | return {"text": text} 69 | else: 70 | raise RuntimeError("Expected `text`, `messages`, or `conversations` column") 71 | 72 | 73 | def _load_dataset( 74 | path: DatasetPath, 75 | seed: int | None, 76 | num_samples: int | None, 77 | tokenizer: transformers.PreTrainedTokenizer, 78 | prepared_dataset_path: str | None = None, 79 | keep_in_memory: bool | None = None, 80 | prepacked: bool = False, 81 | ) -> datasets.Dataset: 82 | if prepared_dataset_path: 83 | honk = json.dumps( 84 | { 85 | "path": path.model_dump(), 86 | "seed": seed, 87 | "num_samples": num_samples, 88 | } 89 | ) 90 | logging.info(f"Dataset spec: {honk}") 91 | ds_hash = hashlib.sha256(honk.encode()).hexdigest() 92 | full_prepared_path = os.path.join(prepared_dataset_path, f"dataset-{ds_hash}") 93 | if os.path.exists(full_prepared_path): 94 | return datasets.load_from_disk(full_prepared_path) 95 | else: 96 | full_prepared_path = None 97 | if isinstance(path, HfRepoDataset): 98 | res = datasets.load_dataset( 99 | path.repo_id, 100 | name=path.config_name, 101 | revision=path.revision, 102 | split=path.split, 103 | keep_in_memory=keep_in_memory, 104 | ) 105 | elif isinstance(path, LocalDataset): 106 | res = datasets.load_from_disk(path.disk_path, keep_in_memory=keep_in_memory) 107 | if path.split: 108 | res = res[path.split] 109 | elif isinstance(res, datasets.DatasetDict): 110 | raise ValueError( 111 | "Dataset dict found but no split specified. Please specify a split." 112 | ) 113 | else: 114 | raise ValueError( 115 | "Unsupported dataset type. Please provide a valid Hugging Face repo ID or local dataset path." 
116 | ) 117 | 118 | if prepacked: 119 | last_idx = len(res) - 1 120 | while len(res) >= 2 and len(res[last_idx]["input_ids"]) != len( 121 | res[0]["input_ids"] 122 | ): 123 | last_idx -= 1 124 | if last_idx <= 0: 125 | raise RuntimeError("Dataset config is probs wrong") 126 | res = res.select(range(last_idx + 1)) 127 | 128 | if seed: 129 | res = res.shuffle(seed=seed) 130 | if num_samples: 131 | res = res.select(range(num_samples)) 132 | if ( 133 | (not prepacked) 134 | and ("text" not in res.column_names) 135 | and ("input_ids" not in res.column_names) 136 | ): 137 | res = res.map( 138 | _format_row, 139 | remove_columns=res.column_names, 140 | fn_kwargs={"tokenizer": tokenizer}, 141 | ) 142 | if full_prepared_path: 143 | os.makedirs(full_prepared_path, exist_ok=True) 144 | logging.info( 145 | f"Saving prepared dataset to {full_prepared_path} (hash: {ds_hash}, path: {path}, seed: {seed}, num_samples: {num_samples})" 146 | ) 147 | res.save_to_disk(full_prepared_path) 148 | del res 149 | return datasets.load_from_disk( 150 | full_prepared_path, keep_in_memory=keep_in_memory 151 | ) 152 | return res 153 | 154 | 155 | def load_data( 156 | config: DatasetConfiguration, 157 | tokenizer: transformers.PreTrainedTokenizer, 158 | keep_in_memory: bool | None = None, 159 | ) -> tuple[datasets.Dataset, datasets.Dataset | None]: 160 | """ 161 | Load the train (and optionally eval) datasets as specified in the configuration. 162 | """ 163 | 164 | LOG.info( 165 | f"Loading datasets: {config.train_dataset} (train), {config.eval_dataset} (eval)" 166 | ) 167 | ds_train = _load_dataset( 168 | config.train_dataset, 169 | config.seed, 170 | config.num_samples, 171 | tokenizer=tokenizer, 172 | prepared_dataset_path=config.prepared_dataset_path, 173 | keep_in_memory=keep_in_memory, 174 | prepacked=config.prepacked, 175 | ) 176 | ds_eval = None 177 | if config.eval_dataset: 178 | ds_eval = _load_dataset( 179 | config.eval_dataset, 180 | config.seed, 181 | config.num_eval_samples, 182 | tokenizer=tokenizer, 183 | prepared_dataset_path=config.prepared_dataset_path, 184 | keep_in_memory=keep_in_memory, 185 | prepacked=config.prepacked, 186 | ) 187 | return ds_train, ds_eval 188 | 189 | 190 | def load_student_model( 191 | config: DistillationRunConfig, 192 | tokenizer_vocab_size: int, 193 | ) -> transformers.PreTrainedModel: 194 | if config.functionary_packing: 195 | monkey_patch_packing_for_model(config.train_model) 196 | auto_cls = getattr(transformers, config.model_auto_class, None) 197 | if auto_cls is None: 198 | raise ValueError( 199 | f"Model class {config.model_auto_class} not found in transformers." 
200 | ) 201 | LOG.info(f"Loading model {config.train_model} with class {auto_cls}") 202 | extra_kwargs = {"trust_remote_code": config.trust_remote_code} 203 | if config.use_flash_attention: 204 | extra_kwargs["attn_implementation"] = "flash_attention_2" 205 | extra_kwargs["torch_dtype"] = torch.bfloat16 206 | model = auto_cls.from_pretrained( 207 | config.train_model, 208 | **extra_kwargs, 209 | **config.model_kwargs, 210 | ) 211 | LOG.info("Loaded model.") 212 | 213 | model_vocab_size = model.get_input_embeddings().weight.shape[0] 214 | if ( 215 | model_vocab_size != tokenizer_vocab_size 216 | or config.resize_embeddings_to_multiple_of 217 | ): 218 | model.resize_token_embeddings( 219 | tokenizer_vocab_size, 220 | pad_to_multiple_of=config.resize_embeddings_to_multiple_of, 221 | ) 222 | new_model_vocab_size = model.get_input_embeddings().weight.shape[0] 223 | if new_model_vocab_size != model_vocab_size: 224 | LOG.info( 225 | f"Resized model vocab size from {model_vocab_size} to {new_model_vocab_size}" 226 | ) 227 | 228 | model: transformers.PreTrainedModel 229 | if config.frozen_modules: 230 | module_set = set(config.frozen_modules) 231 | seen = set() 232 | for name, module in model.named_modules(): 233 | if name in module_set: 234 | module.requires_grad_(False) 235 | seen.add(name) 236 | unseen = module_set - seen 237 | LOG.info(f"Froze {len(seen)} modules") 238 | if unseen: 239 | raise ValueError(f"Frozen modules not found in model: {', '.join(unseen)}") 240 | if config.frozen_res: 241 | num_frozen = 0 242 | frozen_res = [re.compile(s) for s in config.frozen_res] 243 | for name, param in model.named_parameters(): 244 | if any(fre.search(name) for fre in frozen_res): 245 | param.requires_grad = False 246 | num_frozen += 1 247 | if num_frozen: 248 | print(f"Froze {num_frozen} tensors by regular expression") 249 | return model 250 | 251 | 252 | def create_signal_source( 253 | config: DistillationRunConfig, vocab_size: int 254 | ) -> SignalSource: 255 | if isinstance(config.teacher, TeacherDatasetConfig): 256 | compressor = LogprobCompressor( 257 | config=config.teacher.logprob_compressor, 258 | legacy_config=config.teacher.legacy_logit_compression, 259 | ) 260 | return OfflineSignalSource(compressor, vocab_size=vocab_size) 261 | elif isinstance(config.teacher, TeacherModelConfig): 262 | teacher_model = transformers.AutoModelForCausalLM.from_pretrained( 263 | config.teacher.path, **(config.teacher.kwargs or {}) 264 | ) 265 | return OnlineSignalSource( 266 | teacher_model, vocab_size=vocab_size, sparsify_top_k=config.teacher.top_k 267 | ) 268 | else: 269 | raise RuntimeError("Teacher configuration invalid") 270 | 271 | 272 | def collate_packed_batch(examples): 273 | # all sequences in the batch already have the same length 274 | # so we can directly stack them 275 | return { 276 | key: torch.tensor([example[key] for example in examples]) 277 | for key in examples[0].keys() 278 | } 279 | 280 | 281 | def load_tokenizer(config: DistillationRunConfig) -> transformers.PreTrainedTokenizer: 282 | if isinstance(config.teacher, TeacherModelConfig): 283 | src_path = config.teacher.path 284 | logging.info("Using teacher's tokenizer") 285 | else: 286 | src_path = config.train_model 287 | logging.info("Using student's tokenizer") 288 | return transformers.AutoTokenizer.from_pretrained( 289 | src_path, 290 | trust_remote_code=config.trust_remote_code, 291 | ) 292 | 293 | 294 | def do_distill(config: DistillationRunConfig, config_source: str | None = None): 295 | os.makedirs(config.output_path, 
exist_ok=True) 296 | if config_source is None: 297 | config_source = yaml.safe_dump(config.model_dump(mode="json", by_alias=True)) 298 | with open(os.path.join(config.output_path, "distillkit_config.yaml"), "w") as f: 299 | f.write(config_source) 300 | 301 | if config.project_name: 302 | os.environ["WANDB_PROJECT"] = config.project_name 303 | 304 | accelerator = Accelerator() 305 | with accelerator.main_process_first(): 306 | tokenizer = load_tokenizer(config) 307 | ds_train, ds_eval = load_data(config.dataset, tokenizer) 308 | 309 | tokenizer_vocab_size = max( 310 | len(tokenizer.get_vocab()), 311 | max(tokenizer.get_vocab().values()) + 1, 312 | ) 313 | 314 | model = load_student_model(config, tokenizer_vocab_size) 315 | 316 | config_kwargs = dict(config.training_args) 317 | dataset_kwargs = config_kwargs.pop("dataset_kwargs", {}) 318 | if config.dataset.prepacked: 319 | dataset_kwargs["skip_prepare_dataset"] = True 320 | max_length = config_kwargs.pop("max_length", config.sequence_length) 321 | training_arguments = trl.SFTConfig( 322 | **config_kwargs, 323 | max_length=max_length, 324 | output_dir=config.output_path, 325 | dataset_kwargs=dataset_kwargs, 326 | ) 327 | 328 | signal_source = create_signal_source(config, tokenizer_vocab_size) 329 | if config.layer_mapping is not None: 330 | if not isinstance(signal_source, OnlineSignalSource): 331 | raise RuntimeError( 332 | "Hidden state distillation not supported for offline teachers" 333 | ) 334 | teacher_hidden_size = signal_source.teacher_model.config.hidden_size 335 | if config.layer_mapping == "all": 336 | mapping = [(i, i) for i in range(model.config.num_hidden_layers)] 337 | else: 338 | mapping = config.layer_mapping 339 | hsm = HiddenStateMapping( 340 | student=model, 341 | teacher_hidden_size=teacher_hidden_size, 342 | layer_mapping=mapping, 343 | force_projection=config.force_hidden_state_projection, 344 | ) 345 | else: 346 | hsm = None 347 | trainer = DistillationTrainer( 348 | model=model, 349 | config=config, 350 | signal_source=signal_source, 351 | hidden_state_mapping=hsm, 352 | true_vocab_size=tokenizer_vocab_size, 353 | train_dataset=ds_train, 354 | eval_dataset=ds_eval, 355 | args=training_arguments, 356 | data_collator=collate_packed_batch if config.dataset.prepacked else None, 357 | processing_class=None if config.dataset.prepacked else tokenizer, 358 | ) 359 | 360 | resume_from_checkpoint = config.training_args.get("resume_from_checkpoint", None) 361 | 362 | LOG.info("Starting training.") 363 | trainer.train( 364 | resume_from_checkpoint=resume_from_checkpoint, 365 | ) 366 | LOG.info(f"Finished training. Saving model to {config.output_path}.") 367 | trainer.save_model(config.output_path) 368 | LOG.info("Done.") 369 | 370 | 371 | @click.command("distillkit-offline") 372 | @click.argument( 373 | "config_path", 374 | type=click.Path(exists=True, dir_okay=False, readable=True), 375 | ) 376 | @click.option( 377 | "--verbose", 378 | "-v", 379 | "verbosity", 380 | count=True, 381 | help="Increase verbosity of logging. 
Use -vv for debug level.", 382 | ) 383 | def main(config_path: str, verbosity: int): 384 | log_level = logging.WARNING 385 | if verbosity >= 2: 386 | log_level = logging.DEBUG 387 | elif verbosity == 1: 388 | log_level = logging.INFO 389 | logging.basicConfig(level=log_level) 390 | with open(config_path, "r") as f: 391 | config_dict = yaml.safe_load(f) 392 | config = DistillationRunConfig.model_validate(config_dict) 393 | do_distill(config) 394 | 395 | 396 | if __name__ == "__main__": 397 | # torch.autograd.set_detect_anomaly(True) 398 | main() 399 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DistillKit 2 | 3 | A flexible and production-ready toolkit for knowledge distillation of large language models, supporting both online and offline distillation workflows with advanced logit compression. 4 | 5 | DistillKit powers the training of many of Arcee's popular open-source models, including [Virtuoso](https://huggingface.co/arcee-ai/Virtuoso-Large), [SuperNova Medius](https://huggingface.co/arcee-ai/SuperNova-Medius), and [Blitz](https://huggingface.co/arcee-ai/Arcee-Blitz). 6 | 7 | ## Features 8 | 9 | - **Online Distillation**: Real-time teacher inference during student training 10 | - **Offline Distillation**: Train from pre-captured teacher outputs with advanced compression 11 | - **Advanced Logit Compression**: Novel polynomial approximation + quantization + bit-packing achieving vigorous compression ratios while preserving distillation quality 12 | - **Flexible Loss Functions**: Composable losses including KL divergence, JSD, TVD, ranking losses, and hidden state alignment 13 | - **Sparse & Dense Support**: Efficient sparse distributions (top-k) or exact dense distributions 14 | - **Battle-tested**: The infrastructure powering Arcee's distilled model releases 15 | - **HuggingFace Integration**: Built on Transformers, TRL, and Accelerate 16 | 17 | ## Why DistillKit? 18 | 19 | While online distillation is straightforward, **offline distillation at scale** requires careful engineering. Simply storing top-k token-logit pairs becomes prohibitively expensive when distilling on billions of tokens. 20 | 21 | DistillKit's compression system is the result of months of experimentation to strike the delicate balance between storage costs, memory throughput, and distillation quality. Our approach: 22 | 23 | 1. **Polynomial approximation** of the logit distribution curve 24 | 2. **Error-diffusion quantization** of residuals to preserve quality 25 | 3. **Bit-level packing** with arbitrary bit widths (1-64 bits) 26 | 27 | This enables practical offline distillation workflows that would otherwise be infeasible. 28 | 29 | ## Installation 30 | 31 | ```bash 32 | git clone https://github.com/arcee-ai/distillkit.git 33 | cd distillkit 34 | pip install -e . 35 | ``` 36 | 37 | ### Optional: Logit Capture 38 | 39 | To capture your own teacher outputs, install the capture dependencies: 40 | 41 | ```bash 42 | pip install -e ".[capture]" 43 | ``` 44 | 45 | For most users, we recommend starting with the pre-captured teacher datasets we provide (see [Datasets](#datasets) below). 
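To make the three-stage recipe from "Why DistillKit?" a little more concrete, here is a minimal numpy sketch of the idea. It is illustrative only: the basis functions, bit width, and byte layout below are placeholder choices, not the actual `LogprobCompressor` format (see `distillkit/compression/` for the real implementation).

```python
# Toy version of the compression recipe described above. NOT the distillkit
# on-disk format; just the shape of the idea.
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for one token position: top-k teacher logprobs, sorted descending.
k = 128
logits = rng.gumbel(size=50_000)
logprobs = logits - np.logaddexp.reduce(logits)
top_k = np.sort(logprobs)[::-1][:k]

# 1. Polynomial approximation: fit a small basis (1, r, r^2, sqrt(r)) to the
#    sorted logprob-vs-rank curve and keep only the coefficients.
ranks = np.arange(k, dtype=np.float64)
basis = np.stack([ranks**0, ranks, ranks**2, np.sqrt(ranks)], axis=1)
coeffs, *_ = np.linalg.lstsq(basis, top_k, rcond=None)
approx = basis @ coeffs

# 2. Quantize the residuals against the fitted curve (plain rounding here;
#    the real compressor can optionally diffuse the rounding error across values).
residual = top_k - approx
scale = max(np.abs(residual).max() / 127.0, 1e-12)
quantized = np.round(residual / scale).astype(np.int8)

# 3. Pack coefficients, scale, and residuals into a compact byte payload.
#    8-bit residuals already fit one byte each; the real bit-packer handles
#    arbitrary widths from 1 to 64 bits.
payload = (
    coeffs.astype(np.float32).tobytes()
    + np.float32(scale).tobytes()
    + quantized.tobytes()
)

reconstructed = approx + quantized.astype(np.float64) * scale
print(f"{len(payload)} bytes/token, max abs error {np.abs(reconstructed - top_k).max():.4f}")
```

The production compressor additionally keeps the top `exact_k` logprobs as floating-point values, bit-packs the token indices alongside the compressed logprobs, and supports per-bin residual precision; those are the storage/fidelity knobs described in the sections below.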
46 | 47 | ## Quick Start 48 | 49 | ### Offline Distillation 50 | 51 | Train a student model using pre-captured teacher outputs: 52 | 53 | ```yaml 54 | # config.yaml 55 | project_name: my-distillation 56 | model: Qwen/Qwen3-8B 57 | output_path: ./output 58 | sequence_length: 8192 59 | 60 | dataset: 61 | train_dataset: 62 | repo_id: arcee-ai/Qwen3-235B-Logits-Packed-8192 # Pre-captured teacher outputs 63 | split: train 64 | prepacked: true 65 | 66 | teacher: 67 | kind: dataset 68 | logprob_compressor: 69 | d: 151936 # Vocabulary size 70 | delta_encoding: true 71 | error_diffusion: false 72 | exact_dtype: float32 73 | exact_k: 32 74 | k: 128 75 | polynomial_terms: [0, 1, 2] 76 | residual_bins: [] 77 | term_dtype: float32 78 | 79 | loss_functions: 80 | - function: cross_entropy 81 | weight: 0.5 82 | - function: kl 83 | weight: 0.5 84 | temperature: 1.0 85 | missing_probability_handling: zero 86 | sparse_chunk_length: 1024 87 | 88 | training_args: 89 | num_train_epochs: 1 90 | per_device_train_batch_size: 1 91 | gradient_accumulation_steps: 8 92 | learning_rate: 2.0e-6 93 | bf16: true 94 | optim: adamw_torch 95 | gradient_checkpointing: true 96 | ``` 97 | 98 | Run training: 99 | 100 | ```bash 101 | distillkit config.yaml 102 | ``` 103 | 104 | ### Online Distillation 105 | 106 | For online distillation where the teacher runs alongside student training, see [`examples/afm_test.yml`](examples/afm_test.yml) for a complete configuration example. 107 | 108 | ## Core Concepts 109 | 110 | ### Knowledge Distillation for LLMs 111 | 112 | Knowledge distillation transfers knowledge from a (potentially larger) "teacher" model to a "student" model. Instead of training only on hard labels (the correct token), the student learns from the teacher's probability distribution over tokens, which is a much richer learning signal. 113 | 114 | **Key benefits:** 115 | - Smaller, faster models with competitive performance 116 | - Lower inference costs 117 | - Easier deployment in resource-constrained environments 118 | 119 | ### Online vs Offline Distillation 120 | 121 | **Online Distillation:** 122 | - Teacher runs in real-time during student training 123 | - No storage overhead 124 | - Best when: You have sufficient VRAM for both models and dense distributions 125 | 126 | **Offline Distillation:** 127 | - Teacher outputs pre-captured and compressed 128 | - Enables training multiple students from the same teacher 129 | - Best when: VRAM-limited, reusing teacher signals, or training at large scale 130 | 131 | **Rule of thumb:** If you can fit both teacher and student with dense distributions into VRAM, use online distillation. Otherwise, offline distillation with our compression system is the way to go. 132 | 133 | ### Sparse vs Dense Distributions 134 | 135 | **Dense distributions** include probabilities for the full vocabulary. This is more accurate but memory-intensive. 136 | 137 | **Sparse distributions** store only the top-k tokens and serve as a lossy, but useful and efficient, approximation of the full dense distribution. With sufficient training data, sparse distillation can achieve equivalent performance to dense. 138 | 139 | DistillKit supports both, with automatic chunking for memory-efficient processing of long sequences. 140 | 141 | ### Logit Compression 142 | 143 | Our compression system balances storage efficiency with distillation quality: 144 | 145 | 1. Select top-k logits from teacher output 146 | 2. Sort by log-probability, optionally apply delta encoding 147 | 3. 
Fit polynomial to the distribution curve 148 | 4. Quantize residuals, with optional error diffusion 149 | 5. Bitpack everything into byte vectors 150 | 151 | There are lots of knobs you can twiddle here to reach a storage/fidelity tradeoff that works for your particular needs. 152 | 153 | **Recommended configuration** (used at Arcee for new captures): 154 | ```yaml 155 | logprob_compressor: 156 | d: 157 | k: 128 158 | exact_k: 16 159 | exact_dtype: bfloat16 160 | polynomial_terms: [0, 1, 2, 3, 4, "sqrt"] 161 | term_dtype: float32 162 | residual_bins: [] 163 | delta_encoding: false 164 | error_diffusion: false 165 | ``` 166 | 167 | This takes ~300 bytes/token (0.15% of uncompressed distribution size) with minimal quality loss. 168 | 169 | If you're a little tight on storage, try the **budget pick**: 170 | ```yaml 171 | logprob_compressor: 172 | d: 173 | k: 50 174 | exact_k: 1 175 | exact_dtype: bfloat16 176 | polynomial_terms: [0, 1, "sqrt"] 177 | term_dtype: float32 178 | residual_bins: [] 179 | delta_encoding: false 180 | error_diffusion: false 181 | ``` 182 | 183 | This weighs in at around 114 bytes per token, smaller and with better reconstruction quality than storing the top 32 logprobs in bf16. 184 | 185 | Note that the configuration that was used to capture the logits must be reflected in the distillation configuration. Mixing and matching isn't gonna work out so hot. 186 | 187 | ## Configuration Guide 188 | 189 | ### Loss Functions 190 | 191 | DistillKit supports composable loss functions with independent weights: 192 | 193 | #### Distribution-Based Losses 194 | - `kl`: Kullback-Leibler divergence (standard distillation loss) 195 | - `jsd`: Jensen-Shannon divergence (symmetric alternative to KL) 196 | - `tvd`: Total Variation Distance 197 | 198 | #### Ranking Losses 199 | - `hinge`: Hinge ranking loss 200 | - `logistic_ranking`: Logistic ranking loss 201 | 202 | #### Hidden State Alignment 203 | - `hs_mse`: Mean squared error between teacher and student hidden states 204 | - `hs_cosine`: Cosine similarity between hidden states 205 | 206 | #### Standard 207 | - `cross_entropy`: Standard language modeling loss 208 | 209 | All distribution losses support both sparse and dense modes. Combine multiple losses: 210 | 211 | ```yaml 212 | loss_functions: 213 | - function: cross_entropy 214 | weight: 0.25 215 | - function: kl 216 | weight: 0.5 217 | temperature: 2.0 218 | - function: hs_cosine 219 | weight: 0.25 220 | ``` 221 | 222 | ### Teacher Configuration 223 | 224 | **Offline (from dataset):** 225 | ```yaml 226 | teacher: 227 | kind: dataset 228 | logprob_compressor: 229 | d: 128256 230 | k: 128 231 | exact_k: 16 232 | delta_encoding: true 233 | ... 234 | # or: 235 | legacy_logit_compression: 236 | vocab_size: 128256 237 | k: 128 238 | exact_k: 32 239 | polynomial_degree: 8 240 | ... 
241 | ``` 242 | 243 | **Online (HuggingFace model):** 244 | ```yaml 245 | teacher: 246 | kind: hf 247 | path: Qwen/Qwen3-8B 248 | kwargs: # keyword arguments passed when loading teacher model 249 | attn_implementation: flash_attention_2 250 | torch_dtype: bfloat16 251 | ``` 252 | 253 | ## Advanced Topics 254 | 255 | 256 | ### Compression Deep-Dive 257 | 258 | The compression system supports two modes: 259 | 260 | **Legacy compression** (fully polynomial-based): 261 | ```yaml 262 | legacy_logit_compression: 263 | vocab_size: 128256 # Size of teacher vocabulary 264 | k: 128 # Total number of logprobs per token, exact plus approximated 265 | exact_k: 32 # Number of logprobs stored as floating point values 266 | polynomial_degree: 8 # Degree of approximating polynomial 267 | with_sqrt_term: false # Include sqrt term in polynomial 268 | term_dtype: float32 # Precision for polynomial coefficients 269 | invert_polynomial: true # Invert for better numerical properties 270 | ``` 271 | 272 | **Distribution quantization** (newer, more flexible): 273 | ```yaml 274 | logprob_compressor: 275 | d: 128256 # Size of teacher vocabulary 276 | k: 128 # Total number of logprobs per token, exact plus approximated 277 | exact_k: 16 # Number of logprobs stored as floating point values 278 | exact_dtype: bfloat16 # dtype for "exact" logprobs 279 | delta_encoding: false # Store logprobs as deltas (not recommended) 280 | error_diffusion: false # Perform error diffusion to spread quantization error across values (not recommended) 281 | polynomial_terms: # List of polynomial terms used for approximating tail 282 | - 0 283 | - 1 284 | - 2 285 | - "sqrt" 286 | term_dtype: float32 # dtype for storage of polynomial coefficients 287 | residual_bins: # Optional list of bins storing quantized residuals vs. the approximated tail 288 | - scale_dtype: float16 # dtype for scale factor for this bin 289 | element_bits: 8 # Bits/element 290 | num_elements: 16 # Total number of elements in this bin 291 | - scale_dtype: float32 # bfloat16 also works 292 | element_bits: 2 # Can use any number of bits <= 64 293 | num_elements: 64 294 | ... 295 | ``` 296 | 297 | ### Hidden State Distillation 298 | 299 | Align student hidden states with teacher hidden states: 300 | 301 | ```yaml 302 | layer_mapping: all # Or specify layer pairs 303 | loss_functions: 304 | - function: hs_mse 305 | weight: 0.5 306 | ``` 307 | 308 | For cross-architecture distillation, hidden states are projected using learned linear mappings. You can also enable this for same-architecture distillations by setting `force_hidden_state_projection: true`. 309 | 310 | ### Capturing Teacher Outputs 311 | 312 | To create your own offline distillation dataset: 313 | 314 | ```bash 315 | python -m distillkit.sample_logits_vllm \ 316 | --model meta-llama/Llama-3.1-70B \ 317 | --dataset allenai/tulu-3-sft-mixture \ 318 | --output ./llama3_70b_tulu_logits/ \ 319 | --compression-config ./compression_config.yaml 320 | ``` 321 | 322 | Requires vLLM (see [Installation](#optional-logit-capture)). 
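Before launching a large capture run, it can be worth round-tripping a few synthetic teacher distributions through the compressor to confirm that your configuration reconstructs well. Below is a minimal sketch, with the caveat that it assumes `DistributionQuantizationConfig` accepts the same field names and values as the `logprob_compressor` YAML block via pydantic's `model_validate`; check `distillkit/compression/config.py` for the exact schema.

```python
import torch

from distillkit.compression.config import DistributionQuantizationConfig
from distillkit.compression.monotonic_logprobs import (
    compress_monotonic_logprobs,
    decompress_monotonic_logprobs,
)

# Assumed: the config model validates the same fields as the YAML `logprob_compressor` block.
config = DistributionQuantizationConfig.model_validate({
    "d": 151936, "k": 128, "exact_k": 16, "exact_dtype": "bfloat16",
    "polynomial_terms": [0, 1, 2, 3, 4, "sqrt"], "term_dtype": "float32",
    "residual_bins": [], "delta_encoding": False, "error_diffusion": False,
})

# Synthetic teacher output: top-k logprobs, sorted descending along the last dimension.
logits = torch.randn(2, 8, 151936)
topk_logprobs = torch.log_softmax(logits, dim=-1).topk(128, dim=-1).values

packed = compress_monotonic_logprobs(topk_logprobs, config)
restored = decompress_monotonic_logprobs(packed, config, out_dtype=torch.float32)

# Bytes per token here cover the logprob *values* only; token indices are packed
# separately by the full pipeline.
print(f"{packed.shape[-1]} bytes/token of logprob data, "
      f"max abs reconstruction error {(restored - topk_logprobs.float()).abs().max().item():.4f}")
```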
323 | 324 | ### Memory Management Tips 325 | 326 | **For long sequences:** 327 | - Use `sparse_chunk_length` to process sequences in chunks (e.g., `1024`) 328 | - Use DeepSpeed ZeRO Stage 1 or 2 to cram more tokens in there 329 | 330 | **For general savings:** 331 | - Use `optim: paged_adamw_8bit` or `optim: adamw_bnb_8bit` 332 | - Enable Flash Attention 2: `use_flash_attention: true` 333 | - Use bfloat16 instead of float32 334 | - Enable `gradient_checkpointing` 335 | - Reduce batch size, increase gradient accumulation 336 | 337 | ## Examples 338 | 339 | - **Offline Distillation (70B → 8B)**: [`examples/llama_70b_base.yml`](examples/llama_70b_base.yml) 340 | - **Online Distillation with Hidden States**: [`examples/afm_test.yml`](examples/afm_test.yml) 341 | - **Multimodal Model Distillation**: [`examples/mistral3.yaml`](examples/mistral3.yaml) 342 | 343 | ## Datasets 344 | 345 | We're releasing several pre-captured teacher datasets: 346 | 347 | * [Qwen3-235B instruction-following](https://huggingface.co/datasets/arcee-ai/Qwen3-235B-Logits-Packed-8192): ~1.5 billion tokens of general instruct data at 8192 context length 348 | * [DeepSeek V3/R1 synthetic mixed-mode reasoning](https://huggingface.co/datasets/arcee-ai/DeepSeek-MixedModeReasoning-Logits-Packed-16384): ~5 billion tokens captured from DeepSeek V3 and R1, with prefixes to distinguish reasoning from non-reasoning traces - 16k context length 349 | * [DeepSeek V3 base](https://huggingface.co/datasets/arcee-ai/DeepSeek-DCLM-Logits-Packed-8192): ~1.2 billion tokens of raw completion data from DCLM captured from the DeepSeek V3 base model 350 | 351 | ## Cross-Architecture Distillation 352 | 353 | DistillKit can be used together with [mergekit-tokensurgeon](https://github.com/arcee-ai/mergekit/blob/main/docs/tokensurgeon.md) for cross-tokenizer, cross-architecture distillation. Many Arcee models combine both tools: 354 | 355 | 1. Use tokensurgeon to adapt student embeddings to teacher's tokenizer 356 | 2. Use DistillKit to distill teacher knowledge to student 357 | 3. Optionally convert back to student's original tokenizer, maybe do some other weird merges, follow your dreams 358 | 359 | ## Training Tips 360 | 361 | - **Start with ~0.5 cross-entropy weight**, then tune up or down depending on how high quality your dataset is 362 | - **Distillation temperature**: `temperature: 2.0` is a good first choice 363 | - **Missing probability handling**: Use `zero` to focus only on the teacher's most confident predictions; use `uniform` to match the teacher's uncertainty as well 364 | 365 | ## Citation 366 | 367 | If you use DistillKit in your research, please cite: 368 | 369 | ```bibtex 370 | @software{distillkit2024, 371 | title = {DistillKit: Flexible Knowledge Distillation for Large Language Models}, 372 | author = {Goddard, Charles and Atkins, Lucas}, 373 | year = {2024}, 374 | publisher = {Arcee AI}, 375 | url = {https://github.com/arcee-ai/distillkit} 376 | } 377 | ``` 378 | 379 | ## Community & Support 380 | 381 | - **Issues**: [GitHub Issues](https://github.com/arcee-ai/distillkit/issues) 382 | - **Discussions**: [Arcee Discord](https://discord.gg/arceeai) 383 | 384 | Note: DistillKit is an open-source research release. While it powers several of our production models and we'll happily address issues as bandwidth allows, community support is best-effort. 385 | 386 | ## License 387 | 388 | DistillKit is released under the Apache License 2.0. 
389 | 390 | ### Acknowledgments 391 | 392 | - Flash Attention packing implementation adapted from [Functionary](https://github.com/MeetKai/functionary) (MIT License) 393 | - Built on [HuggingFace Transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl), and [Accelerate](https://github.com/huggingface/accelerate) 394 | 395 | --- 396 | 397 | **Built with ♥ by [Arcee AI](https://www.arcee.ai)** 398 | -------------------------------------------------------------------------------- /distillkit/compression/monotonic_logprobs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from distillkit.compression.bitpack import pack_to_bytes, unpack_from_bytes 4 | from distillkit.compression.config import ( 5 | DistributionQuantizationConfig, 6 | SpecialTerm, 7 | ) 8 | 9 | 10 | def _work_dtype(*inputs: torch.Tensor | None) -> torch.dtype: 11 | for x in inputs: 12 | if x is not None and x.dtype == torch.float64: 13 | return torch.float64 14 | return torch.float32 15 | 16 | 17 | def _solve_least_squares(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: 18 | work_dtype = _work_dtype(A, B) 19 | 20 | # because for some reason torch.linalg.lstsq only works for full-rank matrices on GPU 21 | U, S, Vh = torch.linalg.svd(A.to(work_dtype), full_matrices=False) 22 | tol = 1e-5 23 | Spinv = torch.zeros_like(S) 24 | Spinv[S > tol] = 1 / S[S > tol] 25 | UhB = U.transpose(-1, -2) @ B.to(work_dtype) 26 | SpinvUhB = Spinv.unsqueeze(-1) * UhB 27 | return Vh.transpose(-1, -2) @ SpinvUhB 28 | 29 | 30 | def polynomial_terms( 31 | terms: list[SpecialTerm | int], 32 | t: int, 33 | dtype: torch.dtype, 34 | device: torch.device, 35 | normalize_t: bool, 36 | ) -> torch.Tensor: 37 | assert all(isinstance(i, (int, SpecialTerm)) for i in terms), ( 38 | "terms must be a list of integers or SpecialTerm instances" 39 | ) 40 | if normalize_t: 41 | pts = torch.linspace(0, 1, steps=t, dtype=dtype, device=device) 42 | else: 43 | pts = torch.arange(t, dtype=dtype, device=device) 44 | X = torch.stack( 45 | [pts**i if isinstance(i, int) else getattr(torch, i.value)(pts) for i in terms], 46 | dim=-1, 47 | ) 48 | return X 49 | 50 | 51 | def fit_polynomial( 52 | values: torch.Tensor, 53 | terms: list[SpecialTerm | int], 54 | dtype: torch.dtype, 55 | normalize_t: bool, 56 | ) -> tuple[torch.Tensor, torch.Tensor]: 57 | work_dtype = _work_dtype(values) 58 | X = polynomial_terms( 59 | terms, 60 | values.shape[-1], 61 | dtype=work_dtype, 62 | device=values.device, 63 | normalize_t=normalize_t, 64 | ) 65 | while len(X.shape) < len(values.shape): 66 | X = X.unsqueeze(0) 67 | 68 | y = values.unsqueeze(-1) 69 | coeffs = _solve_least_squares(X, y).squeeze(-1) 70 | 71 | coeffs_final = coeffs.to(dtype) 72 | approx = torch.sum( 73 | X.to(dtype) * coeffs_final.unsqueeze(-2), 74 | dim=-1, 75 | ).to(work_dtype) 76 | residual = values - approx.squeeze(-1) 77 | return coeffs_final, residual 78 | 79 | 80 | def _get_quantize_range(element_bits: int): 81 | if element_bits == 1: 82 | quant_min = 0 83 | quant_max = 1 84 | else: 85 | quant_min = -(2 ** (element_bits - 1)) 86 | quant_max = (2 ** (element_bits - 1)) - 1.0 87 | 88 | return quant_min, quant_max 89 | 90 | 91 | def _get_quantize_scale_factors(values: torch.Tensor, element_bits: int): 92 | # Compute max absolute value for each group (..., 1) 93 | max_abs_val = torch.amax(torch.abs(values), dim=-1, keepdim=True) 94 | 95 | # Determine the maximum quantized absolute value based on element_bits 96 | if 
element_bits == 1: 97 | max_quant_abs = 1.0 98 | else: 99 | max_quant_abs = 2 ** (element_bits - 1) 100 | 101 | return max_abs_val, max_quant_abs 102 | 103 | 104 | def error_diffuse_and_quantize( 105 | values: torch.Tensor, 106 | element_bits: int, 107 | scale_dtype: torch.dtype, 108 | error_buffer: torch.Tensor | None = None, 109 | ) -> tuple[torch.Tensor, torch.LongTensor, torch.Tensor]: 110 | """ 111 | Quantize the input tensor to the specified number of bits per element using error diffusion. 112 | """ 113 | original_shape = values.shape 114 | n = original_shape[-1] 115 | 116 | (max_abs_val, max_quant_abs) = _get_quantize_scale_factors(values, element_bits) 117 | 118 | # Calculate scale factor, avoiding division by zero 119 | scale_factor = torch.where( 120 | max_abs_val == 0, torch.ones_like(max_abs_val), max_abs_val / max_quant_abs 121 | ) 122 | 123 | # Simulate output precision 124 | scale_factor = scale_factor.to(scale_dtype).to(values.dtype) 125 | 126 | # Scale the input values 127 | scaled_vals = values / scale_factor 128 | 129 | # Initialize error buffer and quantized tensor 130 | if error_buffer is None: 131 | error_buffer = torch.zeros_like(scaled_vals[..., 0]) 132 | quantized_values = torch.zeros_like(scaled_vals, dtype=torch.long) 133 | 134 | quant_min, quant_max = _get_quantize_range(element_bits) 135 | 136 | # Process each element along the last dimension with error diffusion 137 | for i in range(n): 138 | current = scaled_vals[..., i] + error_buffer 139 | quantized_i = torch.round(current) 140 | quantized_i = torch.clamp(quantized_i, quant_min, quant_max) 141 | error = current - quantized_i 142 | error_buffer = error 143 | quantized_values[..., i] = (quantized_i - quant_min).to(torch.long) 144 | 145 | return scale_factor.to(scale_dtype), quantized_values, error_buffer 146 | 147 | 148 | def error_diffuse_float( 149 | values: torch.Tensor, 150 | out_dtype: torch.dtype, 151 | error_buffer: torch.Tensor | None = None, 152 | ) -> tuple[torch.Tensor, torch.Tensor]: 153 | work_dtype = _work_dtype(values, error_buffer) 154 | values = values.to(work_dtype) 155 | if error_buffer is None: 156 | error_buffer = torch.zeros_like(values[..., 0]) 157 | 158 | out_values = torch.zeros_like(values, dtype=out_dtype) 159 | n = values.shape[-1] 160 | for i in range(n): 161 | current = values[..., i] + error_buffer 162 | q = current.to(out_dtype) 163 | error = current - q.to(work_dtype) 164 | error_buffer = error 165 | out_values[..., i] = q 166 | 167 | return out_values, error_buffer 168 | 169 | 170 | def quantize_naive( 171 | values: torch.Tensor, 172 | element_bits: int, 173 | scale_dtype: torch.dtype, 174 | error_buffer: torch.Tensor | None = None, 175 | ) -> tuple[torch.Tensor, torch.LongTensor, torch.Tensor]: 176 | """ 177 | Naive quantization of the input tensor to the specified number of bits per element.
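Returns a (scale_factor, quantized_values, error_buffer) tuple; the error buffer returned here is always zero, matching the interface of error_diffuse_and_quantize so the two can be used interchangeably.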
178 | """ 179 | work_dtype = _work_dtype(values, error_buffer) 180 | values = values.to(work_dtype) 181 | 182 | (max_abs_val, max_quant_abs) = _get_quantize_scale_factors(values, element_bits) 183 | 184 | # Calculate scale factor, avoiding division by zero 185 | scale_factor = torch.where( 186 | max_abs_val == 0, torch.ones_like(max_abs_val), max_abs_val / max_quant_abs 187 | ) 188 | 189 | # Simulate output precision 190 | scale_factor = scale_factor.to(scale_dtype).to(values.dtype) 191 | 192 | # Scale the input values 193 | scaled_vals = values / scale_factor 194 | 195 | # Quantize the scaled values 196 | quant_min, quant_max = _get_quantize_range(element_bits) 197 | quantized_values = ( 198 | torch.round(scaled_vals).clamp(quant_min, quant_max) - quant_min 199 | ).to(torch.long) 200 | 201 | return ( 202 | scale_factor, 203 | quantized_values, 204 | torch.zeros_like(scaled_vals[..., 0]), 205 | ) 206 | 207 | 208 | def dequantize( 209 | quantized_values: torch.LongTensor, 210 | scale: torch.Tensor, 211 | element_bits: int, 212 | ) -> torch.Tensor: 213 | """ 214 | Dequantize the input tensor using the specified scale and number of bits per element. 215 | """ 216 | if element_bits == 1: 217 | return (quantized_values.to(torch.float32) * 2.0 - 1.0) * scale 218 | quant_min, quant_max = _get_quantize_range(element_bits) 219 | dequantized_values = (quantized_values + quant_min).to(torch.float32) * scale 220 | return dequantized_values 221 | 222 | 223 | def compress_monotonic_logprobs( 224 | logprobs: torch.Tensor, 225 | config: DistributionQuantizationConfig, 226 | ) -> torch.ByteTensor: 227 | """ 228 | Compresses logprobs using the specified configuration. 229 | Args: 230 | logprobs (torch.Tensor): Log probabilities to compress, shape (batch_size, seq_len, k). Must be 231 | monotonically decreasing along the last dimension. 232 | config (DistributionQuantizationConfig): Configuration for compression. 233 | Returns: 234 | torch.ByteTensor: Compressed logprobs. Shape (batch_size, seq_len, ceil(config.total_bits() / 8)). 
235 | """ 236 | work_dtype = _work_dtype(logprobs) 237 | if config.delta_encoding: 238 | # Apply delta encoding 239 | # Logprobs are all <= 0 and monotonically decreasing 240 | # replace logprobs with their deltas with respect to previous value 241 | # the first value is unchanged 242 | if config.error_diffusion: 243 | work_dtype = torch.float64 244 | logprobs_work = logprobs.to(work_dtype) 245 | deltas = logprobs_work[..., 1:] - logprobs_work[..., :-1] 246 | deltas = torch.cat( 247 | [logprobs[..., :1], deltas], 248 | dim=-1, 249 | ) 250 | if config.error_diffusion: 251 | logprobs, _ = error_diffuse_float(deltas, logprobs.dtype, error_buffer=None) 252 | else: 253 | logprobs = deltas.to(logprobs.dtype) 254 | 255 | chunks = [] 256 | 257 | error_buffer = None 258 | if config.exact_k > 0: 259 | exact_values = logprobs[..., : config.exact_k].to(config.exact_dtype.dtype()) 260 | chunks.append( 261 | exact_values.view(torch.uint8).reshape( 262 | *logprobs.shape[:-1], 263 | -1, 264 | ), 265 | ) 266 | 267 | if config.polynomial_terms: 268 | approx_values = logprobs[..., config.exact_k : config.k] 269 | coeffs, residual = fit_polynomial( 270 | approx_values, 271 | config.polynomial_terms, 272 | dtype=config.term_dtype.dtype(), 273 | normalize_t=config.normalize_t, 274 | ) 275 | 276 | coeffs = coeffs.to(config.term_dtype.dtype()) 277 | coeff_bytes = coeffs.view(torch.uint8).reshape( 278 | *logprobs.shape[:-1], 279 | -1, 280 | ) 281 | chunks.append(coeff_bytes) 282 | else: 283 | residual = logprobs[..., config.exact_k : config.k] 284 | 285 | cur_index = 0 286 | for bin in config.residual_bins: 287 | values = residual[..., cur_index : cur_index + bin.num_elements] 288 | if config.error_diffusion: 289 | scale, scaled, error_buffer = error_diffuse_and_quantize( 290 | values, bin.element_bits, bin.scale_dtype.dtype(), error_buffer 291 | ) 292 | else: 293 | scale, scaled, error_buffer = quantize_naive( 294 | values, bin.element_bits, bin.scale_dtype.dtype(), error_buffer 295 | ) 296 | 297 | scale = scale.to(bin.scale_dtype.dtype()) 298 | scale_bytes = scale.view(torch.uint8).reshape( 299 | *logprobs.shape[:-1], 300 | -1, 301 | ) 302 | chunks.append(scale_bytes) 303 | 304 | packed = pack_to_bytes(scaled, bin.element_bits) 305 | packed = packed.reshape( 306 | *logprobs.shape[:-1], 307 | -1, 308 | ) 309 | chunks.append(packed) 310 | 311 | cur_index += bin.num_elements 312 | 313 | # return byte tensor 314 | return torch.cat(chunks, dim=-1) 315 | 316 | 317 | def decompress_monotonic_logprobs( 318 | bytes: torch.ByteTensor, 319 | config: DistributionQuantizationConfig, 320 | out_dtype: torch.dtype | None = None, 321 | use_residual: bool = True, 322 | ) -> torch.Tensor: 323 | """ 324 | Decompresses logprobs using the specified configuration. 325 | 326 | Args: 327 | bytes (torch.ByteTensor): Compressed logprobs, shape (batch_size, seq_len, num_bytes). 328 | config (DistributionQuantizationConfig): Configuration for decompression. 329 | out_dtype (torch.dtype | None): Data type for the output tensor. If None, uses the dtype for the 330 | exact values. 331 | Returns: 332 | torch.Tensor: Decompressed logprobs, shape (batch_size, seq_len, k). 
333 | """ 334 | device = bytes.device 335 | if out_dtype is None: 336 | out_dtype = config.exact_dtype.dtype() 337 | 338 | # Extract exact values 339 | if config.exact_k > 0: 340 | exact_dtype_torch = config.exact_dtype.dtype() 341 | bytes_per_exact = config.exact_dtype.bit_width() // 8 342 | exact_bytes = config.exact_k * bytes_per_exact 343 | exact_part = bytes[..., :exact_bytes].contiguous() 344 | exact_values = exact_part.view(dtype=exact_dtype_torch).reshape( 345 | *bytes.shape[:-1], config.exact_k 346 | ) 347 | remaining_bytes = bytes[..., exact_bytes:] 348 | else: 349 | exact_values = torch.empty( 350 | (*bytes.shape[:-1], 0), dtype=out_dtype, device=device 351 | ) 352 | remaining_bytes = bytes 353 | 354 | # Extract polynomial coefficients if applicable 355 | if config.polynomial_terms and len(config.polynomial_terms) > 0: 356 | term_dtype_torch = config.term_dtype.dtype() 357 | terms_count = len(config.polynomial_terms) 358 | coeff_bytes = terms_count * (config.term_dtype.bit_width() // 8) 359 | coeff_part = remaining_bytes[..., :coeff_bytes].contiguous() 360 | coeffs = coeff_part.view(dtype=term_dtype_torch).reshape( 361 | *remaining_bytes.shape[:-1], terms_count 362 | ) 363 | remaining_bytes = remaining_bytes[..., coeff_bytes:] 364 | else: 365 | coeffs = None 366 | terms_count = 0 367 | 368 | # Process residual bins 369 | residuals = [] 370 | for bin in config.residual_bins: 371 | # Extract scale 372 | scale_dtype_torch = bin.scale_dtype.dtype() 373 | scale_bytes = bin.scale_dtype.bit_width() // 8 374 | scale_part = remaining_bytes[..., :scale_bytes].contiguous() 375 | scale = scale_part.view(dtype=scale_dtype_torch).reshape( 376 | *remaining_bytes.shape[:-1], 1 377 | ) 378 | remaining_bytes = remaining_bytes[..., scale_bytes:] 379 | 380 | # Extract and unpack elements 381 | num_elements = bin.num_elements 382 | element_bits = bin.element_bits 383 | packed_bits = num_elements * element_bits 384 | packed_bytes = (packed_bits + 7) // 8 385 | packed_part = remaining_bytes[..., :packed_bytes].contiguous() 386 | remaining_bytes = remaining_bytes[..., packed_bytes:] 387 | 388 | elements = unpack_from_bytes(packed_part, element_bits, num_elements) 389 | residual_bin = dequantize( 390 | elements, 391 | scale, 392 | element_bits, 393 | ) 394 | residuals.append(residual_bin) 395 | 396 | # Combine residuals and pad if necessary 397 | approx_terms = config.k - config.exact_k 398 | sum_bin_elems = sum(bin.num_elements for bin in config.residual_bins) 399 | if use_residual and residuals: 400 | residual = torch.cat(residuals, dim=-1) 401 | if sum_bin_elems < approx_terms: 402 | residual = torch.nn.functional.pad( 403 | residual, (0, approx_terms - sum_bin_elems) 404 | ) 405 | else: 406 | residual = torch.zeros( 407 | (*remaining_bytes.shape[:-1], approx_terms), dtype=out_dtype, device=device 408 | ) 409 | 410 | # Compute polynomial approximation 411 | if ( 412 | coeffs is not None 413 | and config.polynomial_terms 414 | and len(config.polynomial_terms) > 0 415 | ): 416 | X = polynomial_terms( 417 | terms=config.polynomial_terms, 418 | t=approx_terms, 419 | dtype=config.term_dtype.dtype(), 420 | device=device, 421 | normalize_t=config.normalize_t, 422 | ) 423 | fit = torch.sum( 424 | X.to(coeffs.device, coeffs.dtype) * coeffs.unsqueeze(-2), 425 | dim=-1, 426 | ) 427 | approx_values = fit + residual.to(out_dtype) 428 | else: 429 | approx_values = residual.to(out_dtype) 430 | 431 | logprobs = torch.cat([exact_values.to(out_dtype), approx_values], dim=-1) 432 | if config.delta_encoding: 433 | # 
Apply inverse delta encoding 434 | logprobs = torch.cumsum(logprobs.float(), dim=-1).to(out_dtype) 435 | return logprobs 436 | --------------------------------------------------------------------------------
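A quick usage sketch for the quantization helpers above (illustrative only: the toy residual tensor and the 2-bit width are arbitrary choices, not values DistillKit itself uses):

```python
import torch

from distillkit.compression.monotonic_logprobs import (
    dequantize,
    error_diffuse_and_quantize,
    quantize_naive,
)

torch.manual_seed(0)
# Toy residual tail: (batch, seq, num_elements), small values like a post-fit residual.
residual = torch.randn(1, 4, 64) * 0.05

for quantizer in (quantize_naive, error_diffuse_and_quantize):
    scale, codes, _ = quantizer(residual, element_bits=2, scale_dtype=torch.float16)
    recon = dequantize(codes, scale, element_bits=2)
    # Per-element error is comparable for both; error diffusion keeps the *cumulative*
    # quantization error from drifting along the tail.
    drift = (recon - residual).cumsum(dim=-1).abs().max().item()
    print(f"{quantizer.__name__}: max cumulative drift {drift:.4f}")
```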