├── .gitignore ├── LICENSE ├── README.md ├── configs └── ds_config_zero3.json ├── requirements.txt ├── scripts ├── convert │ ├── gemma2-9B-it.sh │ ├── gemma2-9B.sh │ ├── llama2-7B.sh │ ├── llama3-8B.sh │ ├── mistral-7B-v0.3.sh │ ├── mixtral-8x7B-v0.1.sh │ └── qwen2.5-7B-Instruct.sh └── train │ └── run_train_llama2_7b.sh ├── train.py └── transmla ├── converter.py ├── lora_qkv.py ├── modify_config.py ├── partial_rope.py ├── transformers ├── gemma2 │ ├── configuration_gemma2mla.py │ └── modeling_gemma2mla.py ├── llama │ ├── configuration_llamamla.py │ └── modeling_llamamla.py ├── mixtral │ ├── configuration_mixtralmla.py │ └── modeling_mixtralmla.py └── mla.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *outputs 6 | *meta-llama 7 | debug 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # UV 101 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | #uv.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 
110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 119 | .pdm.toml 120 | .pdm-python 121 | .pdm-build/ 122 | 123 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 124 | __pypackages__/ 125 | 126 | # Celery stuff 127 | celerybeat-schedule 128 | celerybeat.pid 129 | 130 | # SageMath parsed files 131 | *.sage.py 132 | 133 | # Environments 134 | .env 135 | .venv 136 | env/ 137 | venv/ 138 | ENV/ 139 | env.bak/ 140 | venv.bak/ 141 | 142 | # Spyder project settings 143 | .spyderproject 144 | .spyproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | 149 | # mkdocs documentation 150 | /site 151 | 152 | # mypy 153 | .mypy_cache/ 154 | .dmypy.json 155 | dmypy.json 156 | 157 | # Pyre type checker 158 | .pyre/ 159 | 160 | # pytype static type analyzer 161 | .pytype/ 162 | 163 | # Cython debug symbols 164 | cython_debug/ 165 | 166 | # PyCharm 167 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 168 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 169 | # and can be added to the global gitignore or merged into this file. For a more nuclear 170 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 171 | #.idea/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Fanxu Meng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🚀 TransMLA: Migrating GQA Models to MLA with Full DeepSeek Compatibility and Speedup 2 | 3 | Modern large-language models often face communication bottlenecks on current hardware rather than computational limitations. 
Multi-head latent attention (MLA) addresses this by compressing the key-value cache using low-rank matrices, while the Absorb operation prevents the KV cache from reverting to its original size, significantly boosting both training and inference speed. 4 | 5 | Despite the success of DeepSeek V2/V3/R1, most model vendors have heavily invested in optimizing GQA-based models and therefore lack strong incentives to retrain MLA-based models from scratch. In our technical report, we introduce TransMLA, a framework that seamlessly converts any GQA-based pre-trained model (e.g., LLaMA, Qwen, Mixtral) into an MLA-based model. 6 | 7 | 8 | # 📰 News 9 | - [2025.05.29] A new version of the technical report is released: [https://arxiv.org/abs/2502.07864](https://arxiv.org/abs/2502.07864). 10 | - [2025.04.28] Released TransMLA v3, which successfully applies PCA across RoPE and reduces the KV cache. 11 | - [2025.02.16] Released the second version of the TransMLA model and usage code, compatible with RoPE and supporting the Absorb operation. 12 | - [2025.02.13] The technical report of TransMLA is publicly available: [https://huggingface.co/papers/2502.07864](https://huggingface.co/papers/2502.07864) 13 | - [2025.01.02] Released the first version of the TransMLA model code, providing usage code for converting the GQA of Qwen2.5 and LLaMA-3 into an equivalent MLA form. 14 | 15 | # 🛠 Installation 16 | ``` 17 | git clone https://github.com/fxmeng/TransMLA.git 18 | cd TransMLA 19 | conda create -n transmla python=3.12.8 20 | conda activate transmla 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | # ⚡ Quick Start 25 | 26 | 1. Convert MHA / GQA models (e.g. Qwen2.5-7B-Instruct) into DeepSeek-MLA: 27 | ```bash 28 | bash scripts/convert/qwen2.5-7B-Instruct.sh 29 | ``` 30 | 2. Have fun playing with the converted models! 31 | ```python 32 | # using `transformers.AutoModelForCausalLM` 33 | import torch 34 | from transformers import AutoModelForCausalLM 35 | model = AutoModelForCausalLM.from_pretrained("outputs/qwen2_5-7B-Instruct-deepseek", trust_remote_code=True) 36 | ``` 37 | 38 | ## 🔧 Advanced Usage (`converter.py`) 39 | 40 | The `converter.py` script gives you fine-grained control over RoPE removal and the low-rank QKV projection towards DeepSeek-MLA. It supports: 41 | - Auto-search for the optimal freqfold that minimizes PPL. 42 | - Automatic computation of collapse based on head_dim / qk_mqa_dim. 43 | - Evaluation of the original, RoPE-removed, and final MLA models. 44 | 45 | 46 | ### ✅ Example Command: 47 | ```bash 48 | python transmla/converter.py \ 49 | --model-path meta-llama/Llama-2-7b-hf \ 50 | --save-path ./outputs/llama2-7b-deepseek \ 51 | --dtype bf16 \ 52 | --device auto \ 53 | --cal-dataset wikitext2 \ 54 | --cal-nsamples 128 \ 55 | --cal-max-seqlen 256 \ 56 | --cal-batch-size 8 \ 57 | --ppl-eval-batch-size 4 \ 58 | --freqfold auto \ 59 | --collapse auto \ 60 | --qk-mqa-dim 64 \ 61 | --q-lora-rank 512 \ 62 | --kv-lora-rank 512 63 | ``` 64 | 65 | ### 📘 Argument Details 66 | 67 | | Argument | Description | 68 | |----------|-------------| 69 | | --model-path | Path to the base model (e.g., from the HuggingFace hub). | 70 | | --save-path | Output path for the converted model and tokenizer. | 71 | | --cal-dataset | Calibration dataset: wikitext2, ptb, c4, or alpaca. | 72 | | --cal-nsamples, --cal-max-seqlen, --cal-batch-size | Number, max sequence length, and batch size of samples used for calibration. | 73 | | --freqfold | RoPE frequency folding factor, or `auto` to search for the best value. | 74 | | --collapse | Collapse factor for RoPE.
Use `auto` to compute as `head_dim // qk_mqa_dim`. Collapse factor reduces the dim of RoPEd KV cache from `head_dim` to `head_dim // collapse`. | 75 | | --qk-mqa-dim | Target dimension for decoupled RoPE. | 76 | | --q-lora-rank | The inner dimension for query low-rank decomposition, or `None` to disable low-rank decomposition for query. | 77 | | --kv-lora-rank | The inner dimension for key/value joint low-rank decomposition. | 78 | 79 | 80 | ### 🧠 Tips 81 | - Set `--freqfold auto` and `--collapse auto` to simplify configuration. The script will automatically search for the best freqfold factor based on ppl results. 82 | - We recommend setting `--qk-mqa-dim` to 64 and `--kv-lora-rank` to 512 to satisfy FlashMLA's requirements on H100. 83 | 84 | 85 | # 🐒 Model Zoo 86 | 87 | - [x] Llama2 88 | - [x] Llama3 89 | - [x] Qwen2 90 | - [x] Gemma2 91 | - [x] Mistral 92 | - [x] Mixtral 93 | - [ ] MiMo 94 | 95 | 96 | # 📋 To-Do 97 | - [x] Publish the technical report for the new version, detailing how TransMLA is compatible with RoPE, supports the Absorb operation. 98 | - [x] Compress the dimensions of the KV cache to improve inference speed. 99 | - [x] Add support for vLLM to improve inference speed. 100 | - [x] Support FlashMLA. 101 | - [x] Extend support to additional models (e.g., LLaMA, Mistral, Gemma2, etc.). 102 | - [ ] Release checkpoints. 103 | - [ ] Fine-tune on R1 distillation datasets. 104 | 105 | 106 | # 📚 Citation 107 | ``` 108 | @article{meng2025transmla, 109 | title={TransMLA: Multi-head Latent Attention Is All You Need}, 110 | author={Meng, Fanxu and Tang, Pingzhi and Yao, Zengwei and Zhang, Muhan}, 111 | journal={arXiv preprint arXiv:2502.07864}, 112 | year={2025} 113 | } 114 | ``` 115 | 116 | # ⭐ Star History 117 | 118 | [![Star History Chart](https://api.star-history.com/svg?repos=fxmeng/TransMLA&type=Date)](https://www.star-history.com/#fxmeng/TransMLA&Date) -------------------------------------------------------------------------------- /configs/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | 15 | "zero_optimization": { 16 | "stage": 3, 17 | "offload_optimizer": { 18 | "device": "cpu", 19 | "pin_memory": true 20 | }, 21 | "offload_param": { 22 | "device": "cpu", 23 | "pin_memory": true 24 | }, 25 | "overlap_comm": true, 26 | "contiguous_gradients": true, 27 | "sub_group_size": 1e9, 28 | "reduce_bucket_size": "auto", 29 | "stage3_prefetch_bucket_size": "auto", 30 | "stage3_param_persistence_threshold": "auto", 31 | "stage3_max_live_parameters": 1e9, 32 | "stage3_max_reuse_distance": 1e9, 33 | "stage3_gather_16bit_weights_on_model_save": true 34 | }, 35 | 36 | "gradient_accumulation_steps": "auto", 37 | "gradient_clipping": "auto", 38 | "steps_per_print": 20, 39 | "train_batch_size": "auto", 40 | "train_micro_batch_size_per_gpu": "auto", 41 | "wall_clock_breakdown": false 42 | } 43 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | vllm==0.8.4 2 | datasets 3 | accelerate==1.3.0 4 | datatrove 5 | tensorboardX -------------------------------------------------------------------------------- /scripts/convert/gemma2-9B-it.sh: 
-------------------------------------------------------------------------------- 1 | model_path=google/gemma-2-9b-it 2 | save_path=outputs/gemma2-9B-it-deepseek 3 | eval_batch_size=4 4 | 5 | 6 | # 1. convert to deepseek-mla 7 | python transmla/converter.py \ 8 | --model-path $model_path \ 9 | --save-path $save_path \ 10 | --freqfold 8 \ 11 | --ppl-eval-batch-size $eval_batch_size 12 | 13 | 14 | # 2. copy modeling and configuration files 15 | cp transmla/transformers/gemma2/* $save_path/ 16 | cp transmla/transformers/mla.py $save_path/ -------------------------------------------------------------------------------- /scripts/convert/gemma2-9B.sh: -------------------------------------------------------------------------------- 1 | model_path=google/gemma-2-9b 2 | save_path=outputs/gemma2-9B-deepseek 3 | eval_batch_size=4 4 | 5 | 6 | # 1. convert to deepseek-mla 7 | python transmla/converter.py \ 8 | --model-path $model_path \ 9 | --save-path $save_path \ 10 | --freqfold 8 \ 11 | --ppl-eval-batch-size $eval_batch_size \ 12 | --cal-dataset alpaca 13 | 14 | 15 | # 2. copy modeling and configuration files 16 | cp transmla/transformers/gemma2/* $save_path/ 17 | cp transmla/transformers/mla.py $save_path/ -------------------------------------------------------------------------------- /scripts/convert/llama2-7B.sh: -------------------------------------------------------------------------------- 1 | model_path=meta-llama/Llama-2-7b-hf 2 | save_path=outputs/llama2-7B-deepseek 3 | eval_batch_size=4 4 | 5 | 6 | # 1. convert to deepseek-mla 7 | python transmla/converter.py \ 8 | --model-path $model_path \ 9 | --save-path $save_path \ 10 | --freqfold 8 \ 11 | --ppl-eval-batch-size $eval_batch_size 12 | 13 | 14 | # 2. copy modeling and configuration files 15 | cp transmla/transformers/llama/* $save_path/ 16 | cp transmla/transformers/mla.py $save_path/ -------------------------------------------------------------------------------- /scripts/convert/llama3-8B.sh: -------------------------------------------------------------------------------- 1 | model_path=meta-llama/Meta-Llama-3-8B 2 | save_path=outputs/llama3-8B-deepseek 3 | eval_batch_size=4 4 | 5 | 6 | # 1. convert to deepseek-mla 7 | python transmla/converter.py \ 8 | --model-path $model_path \ 9 | --save-path $save_path \ 10 | --freqfold 4 \ 11 | --ppl-eval-batch-size $eval_batch_size 12 | 13 | 14 | # 2. copy modeling and configuration files 15 | cp transmla/transformers/llama/* $save_path/ 16 | cp transmla/transformers/mla.py $save_path/ -------------------------------------------------------------------------------- /scripts/convert/mistral-7B-v0.3.sh: -------------------------------------------------------------------------------- 1 | model_path=mistralai/Mistral-7B-v0.3 2 | save_path=outputs/mistral-7B-deepseek 3 | eval_batch_size=8 4 | 5 | 6 | # 1. convert to deepseek-mla 7 | python transmla/converter.py \ 8 | --model-path $model_path \ 9 | --save-path $save_path \ 10 | --freqfold 8 \ 11 | --ppl-eval-batch-size $eval_batch_size 12 | 13 | 14 | # 2. copy modeling and configuration files 15 | cp transmla/transformers/llama/* $save_path/ 16 | cp transmla/transformers/mla.py $save_path/ -------------------------------------------------------------------------------- /scripts/convert/mixtral-8x7B-v0.1.sh: -------------------------------------------------------------------------------- 1 | model_path=mistralai/Mixtral-8x7B-v0.1 2 | save_path=outputs/mixtral-8x7B-deepseek 3 | eval_batch_size=8 4 | 5 | 6 | # 1. 
convert to deepseek-mla 7 | python transmla/converter.py \ 8 | --model-path $model_path \ 9 | --save-path $save_path \ 10 | --freqfold 8 \ 11 | --ppl-eval-batch-size $eval_batch_size 12 | 13 | 14 | # 2. copy modeling and configuration files 15 | cp transmla/transformers/mixtral/* $save_path/ 16 | cp transmla/transformers/mla.py $save_path/ -------------------------------------------------------------------------------- /scripts/convert/qwen2.5-7B-Instruct.sh: -------------------------------------------------------------------------------- 1 | model_path=Qwen/Qwen2.5-7B-Instruct 2 | save_path=outputs/qwen2_5-7B-Instruct-deepseek 3 | eval_batch_size=8 4 | 5 | 6 | # 1. convert to deepseek-mla 7 | python transmla/converter.py \ 8 | --model-path $model_path \ 9 | --save-path $save_path \ 10 | --freqfold 4 \ 11 | --ppl-eval-batch-size $eval_batch_size 12 | 13 | 14 | # 2. copy modeling and configuration files 15 | cp transmla/transformers/llama/* $save_path/ 16 | cp transmla/transformers/mla.py $save_path/ -------------------------------------------------------------------------------- /scripts/train/run_train_llama2_7b.sh: -------------------------------------------------------------------------------- 1 | BASE_MODEL="outputs/TransMLA-llama-2-7b-q4096-kv448" 2 | OUTPUT_PATH="outputs/ft100m-TransMLA-llama-2-7b-q4096-kv448" 3 | DATA_PATH="/data2/mengfanxu/nanotron/datasets/100m" 4 | 5 | # batch size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus = 128 6 | deepspeed --master_port=16971 --include=localhost:0,1,2,3,4,5,6,7 train.py \ 7 | --deepspeed configs/ds_config_zero3.json \ 8 | --model_name_or_path $BASE_MODEL \ 9 | --bf16 \ 10 | --data_path $DATA_PATH \ 11 | --output_dir $OUTPUT_PATH \ 12 | --num_train_epochs 1 \ 13 | --seq_len 2048 \ 14 | --per_device_train_batch_size 2 \ 15 | --gradient_accumulation_steps 4 \ 16 | --save_strategy "steps" \ 17 | --save_steps 100 \ 18 | --save_total_limit 1 \ 19 | --learning_rate 2e-5 \ 20 | --weight_decay 0.01 \ 21 | --warmup_ratio 0.03 \ 22 | --logging_steps 1 \ 23 | --lr_scheduler_type "cosine" \ 24 | --report_to "tensorboard" 25 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from typing import Optional 4 | from dataclasses import dataclass, field 5 | from datatrove.utils.dataset import DatatroveFolderDataset 6 | 7 | 8 | @dataclass 9 | class TrainingArguments(transformers.TrainingArguments): 10 | model_name_or_path: Optional[str] = field(default="meta-llama/Meta-Llama-3-8B") 11 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 12 | attn_implementation : Optional[str] = field(default="flash_attention_2") 13 | seq_len: int = field(default=2048,metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."},) 14 | 15 | parser = transformers.HfArgumentParser(TrainingArguments) 16 | training_args = parser.parse_args_into_dataclasses()[0] 17 | model = transformers.AutoModelForCausalLM.from_pretrained( 18 | training_args.model_name_or_path, 19 | torch_dtype=torch.bfloat16, 20 | attn_implementation=training_args.attn_implementation, 21 | trust_remote_code=True, 22 | ) 23 | tokenizer = transformers.AutoTokenizer.from_pretrained(training_args.model_name_or_path) 24 | 25 | train_dataset = DatatroveFolderDataset( 26 | training_args.data_path, 27 | seq_len=training_args.seq_len, 28 | return_positions=True 29 | ) 30 | 31 | data_collator = transformers.DataCollatorForLanguageModeling( 32 | tokenizer=tokenizer, 33 | mlm=False, 34 | return_tensors="pt", 35 | ) 36 | trainer = transformers.Trainer( 37 | args=training_args, 38 | model=model, 39 | train_dataset=train_dataset, 40 | data_collator=data_collator 41 | ) 42 | 43 | trainer.train() 44 | trainer.save_state() 45 | trainer.save_model(output_dir=training_args.output_dir) 46 | -------------------------------------------------------------------------------- /transmla/converter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from transformers import AutoTokenizer, AutoModelForCausalLM 4 | import torch 5 | 6 | from modify_config import modify_config 7 | from utils import get_dataset, prepare_dataloader, prepare_test_dataloader, evaluate_ppl 8 | from partial_rope import partial_rope 9 | from lora_qkv import low_rank_qkv 10 | 11 | 12 | def load_model_and_tokenizer(model_path: str): 13 | model = AutoModelForCausalLM.from_pretrained( 14 | args.model_path, 15 | torch_dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16 if args.dtype == "bf16" else torch.float32, 16 | device_map=args.device, 17 | _attn_implementation="eager", 18 | trust_remote_code=True, 19 | ) 20 | tokenizer = AutoTokenizer.from_pretrained( 21 | args.model_path, 22 | trust_remote_code=True, 23 | ) 24 | if tokenizer.pad_token is None: 25 | tokenizer.pad_token = tokenizer.eos_token 26 | 27 | return model, tokenizer 28 | 29 | 30 | def get_dataset_loader(tokenizer: AutoTokenizer, **kwargs): 31 | dataset = get_dataset(kwargs["cal_dataset"]) 32 | train_loader = prepare_dataloader( 33 | dataset=dataset["train"], 34 | tokenizer=tokenizer, 35 | max_seqlen=kwargs["cal_max_seqlen"], 36 | batch_size=kwargs["cal_batch_size"], 37 | nsamples=kwargs["cal_nsamples"], 38 | seed=kwargs["seed"], 39 | ) 40 | if kwargs["ppl_eval_batch_size"] > 0: 41 | test_loader = prepare_test_dataloader( 42 | dataset=dataset["test"], tokenizer=tokenizer, batch_size=kwargs["ppl_eval_batch_size"] 43 | ) 44 | else: 45 | test_loader = None 46 | 47 | return train_loader, test_loader 48 | 49 | 50 | def main(args): 51 | 52 | ############################## 53 | # original model # 54 | ############################## 55 | print("\n" + "="*60) 56 | print("Original Model".center(60)) 57 | print("="*60 + "\n") 58 | 59 | # get model, tokenizer 60 | model, tokenizer = load_model_and_tokenizer(args.model_path) 61 | model_type = model.config.model_type 62 | # get dataset 63 | train_loader, test_loader = get_dataset_loader(tokenizer, **vars(args)) 64 | 65 | if test_loader: 66 | message = "Evaluating original model's ppl" 67 | dataset_ppl = evaluate_ppl(model, tokenizer.pad_token_id, test_loader, message) 68 | print(f'Original ppl: {dataset_ppl:.4f}') 69 | 70 | ############################## 71 | # partial 
rope # 72 | ############################## 73 | print("\n" + "="*60) 74 | print("Partial RoPE Model".center(60)) 75 | print("="*60 + "\n") 76 | 77 | if args.collapse == "auto": 78 | head_dim = model.config.head_dim if hasattr(model.config, "head_dim") else model.config.hidden_size // model.config.num_attention_heads 79 | model.config.head_dim = head_dim 80 | args.collapse = head_dim // args.qk_mqa_dim 81 | print(f"Auto collapse: {args.collapse} (head_dim={head_dim} / qk_mqa_dim={args.qk_mqa_dim})") 82 | else: 83 | args.collapse = int(args.collapse) 84 | 85 | model = partial_rope(model, tokenizer, train_loader, test_loader, **vars(args)) 86 | if args.freqfold == "auto": 87 | args.freqfold = model[1] 88 | model = model[0] 89 | 90 | ############################## 91 | # deepseek-mla model # 92 | ############################## 93 | print("\n" + "="*60) 94 | print("LoraQKV Model".center(60)) 95 | print("="*60 + "\n") 96 | 97 | model = low_rank_qkv(model, tokenizer, train_loader, test_loader, **vars(args)) 98 | 99 | # save model 100 | print(f"\nSaving model and tokenizer to {args.save_path}...") 101 | model.save_pretrained(os.path.join(args.save_path)) 102 | tokenizer.save_pretrained(os.path.join(args.save_path)) 103 | 104 | # modify config 105 | modify_config(model, os.path.join(args.save_path, "config.json"), args) 106 | 107 | 108 | if __name__ == "__main__": 109 | 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument("--model-path", type=str, default="meta-llama/Llama-2-7b-hf", help="Model to load") 112 | parser.add_argument("--save-path", type=str, default="outputs", help="output path.") 113 | parser.add_argument("--dtype", type=str, help="Data type to use.", choices=["fp32", "fp16", "bf16"], default="bf16") 114 | parser.add_argument("--device", type=str, help="Device to use.", default="auto") 115 | parser.add_argument("--cal-dataset", type=str, help="Dataset to calibrate and calculate perplexity on.", choices=["wikitext2", "ptb", "c4", "alpaca"], default="wikitext2") 116 | parser.add_argument("--cal-nsamples", type=int, help="Number of samples of the calibration data to load.", default=128) 117 | parser.add_argument("--cal-batch-size", type=int, default=8, help="Batch size for loading the calibration data.") 118 | parser.add_argument("--cal-max-seqlen", type=int, default=256, help="Maximum sequence length for the calibration data.") 119 | parser.add_argument("--seed", type=int, default=42, help="Seed for sampling the calibration data.") 120 | parser.add_argument("--ppl-eval-batch-size", type=int, default=0, help="Batch size for evaluating the perplexity.") 121 | parser.add_argument("--freqfold", type=str, default="auto", help="Freqfold for removing RoPE, int or auto") 122 | parser.add_argument("--collapse", type=str, default="auto", help="Collapse for removing RoPE, int or auto") 123 | parser.add_argument("--qk-mqa-dim", type=int, default=64, help="") 124 | parser.add_argument("--q-lora-rank", type=int, help="") 125 | parser.add_argument("--kv-lora-rank", type=int, default=512, help="") 126 | parser.add_argument("--balance-kv-ratio", type=float, default=1, help="") 127 | parser.add_argument("--use-qkv-norm", action='store_true', default=False, help="") 128 | args = parser.parse_args() 129 | 130 | main(args) 131 | -------------------------------------------------------------------------------- /transmla/lora_qkv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from typing import 
Optional, Tuple 5 | import torch.nn.functional as F 6 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS 7 | 8 | from utils import pca_calc, get_qkv_calibrate_outputs, evaluate_ppl, statistics_qkv_rmsnorm 9 | 10 | def rotate_half(x): 11 | x1 = x[..., : x.shape[-1] // 2] 12 | x2 = x[..., x.shape[-1] // 2 :] 13 | return torch.cat((-x2, x1), dim=-1) 14 | 15 | def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1): 16 | cos = cos.unsqueeze(unsqueeze_dim) 17 | sin = sin.unsqueeze(unsqueeze_dim) 18 | q_embed = (q * cos) + (rotate_half(q) * sin) 19 | k_embed = (k * cos) + (rotate_half(k) * sin) 20 | return q_embed, k_embed 21 | 22 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 23 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 24 | if n_rep == 1: 25 | return hidden_states 26 | hidden_states = hidden_states[:, :, None, :, :].expand( 27 | batch, num_key_value_heads, n_rep, slen, head_dim 28 | ) 29 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 30 | 31 | class LoraQKV(nn.Module): 32 | def __init__( 33 | self, 34 | self_attn, 35 | query_outputs=None, 36 | key_outputs=None, 37 | value_outputs=None, 38 | q_lora_rank=None, 39 | qk_mqa_dim=64, 40 | collapse=1, 41 | kv_lora_rank=896, 42 | use_qkv_norm=None, 43 | balance_kv_ratio=None, 44 | rms_norm_eps=1e-6, 45 | ): 46 | super().__init__() 47 | assert qk_mqa_dim * collapse == self_attn.head_dim 48 | 49 | self.config = self_attn.config 50 | self.dtype = self_attn.q_proj.weight.dtype 51 | self.layer_idx = self_attn.layer_idx 52 | self.num_attention_heads = self_attn.num_attention_heads 53 | self.num_key_value_heads = self_attn.num_key_value_heads 54 | self.head_dim = self_attn.head_dim 55 | self.qk_mqa_dim = qk_mqa_dim 56 | self.collapse = collapse 57 | self.latent_dim = self_attn.latent_dim 58 | self.attention_dropout = self_attn.attention_dropout 59 | self.hidden_size = self_attn.hidden_size 60 | self.q_lora_rank = q_lora_rank 61 | self.kv_lora_rank = kv_lora_rank 62 | assert self.kv_lora_rank <= 2 * self.latent_dim - self.qk_mqa_dim, f"kv_lora_rank ({self.kv_lora_rank}) must be less than 2 * latent_dim ({self.latent_dim}) - qk_mqa_dim ({self.qk_mqa_dim})" 63 | 64 | self.attention_function = ALL_ATTENTION_FUNCTIONS["sdpa"] 65 | # self.scaling = (self.head_dim + self.qk_mqa_dim)**(-0.5) 66 | self.scaling = self.head_dim**(-0.5) 67 | 68 | # -----------------Attributes for the bias----------------- 69 | q_bias = self_attn.q_proj.bias is not None 70 | k_bias = self_attn.k_proj.bias is not None 71 | v_bias = self_attn.v_proj.bias is not None 72 | assert q_bias == k_bias == v_bias, f"q_bias ({q_bias}), k_bias ({k_bias}), v_bias ({v_bias}) must be the same" 73 | self.attention_bias = q_bias 74 | 75 | # -----------------module definitions----------------- 76 | # q_a_proj & q_b_proj 77 | if q_lora_rank is not None: 78 | self.q_a_proj = nn.Linear( 79 | self.hidden_size, 80 | q_lora_rank, 81 | bias=None, 82 | device = self_attn.q_proj.weight.device, 83 | dtype = self.dtype, 84 | ) 85 | if use_qkv_norm: 86 | self.q_a_layernorm = nn.RMSNorm(q_lora_rank, device=self_attn.q_proj.weight.device, dtype=self.dtype, eps=rms_norm_eps) 87 | self.q_b_proj = nn.Linear( 88 | q_lora_rank, 89 | self.num_attention_heads * (self.qk_mqa_dim + self.head_dim), 90 | bias=self.attention_bias, 91 | device = self_attn.q_proj.weight.device, 92 | dtype = self.dtype, 93 | ) 94 | else: 95 | self.q_proj = nn.Linear( 96 | self.hidden_size, 97 | self.num_attention_heads * (self.qk_mqa_dim + 
self.head_dim), 98 | bias=self.attention_bias, 99 | device = self_attn.q_proj.weight.device, 100 | dtype = self.dtype, 101 | ) 102 | # kv_a_proj & kv_b_proj 103 | self.kv_a_proj_with_mqa = nn.Linear( 104 | self.hidden_size, 105 | kv_lora_rank + qk_mqa_dim, 106 | bias=self.attention_bias, 107 | device = self_attn.k_proj.weight.device, 108 | dtype = self.dtype, 109 | ) 110 | if use_qkv_norm: 111 | self.kv_a_layernorm = nn.RMSNorm(kv_lora_rank, device=self_attn.k_proj.weight.device, dtype=self.dtype, eps=rms_norm_eps) 112 | self.kv_b_proj = nn.Linear( 113 | kv_lora_rank, 114 | self.num_attention_heads * self.head_dim * 2, 115 | bias=self.attention_bias, 116 | device = self_attn.k_proj.weight.device, 117 | dtype = self.dtype, 118 | ) 119 | # nothing else to do for o_proj 120 | self.o_proj = self_attn.o_proj 121 | 122 | # -----------------apply bkv on the key and value outputs----------------- 123 | if balance_kv_ratio is not None: 124 | k_outputs_norm = torch.cat([key.reshape(-1, self.latent_dim)[:,self.qk_mqa_dim:] for key in key_outputs]).norm(p=2,dim=0).mean() 125 | v_outputs_norm = torch.cat([value.reshape(-1, self.latent_dim)[:,self.qk_mqa_dim:] for value in value_outputs]).norm(p=2,dim=0).mean() 126 | ratio = k_outputs_norm/(v_outputs_norm * balance_kv_ratio) 127 | self_attn.k_proj.weight.data[self.qk_mqa_dim:] = self_attn.k_proj.weight.data[self.qk_mqa_dim:] / ratio 128 | self_attn.k_up_proj.weight.data[:, self.qk_mqa_dim:] = self_attn.k_up_proj.weight.data[:, self.qk_mqa_dim:] * ratio 129 | else: 130 | ratio = 1 131 | kv_outputs = [torch.cat([key_outputs[i][:,:,qk_mqa_dim:] / ratio, value_outputs[i]],dim=-1) for i in range(len(key_outputs))] 132 | 133 | # -----------------apply pca on the query and key/value outputs----------------- 134 | if self.q_lora_rank is not None: 135 | if self.attention_bias: 136 | # If q_bias is not None, we need to remove the bias from the query_outputs, 137 | # because the bias part does not need pca. 138 | R_q = pca_calc( 139 | [x.to(dtype=self.dtype, device=self_attn.q_proj.bias.device) - self_attn.q_proj.bias.data for x in query_outputs], 140 | self_attn.q_proj.weight.device 141 | ) 142 | else: 143 | R_q = pca_calc(query_outputs, self_attn.q_proj.weight.device) 144 | else: 145 | R_q = None 146 | if self.attention_bias: 147 | # If k_bias is not None, we need to remove the bias from the kv_outputs, 148 | # because the bias part does not need pca. 149 | kv_bias = torch.cat([self_attn.k_proj.bias.data[qk_mqa_dim:] / ratio, self_attn.v_proj.bias.data]) 150 | R_kv = pca_calc( 151 | [x.to(dtype=self.dtype, device=kv_bias.device) - kv_bias for x in kv_outputs], 152 | self_attn.k_proj.weight.device 153 | ) 154 | else: 155 | R_kv = pca_calc(kv_outputs, self_attn.k_proj.weight.device) 156 | 157 | # -----------------initialize the weights / bias----------------- 158 | self._init_weights(self_attn, R_q, R_kv) 159 | 160 | def _init_weights(self, self_attn, R_q, R_kv): 161 | # 0. Split the weights of k_proj and v_proj into rope / nope parts. 
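        #    Here "rope" means the first `qk_mqa_dim` latent channels, which keep rotary
        #    position information and become the shared decoupled-RoPE key path, while
        #    "nope" means the remaining `latent_dim - qk_mqa_dim` channels, which carry no
        #    rotary signal and are jointly low-rank decomposed together with v_proj below.
        #    At this point k_proj / v_proj come from PartialRope, so their weights have
        #    shape (latent_dim, hidden_size), and k_up_proj / v_up_proj hold the matching
        #    up-projections of shape (hidden_size, latent_dim).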
162 | k_a_rope_weight, k_a_nope_weight = self_attn.k_proj.weight.data.split([self.qk_mqa_dim, self.latent_dim - self.qk_mqa_dim],dim=0) 163 | k_b_rope_weight, k_b_nope_weight = self_attn.k_up_proj.weight.data.split([self.qk_mqa_dim, self.latent_dim - self.qk_mqa_dim], dim=1) 164 | k_b_rope_weight = k_b_rope_weight.view(self.num_attention_heads, self.head_dim, self.qk_mqa_dim) 165 | k_b_nope_weight = k_b_nope_weight.view(self.num_attention_heads, self.head_dim, self.latent_dim-self.qk_mqa_dim) 166 | if self.attention_bias: 167 | k_bias_rope, k_bias_nope = self_attn.k_proj.bias.data.split([self.qk_mqa_dim, self.latent_dim - self.qk_mqa_dim], dim=0) 168 | # k_bias = self_attn.k_proj.bias.data 169 | 170 | v_a_nope_weight = self_attn.v_proj.weight.data 171 | v_b_nope_weight = self_attn.v_up_proj.weight.data 172 | v_b_nope_weight = v_b_nope_weight.view(self.num_attention_heads, self.head_dim, self.latent_dim) 173 | if self.attention_bias: 174 | v_bias = self_attn.v_proj.bias.data 175 | 176 | if self.attention_bias: 177 | q_bias = self_attn.q_proj.bias.data 178 | 179 | 180 | # 1. Initialize q_a_proj / q_b_proj if q_lora_rank is not None 181 | # 1.1 Initialize q_a_proj 182 | if self.q_lora_rank is not None: 183 | q_a_weight = (R_q.T @ self_attn.q_proj.weight.data.to(torch.float64))[: self.q_lora_rank].to(self.dtype) 184 | q_b_weight = R_q[:, :self.q_lora_rank].to(self.dtype) 185 | q_b_weight = q_b_weight.view(self.num_attention_heads, self.head_dim, self.q_lora_rank) 186 | assert self.q_a_proj.weight.data.shape == q_a_weight.shape 187 | self.q_a_proj.weight.data = q_a_weight.contiguous() 188 | 189 | # 1.2 Initialize q_b_proj 190 | # scaling = math.sqrt(self.head_dim + self.qk_mqa_dim) / math.sqrt(self.head_dim) 191 | scaling = 1 192 | if self.q_lora_rank is not None: 193 | # Absorb the rope part of k_b_proj into q_b_proj 194 | q_b_rope_weight = torch.einsum("hdq,hdk->hkq", q_b_weight, k_b_rope_weight) 195 | q_b_with_mqa_weight = torch.cat([q_b_weight, q_b_rope_weight], dim=1).reshape( 196 | self.num_attention_heads * (self.head_dim + self.qk_mqa_dim), self.q_lora_rank 197 | ) 198 | 199 | # Scale the weight before initializing the q_b_proj 200 | # In the original GQA, attention scores are divided by sqrt(head_dim). 201 | # However, in the transformed MLA, the attention scores are divided by sqrt(head_dim + qk_mqa_dim). 
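            # Note: in this implementation self.scaling is left at head_dim ** -0.5 (see
            # __init__ above), so the softmax is still scaled by 1/sqrt(head_dim) and the
            # local `scaling` factor stays 1. The commented-out formula
            # sqrt(head_dim + qk_mqa_dim) / sqrt(head_dim) shows the query rescaling that
            # would be needed if 1/sqrt(head_dim + qk_mqa_dim) were used as the scale instead.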
202 | assert self.q_b_proj.weight.data.shape == q_b_with_mqa_weight.shape 203 | self.q_b_proj.weight.data = q_b_with_mqa_weight.contiguous() * scaling 204 | 205 | # Considering the bias 206 | if self.attention_bias: 207 | q_b_bias = q_bias.reshape(self.num_attention_heads, self.head_dim) 208 | q_b_rope_bias = torch.einsum("hd,hdk->hk", q_b_bias, k_b_rope_weight) 209 | q_b_with_mqa_bias = torch.cat([q_b_bias, q_b_rope_bias], dim=1).flatten() 210 | assert self.q_b_proj.bias.data.shape == q_b_with_mqa_bias.shape 211 | self.q_b_proj.bias.data = q_b_with_mqa_bias.contiguous() * scaling 212 | else: 213 | q_weight = self_attn.q_proj.weight.data.view(self.num_attention_heads, self.head_dim, self.hidden_size) 214 | q_rope_weight = torch.einsum("hdD,hdk->hkD", q_weight, k_b_rope_weight) 215 | q_with_mqa_weight = torch.cat([q_weight, q_rope_weight], dim=1).reshape( 216 | self.num_attention_heads * (self.head_dim + self.qk_mqa_dim), self.hidden_size 217 | ) 218 | assert self.q_proj.weight.data.shape == q_with_mqa_weight.shape 219 | self.q_proj.weight.data = q_with_mqa_weight.contiguous() * scaling 220 | 221 | if self.attention_bias: 222 | q_bias = q_bias.reshape(self.num_attention_heads, self.head_dim) 223 | q_rope_bias = torch.einsum("hd,hdk->hk", q_bias.to(torch.float64), k_b_rope_weight.to(torch.float64)).to(self.dtype) 224 | q_bias = torch.cat([q_bias, q_rope_bias], dim=1).flatten() 225 | assert self.q_proj.bias.data.shape == q_bias.shape 226 | self.q_proj.bias.data = q_bias.contiguous() * scaling 227 | 228 | 229 | # 2. Low-rank decomposing k_proj and v_proj 230 | # 2.1 Concatenate the nope parts of k_proj and v_proj 231 | kv_a_nope_weight = torch.cat([k_a_nope_weight, v_a_nope_weight], dim=0).to(torch.float64) 232 | kv_b_nope_weight = torch.cat( 233 | [ 234 | torch.cat([k_b_nope_weight, torch.zeros_like(v_b_nope_weight)], dim=-1), 235 | torch.cat([torch.zeros_like(k_b_nope_weight), v_b_nope_weight], dim=-1) 236 | ], 237 | dim=1 238 | ).reshape(2 * self.num_attention_heads * self.head_dim, 2 * self.latent_dim - self.qk_mqa_dim).to(torch.float64) 239 | 240 | # 2.2 Low-rank decomposing kv_a_nope_weight and kv_b_nope_weight 241 | kv_a_nope_weight = (R_kv.T @ kv_a_nope_weight)[: self.kv_lora_rank].to(self.dtype) 242 | kv_b_nope_weight = (kv_b_nope_weight @ R_kv)[:, :self.kv_lora_rank].to(self.dtype) 243 | self.kv_b_proj.weight.data = kv_b_nope_weight.contiguous() 244 | kv_a_proj_with_mqa_weight = torch.cat([kv_a_nope_weight, k_a_rope_weight], dim=0) 245 | assert self.kv_a_proj_with_mqa.weight.data.shape == kv_a_proj_with_mqa_weight.shape 246 | self.kv_a_proj_with_mqa.weight.data = kv_a_proj_with_mqa_weight.contiguous() 247 | 248 | # 2.3 Considering the bias of kv 249 | if self.attention_bias: 250 | kv_a_bias = torch.zeros(self.kv_lora_rank, dtype=self.dtype, device=self_attn.k_proj.bias.device) 251 | # kv_a_bias = torch.zeros_like(torch.cat([k_bias_nope, v_bias], dim=0), dtype=self.dtype, device=self_attn.k_proj.bias.device) 252 | kv_a_with_mqa_bias = torch.cat([kv_a_bias, k_bias_rope], dim=0) 253 | assert self.kv_a_proj_with_mqa.bias.data.shape == kv_a_with_mqa_bias.shape, f"{self.kv_a_proj_with_mqa.bias.data.shape} != {kv_a_with_mqa_bias.shape}" 254 | self.kv_a_proj_with_mqa.bias.data = kv_a_with_mqa_bias.contiguous() 255 | 256 | k_b_bias = k_b_nope_weight @ k_bias_nope 257 | v_b_bias = v_b_nope_weight @ v_bias 258 | kv_b_bias = torch.cat([k_b_bias, v_b_bias], dim=1).flatten() 259 | assert self.kv_b_proj.bias.data.shape == kv_b_bias.shape 260 | self.kv_b_proj.bias.data = kv_b_bias.contiguous() 261 | 
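    # Rough KV-cache accounting for the converted attention (an illustrative sketch only;
    # the concrete numbers assume Llama-2-7B-like shapes with head_dim=128,
    # num_key_value_heads=32, and the default kv_lora_rank=512, qk_mqa_dim=64):
    #   * original MHA/GQA cache per token per layer:
    #       2 * num_key_value_heads * head_dim = 2 * 32 * 128 = 8192 values
    #   * DeepSeek-MLA-style cache this layout targets (only the kv_a_proj_with_mqa
    #     output needs to be kept): kv_lora_rank + qk_mqa_dim = 512 + 64 = 576 values
    # Per-head keys and values are then reconstructed on the fly by kv_b_proj in
    # forward() below; this calibration-time module itself does not register a cache.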
262 | 263 | 264 | def forward( 265 | self, 266 | hidden_states: torch.Tensor, 267 | attention_mask: Optional[torch.Tensor] = None, 268 | position_ids: Optional[torch.LongTensor] = None, 269 | past_key_value = None, 270 | output_attentions: bool = False, 271 | use_cache: bool = False, 272 | cache_position: Optional[torch.LongTensor] = None, 273 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 274 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 275 | bsz, q_len, _ = hidden_states.size() 276 | 277 | # query 278 | if self.q_lora_rank is not None: 279 | query_states = self.q_a_proj(hidden_states) 280 | if hasattr(self, "q_a_layernorm"): 281 | query_states = self.q_a_layernorm(query_states) 282 | query_states = self.q_b_proj(query_states) 283 | else: 284 | query_states = self.q_proj(hidden_states) 285 | 286 | query_states = query_states.view(bsz, q_len, self.num_attention_heads, -1).transpose(1,2) 287 | q_nope, q_rope = query_states.split([self.head_dim, self.qk_mqa_dim], dim=-1) 288 | 289 | # key and value 290 | compressed_kv = self.kv_a_proj_with_mqa(hidden_states) 291 | kv_nope, k_rope = compressed_kv.split([self.kv_lora_rank, self.qk_mqa_dim], dim=-1) 292 | kv_nope = kv_nope.view(bsz, 1, q_len, self.kv_lora_rank) 293 | k_rope = k_rope.view(bsz, 1, q_len, self.qk_mqa_dim) 294 | 295 | cos, sin = position_embeddings 296 | q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos[ :, :, : : self.collapse], sin[ :, :, : : self.collapse]) 297 | query_states = torch.cat([q_nope, q_rope], dim=-1) 298 | 299 | if hasattr(self, "kv_a_layernorm"): 300 | kv_nope = self.kv_a_layernorm(kv_nope) 301 | kv_nope = self.kv_b_proj(kv_nope).view(bsz, q_len, self.num_attention_heads, self.head_dim * 2).transpose(1, 2) 302 | k_nope, value_states = kv_nope.split([self.head_dim, self.head_dim],dim=-1) 303 | key_states = torch.cat([k_nope, repeat_kv(k_rope, self.num_attention_heads)], dim=-1) 304 | 305 | attn_output, attn_weights = self.attention_function( 306 | self, 307 | query_states, 308 | key_states, 309 | value_states, 310 | attention_mask, 311 | dropout=0.0 if not self.training else self.attention_dropout, 312 | scaling=self.scaling, 313 | ) 314 | 315 | attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() 316 | attn_output = self.o_proj(attn_output) 317 | 318 | return attn_output, attn_weights 319 | 320 | 321 | def low_rank_qkv(model, tokenizer, train_loader, test_loader, **kwargs): 322 | 323 | message = "Calibrating rope-removed model's qkv outputs" 324 | rm_rope_qkv_outputs = get_qkv_calibrate_outputs(model, train_loader, message) 325 | 326 | for layer_idx, layer in enumerate(model.model.layers): 327 | setattr(layer, "self_attn", LoraQKV( 328 | layer.self_attn, 329 | rm_rope_qkv_outputs["query"][layer_idx], 330 | rm_rope_qkv_outputs["key"][layer_idx], 331 | rm_rope_qkv_outputs["value"][layer_idx], 332 | q_lora_rank=kwargs["q_lora_rank"], 333 | qk_mqa_dim=kwargs["qk_mqa_dim"], 334 | collapse=kwargs["collapse"], 335 | kv_lora_rank=kwargs["kv_lora_rank"], 336 | use_qkv_norm=kwargs["use_qkv_norm"], 337 | balance_kv_ratio=kwargs["balance_kv_ratio"], 338 | rms_norm_eps=model.config.rms_norm_eps, 339 | )) 340 | 341 | if kwargs["use_qkv_norm"]: 342 | lora_qkv_outputs = get_qkv_calibrate_outputs(model, train_loader) 343 | for layer_idx, layer in enumerate(model.model.layers): 344 | statistics_qkv_rmsnorm( 345 | layer.self_attn, 346 | lora_qkv_outputs["q_a_proj"][layer_idx] if len(lora_qkv_outputs["q_a_proj"]) > layer_idx else None, 347 | 
lora_qkv_outputs["kv_a_proj"][layer_idx] 348 | ) 349 | 350 | if test_loader: 351 | message = "Evaluating lora-qkv model's ppl" 352 | dataset_ppl = evaluate_ppl(model, tokenizer.pad_token_id, test_loader, message) 353 | print(f'Low rank approximate QKV ppl: {dataset_ppl:.4f}') 354 | 355 | return model -------------------------------------------------------------------------------- /transmla/modify_config.py: -------------------------------------------------------------------------------- 1 | import json 2 | import transformers.models as models 3 | 4 | 5 | settings = { 6 | "llama": { 7 | "config_class": models.llama.configuration_llama.LlamaConfig, 8 | "auto_map": { 9 | "AutoConfig": "configuration_llamamla.LlamaMLAConfig", 10 | "AutoModel": "modeling_llamamla.LlamaMLAModel", 11 | "AutoModelForCausalLM": "modeling_llamamla.LlamaMLAForCausalLM" 12 | }, 13 | "architectures": ["LlamaMLAForCausalLM"], 14 | "model_type": "llamamla", 15 | }, 16 | "mixtral": { 17 | "config_class": models.mixtral.configuration_mixtral.MixtralConfig, 18 | "auto_map": { 19 | "AutoConfig": "configuration_mixtralmla.MixtralMLAConfig", 20 | "AutoModel": "modeling_mixtralmla.MixtralMLAModel", 21 | "AutoModelForCausalLM": "modeling_mixtralmla.MixtralMLAForCausalLM" 22 | }, 23 | "architectures": ["MixtralMLAForCausalLM"], 24 | "model_type": "mixtralmla", 25 | }, 26 | "gemma2": { 27 | "config_class": models.gemma2.configuration_gemma2.Gemma2Config, 28 | "auto_map": { 29 | "AutoConfig": "configuration_gemma2mla.Gemma2MLAConfig", 30 | "AutoModel": "modeling_gemma2mla.Gemma2MLAModel", 31 | "AutoModelForCausalLM": "modeling_gemma2mla.Gemma2MLAForCausalLM" 32 | }, 33 | "architectures": ["Gemma2MLAForCausalLM"], 34 | "model_type": "gemma2mla", 35 | } 36 | } 37 | settings["qwen2"] = settings["llama"] 38 | settings["mistral"] = settings["llama"] 39 | 40 | 41 | 42 | def modify_config(model, config_path: str, args): 43 | setting = settings[model.config.model_type] 44 | config_class = setting["config_class"] 45 | config = config_class.from_pretrained(config_path) 46 | config = config.to_dict() 47 | 48 | config["auto_map"] = setting["auto_map"] 49 | config["architectures"] = setting["architectures"] 50 | config["model_type"] = setting["model_type"] 51 | 52 | config["num_key_value_heads"] = config["num_attention_heads"] 53 | config["attention_bias"] = model.model.layers[0].self_attn.attention_bias 54 | config["qk_rope_head_dim"] = config["head_dim"] = args.qk_mqa_dim 55 | config["qk_nope_head_dim"] = config["v_head_dim"] = model.model.layers[0].self_attn.head_dim 56 | config["q_lora_rank"] = args.q_lora_rank 57 | config["kv_lora_rank"] = args.kv_lora_rank 58 | 59 | if config.get("query_pre_attn_scalar", None) is None: 60 | config["query_pre_attn_scalar"] = model.model.layers[0].self_attn.head_dim 61 | 62 | with open(config_path, "w") as f: 63 | json.dump(config, f, indent=4) -------------------------------------------------------------------------------- /transmla/partial_rope.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from copy import deepcopy 5 | from typing import Optional, Tuple 6 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS 7 | 8 | from utils import get_qkv_calibrate_outputs, evaluate_ppl 9 | 10 | def rotate_half(x, group): 11 | rotate_x = [] 12 | dh = x.shape[-1] // group 13 | for i in range(group): 14 | rotate_x.append(-x[..., i * dh + dh // 2 : (i + 1) * dh]) 15 | rotate_x.append(x[..., i * dh : i * dh + 
dh // 2]) 16 | return torch.cat(rotate_x, dim=-1) 17 | 18 | def apply_rotary_pos_emb(q, k, cos, sin, rope_head=1): 19 | rope_dim = cos.shape[-1] * rope_head 20 | nope_dim = q.shape[-1] - rope_dim 21 | q_rope, q_nope = q.split([rope_dim, nope_dim], dim=-1) 22 | k_rope, k_nope = k.split([rope_dim, nope_dim], dim=-1) 23 | 24 | cos = cos.unsqueeze(1) 25 | sin = sin.unsqueeze(1) 26 | rope_repeat = q_rope.shape[-1] // cos.shape[-1] 27 | q_rope_embed = q_rope * cos.repeat(1,1,1,rope_repeat) + rotate_half(q_rope, rope_repeat) * sin.repeat(1,1,1,rope_repeat) 28 | rope_repeat = k_rope.shape[-1] // cos.shape[-1] 29 | k_rope_embed = k_rope * cos.repeat(1,1,1,rope_repeat) + rotate_half(k_rope, rope_repeat) * sin.repeat(1,1,1,rope_repeat) 30 | 31 | q_embed = torch.cat([q_rope_embed, q_nope], dim=-1) 32 | k_embed = torch.cat([k_rope_embed, k_nope], dim=-1) 33 | return q_embed, k_embed 34 | 35 | class PartialRope(nn.Module): 36 | def __init__(self, self_attn, key_outputs=None, freqfold=1, rope_head=1, collapse=1): 37 | super().__init__() 38 | self.config = self_attn.config 39 | self.layer_idx = self_attn.layer_idx 40 | self.hidden_size = self_attn.config.hidden_size 41 | self.num_attention_heads = self_attn.config.num_attention_heads 42 | self.head_dim = self_attn.head_dim 43 | self.num_key_value_heads = self_attn.config.num_key_value_heads 44 | self.latent_dim = self.num_key_value_heads * self.head_dim 45 | self.attention_dropout = self_attn.attention_dropout 46 | self.rope_head = rope_head 47 | self.collapse = collapse 48 | self.scaling = self.head_dim**(-0.5) 49 | self.attention_function = ALL_ATTENTION_FUNCTIONS["sdpa"] 50 | assert freqfold % self.collapse == 0, f"freqfold ({freqfold}) must be divisible by collapse ({self.collapse})" 51 | 52 | self.q_proj = self_attn.q_proj 53 | self.k_proj = self_attn.k_proj 54 | self.v_proj = self_attn.v_proj 55 | self.o_proj = self_attn.o_proj 56 | self.__insert_kv_up_proj__() 57 | if key_outputs is not None: 58 | Rk = self.joint_complex_pca(key_outputs, freqfold) 59 | self.rotate_k_proj(Rk, freqfold=freqfold) 60 | self.rotate_k_up_proj(Rk, freqfold=freqfold) 61 | 62 | def __insert_kv_up_proj__(self): 63 | self.k_up_proj = nn.Linear(self.latent_dim, self.hidden_size, bias=False, dtype=self.k_proj.weight.dtype, device=self.k_proj.weight.device) 64 | self.v_up_proj = nn.Linear(self.latent_dim, self.hidden_size, bias=False, dtype=self.v_proj.weight.dtype, device=self.v_proj.weight.device) 65 | kv_groups = self.num_attention_heads // self.num_key_value_heads 66 | k_up_eye = torch.eye(self.latent_dim, dtype=self.k_proj.weight.dtype, device=self.k_proj.weight.device) 67 | v_up_eye = torch.eye(self.latent_dim, dtype=self.v_proj.weight.dtype, device=self.v_proj.weight.device) 68 | k_up_eye = k_up_eye.reshape(self.num_key_value_heads, self.head_dim, self.latent_dim) 69 | v_up_eye = v_up_eye.reshape(self.num_key_value_heads, self.head_dim, self.latent_dim) 70 | self.k_up_proj.weight.data = torch.stack([k_up_eye]*kv_groups,dim=1).reshape(-1, self.latent_dim).contiguous() 71 | self.v_up_proj.weight.data = torch.stack([v_up_eye]*kv_groups,dim=1).reshape(-1, self.latent_dim).contiguous() 72 | 73 | @torch.no_grad() 74 | def joint_complex_pca(self, Z: list[torch.Tensor], freqfold: int = 1) -> torch.Tensor: 75 | dtype = self.k_proj.weight.dtype 76 | eigen_vecs = [] 77 | for i in range(self.head_dim//2//freqfold): 78 | H = None 79 | for Z_batch in Z: 80 | b,n,d = Z_batch.shape 81 | head_batch = deepcopy(Z_batch).view(b,n, self.num_key_value_heads, 2, self.head_dim//2//freqfold, 
freqfold//self.collapse, self.collapse) 82 | head_batch = head_batch.permute(0, 1, 3, 6, 2, 5, 4) 83 | head_batch = head_batch.reshape(b,n*2, self.num_key_value_heads*freqfold, self.head_dim//2//freqfold) 84 | head_batch_i = head_batch[:,:,:,i].double().to(self.k_proj.weight.device) 85 | head_batch_i = torch.sum(head_batch_i.mT @ head_batch_i, dim=0) # sum over the batch dimension. 86 | H = head_batch_i if H is None else H + head_batch_i 87 | damp = 0.01 * torch.mean(torch.diag(H)) 88 | diag = torch.arange(H.shape[-1]).to(self.k_proj.weight.device) 89 | H[diag, diag] = H[diag, diag] + damp 90 | X_eig = torch.linalg.eigh(H) 91 | del H 92 | index = torch.argsort(X_eig[0], descending=True) 93 | eigen_vecs.append(X_eig[1][:, index]) 94 | return torch.stack(eigen_vecs+eigen_vecs).to(dtype) 95 | 96 | def rotate_k_proj(self, U, freqfold=1): 97 | k_weight = deepcopy(self.k_proj.weight.data) 98 | U = U.to(k_weight.dtype).to(k_weight.device) 99 | if self.k_proj.bias is not None: 100 | k_bias = deepcopy(self.k_proj.bias.data) 101 | k_weight = torch.cat([k_weight, k_bias.unsqueeze(1)], dim=1) 102 | k_weight = k_weight.reshape(self.num_key_value_heads, self.head_dim//freqfold, freqfold//self.collapse, self.collapse, -1) 103 | k_weight = k_weight.permute(3, 0, 2, 1, 4).reshape(self.num_key_value_heads*freqfold, self.head_dim//freqfold, -1) 104 | k_weight = torch.einsum("dhc,hdD->cdD", U, k_weight) 105 | k_weight = k_weight.reshape(self.collapse, self.num_key_value_heads, freqfold//self.collapse, self.head_dim//freqfold, -1) 106 | k_weight = k_weight.permute(0, 1, 3, 2, 4).reshape(self.num_key_value_heads, self.head_dim, -1) 107 | if self.k_proj.bias is not None: 108 | k_bias = k_weight[:, :, -1] 109 | k_weight = k_weight[:, :, :-1] 110 | self.k_proj.bias.data = k_bias.flatten().contiguous() 111 | assert self.k_proj.weight.data.shape == (self.latent_dim, self.hidden_size) 112 | self.k_proj.weight.data = k_weight.reshape(self.latent_dim, self.hidden_size).contiguous() 113 | 114 | def rotate_k_up_proj(self, U, freqfold=1): 115 | k_up_weight = deepcopy(self.k_up_proj.weight.data) 116 | U = U.to(k_up_weight.dtype).to(k_up_weight.device) 117 | k_up_weight = k_up_weight.reshape(-1, self.num_key_value_heads, self.head_dim//freqfold, freqfold//self.collapse, self.collapse) 118 | k_up_weight = k_up_weight.permute(0, 4, 1, 3, 2).reshape(-1, self.num_key_value_heads*freqfold, self.head_dim//freqfold) 119 | k_up_weight = torch.einsum("dhc,Dhd->Dcd", U, k_up_weight) 120 | k_up_weight = k_up_weight.reshape(-1, self.collapse, self.num_key_value_heads, freqfold//self.collapse, self.head_dim//freqfold) 121 | k_up_weight = k_up_weight.permute(0, 1, 2, 4, 3).reshape(-1, self.latent_dim) 122 | # assert self.k_up_proj.weight.data.shape == (self.hidden_size, self.latent_dim) 123 | self.k_up_proj.weight.data = k_up_weight.contiguous() 124 | 125 | def forward( 126 | self, 127 | hidden_states: torch.Tensor, 128 | attention_mask: Optional[torch.Tensor] = None, 129 | position_ids: Optional[torch.LongTensor] = None, 130 | past_key_value = None, 131 | output_attentions: bool = False, 132 | use_cache: bool = False, 133 | cache_position: Optional[torch.LongTensor] = None, 134 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 135 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 136 | bsz, q_len, _ = hidden_states.size() 137 | 138 | query_states = self.q_proj(hidden_states) 139 | key_states = self.k_proj(hidden_states) 140 | value_states = self.v_proj(hidden_states) 141 | 142 | 
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim) 143 | k_up_weight = self.k_up_proj.weight.view(self.num_attention_heads, self.head_dim, self.latent_dim) 144 | query_states = torch.einsum("bthd,hdc->bhtc", query_states, k_up_weight) 145 | 146 | key_states = key_states.view(bsz, 1, q_len, self.latent_dim) 147 | value_states = value_states.view(bsz, 1, q_len, self.latent_dim) 148 | 149 | cos, sin = position_embeddings 150 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos[:,:,::self.collapse], sin[:,:,::self.collapse], self.rope_head) 151 | 152 | if past_key_value is not None: 153 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models 154 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 155 | 156 | v_up_weight = self.v_up_proj.weight.view(self.num_key_value_heads, self.num_attention_heads//self.num_key_value_heads, self.head_dim, self.latent_dim) 157 | value_states = torch.einsum("bhtc,hgdc->bhgtd", value_states, v_up_weight) 158 | value_states = value_states.reshape(bsz, self.num_attention_heads, -1, self.head_dim) 159 | 160 | 161 | attn_output, attn_weights = self.attention_function( 162 | self, 163 | query_states, 164 | key_states, 165 | value_states, 166 | attention_mask, 167 | dropout=0.0 if not self.training else self.attention_dropout, 168 | scaling=self.scaling, 169 | ) 170 | 171 | attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() 172 | attn_output = self.o_proj(attn_output) 173 | 174 | return attn_output, attn_weights 175 | 176 | 177 | 178 | def partial_rope(model, tokenizer, train_loader, test_loader, **kwargs): 179 | 180 | freqfold = kwargs["freqfold"] 181 | collapse = kwargs["collapse"] 182 | 183 | message = "Calibrating original model's qkv outputs" 184 | ori_qkv_outputs = get_qkv_calibrate_outputs(model, train_loader, message) 185 | 186 | def partial_rope_freqfold(model, ori_qkv_outputs, test_loader, freqfold: int, collapse): 187 | for layer_idx, layer in enumerate(model.model.layers): 188 | setattr(layer, "self_attn", PartialRope( 189 | layer.self_attn, 190 | ori_qkv_outputs["key"][layer_idx], 191 | freqfold=freqfold, 192 | collapse=collapse, 193 | )) 194 | 195 | if test_loader: 196 | message = f"Evaluating partial-rope model's ppl, freqfold={freqfold}" 197 | dataset_ppl = evaluate_ppl(model, tokenizer.pad_token_id, test_loader, message) 198 | print(f'Partial RoPE ppl, freqfold={freqfold}: {dataset_ppl:.4f}') 199 | return model, dataset_ppl 200 | else: 201 | return model, None 202 | 203 | if freqfold != "auto": 204 | freqfold = int(freqfold) 205 | return partial_rope_freqfold(model, ori_qkv_outputs, test_loader, freqfold, collapse)[0] 206 | else: 207 | assert test_loader is not None, "test_loader is required for auto freqfold detection" 208 | device = model.device 209 | model_original = model.to("cpu") 210 | 211 | print(f"Auto freqfold detection...") 212 | 213 | best_freqfold = freqfold = collapse 214 | best_ppl = float("inf") 215 | while freqfold <= model_original.config.head_dim // 2: 216 | model = deepcopy(model_original) 217 | model = model.to(device) 218 | model, ppl = partial_rope_freqfold(model, ori_qkv_outputs, test_loader, freqfold, collapse) 219 | if ppl < best_ppl: 220 | best_ppl = ppl 221 | best_freqfold = freqfold 222 | freqfold *= 2 223 | else: 224 | break 225 | 226 | model = deepcopy(model_original) 227 | model = model.to(device) 228 | model, _ = partial_rope_freqfold(model, 
ori_qkv_outputs, None, best_freqfold, collapse) 229 | 230 | print(f"Best freqfold: {best_freqfold}") 231 | 232 | return model, best_freqfold -------------------------------------------------------------------------------- /transmla/transformers/gemma2/configuration_gemma2mla.py: -------------------------------------------------------------------------------- 1 | from transformers.models.gemma2.configuration_gemma2 import Gemma2Config 2 | 3 | class Gemma2MLAConfig(Gemma2Config): 4 | model_type = "gemma2mla" 5 | 6 | def __init__( 7 | self, 8 | *args, 9 | kv_lora_rank=512, 10 | q_lora_rank=None, 11 | qk_rope_head_dim=64, 12 | qk_nope_head_dim=128, 13 | v_head_dim=128, 14 | query_pre_attn_scalar=128, 15 | attention_bias=False, 16 | softcap=None, 17 | **kwargs 18 | ): 19 | super().__init__(*args, **kwargs) 20 | 21 | self.kv_lora_rank = kv_lora_rank 22 | self.q_lora_rank = q_lora_rank 23 | self.qk_rope_head_dim = qk_rope_head_dim 24 | self.qk_nope_head_dim = qk_nope_head_dim 25 | self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim 26 | self.v_head_dim = v_head_dim 27 | self.query_pre_attn_scalar = query_pre_attn_scalar 28 | self.softcap = softcap 29 | self.attention_bias = attention_bias -------------------------------------------------------------------------------- /transmla/transformers/gemma2/modeling_gemma2mla.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.cache_utils import Cache 6 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs 7 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel 8 | from transformers.processing_utils import Unpack 9 | from transformers.utils import LossKwargs 10 | 11 | from transformers.models.gemma2.modeling_gemma2 import ( 12 | Gemma2Model, 13 | Gemma2DecoderLayer, 14 | Gemma2PreTrainedModel, 15 | Gemma2ForCausalLM 16 | ) 17 | 18 | from .configuration_gemma2mla import Gemma2MLAConfig 19 | from .mla import MLAAttention, eager_attention_forward 20 | 21 | 22 | class Gemma2MLADecoderLayer(Gemma2DecoderLayer): 23 | 24 | def __init__(self, config: Gemma2MLAConfig, layer_idx: int): 25 | super().__init__(config, layer_idx) 26 | self.self_attn = MLAAttention(config, layer_idx) 27 | 28 | 29 | class Gemma2MLAPreTrainedModel(Gemma2PreTrainedModel): 30 | 31 | config_class = Gemma2MLAConfig 32 | _no_split_modules = ["Gemma2MLADecoderLayer"] 33 | 34 | 35 | class Gemma2MLAModel(Gemma2MLAPreTrainedModel, Gemma2Model): 36 | 37 | def __init__(self, config: Gemma2MLAConfig): 38 | super().__init__(config) 39 | 40 | self.layers = nn.ModuleList( 41 | [Gemma2MLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 42 | ) 43 | 44 | 45 | class Gemma2MLAForCausalLM(Gemma2MLAPreTrainedModel, Gemma2ForCausalLM): 46 | 47 | def __init__(self, config): 48 | super().__init__(config) 49 | self.model = Gemma2MLAModel(config) 50 | 51 | 52 | __all__ = [ 53 | "Gemma2MLAForCausalLM", 54 | "Gemma2MLAModel", 55 | "Gemma2MLAPreTrainedModel", 56 | ] -------------------------------------------------------------------------------- /transmla/transformers/llama/configuration_llamamla.py: -------------------------------------------------------------------------------- 1 | from transformers.models.llama.configuration_llama import LlamaConfig 2 | 3 | class LlamaMLAConfig(LlamaConfig): 4 | model_type = "llamamla" 5 | 6 | def __init__( 7 | self, 8 | *args, 9 | kv_lora_rank=512, 
10 | q_lora_rank=None, 11 | qk_rope_head_dim=64, 12 | qk_nope_head_dim=128, 13 | v_head_dim=128, 14 | query_pre_attn_scalar=128, 15 | softcap=None, 16 | **kwargs 17 | ): 18 | super().__init__(*args, **kwargs) 19 | 20 | self.kv_lora_rank = kv_lora_rank 21 | self.q_lora_rank = q_lora_rank 22 | self.qk_rope_head_dim = qk_rope_head_dim 23 | self.qk_nope_head_dim = qk_nope_head_dim 24 | self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim 25 | self.v_head_dim = v_head_dim 26 | self.query_pre_attn_scalar = query_pre_attn_scalar 27 | self.softcap = softcap -------------------------------------------------------------------------------- /transmla/transformers/llama/modeling_llamamla.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.cache_utils import Cache 6 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs 7 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel 8 | from transformers.processing_utils import Unpack 9 | from transformers.utils import LossKwargs 10 | 11 | from transformers.models.llama.modeling_llama import ( 12 | LlamaModel, 13 | LlamaDecoderLayer, 14 | LlamaPreTrainedModel, 15 | LlamaForCausalLM 16 | ) 17 | 18 | from .configuration_llamamla import LlamaMLAConfig 19 | from .mla import MLAAttention, eager_attention_forward 20 | 21 | 22 | class LlamaMLADecoderLayer(LlamaDecoderLayer): 23 | 24 | def __init__(self, config: LlamaMLAConfig, layer_idx: int): 25 | super().__init__(config, layer_idx) 26 | self.self_attn = MLAAttention(config, layer_idx) 27 | 28 | 29 | class LlamaMLAPreTrainedModel(LlamaPreTrainedModel): 30 | 31 | config_class = LlamaMLAConfig 32 | _no_split_modules = ["LlamaMLADecoderLayer"] 33 | 34 | 35 | class LlamaMLAModel(LlamaMLAPreTrainedModel, LlamaModel): 36 | 37 | def __init__(self, config: LlamaMLAConfig): 38 | super().__init__(config) 39 | 40 | self.layers = nn.ModuleList( 41 | [LlamaMLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 42 | ) 43 | 44 | 45 | class LlamaMLAForCausalLM(LlamaMLAPreTrainedModel, LlamaForCausalLM): 46 | 47 | def __init__(self, config): 48 | super().__init__(config) 49 | self.model = LlamaMLAModel(config) 50 | 51 | 52 | __all__ = [ 53 | "LlamaMLAForCausalLM", 54 | "LlamaMLAModel", 55 | "LlamaMLAPreTrainedModel", 56 | ] -------------------------------------------------------------------------------- /transmla/transformers/mixtral/configuration_mixtralmla.py: -------------------------------------------------------------------------------- 1 | from transformers.models.mixtral.configuration_mixtral import MixtralConfig 2 | 3 | class MixtralMLAConfig(MixtralConfig): 4 | model_type = "mixtralmla" 5 | 6 | def __init__( 7 | self, 8 | *args, 9 | kv_lora_rank=512, 10 | q_lora_rank=None, 11 | qk_rope_head_dim=64, 12 | qk_nope_head_dim=128, 13 | v_head_dim=128, 14 | query_pre_attn_scalar=128, 15 | attention_bias=False, 16 | softcap=None, 17 | **kwargs 18 | ): 19 | super().__init__(*args, **kwargs) 20 | 21 | self.kv_lora_rank = kv_lora_rank 22 | self.q_lora_rank = q_lora_rank 23 | self.qk_rope_head_dim = qk_rope_head_dim 24 | self.qk_nope_head_dim = qk_nope_head_dim 25 | self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim 26 | self.v_head_dim = v_head_dim 27 | self.query_pre_attn_scalar = query_pre_attn_scalar 28 | self.softcap = softcap 29 | self.attention_bias = attention_bias 
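The three MLA config classes above (`Gemma2MLAConfig`, `LlamaMLAConfig`, `MixtralMLAConfig`) extend their base configs with the same DeepSeek-style hyperparameters: a KV latent rank, an optional query latent rank, split RoPE/no-RoPE head dimensions, and a fixed `query_pre_attn_scalar`. Below is a minimal sketch of how these fields fit together; it is not part of the repo and assumes the repository root is on `PYTHONPATH` so the config module resolves as a namespace package.

```python
# Hypothetical sketch (not part of the repo): assumes the repository root is on
# PYTHONPATH so transmla.transformers.llama is importable.
from transmla.transformers.llama.configuration_llamamla import LlamaMLAConfig

cfg = LlamaMLAConfig()  # Llama defaults plus the MLA-specific fields defined above

print(cfg.model_type)                     # "llamamla"
print(cfg.kv_lora_rank, cfg.q_lora_rank)  # 512 None -> KV latent rank, no query compression
print(cfg.qk_rope_head_dim, cfg.qk_nope_head_dim, cfg.qk_head_dim)  # 64 128 192
print(cfg.query_pre_attn_scalar ** -0.5)  # attention scaling used by MLAAttention
```

The derived `qk_head_dim` is what `MLAAttention` uses to split each query head into its RoPE and non-RoPE parts, and `query_pre_attn_scalar ** -0.5` is the attention scaling it applies.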
-------------------------------------------------------------------------------- /transmla/transformers/mixtral/modeling_mixtralmla.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.cache_utils import Cache 6 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs 7 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel 8 | from transformers.processing_utils import Unpack 9 | from transformers.utils import LossKwargs 10 | 11 | from transformers.models.mixtral.modeling_mixtral import ( 12 | MixtralModel, 13 | MixtralDecoderLayer, 14 | MixtralPreTrainedModel, 15 | MixtralForCausalLM 16 | ) 17 | 18 | from .configuration_mixtralmla import MixtralMLAConfig 19 | from .mla import MLAAttention, eager_attention_forward 20 | 21 | 22 | class MixtralMLADecoderLayer(MixtralDecoderLayer): 23 | 24 | def __init__(self, config: MixtralMLAConfig, layer_idx: int): 25 | super().__init__(config, layer_idx) 26 | self.self_attn = MLAAttention(config, layer_idx) 27 | 28 | 29 | class MixtralMLAPreTrainedModel(MixtralPreTrainedModel): 30 | 31 | config_class = MixtralMLAConfig 32 | _no_split_modules = ["MixtralMLADecoderLayer"] 33 | 34 | 35 | class MixtralMLAModel(MixtralMLAPreTrainedModel, MixtralModel): 36 | 37 | def __init__(self, config: MixtralMLAConfig): 38 | super().__init__(config) 39 | 40 | self.layers = nn.ModuleList( 41 | [MixtralMLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 42 | ) 43 | 44 | 45 | class MixtralMLAForCausalLM(MixtralMLAPreTrainedModel, MixtralForCausalLM): 46 | 47 | def __init__(self, config): 48 | super().__init__(config) 49 | self.model = MixtralMLAModel(config) 50 | 51 | 52 | __all__ = [ 53 | "MixtralMLAForCausalLM", 54 | "MixtralMLAModel", 55 | "MixtralMLAPreTrainedModel", 56 | ] -------------------------------------------------------------------------------- /transmla/transformers/mla.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from transformers.cache_utils import Cache 6 | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs 7 | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel 8 | from transformers.processing_utils import Unpack 9 | from transformers.utils import LossKwargs, logging 10 | import torch.nn.functional as F 11 | from transformers.models.gemma2.modeling_gemma2 import ( 12 | repeat_kv, 13 | apply_rotary_pos_emb, 14 | eager_attention_forward 15 | ) 16 | logger = logging.get_logger(__name__) 17 | 18 | class MLAAttention(nn.Module): 19 | """ 20 | Modified from `transformers.models.deepseek_v3.modeling_deepseek_v3.DeepseekV3Attention`, 21 | with added support for attention bias and softcapping. 22 | """ 23 | def __init__(self, config, layer_idx: int): 24 | 25 | super().__init__() 26 | self.config = config 27 | self.layer_idx = layer_idx 28 | 29 | self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads 30 | self.attention_dropout = config.attention_dropout 31 | self.num_heads = config.num_attention_heads 32 | self.rope_theta = config.rope_theta 33 | self.q_lora_rank = config.q_lora_rank 34 | self.kv_lora_rank = config.kv_lora_rank 35 | self.qk_rope_head_dim = config.qk_rope_head_dim 36 | self.qk_nope_head_dim = config.qk_nope_head_dim 37 | self.v_head_dim = config.v_head_dim 38 | self.qk_head_dim = config.qk_head_dim 39 |
self.softcap = config.softcap 40 | 41 | self.is_causal = True 42 | if self.q_lora_rank is None: 43 | self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=config.attention_bias) 44 | else: 45 | self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=False) 46 | self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=config.attention_bias) 47 | 48 | self.kv_a_proj_with_mqa = nn.Linear( 49 | config.hidden_size, 50 | self.kv_lora_rank + self.qk_rope_head_dim, 51 | bias=config.attention_bias, 52 | ) 53 | self.kv_b_proj = nn.Linear( 54 | self.kv_lora_rank, 55 | self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), 56 | bias=config.attention_bias, 57 | ) 58 | 59 | self.o_proj = nn.Linear( 60 | self.num_heads * self.v_head_dim, 61 | config.hidden_size, 62 | bias=False, 63 | ) 64 | 65 | self.scaling = self.config.query_pre_attn_scalar ** (-0.5) 66 | 67 | def forward( 68 | self, 69 | hidden_states: torch.Tensor, 70 | position_embeddings: Tuple[torch.Tensor, torch.Tensor], 71 | attention_mask: Optional[torch.Tensor], 72 | past_key_value: Optional[Cache] = None, 73 | cache_position: Optional[torch.LongTensor] = None, 74 | **kwargs: Unpack[FlashAttentionKwargs], 75 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 76 | batch_size, seq_length = hidden_states.shape[:-1] 77 | query_shape = (batch_size, seq_length, -1, self.qk_head_dim) 78 | key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) 79 | if self.q_lora_rank is None: 80 | q_states = self.q_proj(hidden_states) 81 | else: 82 | q_states = self.q_b_proj(self.q_a_proj(hidden_states)) 83 | q_states = q_states.view(query_shape).transpose(1, 2) 84 | q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) 85 | compressed_kv = self.kv_a_proj_with_mqa(hidden_states) 86 | k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) 87 | 88 | k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2) 89 | k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) 90 | 91 | k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) 92 | 93 | cos, sin = position_embeddings 94 | q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) 95 | k_rot = k_rot.expand(*k_pass.shape[:-1], -1) 96 | 97 | query_states = torch.cat((q_pass, q_rot), dim=-1) 98 | key_states = torch.cat((k_pass, k_rot), dim=-1) 99 | 100 | if past_key_value is not None: 101 | # sin and cos are specific to RoPE models; cache_position needed for the static cache 102 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 103 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 104 | 105 | if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: 106 | value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) 107 | 108 | attention_interface = eager_attention_forward 109 | if self.config._attn_implementation != "eager": 110 | if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): 111 | logger.warning_once( 112 | "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " 113 | 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
114 | ) 115 | else: 116 | attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] 117 | 118 | attn_output, attn_weights = attention_interface( 119 | self, 120 | query_states, 121 | key_states, 122 | value_states, 123 | attention_mask, 124 | dropout=0.0 if not self.training else self.attention_dropout, 125 | scaling=self.scaling, 126 | softcap=self.softcap, 127 | **kwargs, 128 | ) 129 | if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: 130 | attn_output = attn_output[:, :, :, : self.v_head_dim] 131 | attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() 132 | attn_output = self.o_proj(attn_output) 133 | return attn_output, attn_weights 134 | -------------------------------------------------------------------------------- /transmla/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | import torch 4 | import datasets 5 | from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler 6 | from transformers import PreTrainedTokenizerBase 7 | from tqdm import tqdm 8 | import logging 9 | 10 | def get_dataset(name: str) -> datasets.DatasetDict: 11 | """ 12 | Get the dataset from the HuggingFace datasets library. 13 | 14 | Args: 15 | name: The name of the HuggingFace dataset to load. Must be one of "wikitext2", "ptb", "c4" or "alpaca". 16 | 17 | Returns: 18 | The dataset. 19 | """ 20 | logging.info(f"Loading dataset: {name}") 21 | 22 | ds_properties = { 23 | "wikitext2": {"path": "wikitext", "config_name": "wikitext-2-raw-v1"}, 24 | "ptb": {"path": "ptb_text_only", "config_name": "penn_treebank"}, 25 | "c4": { 26 | "path": "allenai/c4", 27 | "config_name": "en", 28 | "data_files": { 29 | "train": "en/c4-train.00000-of-01024.json.gz", 30 | "validation": "en/c4-validation.00000-of-00008.json.gz", 31 | }, 32 | "cols_to_remove": ['url', 'timestamp'], 33 | }, 34 | "alpaca": {"path": "tatsu-lab/alpaca", "cols_to_remove": ['input', 'output', 'instruction']}, 35 | } 36 | 37 | if name not in ds_properties: 38 | raise NotImplementedError("The provided dataset is not supported") 39 | 40 | properties = ds_properties[name] 41 | ds = datasets.load_dataset( 42 | properties["path"], name=properties.get("config_name"), data_files=properties.get("data_files") 43 | ) 44 | 45 | if "cols_to_remove" in properties: 46 | ds = ds.remove_columns(properties["cols_to_remove"]) 47 | 48 | # if alpaca, create a test and validation set from the training set 49 | if name == "alpaca": 50 | ds = ds["train"].train_test_split(test_size=0.2, seed=42) 51 | temp_ds = ds.pop("test") 52 | temp_ds = temp_ds.train_test_split(test_size=0.5, seed=42) 53 | ds["test"] = temp_ds["train"] 54 | ds["validation"] = temp_ds["test"] 55 | 56 | logging.info("Loading dataset done") 57 | return ds 58 | 59 | def prepare_test_dataloader( 60 | dataset: datasets.Dataset, tokenizer: PreTrainedTokenizerBase, seqlen: int = 2048, batch_size: int = 1 61 | ) -> DataLoader[dict[str, torch.Tensor]]: 62 | """ 63 | Get a DataLoader from a test dataset. This dataloader should be used when comparing WikiText2 perplexities with other papers, e.g. SparseGPT (arxiv.org/abs/2301.00774). 64 | 65 | Args: 66 | dataset: The dataset to create a dataloader from. 67 | tokenizer: The tokenizer to use. 68 | seqlen: The sequence length of sequences in the dataset. 69 | batch_size: The batch size. 70 | 71 | Returns: 72 | A DataLoader. 
73 | """ 74 | 75 | logging.info(f"Preparing test dataloader") 76 | 77 | class TestDataset(Dataset): 78 | def __init__(self, ds, tokenizer, seqlen=2048): 79 | """Tokenize the entire dataset and reshape it into sequences of length seqlen.""" 80 | 81 | tokenized_ds = tokenizer("\n\n".join(ds['text']), return_tensors='pt') 82 | nsamples = tokenized_ds.input_ids.numel() // seqlen 83 | 84 | input_ids = tokenized_ds.input_ids[0, : nsamples * seqlen] 85 | input_ids = input_ids.reshape(nsamples, seqlen) 86 | attn_mask = tokenized_ds.attention_mask[0, : nsamples * seqlen] 87 | attn_mask = attn_mask.reshape(nsamples, seqlen) 88 | 89 | self.input_ids = input_ids 90 | self.attn_mask = attn_mask 91 | 92 | def __getitem__(self, idx): 93 | return {"input_ids": self.input_ids[idx], "attention_mask": self.attn_mask[idx]} 94 | 95 | def __len__(self): 96 | return len(self.input_ids) 97 | 98 | test_ds = TestDataset(dataset, tokenizer, seqlen) 99 | loader = DataLoader(test_ds, batch_size=batch_size) 100 | logging.info(f"Preparing test dataloader done") 101 | return loader 102 | 103 | def prepare_dataloader( 104 | dataset: datasets.Dataset, 105 | tokenizer: PreTrainedTokenizerBase, 106 | max_seqlen: int = 2048, 107 | batch_size: int = 1, 108 | nsamples: int = 128, 109 | varied_seqlen: bool = False, 110 | seed=42, 111 | ) -> DataLoader[dict[str, torch.Tensor]]: 112 | """ 113 | Get a DataLoader from a dataset. 114 | 115 | Args: 116 | dataset: The dataset to create a dataloader from. 117 | tokenizer: The tokenizer to use. 118 | max_seqlen: The maximum sequence length, used for truncation of sequences in the dataset. 119 | batch_size: The batch size. 120 | nsamples: The number of samples to produce. 121 | varied_seqlen: If False, concatenate multiple examples from the dataset into one example until max_seqlen is reached. 122 | seed: The seed for sampling the dataset. 123 | 124 | Returns: 125 | A DataLoader. 126 | """ 127 | logging.info(f"Preparing dataloader") 128 | 129 | if not varied_seqlen and not nsamples: 130 | logging.warning( 131 | "varied_seqlen=False, but nsamples is not specified. This will lead to tokenization of the entire dataset, which will be slow." 132 | ) 133 | 134 | data_name = dataset.column_names[0] 135 | ds = dataset.filter(lambda x: len(x[data_name]) > 0) 136 | 137 | if not varied_seqlen: 138 | # create a new dataset where each example is a concatenation of multiple examples of total length = max_seqlen. 
139 | data_list = ds[data_name] 140 | new_data_list = [] 141 | 142 | torch.manual_seed(seed) 143 | indices = list(range(len(data_list))) 144 | 145 | while len(new_data_list) < nsamples and len(indices) > 0: 146 | start_idx = torch.randint(0, len(indices), (1,)).item() 147 | idx = start_idx 148 | tokens = [] 149 | while len(tokens) < max_seqlen and idx < len(indices): 150 | item = data_list[indices[idx]] 151 | sep = "" if not tokens else "\n\n" 152 | tokens += tokenizer.tokenize(sep + item) 153 | idx += 1 154 | 155 | indices = indices[:start_idx] + indices[idx:] # remove the used indices 156 | 157 | if len(tokens) >= max_seqlen: 158 | tokens = tokens[:max_seqlen] # truncate to max_seqlen 159 | new_data_list.append(tokenizer.convert_tokens_to_string(tokens)) 160 | 161 | ds = datasets.Dataset.from_dict({data_name: new_data_list}) 162 | 163 | def tokenize(data_batch): 164 | # tokenize then pad each batch according to the longest sequence in the batch 165 | batch = tokenizer( 166 | data_batch[data_name], 167 | padding="longest", 168 | max_length=max_seqlen, 169 | truncation=True, 170 | return_tensors="pt", 171 | ) 172 | batch["labels"] = batch["input_ids"].clone() 173 | return batch 174 | 175 | # tokenize lazily 176 | ds.set_transform(tokenize) 177 | 178 | torch.manual_seed(seed) 179 | sampler = SubsetRandomSampler(torch.randperm(len(ds))[:nsamples]) 180 | 181 | loader = DataLoader(ds, batch_size=batch_size, sampler=sampler) 182 | logging.info(f"Preparing dataloader done") 183 | return loader 184 | 185 | def sync_gpus() -> None: 186 | """Sync all GPUs to make sure all operations are finished, needed for correct benchmarking of latency/throughput.""" 187 | for i in range(torch.cuda.device_count()): 188 | torch.cuda.synchronize(device=i) 189 | 190 | def map_tensors(obj, device: torch.device | str | None = None, dtype: torch.dtype | None = None): 191 | """Recursively map tensors to device and dtype.""" 192 | if isinstance(obj, torch.Tensor): 193 | if device is not None: 194 | obj = obj.to(device=device) 195 | if dtype is not None: 196 | obj = obj.to(dtype=dtype) 197 | return obj 198 | elif isinstance(obj, (list, tuple)): 199 | return type(obj)(map_tensors(x, device, dtype) for x in obj) 200 | elif isinstance(obj, dict): 201 | return {k: map_tensors(v, device, dtype) for k, v in obj.items()} # type: ignore 202 | else: 203 | return obj 204 | 205 | @torch.no_grad() 206 | def evaluate_ppl( 207 | model: torch.nn.Module, 208 | pad_token_id: int | None, 209 | testloader: DataLoader[dict[str, torch.Tensor]], 210 | message: str = "Evaluating perplexity" 211 | ) -> float: 212 | """ 213 | Evaluate the model's perplexity on the test set using batch processing. 214 | It is expected that model is already on the correct device. 215 | """ 216 | sync_gpus() 217 | 218 | start_time = time.time() 219 | 220 | model.eval() 221 | 222 | if pad_token_id: 223 | loss_fn = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=pad_token_id) 224 | else: 225 | loss_fn = torch.nn.CrossEntropyLoss(reduction="none") 226 | 227 | nlls = [] 228 | 229 | logging.info(message) 230 | for batch in tqdm(testloader, desc=message): 231 | logging.debug(f"Evaluating batch {len(nlls)}") 232 | batch = map_tensors(batch, model.model.embed_tokens.weight.device) 233 | logits = model(**batch, use_cache=False).logits 234 | 235 | # shift outputs and labels autoregressively. 236 | logits = logits[:, :-1, :] 237 | shift_labels = batch["input_ids"][:, 1:] 238 | 239 | # CrossEntropyLoss demands data dimension is dimension 1. 
240 | nll = loss_fn(logits.permute(0, 2, 1), shift_labels).float() 241 | 242 | mask = shift_labels != loss_fn.ignore_index 243 | nll_means = (nll * mask).sum(dim=1) / mask.sum(dim=1) 244 | nlls.append(nll_means) 245 | 246 | nlls_tensor = torch.cat(nlls) 247 | ppl = torch.exp(nlls_tensor.mean()) 248 | 249 | sync_gpus() 250 | 251 | elapsed = time.time() - start_time 252 | logging.info( 253 | "Time spent on evaluation: %s", 254 | time.strftime("%H:%M:%S.{}".format(str(elapsed % 1)[2:])[:13], time.gmtime(elapsed)), 255 | ) 256 | 257 | return ppl.item() 258 | 259 | def insert_qkv_hooks(model): 260 | query_hooks = [] 261 | key_hooks = [] 262 | value_hooks = [] 263 | q_a_proj_hooks = [] 264 | kv_a_proj_with_mqa_hooks = [] 265 | query_outputs = {} 266 | key_outputs = {} 267 | value_outputs = {} 268 | q_a_proj_outputs = {} 269 | kv_a_proj_with_mqa_outputs = {} 270 | 271 | def query_hook_fn(module, input, output, index): 272 | if index not in query_outputs: 273 | query_outputs[index] = [] 274 | query_outputs[index].append(output.to('cpu')) 275 | 276 | def key_hook_fn(module, input, output, index): 277 | if index not in key_outputs: 278 | key_outputs[index] = [] 279 | key_outputs[index].append(output.to('cpu')) 280 | 281 | def value_hook_fn(module, input, output, index): 282 | if index not in value_outputs: 283 | value_outputs[index] = [] 284 | value_outputs[index].append(output.to('cpu')) 285 | 286 | def q_a_proj_hook_fn(module, input, output, index): 287 | if index not in q_a_proj_outputs: 288 | q_a_proj_outputs[index] = [] 289 | q_a_proj_outputs[index].append(output.to('cpu')) 290 | 291 | def kv_a_proj_with_mqa_hook_fn(module, input, output, index): 292 | if index not in kv_a_proj_with_mqa_outputs: 293 | kv_a_proj_with_mqa_outputs[index] = [] 294 | kv_a_proj_with_mqa_outputs[index].append(output.to('cpu')) 295 | 296 | for idx, layer in enumerate(model.model.layers): 297 | if hasattr(layer.self_attn, "q_proj"): 298 | query_hook = layer.self_attn.q_proj.register_forward_hook(lambda module, input, output, idx=idx: query_hook_fn(module, input, output, idx)) 299 | query_hooks.append(query_hook) 300 | if hasattr(layer.self_attn, "k_proj"): 301 | key_hook = layer.self_attn.k_proj.register_forward_hook(lambda module, input, output, idx=idx: key_hook_fn(module, input, output, idx)) 302 | key_hooks.append(key_hook) 303 | if hasattr(layer.self_attn, "v_proj"): 304 | value_hook = layer.self_attn.v_proj.register_forward_hook(lambda module, input, output, idx=idx: value_hook_fn(module, input, output, idx)) 305 | value_hooks.append(value_hook) 306 | if hasattr(layer.self_attn, "q_a_proj"): 307 | q_a_proj_hook = layer.self_attn.q_a_proj.register_forward_hook(lambda module, input, output, idx=idx: q_a_proj_hook_fn(module, input, output, idx)) 308 | q_a_proj_hooks.append(q_a_proj_hook) 309 | if hasattr(layer.self_attn, "kv_a_proj_with_mqa"): 310 | kv_a_proj_with_mqa_hook = layer.self_attn.kv_a_proj_with_mqa.register_forward_hook(lambda module, input, output, idx=idx: kv_a_proj_with_mqa_hook_fn(module, input, output, idx)) 311 | kv_a_proj_with_mqa_hooks.append(kv_a_proj_with_mqa_hook) 312 | 313 | return query_hooks, key_hooks, value_hooks, q_a_proj_hooks, kv_a_proj_with_mqa_hooks, query_outputs, key_outputs, value_outputs, q_a_proj_outputs, kv_a_proj_with_mqa_outputs 314 | 315 | @torch.no_grad() 316 | def get_qkv_calibrate_outputs( 317 | model: torch.nn.Module, 318 | trainloader: DataLoader[dict[str, torch.Tensor]], 319 | message: str = "Calibrating QKV" 320 | ): 321 | """ 322 | Take the input signals 
("activations") for a layer, run the layer forward. 323 | """ 324 | 325 | start_time = time.time() 326 | 327 | model.eval() 328 | query_hooks, key_hooks, value_hooks, q_a_proj_hooks, kv_a_proj_with_mqa_hooks, query_outputs, key_outputs, value_outputs, q_a_proj_outputs, kv_a_proj_with_mqa_outputs = insert_qkv_hooks(model) 329 | ignore_masks = [] 330 | logging.info(message) 331 | for batch in tqdm(trainloader, desc=message): 332 | batch = map_tensors(batch, model.model.embed_tokens.weight.device) 333 | ignore_masks.append(batch["attention_mask"].to('cpu')) 334 | model(**batch, use_cache=False) 335 | 336 | elapsed = time.time() - start_time 337 | logging.info( 338 | "Time spent on evaluation: %s", 339 | time.strftime("%H:%M:%S.{}".format(str(elapsed % 1)[2:])[:13], time.gmtime(elapsed)), 340 | ) 341 | 342 | for hook in query_hooks: 343 | hook.remove() 344 | for hook in key_hooks: 345 | hook.remove() 346 | for hook in value_hooks: 347 | hook.remove() 348 | for hook in q_a_proj_hooks: 349 | hook.remove() 350 | for hook in kv_a_proj_with_mqa_hooks: 351 | hook.remove() 352 | 353 | for value in query_outputs.values(): 354 | for idx, X_batch in enumerate(value): 355 | if ignore_masks: 356 | X_batch[ignore_masks[idx] == 0] = 0 357 | 358 | for value in key_outputs.values(): 359 | for idx, X_batch in enumerate(value): 360 | if ignore_masks: 361 | X_batch[ignore_masks[idx] == 0] = 0 362 | 363 | for value in value_outputs.values(): 364 | for idx, X_batch in enumerate(value): 365 | if ignore_masks: 366 | X_batch[ignore_masks[idx] == 0] = 0 367 | 368 | for value in q_a_proj_outputs.values(): 369 | for idx, X_batch in enumerate(value): 370 | if ignore_masks: 371 | X_batch[ignore_masks[idx] == 0] = 0 372 | 373 | for value in kv_a_proj_with_mqa_outputs.values(): 374 | for idx, X_batch in enumerate(value): 375 | if ignore_masks: 376 | X_batch[ignore_masks[idx] == 0] = 0 377 | 378 | qkv_outputs = { 379 | "query": query_outputs, 380 | "key": key_outputs, 381 | "value": value_outputs, 382 | "q_a_proj": q_a_proj_outputs, 383 | "kv_a_proj": kv_a_proj_with_mqa_outputs, 384 | } 385 | return qkv_outputs 386 | 387 | @torch.no_grad() 388 | def pca_calc(X: list[torch.Tensor], device: str) -> torch.Tensor: 389 | H = None 390 | for idx, X_batch in enumerate(X): 391 | 392 | X_batch = X_batch.double().to(device) 393 | H_batch = torch.sum(X_batch.mT @ X_batch, dim=0) # sum over the batch dimension. 
394 | H = H_batch if H is None else H + H_batch 395 | 396 | damp = 0.01 * torch.mean(torch.diag(H)) 397 | diag = torch.arange(H.shape[-1]).to(device) 398 | H[diag, diag] = H[diag, diag] + damp 399 | X_eig = torch.linalg.eigh(H) 400 | del H 401 | index = torch.argsort(X_eig[0], descending=True) 402 | eigen_vec = X_eig[1][:, index] 403 | return eigen_vec 404 | 405 | def statistics_qkv_rmsnorm(self_attn, q_a_outputs, kv_a_outputs): 406 | if q_a_outputs is not None: 407 | self_attn.q_a_layernorm.weight.data = self_attn.q_a_layernorm.weight.data.to(self_attn.q_a_proj.weight.device).to(self_attn.dtype) 408 | q_a_proj = torch.cat(q_a_outputs) 409 | q_a_rmsnorm = torch.rsqrt(q_a_proj.pow(2).mean(-1) + self_attn.q_a_layernorm.eps).mean() 410 | self_attn.q_a_layernorm.weight.data = torch.full_like(self_attn.q_a_layernorm.weight.data, q_a_rmsnorm) 411 | 412 | self_attn.kv_a_layernorm.weight.data = self_attn.kv_a_layernorm.weight.data.to(self_attn.kv_a_proj_with_mqa.weight.device).to(self_attn.dtype) 413 | kv_a_proj = torch.cat(kv_a_outputs) 414 | kv_a_rmsnorm = torch.rsqrt(kv_a_proj.pow(2).mean(-1) + self_attn.kv_a_layernorm.eps).mean() 415 | self_attn.kv_a_layernorm.weight.data = torch.full_like(self_attn.kv_a_layernorm.weight.data, kv_a_rmsnorm) 416 | --------------------------------------------------------------------------------
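Taken together, `transmla/utils.py` provides the calibration plumbing used during conversion: dataset loading, calibration and test dataloaders, hook-based capture of per-layer q/k/v activations, PCA over those activations, and perplexity evaluation. The sketch below is a hypothetical wiring of these helpers, not code from the repo; the model checkpoint, dtype, and the assumption that the repository root is on `PYTHONPATH` are illustrative.

```python
# Hypothetical end-to-end sketch, not code from the repo: the model id, dtype and
# device placement are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from transmla.utils import (
    get_dataset, prepare_dataloader, prepare_test_dataloader,
    get_qkv_calibrate_outputs, pca_calc, evaluate_ppl,
)

name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # padding is needed by prepare_dataloader

ds = get_dataset("wikitext2")
train_loader = prepare_dataloader(ds["train"], tokenizer, max_seqlen=2048, batch_size=1, nsamples=128)
test_loader = prepare_test_dataloader(ds["test"], tokenizer, seqlen=2048, batch_size=1)

qkv = get_qkv_calibrate_outputs(model, train_loader)      # per-layer q/k/v activations, kept on CPU
key_basis = pca_calc(qkv["key"][0], device=model.device)  # eigenbasis of layer-0 key outputs
ppl = evaluate_ppl(model, tokenizer.pad_token_id, test_loader)
print(f"wikitext2 ppl: {ppl:.2f}")
```

This mirrors how `partial_rope()` in `transmla/partial_rope.py` drives the same helpers: it calls `get_qkv_calibrate_outputs` on a calibration loader and `evaluate_ppl` on a test loader when searching for the best `freqfold`.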