├── .gitignore ├── README.md ├── configs ├── default_chat_config.yaml ├── llama_small.json ├── llama_small_lckv.json ├── llama_tiny.json ├── llama_tiny_lckv.json ├── tinyllama.json └── tinyllama_lckv.json ├── convert_pretrained.py ├── models ├── __init__.py ├── cache_utils.py ├── configuration_lckv.py ├── kernel.py ├── modeling_lckv.py ├── ops_rope.py └── utils.py ├── pyproject.toml ├── requirements.txt ├── run_chat.py ├── run_clm.py ├── run_clm.sh ├── run_generation.py ├── run_generation.sh ├── run_sft.py ├── run_sft.sh ├── run_streaming.py ├── run_streaming.sh ├── test_harness.py ├── test_latency.py ├── test_streaming.py └── tests ├── test_kernel.py └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | wandb 3 | outputs 4 | harness 5 | streaming 6 | run*.sh 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Layer-Condensed KV Cache 2 | 3 |
4 | 5 |

6 | The KVs of the top layer 7 |
8 | are the most informative and important. 9 |
10 | So why bother caching the rest? 11 |

12 |
13 | 14 | This is the code base for the project **Layer-Condensed KV Cache**, a new variant of transformer decoders in which the queries of all layers are paired with the keys and values of just the top layer. It reduces memory and computation costs as well as the number of parameters, and significantly improves inference throughput with comparable or better task performance. The paper "[Layer-Condensed KV Cache for Efficient Inference of Large Language Models](https://arxiv.org/abs/2405.10637)" was accepted to the ACL 2024 main conference. 15 | 16 | This work is inspired by [Probabilistic Transformer](https://github.com/whyNLP/Probabilistic-Transformer), where we view the stacked layers of a transformer as an iterative process of refining token representations. 17 | 18 |
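The sketch below makes the memory saving concrete with a back-of-the-envelope estimate of KV cache size per token. It rests on three assumptions:

- Only layers that appear as attention targets in a `layer_types` string (the format used in the Configuration section) keep KVs in the cache.
- The sliding-window marker `s` can be ignored for this count.
- KVs are stored in 16-bit precision.

The helper names are hypothetical and not part of this repo. The model sizes come from `configs/tinyllama.json` and `configs/tinyllama_lckv.json`.

```python
from typing import Set


def kv_layers(layer_types: str) -> Set[int]:
    """Hypothetical helper: indices of layers whose KVs are cached, assumed to be
    every layer index that some layer attends to. A trailing 's' (sliding window) is ignored."""
    return {int(t.rstrip("s")) for t in layer_types.split("_")}


def kv_bytes_per_token(num_kv_layers: int, num_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2) -> int:
    """Cache bytes per token: one key and one value vector per KV head per cached layer."""
    return 2 * num_kv_layers * num_kv_heads * head_dim * bytes_per_elem


# TinyLlama sizes from configs/tinyllama*.json: 22 layers, hidden size 2048,
# 32 attention heads (head_dim = 2048 // 32 = 64) and 4 KV heads (GQA).
NUM_LAYERS, HIDDEN, NUM_HEADS, NUM_KV_HEADS = 22, 2048, 32, 4
HEAD_DIM = HIDDEN // NUM_HEADS

llama_types = "_".join(str(i) for i in range(NUM_LAYERS))  # every layer attends to itself
lckv_types = "_".join(["0"] + ["20"] * 20 + ["21"])        # same pattern as configs/tinyllama_lckv.json

for name, types in [("llama", llama_types), ("lckv", lckv_types)]:
    n = len(kv_layers(types))
    print(f"{name}: {n} cached layers, {kv_bytes_per_token(n, NUM_KV_HEADS, HEAD_DIM)} bytes/token")
```

Under these assumptions, the standard model caches 22 layers and the LCKV configuration caches 3 (layers 0, 20 and 21). The per-token cache size therefore shrinks in proportion to the number of cached layers. The estimate ignores sliding-window and sink caches, which bound the cached sequence length rather than the number of cached layers.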
19 | The Map of AI Approaches 20 |
21 | 22 |
23 |
24 | 25 | ## News 26 | 27 | - [25/01/23] Our paper "[A Systematic Study of Cross-Layer KV Sharing for Efficient LLM Inference](http://arxiv.org/abs/2410.14442)" was accepted to the NAACL 2025 main conference. 28 | - [24/12/08] We released the main branch, with a general framework for Cross-Layer KV Sharing. An illustrative post can be found on [PaperWeekly](https://mp.weixin.qq.com/s/Nr7K-xgcQRvHYNs82HU4gQ) (in Chinese). See the [published branch](https://github.com/whyNLP/LCKV/tree/dev-lckv-publish) for the old version of the code. 29 | - [24/10/18] Our new empirical study "[A Systematic Study of Cross-Layer KV Sharing for Efficient LLM Inference](http://arxiv.org/abs/2410.14442)" has been released on arXiv. A new configuration has been found to be more efficient than the original LCKV. 30 | - [24/05/28] This code base now also supports Cross-Layer Attention (CLA). The idea is similar, but CLA 1) divides the transformer layers into small groups of 2-4 layers each; 2) pairs the queries of all layers in a group with the keys and values of the bottom layer of that group. See details in their paper "[Reducing Transformer Key-Value Cache Size with Cross-Layer Attention](http://arxiv.org/abs/2405.12981)". 31 | - [24/05/20] Initial release of the LCKV paper and code. 32 | - [24/05/12] Our paper "[Layer-Condensed KV Cache for Efficient Inference of Large Language Models](http://arxiv.org/abs/2405.10637)" was accepted to the ACL 2024 main conference. 33 | - [24/02/14] Our paper "[Layer-Condensed KV Cache for Efficient Inference of Large Language Models](http://arxiv.org/abs/2405.10637)" was submitted to the ARR February 2024 cycle. 34 | 35 | ## Quick Start 36 | 37 | We have released on HuggingFace a series of pre-trained models described in our paper. There is no need to clone this repo if you just want to use the pre-trained models.
Load the model with the following code: 38 | 39 | ```python 40 | # Use a pipeline as a high-level helper 41 | from transformers import pipeline 42 | pipe = pipeline("text-generation", model="whynlp/tinyllama-lckv-w2-ft-100b", trust_remote_code=True) 43 | 44 | # Load model directly 45 | from transformers import AutoModelForCausalLM 46 | model = AutoModelForCausalLM.from_pretrained("whynlp/tinyllama-lckv-w2-ft-100b", trust_remote_code=True) 47 | ``` 48 | 49 | See more models on the [HuggingFace model hub](https://huggingface.co/models?search=whynlp). Note that these models are for research purposes only and may not be suitable for production. 50 | 51 | | Model | Paper Section | Dev ppl. | Common-sense Reasoning | 52 | | --------------------------------------------------------------------------------------------- | ------------------------------ | -------- | ---------------------- | 53 | | [whynlp/tinyllama-lckv-w10-ft-250b](https://huggingface.co/whynlp/tinyllama-lckv-w10-ft-250b) | -- | 7.939 | 50.86 | 54 | | [whynlp/tinyllama-lckv-w2-ft-100b](https://huggingface.co/whynlp/tinyllama-lckv-w2-ft-100b) | Appendix C.1, Table 7 (line 5) | 8.514 | 49.55 | 55 | | [whynlp/tinyllama-lckv-w10-100b](https://huggingface.co/whynlp/tinyllama-lckv-w10-100b) | Section 3.2, Table 2 (line 3) | 9.265 | 46.84 | 56 | | [whynlp/tinyllama-lckv-w2-100b](https://huggingface.co/whynlp/tinyllama-lckv-w2-100b) | Section 3.2, Table 2 (line 2) | 9.746 | 45.45 | 57 | 58 | ## Installation 59 | 60 | You may install the dependencies with the following commands: 61 | 62 | ```sh 63 | conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia 64 | pip install -r requirements.txt 65 | ``` 66 | 67 | where the CUDA version is set to `12.1`. For other CUDA versions, please refer to installation instructions of [PyTorch](https://pytorch.org/get-started/locally/). See [Trouble shooting](#trouble-shooting) for more details. 
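Mismatched versions of PyTorch, its CUDA build and Flash-Attn are a common source of installation trouble. It can therefore help to print what is actually installed before running the scripts.

The following is a minimal sketch, not part of this repo. It uses only the standard library plus an optional `torch` import. The function names are hypothetical, and the package list is just an example.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def torch_cuda_version() -> Optional[str]:
    """CUDA version PyTorch was built with; None if torch is missing or is a CPU-only build."""
    try:
        import torch
    except ImportError:
        return None
    return torch.version.cuda


if __name__ == "__main__":
    for name, version in installed_versions(["torch", "transformers", "flash-attn"]).items():
        print(f"{name}: {version or 'not installed'}")
    print(f"torch CUDA build: {torch_cuda_version()}")
```

You can compare the reported CUDA build with the output of `nvcc --version`. As the Trouble shooting section notes, these packages need to be consistent with each other.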
68 | 69 | ## Usage 70 | 71 | Our implementation is based on HuggingFace `transformers`. We register a new model type `lckv-llama`, which inherits from the `llama` model and adds support for the Layer-Condensed KV Cache. 72 | 73 | > [!NOTE] 74 | > It is difficult to support the Layer-Condensed KV Cache for a variety of models with a small amount of code, because it requires modifying both the attention mechanism and the training recipe of the transformer decoder. Currently, we have only implemented the Layer-Condensed KV Cache for the `llama` model, but it should be possible to extend it to other models with similar structures. 75 | 76 | ```python 77 | import models # register the lckv-llama model 78 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 79 | 80 | tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") 81 | model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("configs/tinyllama_lckv.json")) # from_config expects a config object, not a path 82 | ``` 83 | 84 | Now you have a randomly initialized model with the Layer-Condensed KV Cache. 85 | 86 | ### Optimization 87 | 88 | To accelerate the training and inference of the model, one could apply the Liger kernel supported by the `transformers` library. The provided training script `run_clm.py` already activates the Liger kernel. See more details [here](https://huggingface.co/docs/transformers/v4.45.2/en/trainer#liger-kernel). 89 | 90 | ### Configuration 91 | 92 | We provide some sample configuration files in the `configs` folder. The config settings are defined in [models/configuration_lckv.py](models/configuration_lckv.py). You may refer to this file for more details. 93 | 94 | #### Option 1: Modify the configurations in Python: 95 | 96 | ```python 97 | from models import LCKVLlamaConfig 98 | 99 | # we have prepared a sample configuration file 100 | config = LCKVLlamaConfig.from_pretrained("configs/tinyllama_lckv.json") 101 | 102 | # below is the LCKV config.
you may modify the configuration as you like 103 | config.forward_passes = 7 # m in the paper 104 | config.backward_passes = 2 # b in the paper 105 | config.layer_types = "0_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_21" # for each layer, which layer to attend to 106 | 107 | # we also support this 108 | config.layer_types = "0_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_21" # the sandwich-middle configuration 109 | config.layer_types = "0_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21" # Llama config 110 | config.layer_types = "0_0_2_2_4_4_6_6_8_8_10_10_12_12_14_14_16_16_18_18_20_20" # CLA config 111 | 112 | config.sliding_window = 1024 # the window size for the sliding window attention 113 | config.layer_types = "0s_1s_2s_3s_4s_5s_6s_7s_8s_9s_10s_11_11_11_11_11_11_11_11_11_11_11" # YOCO config, 's' is for sliding window 114 | 115 | config.sliding_window = 1024 # the window size for the sliding window attention 116 | config.layer_types = "0_1s_1s_3s_3s_3s_0_7s_7s_9s_9s_9s_12_13s_13s_15s_15s_15s_12_19s_19s_19s" # MixAttention (Pairs) config 117 | 118 | # we also support sequential training / inference, which will process the tokens one by one 119 | # corresponding to LCKV paper Figure 2(a) 120 | config.use_sequential = True 121 | ``` 122 | 123 | #### Option 2: Modify the configurations in the shell script (via `--config_overrides`): 124 | 125 | ```sh 126 | accelerate launch run_clm.py \ 127 | --config_name configs/tinyllama_lckv.json \ 128 | --config_overrides forward_passes=7,backward_passes=2,layer_types=0_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_21 \ 129 | ... 130 | ``` 131 | 132 | With the above configurations, you can create [CLA](http://arxiv.org/abs/2405.12981), [YOCO](https://arxiv.org/abs/2405.05254) or any configurations in [Cross-Layer KV Sharing](http://arxiv.org/abs/2410.14442) or [MixAttention](http://arxiv.org/abs/2409.15012) without changing the code. 
The only thing you need to do is to write the correct `layer_types` in the configuration file. 133 | 134 | ### Pre-training 135 | 136 | We use the same [training script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py) as the original `transformers` library. You may refer to the [official documentation](https://huggingface.co/transformers/training.html) for more details. 137 | 138 | We provide a training script `run_clm.sh` for training a 50M parameter model on the `wikitext-103` dataset. You may run the script with: 139 | 140 | ```sh 141 | bash run_clm.sh 142 | ``` 143 | 144 | See the script for more details. For pretraining on SlimPajama, please follow the instructions in [tinyllama-zh](https://github.com/whyNLP/tinyllama-zh) and replace the dataset with SlimPajama. 145 | 146 | 147 | #### Initializing from a Pretrained Model 148 | 149 | We may initialize our LCKV model from a pretrained model. Most parts of the model structure are consistent with the standard transformer model and we can directly inherit the weights. For the KV weights $W_K, W_V$, we mainly have 2 options: 150 | 151 | ##### Option 1: Directly Copy the Weights 152 | 153 | Simply add `--model_name_or_path` to the training script: 154 | 155 | ```sh 156 | accelerate launch run_clm.py \ 157 | --model_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T \ 158 | --config configs/tinyllama_lckv.json \ 159 | ... 160 | ``` 161 | 162 | See the script `run_clm.sh` for more details. 163 | 164 | ##### Option 2: Average the Weights from Multiple Layers 165 | 166 | Following [MLKV](http://arxiv.org/abs/2406.09297), we may average the weights from multiple layers to initialize the KV weights. We provide a script `convert_pretrained.py` to convert the pretrained model to the LCKV model. 
You may run the following command: 167 | 168 | ```sh 169 | python convert_pretrained.py --model_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T --config_name configs/tinyllama_lckv.json --output_dir outputs/tinyllama-converted 170 | ``` 171 | 172 | The KV weights of each layer that computes KVs are initialized as the average of the pretrained KV weights of all the layers that attend to it. For example, 173 | 174 | ```python 175 | # the CLA / MLKV config 176 | config.layer_types = "0_0_2_2_4_4_6_6" 177 | # then layer 0 will have the average KV weights of layers 0 and 1 in the pretrained model 178 | # layer 2 will have the average KV weights of layers 2 and 3 in the pretrained model 179 | 180 | # the LCKV config 181 | config.layer_types = "0_6_6_6_6_6_6_7" 182 | # then layer 0 will inherit the KV weights of layer 0 in the pretrained model 183 | # layer 6 will have the average KV weights of layers 1, 2, 3, 4, 5, 6 in the pretrained model 184 | # layer 7 will inherit the KV weights of layer 7 in the pretrained model 185 | ``` 186 | 187 | Then, use the converted model to initialize the LCKV model: 188 | 189 | ```sh 190 | accelerate launch run_clm.py \ 191 | --model_name_or_path outputs/tinyllama-converted \ 192 | ... 193 | ``` 194 | 195 | Our experiments show that such an initialization strategy can effectively improve the performance of the model in most cases. 196 | 197 | 198 | ### Inference 199 | 200 | We use the same [inference script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py) as the original `transformers` library. To perform inference, you may run the following command: 201 | 202 | ```sh 203 | bash run_generation.sh 204 | ``` 205 | 206 | The script generates responses from the trained model for any given prompt. See the script for more details. 207 | 208 | #### Streaming 209 | 210 | We integrate our model with [StreamingLLM](https://github.com/mit-han-lab/streaming-llm).
To perform streaming inference, you may run the following command: 211 | 212 | ```sh 213 | bash run_streaming.sh 214 | ``` 215 | 216 | See the script for more details. The `run_generation.py` script also supports streaming inference with the `--sink_cache` flag. 217 | 218 | #### Sliding Window Attention 219 | 220 | The generation script also supports sliding window attention inference. If the model was trained with sliding window attention, the generation script automatically uses sliding window attention for inference. 221 | 222 | ### Evaluation 223 | 224 | We use [LM-Harness](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate the model. You may run the following command: 225 | 226 | ```sh 227 | python test_harness.py --model_name_or_path ... 228 | ``` 229 | 230 | where `--model_name_or_path` takes the path to the model checkpoint. Run `python test_harness.py --help` for more details. 231 | 232 | ### Latency Testing 233 | 234 | To test the latency of the model, you may run the following command: 235 | 236 | ```sh 237 | python test_latency.py 238 | ``` 239 | 240 | ### Instruction Fine-tuning 241 | 242 | > [!WARNING] 243 | > This section is currently experimental and may not work as expected. 244 | 245 | We provide a script `run_sft.sh` for supervised instruction fine-tuning. The code is consistent with the official `trl` library from HuggingFace. You may run the script with: 246 | 247 | ```sh 248 | bash run_sft.sh 249 | ``` 250 | 251 | See the script for more details. 252 | 253 | To chat with the fine-tuned model, you may run the following command: 254 | 255 | ```sh 256 | python run_chat.py --model_name_or_path outputs/llamatiny-sft-test 257 | ``` 258 | 259 | This loads the fine-tuned model so that you can chat with it. 260 | 261 | ## Code Style 262 | 263 | We mostly follow the code style of `transformers`.
Run the following command to check the code style: 264 | 265 | ```sh 266 | # Use `pip install ruff` to install ruff if it is not available 267 | ruff check models 268 | ``` 269 | 270 | See more details in `pyproject.toml`. 271 | 272 | 273 | ## Trouble shooting 274 | 275 | ### Flash-Attn Installation 276 | 277 | https://github.com/Dao-AILab/flash-attention/issues/451 278 | 279 | Behavior: 280 | 281 | Runtime error. 282 | ```sh 283 | ImportError: /home/.../flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_... 284 | ``` 285 | 286 | Solution: 287 | ```sh 288 | pip uninstall flash-attn 289 | FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn 290 | ``` 291 | 292 | ### CUDA version 293 | 294 | The CUDA version may affect the installation of: 295 | - [PyTorch](https://pytorch.org/get-started/locally/) 296 | - [Flash-Attn](https://github.com/Dao-AILab/flash-attention) 297 | 298 | Please make sure to install the correct versions of these packages (as long as they are consistent with each other, the code should work). Also make sure that `nvcc` is installed and available on your `PATH`. 299 | 300 | Our experiment environment uses `CUDA 12.1`, which you may install with 301 | ```sh 302 | conda install pytorch==2.5.0 pytorch-cuda=12.1 -c pytorch -c nvidia 303 | pip install -r requirements.txt 304 | ``` 305 | 306 | ### Sequential update produces different outputs 307 | 308 | Behavior: Model inference with sequential update produces different outputs from inference with parallel update. 309 | 310 | This is due to numerical precision. We find that with `bfloat16`, the down projection in the Llama MLP produces slightly different results when running inference on different numbers of tokens. 311 | 312 | ## Questions 313 | 314 | > 1. Is it possible to integrate LCKV with MQA / GQA? 315 | 316 | Yes. In fact, we have already done this in our experiments: TinyLlama uses 32 attention heads and 4 KV heads.
We follow the same setting in our experiments. If you want to experiment with different settings, you may modify the `num_attention_heads` and `num_key_value_heads` in the configuration file. 317 | -------------------------------------------------------------------------------- /configs/default_chat_config.yaml: -------------------------------------------------------------------------------- 1 | examples: 2 | llama: 3 | text: There is a Llama in my lawn, how can I get rid of it? 4 | code: 5 | text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end]. 6 | helicopter: 7 | text: How many helicopters can a human eat in one sitting? 8 | numbers: 9 | text: Count to 10 but skip every number ending with an 'e' 10 | birds: 11 | text: Why aren't birds real? 12 | socks: 13 | text: Why is it important to eat socks after meditating? -------------------------------------------------------------------------------- /configs/llama_small.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 2048, 12 | "max_position_embeddings": 1024, 13 | "model_type": "lckv-llama", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "num_key_value_heads": 6, 17 | "pretraining_tp": 1, 18 | "rms_norm_eps": 1e-05, 19 | "rope_scaling": null, 20 | "tie_word_embeddings": false, 21 | "tokenizer_class": "LlamaTokenizer", 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.45.2", 24 | "use_cache": true, 25 | "vocab_size": 32000 26 | } -------------------------------------------------------------------------------- /configs/llama_small_lckv.json: -------------------------------------------------------------------------------- 
1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 4 | "LCKVLlamaForCausalLM" 5 | ], 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 768, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 2048, 12 | "max_position_embeddings": 1024, 13 | "model_type": "lckv-llama", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "num_key_value_heads": 6, 17 | "pretraining_tp": 1, 18 | "rms_norm_eps": 1e-05, 19 | "rope_scaling": null, 20 | "tie_word_embeddings": false, 21 | "tokenizer_class": "LlamaTokenizer", 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.45.2", 24 | "use_cache": true, 25 | "vocab_size": 32000, 26 | "layer_types": "0_10_10_10_10_10_10_10_10_10_10_11", 27 | "forward_passes": 7, 28 | "backward_passes": 2 29 | } -------------------------------------------------------------------------------- /configs/llama_tiny.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 512, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 1024, 12 | "max_position_embeddings": 1024, 13 | "model_type": "llama", 14 | "num_attention_heads": 8, 15 | "num_hidden_layers": 8, 16 | "num_key_value_heads": 4, 17 | "pretraining_tp": 1, 18 | "rms_norm_eps": 1e-05, 19 | "rope_scaling": null, 20 | "tie_word_embeddings": false, 21 | "tokenizer_class": "LlamaTokenizer", 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.45.2", 24 | "use_cache": true, 25 | "vocab_size": 32000 26 | } -------------------------------------------------------------------------------- /configs/llama_tiny_lckv.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 
4 | "LCKVLlamaForCausalLM" 5 | ], 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 512, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 1024, 12 | "max_position_embeddings": 1024, 13 | "model_type": "lckv-llama", 14 | "num_attention_heads": 8, 15 | "num_hidden_layers": 8, 16 | "num_key_value_heads": 4, 17 | "pretraining_tp": 1, 18 | "rms_norm_eps": 1e-05, 19 | "rope_scaling": null, 20 | "tie_word_embeddings": false, 21 | "tokenizer_class": "LlamaTokenizer", 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.45.2", 24 | "use_cache": true, 25 | "vocab_size": 32000, 26 | "layer_types": "0_6_6_6_6_6_6_7", 27 | "forward_passes": 7, 28 | "backward_passes": 2 29 | } -------------------------------------------------------------------------------- /configs/tinyllama.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": "silu", 9 | "hidden_size": 2048, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 5632, 12 | "max_position_embeddings": 2048, 13 | "model_type": "llama", 14 | "num_attention_heads": 32, 15 | "num_hidden_layers": 22, 16 | "num_key_value_heads": 4, 17 | "pretraining_tp": 1, 18 | "rms_norm_eps": 1e-05, 19 | "rope_scaling": null, 20 | "tie_word_embeddings": false, 21 | "tokenizer_class": "LlamaTokenizer", 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.45.2", 24 | "use_cache": true, 25 | "vocab_size": 32000 26 | } -------------------------------------------------------------------------------- /configs/tinyllama_lckv.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "meta-llama/Llama-2-7b-hf", 3 | "architectures": [ 4 | "LCKVLlamaForCausalLM" 5 | ], 6 | "bos_token_id": 1, 7 | "eos_token_id": 2, 8 | "hidden_act": 
"silu", 9 | "hidden_size": 2048, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 5632, 12 | "max_position_embeddings": 2048, 13 | "model_type": "lckv-llama", 14 | "num_attention_heads": 32, 15 | "num_hidden_layers": 22, 16 | "num_key_value_heads": 4, 17 | "pretraining_tp": 1, 18 | "rms_norm_eps": 1e-05, 19 | "rope_scaling": null, 20 | "tie_word_embeddings": false, 21 | "tokenizer_class": "LlamaTokenizer", 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.45.2", 24 | "use_cache": true, 25 | "vocab_size": 32000, 26 | "layer_types": "0_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_21", 27 | "forward_passes": 7, 28 | "backward_passes": 2 29 | } -------------------------------------------------------------------------------- /convert_pretrained.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | 4 | import models 5 | from models.utils import LayerTypeParser 6 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 7 | 8 | 9 | def main(): 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--model_name_or_path", help="The pretrained llama model to convert", required=True) 13 | parser.add_argument("--config_name", help="The config file of the expected LCKV model", required=True) 14 | parser.add_argument("--config_overrides", help="Override some existing config settings. 
Example: layer_types=0_6_6_6_6_6_6_7,forward_passes=7", default=None, required=False) 15 | parser.add_argument("--tokenizer_name", help="Pretrained tokenizer name or path if not the same as the pretrained model.", default=None, required=False) 16 | parser.add_argument("--output_dir", help="The output directory where the converted model will be written.", required=True) 17 | args = parser.parse_args() 18 | 19 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path) 20 | config = AutoConfig.from_pretrained(args.config_name) 21 | pt_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path) 22 | pt_model_state_dict = pt_model.state_dict() 23 | 24 | assert config.model_type == "lckv-llama", "The target model must be a LCKV model" 25 | # allow config overrides under all circumstances 26 | if args.config_overrides is not None: 27 | print(f"Overriding config: {args.config_overrides}") 28 | config.update_from_string(args.config_overrides) 29 | print(f"New config: {config}") 30 | 31 | model = AutoModelForCausalLM.from_config(config) 32 | model_state_dict = model.state_dict() 33 | 34 | # Copy the weights from the pretrained model to the LCKV model 35 | print("Copying weights from the pretrained model to the LCKV model...") 36 | for name, param in pt_model.named_parameters(): 37 | if ('k_proj' in name or 'v_proj' in name): 38 | continue 39 | 40 | if name in model_state_dict: 41 | model_state_dict[name].copy_(param.data) 42 | else: 43 | print(f"WARNING: {name} not found in the model") 44 | 45 | # Average the weights of the k_proj and v_proj layers 46 | # The pretrained layer weights will contribute to the layer it attends to 47 | # XXX: how to align heads? 
48 | print("Averaging the weights of the k_proj and v_proj layers...") 49 | parser = LayerTypeParser(config.layer_types) 50 | k_proj, v_proj = defaultdict(list), defaultdict(list) 51 | for layer_type in parser: 52 | k_proj[layer_type.attends_to].append(pt_model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.k_proj.weight"]) 53 | v_proj[layer_type.attends_to].append(pt_model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.v_proj.weight"]) 54 | 55 | for layer_type in parser: 56 | if layer_type.computes_kv: 57 | model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.k_proj.weight"].copy_(sum(k_proj[layer_type.layer_idx]) / len(k_proj[layer_type.layer_idx])) 58 | model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.v_proj.weight"].copy_(sum(v_proj[layer_type.layer_idx]) / len(v_proj[layer_type.layer_idx])) 59 | 60 | # Save the model 61 | print(f"Saving the model to {args.output_dir}...") 62 | model.save_pretrained(args.output_dir) 63 | tokenizer.save_pretrained(args.output_dir) 64 | 65 | print("Model conversion finished successfully") 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 2 | from transformers.utils import is_liger_kernel_available 3 | 4 | from .configuration_lckv import LCKVLlamaConfig 5 | from .modeling_lckv import LCKVLlamaForCausalLM, LCKVLlamaModel 6 | 7 | 8 | AutoConfig.register("lckv-llama", LCKVLlamaConfig) 9 | AutoModel.register(LCKVLlamaConfig, LCKVLlamaModel) 10 | AutoModelForCausalLM.register(LCKVLlamaConfig, LCKVLlamaForCausalLM) 11 | 12 | 13 | if is_liger_kernel_available(): 14 | from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN 15 | 16 | from .kernel import apply_liger_kernel_to_lckv_llama 17 |
MODEL_TYPE_TO_APPLY_LIGER_FN["lckv-llama"] = apply_liger_kernel_to_lckv_llama 18 | -------------------------------------------------------------------------------- /models/cache_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import torch 4 | 5 | from transformers.cache_utils import Cache, DynamicCache, SinkCache 6 | 7 | from .utils import LayerTypeParser 8 | 9 | 10 | class IndexedCache(Cache): 11 | """ 12 | Similar to the `DynamicCache` class, but with the ability to index the cache by layer index. DynamicCache 13 | assumes that all layers compute KVs, while IndexedCache allows for a more flexible cache structure. 14 | """ 15 | build_position_ids_based_on_cache = False 16 | 17 | def __init__(self) -> None: 18 | super().__init__() 19 | self.key_cache: Dict[int, torch.Tensor] = {} 20 | self.value_cache: Dict[int, torch.Tensor] = {} 21 | self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen 22 | self._update = True # to prevent the cache from updating when inference with iterations 23 | 24 | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: 25 | """ 26 | Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the 27 | sequence length. 28 | """ 29 | if layer_idx in self.key_cache: 30 | return (self.key_cache[layer_idx], self.value_cache[layer_idx]) 31 | else: 32 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") 33 | 34 | def __iter__(self): 35 | """ 36 | Support for backwards-compatible `past_key_value` iteration, e.g. 
`for x in past_key_value:` to iterate over 37 | keys and values 38 | """ 39 | for layer_idx in sorted(self.key_cache.keys()): 40 | yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) 41 | 42 | def __len__(self): 43 | """ 44 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds 45 | to the number of layers that compute KVs in the model. 46 | """ 47 | return len(self.key_cache) 48 | 49 | @property 50 | def min_layer(self) -> int: 51 | return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None 52 | 53 | def is_min_layer(self, layer_idx: int) -> bool: 54 | return self.min_layer is None or self.min_layer == layer_idx 55 | 56 | def update( 57 | self, 58 | key_states: torch.Tensor, 59 | value_states: torch.Tensor, 60 | layer_idx: int, 61 | cache_kwargs: Optional[Dict[str, Any]] = None, 62 | ) -> Tuple[torch.Tensor, torch.Tensor]: 63 | """ 64 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 65 | 66 | Parameters: 67 | key_states (`torch.Tensor`): 68 | The new key states to cache. 69 | value_states (`torch.Tensor`): 70 | The new value states to cache. 71 | layer_idx (`int`): 72 | The index of the layer to cache the states for. 73 | cache_kwargs (`Dict[str, Any]`, `optional`): 74 | Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. 75 | 76 | Return: 77 | A tuple containing the updated key and value states. 
78 | """ 79 | # Update the number of seen tokens 80 | if self.is_min_layer(layer_idx): 81 | self._seen_tokens += key_states.shape[-2] 82 | 83 | # Retrieve the cache 84 | if layer_idx not in self.key_cache: 85 | new_key_states = key_states 86 | new_value_states = value_states 87 | else: 88 | new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) 89 | new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) 90 | 91 | # Update the cache 92 | if self._update: 93 | self.key_cache[layer_idx] = new_key_states 94 | self.value_cache[layer_idx] = new_value_states 95 | 96 | return new_key_states, new_value_states 97 | 98 | def get_seq_length(self, layer_idx: Optional[int] = None) -> int: 99 | """Returns the sequence length of the cached states. A layer index can be optionally passed.""" 100 | if layer_idx is None: 101 | layer_idx = self.min_layer 102 | 103 | # TODO: deprecate this function in favor of `cache_position` 104 | is_empty_layer = ( 105 | (len(self.key_cache) == 0) # no cache in any layer 106 | or (layer_idx not in self.key_cache) # skipped `layer_idx` and hasn't run a layer with cache after it 107 | ) 108 | layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 109 | return layer_seq_length 110 | 111 | def get_max_length(self) -> Optional[int]: 112 | """Returns the maximum sequence length of the cached states. 
IndexedCache does not have a maximum length.""" 113 | return None 114 | 115 | @classmethod 116 | def from_cache(cls, dynamic_cache: DynamicCache, *args, **kwargs) -> "IndexedCache": 117 | """Converts a dynamic cache into an equivalent `IndexedCache`.""" 118 | cache = cls(*args, **kwargs) 119 | 120 | cache._seen_tokens = dynamic_cache._seen_tokens 121 | for layer_idx in range(len(dynamic_cache.key_cache)): 122 | key_states, value_states = dynamic_cache[layer_idx] 123 | cache.update(key_states, value_states, layer_idx) 124 | 125 | return cache 126 | 127 | 128 | class IndexedSinkCache(Cache): 129 | """ 130 | This is a fix to the SinkCache class in the transformers library. It also allows for the cache to be indexed by 131 | layer index, similar to the `IndexedCache` class. 132 | """ 133 | build_position_ids_based_on_cache = True 134 | 135 | def __init__(self, window_length: int = None, num_sink_tokens: int = None) -> None: 136 | super().__init__() 137 | self.key_cache: Dict[int, torch.Tensor] = {} 138 | self.value_cache: Dict[int, torch.Tensor] = {} 139 | self.window_length = window_length 140 | self.num_sink_tokens = num_sink_tokens 141 | self.cos_sin_rerotation_cache = {} 142 | self._cos_cache = None 143 | self._sin_cache = None 144 | self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen 145 | self._update = True # to prevent the cache from updating when inference with iterations 146 | 147 | @staticmethod 148 | def _rotate_half(x): 149 | x1 = x[..., : x.shape[-1] // 2] 150 | x2 = x[..., x.shape[-1] // 2 :] 151 | return torch.cat((-x2, x1), dim=-1) 152 | 153 | def _apply_key_rotary_pos_emb( 154 | self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor 155 | ) -> torch.Tensor: 156 | rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin) 157 | return rotated_key_states 158 | 159 | def _get_rerotation_cos_sin( 160 | self, offset: int, dtype: torch.dtype, cos: torch.Tensor, sin: torch.Tensor 
161 | ) -> Tuple[torch.Tensor, torch.Tensor]: 162 | if offset not in self.cos_sin_rerotation_cache: 163 | # Upcast to float32 temporarily for better accuracy 164 | cos = cos.to(torch.float32) 165 | sin = sin.to(torch.float32) 166 | 167 | # Compute the cos and sin required for back- and forward-rotating the keys `offset` positions earlier in the sequence 168 | original_cos = cos[self.num_sink_tokens + offset :] 169 | shifted_cos = cos[self.num_sink_tokens : -offset] 170 | original_sin = sin[self.num_sink_tokens + offset :] 171 | shifted_sin = sin[self.num_sink_tokens : -offset] 172 | rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin 173 | rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin 174 | 175 | self.cos_sin_rerotation_cache[offset] = ( 176 | rerotation_cos.to(dtype).unsqueeze(0), 177 | rerotation_sin.to(dtype).unsqueeze(0), 178 | ) 179 | return self.cos_sin_rerotation_cache[offset] 180 | 181 | @property 182 | def min_layer(self) -> Optional[int]: 183 | return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None 184 | 185 | def is_min_layer(self, layer_idx: int) -> bool: 186 | return self.min_layer is None or self.min_layer == layer_idx 187 | 188 | def get_seq_length(self, layer_idx: Optional[int] = None) -> int: 189 | """Returns the sequence length of the cached states.
A layer index can be optionally passed.""" 190 | # TODO: deprecate this function in favor of `cache_position` 191 | # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length 192 | if layer_idx is None: 193 | layer_idx = self.min_layer 194 | 195 | if layer_idx not in self.key_cache: 196 | return 0 197 | 198 | return self.key_cache[layer_idx].shape[-2] 199 | 200 | def get_max_length(self) -> Optional[int]: 201 | """Returns the maximum sequence length of the cached states.""" 202 | return self.window_length 203 | 204 | def update( 205 | self, 206 | key_states: torch.Tensor, 207 | value_states: torch.Tensor, 208 | layer_idx: int, 209 | cache_kwargs: Optional[Dict[str, Any]] = None, 210 | ) -> Tuple[torch.Tensor, torch.Tensor]: 211 | """ 212 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. 213 | 214 | Parameters: 215 | key_states (`torch.Tensor`): 216 | The new key states to cache. 217 | value_states (`torch.Tensor`): 218 | The new value states to cache. 219 | layer_idx (`int`): 220 | The index of the layer to cache the states for. 221 | cache_kwargs (`Dict[str, Any]`, `optional`): 222 | Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`, 223 | `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the 224 | rotation as the tokens are shifted. 225 | 226 | Return: 227 | A tuple containing the updated key and value states. 228 | """ 229 | # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models 230 | # with partially rotated position embeddings, like Phi or Persimmon. 
231 | sin = cache_kwargs.get("sin") 232 | cos = cache_kwargs.get("cos") 233 | partial_rotation_size = cache_kwargs.get("partial_rotation_size") 234 | using_rope = cos is not None and sin is not None 235 | 236 | # Update the number of seen tokens 237 | if self.is_min_layer(layer_idx): 238 | self._seen_tokens += key_states.shape[-2] 239 | 240 | # Update the sin/cos cache, which holds sin/cos values for all possible positions 241 | if using_rope and self.is_min_layer(layer_idx): 242 | # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove 243 | # after all RoPE models have a llama-like cache utilization. 244 | if cos.dim() == 2: 245 | self._cos_cache = cos 246 | self._sin_cache = sin 247 | else: 248 | if self._cos_cache is None: 249 | self._cos_cache = cos[0, ...] 250 | self._sin_cache = sin[0, ...] 251 | elif self._cos_cache.shape[0] < self.window_length + key_states.shape[-2]: 252 | self._cos_cache = torch.cat([self._cos_cache[: self.window_length], cos[0, ...]], dim=0) 253 | self._sin_cache = torch.cat([self._sin_cache[: self.window_length], sin[0, ...]], dim=0) 254 | 255 | # [bsz, num_heads, seq_len, head_dim] 256 | if layer_idx not in self.key_cache: 257 | # Empty cache 258 | new_key_states = key_states 259 | new_value_states = value_states 260 | 261 | else: 262 | # Growing cache 263 | new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) 264 | new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) 265 | 266 | if self._update: 267 | self.key_cache[layer_idx] = new_key_states 268 | self.value_cache[layer_idx] = new_value_states 269 | 270 | # If the cache is full, we need to shift the cache 271 | if (seq_length := self.get_seq_length(layer_idx)) > self.window_length: 272 | # Shifting cache 273 | keys_to_keep = self.key_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :] 274 | 275 | # On RoPE models, we need to recompute the Key rotation as the tokens 
are shifted 276 | if using_rope: 277 | rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin( 278 | seq_length - self.window_length, 279 | key_states.dtype, 280 | self._cos_cache[:seq_length], 281 | self._sin_cache[:seq_length], 282 | ) 283 | if partial_rotation_size is not None: 284 | keys_to_keep, keys_pass = ( 285 | keys_to_keep[..., :partial_rotation_size], 286 | keys_to_keep[..., partial_rotation_size:], 287 | ) 288 | keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin) 289 | if partial_rotation_size is not None: 290 | keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1) 291 | 292 | # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens 293 | sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens] 294 | self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep], dim=-2) 295 | 296 | sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens] 297 | values_to_keep = self.value_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :] 298 | self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep], dim=-2) 299 | 300 | return new_key_states, new_value_states 301 | 302 | @classmethod 303 | def from_cache(cls, sink_cache: SinkCache, *args, **kwargs) -> "IndexedSinkCache": 304 | """Converts a `SinkCache` into an equivalent `IndexedSinkCache`.""" 305 | cache = cls(*args, **kwargs) 306 | 307 | cache.window_length = sink_cache.window_length 308 | cache.num_sink_tokens = sink_cache.num_sink_tokens 309 | cache._seen_tokens = sink_cache._seen_tokens 310 | cache._cos_cache = sink_cache._cos_cache 311 | cache._sin_cache = sink_cache._sin_cache 312 | cache.cos_sin_rerotation_cache = sink_cache.cos_sin_rerotation_cache 313 | for layer_idx in range(len(sink_cache.key_cache)): 314 | cache.key_cache[layer_idx] = sink_cache.key_cache[layer_idx] 315 | cache.value_cache[layer_idx] = sink_cache.value_cache[layer_idx] 316 | 317 | return cache 318 | 319 | 320
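The following is a minimal pure-Python sketch of the identity behind `IndexedSinkCache._get_rerotation_cos_sin`. It is not part of this repository, and all function names in it are hypothetical. Write `a` for the "original" RoPE angles (position `n + offset`) and `b` for the "shifted" angles (position `n`). The formulas then give `rerotation_cos = cos(a - b)` and `rerotation_sin = sin(b - a)`. Applying them to a key that was already rotated to position `n + offset` should therefore give the key as if it had been rotated to position `n`. The sketch assumes Llama-style RoPE (base 10000, `rotate_half` over duplicated frequency halves) and uses plain lists instead of tensors.

```python
import math
from typing import List, Tuple


def rope_cos_sin(pos: int, head_dim: int, base: float = 10000.0) -> Tuple[List[float], List[float]]:
    """cos/sin of Llama-style RoPE angles at `pos`; both halves share the same frequencies."""
    half = head_dim // 2
    inv_freq = [base ** (-2 * i / head_dim) for i in range(half)]
    angles = [pos * f for f in inv_freq] * 2  # duplicated, like cat((freqs, freqs), dim=-1)
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def rotate_half(x: List[float]) -> List[float]:
    """[x1, x2] -> [-x2, x1], mirroring `IndexedSinkCache._rotate_half`."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x: List[float], cos: List[float], sin: List[float]) -> List[float]:
    """x * cos + rotate_half(x) * sin, mirroring `_apply_key_rotary_pos_emb`."""
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotate_half(x), cos, sin)]


def rerotation_cos_sin(old_pos: int, new_pos: int, head_dim: int) -> Tuple[List[float], List[float]]:
    """Same formulas as `_get_rerotation_cos_sin`, for one token moved from old_pos to new_pos."""
    cos_a, sin_a = rope_cos_sin(old_pos, head_dim)  # "original" angles
    cos_b, sin_b = rope_cos_sin(new_pos, head_dim)  # "shifted" angles
    re_cos = [ca * cb + sa * sb for ca, sa, cb, sb in zip(cos_a, sin_a, cos_b, sin_b)]
    re_sin = [-sa * cb + ca * sb for ca, sa, cb, sb in zip(cos_a, sin_a, cos_b, sin_b)]
    return re_cos, re_sin


# A raw key (head_dim = 4) cached at position 7, later shifted to position 3
key = [0.3, -1.2, 0.5, 2.0]
cached = apply_rope(key, *rope_cos_sin(7, 4))
re_cos, re_sin = rerotation_cos_sin(7, 3, 4)
shifted = apply_rope(cached, re_cos, re_sin)
expected = apply_rope(key, *rope_cos_sin(3, 4))
print(max(abs(s - e) for s, e in zip(shifted, expected)))  # only floating-point error remains
```

RoPE rotations compose additively, so the rerotation angle depends only on `offset` and not on the token's position. This is presumably why the repository memoizes the result per offset in `cos_sin_rerotation_cache`.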
| class IndexedSlidingWindowCache(IndexedCache): 321 | """ 322 | Similar to the `SlidingWindowCache` class, but with the ability to index the cache by layer index. It is no longer 323 | a subclass of `StaticCache` as it is dynamic. 324 | """ 325 | build_position_ids_based_on_cache = False 326 | 327 | def __init__(self, sliding_window: int = None) -> None: 328 | super().__init__() 329 | self.sliding_window = sliding_window 330 | 331 | def update( 332 | self, 333 | key_states: torch.Tensor, 334 | value_states: torch.Tensor, 335 | layer_idx: int, 336 | cache_kwargs: Optional[Dict[str, Any]] = None, 337 | ) -> Tuple[torch.Tensor]: 338 | # Update the number of seen tokens 339 | if self.is_min_layer(layer_idx): 340 | self._seen_tokens += key_states.shape[-2] 341 | 342 | # [bsz, num_heads, seq_len, head_dim] 343 | if layer_idx not in self.key_cache: 344 | # Empty cache 345 | new_key_states = key_states 346 | new_value_states = value_states 347 | 348 | else: 349 | # Growing cache 350 | new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) 351 | new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) 352 | 353 | if self._update: 354 | self.key_cache[layer_idx] = new_key_states 355 | self.value_cache[layer_idx] = new_value_states 356 | 357 | # If the cache is full, we need to shift the cache 358 | if self.get_seq_length(layer_idx) > self.sliding_window: 359 | self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, -self.sliding_window :] 360 | self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, -self.sliding_window :] 361 | 362 | return new_key_states, new_value_states 363 | 364 | def get_max_length(self) -> Optional[int]: 365 | return self.sliding_window 366 | 367 | @classmethod 368 | def from_cache(cls, sliding_window_cache: "IndexedSlidingWindowCache", *args, **kwargs) -> "IndexedSlidingWindowCache": 369 | """This is to override the `from_cache` method in the `IndexedCache` class.""" 370 | cache = cls(*args, 
**kwargs) 371 | 372 | cache._seen_tokens = sliding_window_cache._seen_tokens 373 | cache.sliding_window = sliding_window_cache.sliding_window 374 | for layer_idx in sliding_window_cache.key_cache: 375 | cache.key_cache[layer_idx] = sliding_window_cache.key_cache[layer_idx] 376 | cache.value_cache[layer_idx] = sliding_window_cache.value_cache[layer_idx] 377 | 378 | return cache 379 | 380 | 381 | class IndexedHybridCache(IndexedSlidingWindowCache, IndexedCache): 382 | """ 383 | Hybrid Cache class to be used for models that mix local sliding window attention and global 384 | attention across layers, as specified by `parser`. Under the hood, Hybrid Cache leverages ["IndexedSlidingWindowCache"] for 385 | sliding window attention and ["IndexedCache"] for global attention. 386 | """ 387 | build_position_ids_based_on_cache = False 388 | 389 | def __init__(self, parser: LayerTypeParser = None, sliding_window: int = None) -> None: 390 | super().__init__(sliding_window=sliding_window) 391 | self.parser = parser 392 | 393 | def update( 394 | self, 395 | key_states: torch.Tensor, 396 | value_states: torch.Tensor, 397 | layer_idx: int, 398 | cache_kwargs: Optional[Dict[str, Any]] = None, 399 | ) -> Tuple[torch.Tensor, torch.Tensor]: 400 | if self.parser[layer_idx].use_sliding_window: 401 | return IndexedSlidingWindowCache.update(self, key_states, value_states, layer_idx, cache_kwargs) 402 | else: 403 | return IndexedCache.update(self, key_states, value_states, layer_idx, cache_kwargs) 404 | 405 | def get_max_length(self) -> Optional[int]: 406 | return IndexedCache.get_max_length(self) 407 | 408 | @classmethod 409 | def from_cache(cls, hybrid_cache: "IndexedHybridCache", *args, **kwargs) -> "IndexedHybridCache": 410 | """This is to override the `from_cache` method in the `IndexedSlidingWindowCache` class.""" 411 | cache = cls(*args, **kwargs) 412 | 413 | cache._seen_tokens = hybrid_cache._seen_tokens 414 | cache.sliding_window = hybrid_cache.sliding_window 415 | cache.parser =
hybrid_cache.parser 416 | for layer_idx in hybrid_cache.key_cache: 417 | cache.key_cache[layer_idx] = hybrid_cache.key_cache[layer_idx] 418 | cache.value_cache[layer_idx] = hybrid_cache.value_cache[layer_idx] 419 | 420 | return cache 421 | 422 | 423 | class LayerCache(torch.nn.Module): 424 | """ 425 | A cache for storing the key-value pairs for layers. 426 | """ 427 | def __init__(self) -> None: 428 | """ 429 | The placeholder is used to expand the key-value pairs if the layer attends to the top layers. 430 | Size: (batch_size, num_key_value_heads, 1, head_dim) 431 | """ 432 | super().__init__() 433 | self.key_layer_cache: Dict[int, torch.Tensor] = {} 434 | self.value_layer_cache: Dict[int, torch.Tensor] = {} 435 | self.layer_type = None 436 | self.placeholder = None 437 | 438 | def setup(self, placeholder: torch.Tensor): 439 | """Set up the cache. Calling this function is necessary if any layer attends to the top layers.""" 440 | self.placeholder = placeholder 441 | 442 | def initialize(self, parser: LayerTypeParser, sequence_length: int): 443 | """initialize the cache""" 444 | layers_to_init = {parser[idx].attends_to for idx in range(len(parser)) if parser[idx].attends_top} 445 | 446 | if layers_to_init: 447 | b, h, _, d = self.placeholder.size() 448 | init_kvs = self.placeholder.new_zeros((b, h, sequence_length, d)) 449 | 450 | for layer_idx in layers_to_init: 451 | self.layer_append(layer_idx, init_kvs, init_kvs) 452 | 453 | def layer_get(self, layer_idx: int, zerofill: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: 454 | key_states = self.key_layer_cache.get(layer_idx, None) 455 | value_states = self.value_layer_cache.get(layer_idx, None) 456 | 457 | if zerofill: 458 | if key_states is None: 459 | key_states = self.placeholder 460 | value_states = self.placeholder 461 | else: 462 | key_states = torch.cat([self.placeholder, key_states], dim=2) 463 | value_states = torch.cat([self.placeholder, value_states], dim=2) 464 | 465 | return
key_states, value_states 466 | 467 | def layer_set(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor): 468 | self.key_layer_cache[layer_idx] = key 469 | self.value_layer_cache[layer_idx] = value 470 | 471 | def layer_append(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor): 472 | if layer_idx not in self.key_layer_cache: 473 | self.key_layer_cache[layer_idx] = key 474 | self.value_layer_cache[layer_idx] = value 475 | else: 476 | self.key_layer_cache[layer_idx] = torch.cat([self.key_layer_cache[layer_idx], key], dim=2) 477 | self.value_layer_cache[layer_idx] = torch.cat([self.value_layer_cache[layer_idx], value], dim=2) 478 | 479 | 480 | class LayerIndexedCache(LayerCache, IndexedCache): 481 | """ 482 | A cache for storing the key-value pairs for layers, in combination with the ability of standard KV cache. 483 | """ 484 | def __init__(self) -> None: 485 | LayerCache.__init__(self) 486 | IndexedCache.__init__(self) 487 | 488 | 489 | class LayerIndexedSinkCache(LayerCache, IndexedSinkCache): 490 | """ 491 | A cache for storing the key-value pairs for layers, in combination with the ability of sink KV cache. 492 | """ 493 | def __init__(self) -> None: 494 | LayerCache.__init__(self) 495 | IndexedSinkCache.__init__(self) 496 | 497 | 498 | class LayerIndexedSlidingWindowCache(LayerCache, IndexedSlidingWindowCache): 499 | """ 500 | A cache for storing the key-value pairs for layers, in combination with the ability of sliding window KV cache. 501 | """ 502 | def __init__(self) -> None: 503 | LayerCache.__init__(self) 504 | IndexedSlidingWindowCache.__init__(self) 505 | 506 | 507 | class LayerIndexedHybridCache(LayerCache, IndexedHybridCache): 508 | """ 509 | A cache for storing the key-value pairs for layers, in combination with the ability of hybrid KV cache. 
510 | """ 511 | def __init__(self) -> None: 512 | LayerCache.__init__(self) 513 | IndexedHybridCache.__init__(self) 514 | 515 | 516 | class AutoLayerCache(torch.nn.Module): 517 | """ 518 | AutoLayerCache is a module that automatically creates a cache from an existing cache. 519 | """ 520 | CACHE_MAPPING = { 521 | DynamicCache: LayerIndexedCache, 522 | SinkCache: LayerIndexedSinkCache, 523 | IndexedSlidingWindowCache: LayerIndexedSlidingWindowCache, 524 | IndexedHybridCache: LayerIndexedHybridCache, 525 | } 526 | 527 | def __init__(self, *args, **kwargs): 528 | raise RuntimeError( 529 | f"{self.__class__.__name__} is designed to be instantiated " 530 | f"using the `{self.__class__.__name__}.from_cache(cache)` method." 531 | ) 532 | 533 | @classmethod 534 | def from_cache(cls, cache: Cache, *args, **kwargs): 535 | """ 536 | Create a new cache from an existing cache. The new cache is the layer-indexed counterpart of the original cache type, as given by `CACHE_MAPPING`. 537 | """ 538 | cache_type = type(cache) 539 | if cache_type not in cls.CACHE_MAPPING: 540 | raise ValueError(f"Cache type {cache_type} is not supported by {cls.__name__}.") 541 | 542 | cache_class = cls.CACHE_MAPPING[cache_type] 543 | 544 | if hasattr(cache_class, "from_cache"): 545 | return cache_class.from_cache(cache, *args, **kwargs) 546 | else: 547 | # we init an empty cache and copy the attributes 548 | new_cache = cache_class(*args, **kwargs) 549 | new_cache.__dict__.update(cache.__dict__) 550 | return new_cache 551 | -------------------------------------------------------------------------------- /models/configuration_lckv.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library.
It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LCKV LLaMA model configuration""" 21 | from transformers.models.llama.configuration_llama import LlamaConfig 22 | 23 | from .utils import LayerTypeParser 24 | 25 | 26 | class LCKVLlamaConfig(LlamaConfig): 27 | 28 | model_type = "lckv-llama" 29 | 30 | def __init__( 31 | self, 32 | layer_types: str = None, 33 | forward_passes: int = 7, 34 | backward_passes: int = 2, 35 | sliding_window: int = 4096, 36 | use_sequential: bool = False, 37 | force_nodiag: bool = False, 38 | **kwargs, 39 | ): 40 | """ 41 | Initialize an LCKV LLaMA configuration. Instantiating a configuration with the defaults 42 | will yield a similar configuration to that of the LLaMA-7B with the standard transformer 43 | training scheme. 44 | 45 | Args: 46 | layer_types (`str`, *optional*): 47 | A string of integers separated by underscores. If the i-th integer is j, layer i 48 | will use the key-value pair of layer j as its kv cache. Special characters 49 | may be placed after the integers: 50 | - `s` means the layer will use sliding window attention. 51 | The default value is "0_1_2_..." up to the number of layers in the current config.
52 | forward_passes (`int`, *optional*, defaults to 7): 53 | The number of forward passes during training and prompt encoding. Equivalent 54 | to `m` in the paper. 55 | backward_passes (`int`, *optional*, defaults to 2): 56 | The number of backward passes during training and prompt encoding. Equivalent 57 | to `b` in the paper. 58 | sliding_window (`int`, *optional*, defaults to 4096): 59 | Sliding window attention window size. If not specified, will default to `4096`. 60 | It only takes effect if the corresponding layer uses sliding window attention. 61 | use_sequential (`bool`, *optional*, defaults to False): 62 | Whether to run the forward pass sequentially, token by token. Useful for testing purposes 63 | for models with cyclic dependencies. Can also be used for sequential training. 64 | force_nodiag (`bool`, *optional*, defaults to False): 65 | Whether to force the model to not use the diagonal attention. By default, the model 66 | masks the diagonal attention only in the layers where it is necessary. If set to `True`, the model 67 | will never use the diagonal attention in any layer. This is mainly for backward compatibility.
68 | """ 69 | super().__init__(**kwargs) 70 | self.layer_types = layer_types 71 | self.forward_passes = forward_passes 72 | self.backward_passes = backward_passes 73 | self.sliding_window = sliding_window 74 | self.use_sequential = use_sequential 75 | self.force_nodiag = force_nodiag 76 | 77 | if self.layer_types is None: 78 | self.layer_types = "_".join(map(str, range(self.num_hidden_layers))) 79 | 80 | # post check 81 | LayerTypeParser(self.layer_types).check(self.num_hidden_layers) 82 | -------------------------------------------------------------------------------- /models/kernel.py: -------------------------------------------------------------------------------- 1 | from liger_kernel.transformers.monkey_patch import ( 2 | LigerCrossEntropyLoss, 3 | LigerRMSNorm, 4 | LigerSwiGLUMLP, 5 | PreTrainedModel, 6 | _bind_method_to_module, 7 | _patch_rms_norm_module, 8 | llama_lce_forward, 9 | ) 10 | 11 | from .ops_rope import SingleLigerRopeFunction 12 | 13 | 14 | def liger_rotary(q, cos, sin, unsqueeze_dim=1): 15 | """ 16 | Applies the Rotary Positional Embedding (RoPE) operation to a single tensor (query or key states). 17 | 18 | Args: 19 | q (torch.Tensor): The query (or key) tensor of shape (bsz, n_q_head, seq_len, head_dim). 20 | cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim). 21 | sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim). 22 | unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1. Currently ignored by `SingleLigerRopeFunction`. 23 | 24 | Returns: 25 | torch.Tensor: The input tensor after applying the RoPE operation.
26 | """ 27 | 28 | return SingleLigerRopeFunction.apply(q, cos, sin, None, unsqueeze_dim) 29 | 30 | 31 | def apply_liger_kernel_to_lckv_llama( 32 | rope: bool = True, 33 | cross_entropy: bool = False, 34 | fused_linear_cross_entropy: bool = True, 35 | rms_norm: bool = True, 36 | swiglu: bool = True, 37 | model: PreTrainedModel = None, 38 | ) -> None: 39 | """ 40 | Apply Liger kernels to replace the original implementations in LCKV models. 41 | 42 | Args: 43 | rope (bool): Whether to apply Liger's rotary position embedding. Default is True. 44 | cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. 45 | fused_linear_cross_entropy (bool): 46 | Whether to apply Liger's fused linear cross entropy loss. Default is True. 47 | `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. 48 | If `fused_linear_cross_entropy` is True, the logits are not materialized, which is more memory efficient. 49 | rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. 50 | swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True. 51 | model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been 52 | loaded. Default is None. 53 | """ 54 | 55 | assert not ( 56 | cross_entropy and fused_linear_cross_entropy 57 | ), "cross_entropy and fused_linear_cross_entropy cannot both be True." 58 | 59 | from transformers.models.llama import modeling_llama 60 | 61 | from .
import modeling_lckv 62 | 63 | if rope: 64 | modeling_lckv.apply_rotary = liger_rotary 65 | if rms_norm: 66 | modeling_llama.LlamaRMSNorm = LigerRMSNorm 67 | if swiglu: 68 | modeling_llama.LlamaMLP = LigerSwiGLUMLP 69 | if cross_entropy: 70 | modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss 71 | if fused_linear_cross_entropy: 72 | modeling_lckv.LlamaForCausalLM.forward = llama_lce_forward 73 | 74 | if model is not None: 75 | # The model instance already exists, so we need to additionally patch the 76 | # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP) 77 | 78 | base_model = getattr(model, model.base_model_prefix, model) 79 | 80 | if rms_norm: 81 | _patch_rms_norm_module(base_model.norm) 82 | 83 | for decoder_layer in base_model.layers: 84 | if swiglu: 85 | _bind_method_to_module( 86 | decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward 87 | ) 88 | if rms_norm: 89 | _patch_rms_norm_module(decoder_layer.input_layernorm) 90 | _patch_rms_norm_module(decoder_layer.post_attention_layernorm) 91 | -------------------------------------------------------------------------------- /models/ops_rope.py: -------------------------------------------------------------------------------- 1 | """https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py""" 2 | 3 | import torch 4 | import triton 5 | import triton.language as tl 6 | 7 | 8 | @triton.jit 9 | def _triton_rope( 10 | q_ptr, 11 | q_row_stride, 12 | cos, 13 | cos_row_stride, 14 | sin, 15 | sin_row_stride, 16 | sl, 17 | bs: tl.constexpr, 18 | n_qh: tl.constexpr, 19 | hd: tl.constexpr, 20 | pad_n_qh: tl.constexpr, 21 | pad_hd: tl.constexpr, 22 | BLOCK_SIZE: tl.constexpr, 23 | BACKWARD_PASS: tl.constexpr = False, 24 | ): 25 | # q size: (bsz, seq_len, num_q_heads, head_dim) 26 | # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1) 27 | 28 | # cos size: (1, seq_len, head_dim) 29 | # stride: (seq_len * head_dim, head_dim, 1) 30 
| pid = tl.program_id(0) 31 | 32 | # locate start address 33 | q_ptr = q_ptr + pid * q_row_stride 34 | 35 | # #################################################################### 36 | # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position 37 | # m of this program instance 38 | # #################################################################### 39 | 40 | # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which 41 | # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension 42 | # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index 43 | # and pid % sl to get the sequence index. 44 | # 2. We only need the left half of cos and sin matrix because the right half is just 45 | # a clone of the left half. 46 | cos_row_idx = pid % (sl) 47 | cos = cos + cos_row_idx * cos_row_stride 48 | sin = sin + cos_row_idx * sin_row_stride 49 | cos_offsets = tl.arange(0, pad_hd // 2) 50 | cos_mask = cos_offsets < hd // 2 51 | cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0) 52 | sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0) 53 | 54 | # #################################################################### 55 | # Load the left and right half of q for the current 56 | # program instance (i.e. 
for the current token) separately 57 | # #################################################################### 58 | # left half of the head 59 | first_half_q_offsets = ( 60 | tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] 61 | ) 62 | first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( 63 | tl.arange(0, pad_hd // 2)[None, :] < hd // 2 64 | ) 65 | q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( 66 | sin_row.dtype 67 | ) 68 | 69 | # right half of the head 70 | second_half_q_offsets = first_half_q_offsets + (hd // 2) 71 | second_q_mask = first_q_mask 72 | q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( 73 | sin_row.dtype 74 | ) 75 | 76 | if not BACKWARD_PASS: 77 | # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] 78 | new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row 79 | tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) 80 | new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row 81 | tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) 82 | else: 83 | # with some math, we can get: 84 | # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] 85 | new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row 86 | tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) 87 | new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row 88 | tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) 89 | 90 | 91 | def rope_forward(q, cos, sin): 92 | 93 | # transpose it back to the physical shape because Triton looks at the physical storage 94 | # note: q is non-contiguous before the transformation and will become contiguous after transpose 95 | q = q.transpose(1, 2) 96 | 97 | batch_size, seq_len, n_q_head, head_dim = q.shape 98 | pad_hd = triton.next_power_of_2(head_dim) 99 | pad_n_q_head = triton.next_power_of_2(n_q_head) 100 | BLOCK_SIZE = pad_n_q_head 101 | 102 | n_row = batch_size * seq_len 103 | 104 |
# ensure tensors passed into the kernel are contiguous. It is a no-op if they are already contiguous 105 | q = q.contiguous() 106 | cos = cos.contiguous() 107 | sin = sin.contiguous() 108 | 109 | _triton_rope[(n_row,)]( 110 | q, 111 | q.stride(1), 112 | cos, 113 | cos.stride(-2), 114 | sin, 115 | sin.stride(-2), 116 | seq_len, 117 | batch_size, 118 | n_q_head, 119 | head_dim, 120 | pad_n_q_head, 121 | pad_hd, 122 | BLOCK_SIZE=BLOCK_SIZE, 123 | BACKWARD_PASS=False, 124 | ) 125 | return q.transpose(1, 2), cos, sin 126 | 127 | 128 | def rope_backward(dq, cos, sin): 129 | dq = dq.transpose(1, 2) 130 | 131 | batch_size, seq_len, n_q_head, head_dim = dq.shape 132 | pad_hd = triton.next_power_of_2(head_dim) 133 | pad_n_q_head = triton.next_power_of_2(n_q_head) 134 | BLOCK_SIZE = pad_n_q_head 135 | 136 | n_row = batch_size * seq_len 137 | 138 | # ensure dq is contiguous 139 | dq = dq.contiguous() 140 | 141 | # backward is similar to forward except for swapping a few ops 142 | _triton_rope[(n_row,)]( 143 | dq, 144 | dq.stride(1), 145 | cos, 146 | cos.stride(-2), 147 | sin, 148 | sin.stride(-2), 149 | seq_len, 150 | batch_size, 151 | n_q_head, 152 | head_dim, 153 | pad_n_q_head, 154 | pad_hd, 155 | BLOCK_SIZE=BLOCK_SIZE, 156 | BACKWARD_PASS=True, 157 | ) 158 | return dq.transpose(1, 2) 159 | 160 | 161 | class SingleLigerRopeFunction(torch.autograd.Function): 162 | """ 163 | This autograd function re-implements the RoPE operation for a single input tensor.
164 | """ 165 | 166 | @staticmethod 167 | def forward(ctx, q, cos, sin, position_ids=None, unsqueeze_dim=1): 168 | """ 169 | q size: (bsz, n_q_head, seq_len, head_dim) 170 | cos size: (1, seq_len, head_dim) 171 | sin size: (1, seq_len, head_dim) 172 | """ 173 | q, cos, sin = rope_forward(q, cos, sin) 174 | ctx.save_for_backward(cos, sin) 175 | return q 176 | 177 | @staticmethod 178 | def backward(ctx, dq): 179 | """ 180 | dq size: (bsz, n_q_head, seq_len, head_dim) 181 | cos size: (1, seq_len, head_dim) 182 | sin size: (1, seq_len, head_dim) 183 | """ 184 | 185 | cos, sin = ctx.saved_tensors 186 | dq = rope_backward(dq, cos, sin) 187 | return dq, None, None, None, None 188 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from dataclasses import dataclass 3 | from typing import List, Optional 4 | 5 | import torch 6 | 7 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 8 | 9 | 10 | @dataclass 11 | class IterStep: 12 | """A helper class for the iteration plan""" 13 | layer_slice: slice = slice(None) 14 | requires_grad: bool = True 15 | update: bool = True 16 | 17 | @dataclass 18 | class LayerType: 19 | """A helper class to collect the layer type information""" 20 | layer_idx: int 21 | use_sliding_window: bool 22 | attends_to: int 23 | attends_top: bool 24 | computes_kv: bool 25 | 26 | class LayerTypeParser: 27 | """ 28 | A helper class to parse the layer type string and provide some useful methods. 29 | 30 | Arguments: 31 | layer_type (str): A string of integers separated by underscores. If the i-th integer 32 | is j, layer i will use the key-value pair of layer j as its kv cache. 33 | Special characters may be placed after the integers: 34 | - `s` means the layer will use sliding window attention.
35 | 36 | >>> layer_type = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8")[3] 37 | >>> layer_type.attends_to 38 | 5 39 | >>> layer_type.attends_top 40 | True 41 | >>> layer_type.use_sliding_window 42 | True 43 | """ 44 | def __init__(self, layer_type: str): 45 | self._layer_type = layer_type 46 | 47 | # parse the layer type 48 | self.layer_indices = [] 49 | self.sliding_window = [] 50 | for s in layer_type.split("_"): 51 | layer_idx, sliding_window = re.match(r"^(\d+)(s)?$", s).groups() 52 | self.layer_indices.append(int(layer_idx)) 53 | self.sliding_window.append(bool(sliding_window)) 54 | 55 | def __len__(self): 56 | return len(self.layer_indices) 57 | 58 | def __getitem__(self, layer_idx: int) -> LayerType: 59 | """return the layer type information for the given layer index""" 60 | return LayerType( 61 | layer_idx=layer_idx, 62 | use_sliding_window=self.sliding_window[layer_idx], 63 | attends_to=self.layer_indices[layer_idx], 64 | attends_top=self.layer_indices[layer_idx] > layer_idx, 65 | computes_kv=layer_idx in self.layer_indices, 66 | ) 67 | 68 | def use_sliding_window(self) -> bool: 69 | """whether there exists a layer that uses sliding window attention""" 70 | return any(self.sliding_window) 71 | 72 | def attends_top(self) -> bool: 73 | """whether there exists a layer that attends to layers above it""" 74 | return any(self.layer_indices[i] > i for i in range(len(self))) 75 | 76 | def iteration_plan(self, forward_passes: int = 7, backward_passes: int = 2) -> List[IterStep]: 77 | """ 78 | Return an iteration plan for the layer types. The plan is a list of IterStep objects.
79 | """ 80 | # if there is no cyclic dependency, return the default plan 81 | if not self.attends_top(): 82 | return [IterStep()] 83 | 84 | # otherwise, return the plan for the cyclic dependency 85 | plan = [] 86 | i = 0 87 | while i < len(self): 88 | 89 | # if the layer attends to top layers, resolve the cyclic dependency 90 | if self[i].attends_top: 91 | 92 | # find the top layer in the cyclic dependency 93 | top = self[i].attends_to 94 | while top < max(self.layer_indices[i: top + 1]): 95 | top = max(self.layer_indices[i: top + 1]) 96 | top += 1 97 | 98 | # create iteration plan for this group 99 | layer_slice = slice(i, top) 100 | plan.extend([ 101 | *forward_passes * [IterStep(layer_slice, requires_grad=False, update=False)], 102 | *(backward_passes - 1) * [IterStep(layer_slice, update=False)], 103 | IterStep(layer_slice) 104 | ]) 105 | 106 | # otherwise, create a default plan 107 | else: 108 | 109 | top = i + 1 110 | while top < len(self) and not self[top].attends_top: 111 | top += 1 112 | plan.append(IterStep(slice(i, top))) 113 | 114 | # update the index 115 | i = top 116 | 117 | return plan 118 | 119 | def check(self, num_hidden_layers: int): 120 | """Check if the layer type is valid""" 121 | if len(self.layer_indices) != num_hidden_layers: 122 | raise ValueError("The number of layer types should be equal to the number of hidden layers.") 123 | for i in range(num_hidden_layers): 124 | if self.layer_indices[i] not in range(num_hidden_layers): 125 | raise ValueError("The layer type should be in the range of the number of hidden layers.") 126 | 127 | 128 | def flash_attention_forward( 129 | query_states: torch.Tensor, 130 | key_states: torch.Tensor, 131 | value_states: torch.Tensor, 132 | attention_mask: torch.Tensor, 133 | query_length: int, 134 | is_causal: bool, 135 | dropout: float = 0.0, 136 | position_ids: Optional[torch.Tensor] = None, 137 | softmax_scale: Optional[float] = None, 138 | sliding_window: Optional[int] = None, 139 | use_top_left_mask: 
bool = False, 140 | softcap: Optional[float] = None, 141 | deterministic: bool = None, 142 | no_diag: bool = False, 143 | ): 144 | """ 145 | This function is a wrapper around the _flash_attention_forward function in the 146 | transformers library. It adds support to mask the diagonal elements of the attention 147 | matrix. The diagonal mask is used to resolve the cyclic dependencies in the LCKV model. 148 | """ 149 | prune_query = False 150 | if no_diag: 151 | if key_states.size(1) == 1: 152 | b, l, _, d = value_states.size() 153 | _, _, h, _ = query_states.size() 154 | return value_states.new_zeros((b, l, h, d)) 155 | 156 | if key_states.size(1) == query_states.size(1): 157 | prune_query = True 158 | query_states = query_states[:, 1:, :, :] 159 | query_length -= 1 160 | 161 | if attention_mask is not None: 162 | attention_mask = attention_mask[:, 1:] 163 | 164 | key_states = key_states[:, :-1, :, :] 165 | value_states = value_states[:, :-1, :, :] 166 | 167 | if sliding_window is not None: 168 | sliding_window = sliding_window - 1 169 | 170 | result: torch.Tensor = _flash_attention_forward( 171 | query_states=query_states, 172 | key_states=key_states, 173 | value_states=value_states, 174 | attention_mask=attention_mask, 175 | query_length=query_length, 176 | is_causal=is_causal, 177 | dropout=dropout, 178 | position_ids=position_ids, 179 | softmax_scale=softmax_scale, 180 | sliding_window=sliding_window, 181 | use_top_left_mask=use_top_left_mask, 182 | softcap=softcap, 183 | deterministic=deterministic, 184 | ) 185 | 186 | if prune_query: 187 | b, _, h, d = result.size() 188 | result = torch.cat([result.new_zeros((b, 1, h, d)), result], dim=1) 189 | 190 | return result 191 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length = 119 3 | 4 | [tool.ruff.lint] 5 | # Never enforce `E501` (line length violations). 
6 | ignore = ["C901", "E501", "E741", "F402", "F823" ] 7 | select = ["C", "E", "F", "I", "W"] 8 | 9 | # Ignore import violations in all `__init__.py` files. 10 | [tool.ruff.lint.per-file-ignores] 11 | "__init__.py" = ["E402", "F401", "F403", "F811"] 12 | 13 | [tool.ruff.lint.isort] 14 | lines-after-imports = 2 15 | known-first-party = ["transformers"] 16 | 17 | [tool.ruff.format] 18 | # Like Black, use double quotes for strings. 19 | quote-style = "double" 20 | 21 | # Like Black, indent with spaces, rather than tabs. 22 | indent-style = "space" 23 | 24 | # Like Black, respect magic trailing commas. 25 | skip-magic-trailing-comma = false 26 | 27 | # Like Black, automatically detect the appropriate line ending. 28 | line-ending = "auto" 29 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers ~= 4.45.2 2 | trl ~= 0.11.4 3 | accelerate ~= 1.0.1 4 | datasets 5 | evaluate 6 | scikit-learn 7 | sentencepiece 8 | liger-kernel ~= 0.3.0 9 | flash-attn >= 2.6.3 10 | git+https://github.com/EleutherAI/lm-evaluation-harness@v0.4.5 11 | -------------------------------------------------------------------------------- /run_chat.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # Copyright 2024 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from trl.commands.cli_utils import init_zero_verbose 18 | 19 | init_zero_verbose() 20 | 21 | import copy 22 | import json 23 | import os 24 | import sys 25 | import pwd 26 | import re 27 | import time 28 | from threading import Thread 29 | 30 | import torch 31 | from rich.console import Console 32 | from rich.live import Live 33 | from rich.markdown import Markdown 34 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer 35 | import models 36 | 37 | from trl.commands.cli_utils import ChatArguments, TrlParser, init_zero_verbose 38 | from trl.trainer.utils import get_quantization_config 39 | 40 | 41 | HELP_STRING = """\ 42 | 43 | **TRL CHAT INTERFACE** 44 | 45 | The chat interface is a simple tool to try out a chat model. 46 | 47 | Besides talking to the model, you can use several commands: 48 | - **clear**: clears the current conversation and starts a new one 49 | - **example {NAME}**: loads the example named `{NAME}` from the config and uses it as the user input 50 | - **set {SETTING_NAME}={SETTING_VALUE};**: changes the system prompt or generation settings (multiple settings are separated by a ';').
51 | - **reset**: same as **clear**, but also resets the generation settings to their defaults if they were changed by **set** 52 | - **save {SAVE_NAME} (optional)**: saves the current chat and settings to a file, by default `./chat_history/{MODEL_NAME}/chat_{DATETIME}.json`, or `{SAVE_NAME}` if provided 53 | - **exit**: closes the interface 54 | """ 55 | 56 | SUPPORTED_GENERATION_KWARGS = [ 57 | "max_new_tokens", 58 | "do_sample", 59 | "num_beams", 60 | "temperature", 61 | "top_p", 62 | "top_k", 63 | "repetition_penalty", 64 | ] 65 | 66 | SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$" 67 | 68 | 69 | # the default chat template for timdettmers/openassistant-guanaco 70 | DEFAULT_CHAT_TEMPLATE = """ 71 | {%- for message in messages %} 72 | {%- if message['role'] == 'user' %} 73 | {{- 'Human: ' + message['content'].strip() }} 74 | {%- elif message['role'] == 'assistant' %} 75 | {{- 'Assistant: ' + message['content'].strip() }} 76 | {%- endif %} 77 | {%- endfor %} 78 | """ 79 | 80 | 81 | class RichInterface: 82 | def __init__(self, model_name=None, user_name=None): 83 | self._console = Console() 84 | if model_name is None: 85 | self.model_name = "assistant" 86 | else: 87 | self.model_name = model_name 88 | if user_name is None: 89 | self.user_name = "user" 90 | else: 91 | self.user_name = user_name 92 | 93 | def stream_output(self, output_stream): 94 | """Stream output from a role.""" 95 | # This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py 96 | # Create a Live context for updating the console output 97 | text = "" 98 | self._console.print(f"[bold blue]<{self.model_name}>:") 99 | with Live(console=self._console, refresh_per_second=4) as live: 100 | # Read lines from the stream 101 | for i, outputs in enumerate(output_stream): 102 | if not outputs or i == 0: 103 | continue 104 | text += outputs 105 | # Render
the accumulated text as Markdown 106 | # NOTE: this is a workaround for rendering "non-standard markdown" 107 | # in rich. The chatbot output treats "\n" as a new line for 108 | # better compatibility with real-world text. However, rendering it 109 | # as markdown would break the format, because standard markdown 110 | # treats a single "\n" in normal text as a space. 111 | # Our workaround is to add two spaces at the end of each line. 112 | # This is not a perfect solution, as it 113 | # introduces trailing spaces (only) in code blocks, but it works well, 114 | # especially for console output, because in general the console does not 115 | # care about trailing spaces. 116 | lines = [] 117 | for line in text.splitlines(): 118 | lines.append(line) 119 | if line.startswith("```"): 120 | # Code block marker - do not add trailing spaces, as it would 121 | # break the syntax highlighting 122 | lines.append("\n") 123 | else: 124 | lines.append("  \n") 125 | markdown = Markdown("".join(lines).strip(), code_theme="github-dark") 126 | # Update the Live console output 127 | live.update(markdown) 128 | self._console.print() 129 | return text 130 | 131 | def input(self): 132 | input = self._console.input(f"[bold red]<{self.user_name}>:\n") 133 | self._console.print() 134 | return input 135 | 136 | def clear(self): 137 | self._console.clear() 138 | 139 | def print_user_message(self, text): 140 | self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}") 141 | self._console.print() 142 | 143 | def print_green(self, text): 144 | self._console.print(f"[bold green]{text}") 145 | self._console.print() 146 | 147 | def print_red(self, text): 148 | self._console.print(f"[bold red]{text}") 149 | self._console.print() 150 | 151 | def print_help(self): 152 | self._console.print(Markdown(HELP_STRING)) 153 | self._console.print() 154 | 155 | 156 | def get_username(): 157 | return pwd.getpwuid(os.getuid())[0] 158 | 159 | 160 | def create_default_filename(model_name): 161
| time_str = time.strftime("%Y-%m-%d_%H-%M-%S") 162 | return f"{model_name}/chat_{time_str}.json" 163 | 164 | 165 | def save_chat(chat, args, filename): 166 | output_dict = {} 167 | output_dict["settings"] = vars(args) 168 | output_dict["chat_history"] = chat 169 | 170 | folder = args.save_folder 171 | 172 | if filename is None: 173 | filename = create_default_filename(args.model_name_or_path) 174 | filename = os.path.join(folder, filename) 175 | os.makedirs(os.path.dirname(filename), exist_ok=True) 176 | 177 | with open(filename, "w") as f: 178 | json.dump(output_dict, f, indent=4) 179 | return os.path.abspath(filename) 180 | 181 | 182 | def clear_chat_history(system_prompt): 183 | if system_prompt is None: 184 | chat = [] 185 | else: 186 | chat = [{"role": "system", "content": system_prompt}] 187 | return chat 188 | 189 | 190 | def parse_settings(user_input, current_args, interface): 191 | settings = user_input[4:].strip().split(";") 192 | settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings] 193 | settings = dict(settings) 194 | error = False 195 | 196 | for name in settings: 197 | if hasattr(current_args, name): 198 | try: 199 | if isinstance(getattr(current_args, name), bool): 200 | if settings[name] == "True": 201 | settings[name] = True 202 | elif settings[name] == "False": 203 | settings[name] = False 204 | else: 205 | raise ValueError 206 | else: 207 | settings[name] = type(getattr(current_args, name))(settings[name]) 208 | except ValueError: 209 | error = True 210 | interface.print_red( 211 | f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}.") 212 | else: 213 | interface.print_red(f"There is no '{name}' setting.") 214 | error = True 215 | if error: 216 | interface.print_red("There was an issue parsing the settings.
No settings have been changed.") 217 | return current_args, False 218 | else: 219 | for name in settings: 220 | setattr(current_args, name, settings[name]) 221 | interface.print_green(f"Set {name} to {settings[name]}.") 222 | 223 | time.sleep(1.5) # so the user has time to read the changes 224 | return current_args, True 225 | 226 | 227 | def load_model_and_tokenizer(args): 228 | tokenizer = AutoTokenizer.from_pretrained( 229 | args.model_name_or_path, 230 | revision=args.model_revision, 231 | trust_remote_code=args.trust_remote_code, 232 | ) 233 | 234 | torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype) 235 | quantization_config = get_quantization_config(args) 236 | model_kwargs = dict( 237 | revision=args.model_revision, 238 | attn_implementation=args.attn_implementation, 239 | torch_dtype=torch_dtype, 240 | device_map="auto", 241 | quantization_config=quantization_config, 242 | ) 243 | model = AutoModelForCausalLM.from_pretrained( 244 | args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs 245 | ) 246 | 247 | if getattr(model, "hf_device_map", None) is None: 248 | model = model.to(args.device) 249 | 250 | return model, tokenizer 251 | 252 | 253 | def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids): 254 | if tokenizer.pad_token_id is None: 255 | pad_token_id = tokenizer.eos_token_id 256 | else: 257 | pad_token_id = tokenizer.pad_token_id 258 | 259 | all_eos_token_ids = [] 260 | 261 | if eos_tokens is not None: 262 | all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(","))) 263 | 264 | if eos_token_ids is not None: 265 | all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")]) 266 | 267 | if len(all_eos_token_ids) == 0: 268 | all_eos_token_ids.append(tokenizer.eos_token_id) 269 | 270 | return pad_token_id, all_eos_token_ids 271 | 272 | 273 | def chat_cli(): 274 | parser = TrlParser(ChatArguments) 275 | 276 | if "--config" not 
in sys.argv: 277 | sys.argv.append("--config") 278 | sys.argv.append(os.path.join(os.path.dirname(__file__), "configs/default_chat_config.yaml")) 279 | args = parser.parse_args_and_config()[0] 280 | if args.examples is None: 281 | args.examples = {} 282 | 283 | current_args = copy.deepcopy(args) 284 | 285 | if args.user is None: 286 | user = get_username() 287 | else: 288 | user = args.user 289 | 290 | model, tokenizer = load_model_and_tokenizer(args) 291 | generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True) 292 | 293 | pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids) 294 | 295 | interface = RichInterface(model_name=args.model_name_or_path, user_name=user) 296 | interface.clear() 297 | chat = clear_chat_history(current_args.system_prompt) 298 | while True: 299 | try: 300 | user_input = interface.input() 301 | 302 | if user_input == "clear": 303 | chat = clear_chat_history(current_args.system_prompt) 304 | interface.clear() 305 | continue 306 | 307 | if user_input == "help": 308 | interface.print_help() 309 | continue 310 | 311 | if user_input == "exit": 312 | break 313 | 314 | if user_input == "reset": 315 | interface.clear() 316 | current_args = copy.deepcopy(args) 317 | chat = clear_chat_history(current_args.system_prompt) 318 | continue 319 | 320 | if user_input.startswith("save") and len(user_input.split()) <= 2: 321 | split_input = user_input.split() 322 | 323 | if len(split_input) == 2: 324 | filename = split_input[1] 325 | else: 326 | filename = None 327 | filename = save_chat(chat, current_args, filename) 328 | interface.print_green(f"Chat saved in {filename}!") 329 | continue 330 | 331 | if re.match(SETTING_RE, user_input): 332 | current_args, success = parse_settings(user_input, current_args, interface) 333 | if success: 334 | chat = [] 335 | interface.clear() 336 | continue 337 | 338 | if user_input.startswith("example") and len(user_input.split()) == 2: 339 | example_name =
user_input.split()[1] 340 | if example_name in current_args.examples: 341 | interface.clear() 342 | chat = [] 343 | interface.print_user_message(current_args.examples[example_name]["text"]) 344 | user_input = current_args.examples[example_name]["text"] 345 | else: 346 | interface.print_red( 347 | f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}." 348 | ) 349 | continue 350 | 351 | chat.append({"role": "user", "content": user_input}) 352 | 353 | if tokenizer.chat_template is None: 354 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE 355 | 356 | generation_kwargs = dict( 357 | inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to( 358 | model.device 359 | ), 360 | streamer=generation_streamer, 361 | max_new_tokens=current_args.max_new_tokens, 362 | do_sample=current_args.do_sample, 363 | num_beams=current_args.num_beams, 364 | temperature=current_args.temperature, 365 | top_k=current_args.top_k, 366 | top_p=current_args.top_p, 367 | repetition_penalty=current_args.repetition_penalty, 368 | pad_token_id=pad_token_id, 369 | eos_token_id=eos_token_ids, 370 | ) 371 | 372 | thread = Thread(target=model.generate, kwargs=generation_kwargs) 373 | thread.start() 374 | model_output = interface.stream_output(generation_streamer) 375 | thread.join() 376 | chat.append({"role": "assistant", "content": model_output}) 377 | 378 | except KeyboardInterrupt: 379 | break 380 | 381 | 382 | if __name__ == "__main__": 383 | chat_cli() 384 | -------------------------------------------------------------------------------- /run_clm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 18 | 19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 20 | https://huggingface.co/models?filter=text-generation 21 | """ 22 | # You can also adapt this script to your own causal language modeling task. Pointers for this are left as comments. 23 | 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | import warnings 29 | from dataclasses import dataclass, field 30 | from itertools import chain 31 | from typing import Optional 32 | 33 | import datasets 34 | import evaluate 35 | import torch 36 | from datasets import load_dataset 37 | 38 | import transformers 39 | from transformers import ( 40 | CONFIG_MAPPING, 41 | MODEL_FOR_CAUSAL_LM_MAPPING, 42 | AutoConfig, 43 | AutoModelForCausalLM, 44 | AutoTokenizer, 45 | HfArgumentParser, 46 | Trainer, 47 | TrainingArguments, 48 | default_data_collator, 49 | is_torch_tpu_available, 50 | set_seed, 51 | ) 52 | from transformers.testing_utils import CaptureLogger 53 | from transformers.trainer_utils import get_last_checkpoint 54 | from transformers.utils import check_min_version, send_example_telemetry 55 | from transformers.utils.versions import require_version 56 | 57 | import models 58 | 59 | 60 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk.
61 | # check_min_version("4.33.0.dev0") 62 | 63 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") 64 | 65 | logger = logging.getLogger(__name__) 66 | 67 | 68 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 69 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 70 | 71 | 72 | # learning rate decay scheduler (cosine with warmup) 73 | def tinyllama_get_lr( 74 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float 75 | ): 76 | final_div_factor = 0.1 77 | if current_step < num_warmup_steps: 78 | return float(current_step) / float(max(1, num_warmup_steps)) 79 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 80 | coeff = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)) 81 | return final_div_factor + coeff * (1 - final_div_factor) 82 | 83 | ## Uncomment the following line to use the cosine schedule with the minimum lr set to 10% of the initial lr 84 | # transformers.optimization._get_cosine_schedule_with_warmup_lr_lambda = tinyllama_get_lr 85 | 86 | 87 | @dataclass 88 | class ModelArguments: 89 | """ 90 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 91 | """ 92 | 93 | model_name_or_path: Optional[str] = field( 94 | default=None, 95 | metadata={ 96 | "help": ( 97 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch." 98 | ) 99 | }, 100 | ) 101 | model_type: Optional[str] = field( 102 | default=None, 103 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 104 | ) 105 | config_overrides: Optional[str] = field( 106 | default=None, 107 | metadata={ 108 | "help": ( 109 | "Override some existing default config settings when a model is trained from scratch. 
Example: " 110 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 111 | ) 112 | }, 113 | ) 114 | config_name: Optional[str] = field( 115 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 116 | ) 117 | tokenizer_name: Optional[str] = field( 118 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 119 | ) 120 | cache_dir: Optional[str] = field( 121 | default=None, 122 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}, 123 | ) 124 | use_fast_tokenizer: bool = field( 125 | default=True, 126 | metadata={"help": "Whether to use a fast tokenizer (backed by the tokenizers library) or not."}, 127 | ) 128 | model_revision: str = field( 129 | default="main", 130 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 131 | ) 132 | token: str = field( 133 | default=None, 134 | metadata={ 135 | "help": ( 136 | "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " 137 | "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." 138 | ) 139 | }, 140 | ) 141 | use_auth_token: bool = field( 142 | default=None, 143 | metadata={ 144 | "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`." 145 | }, 146 | ) 147 | trust_remote_code: bool = field( 148 | default=False, 149 | metadata={ 150 | "help": ( 151 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " 152 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will " 153 | "execute code present on the Hub on your local machine."
154 | ) 155 | }, 156 | ) 157 | torch_dtype: Optional[str] = field( 158 | default=None, 159 | metadata={ 160 | "help": ( 161 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 162 | "dtype will be automatically derived from the model's weights." 163 | ), 164 | "choices": ["auto", "bfloat16", "float16", "float32"], 165 | }, 166 | ) 167 | low_cpu_mem_usage: bool = field( 168 | default=False, 169 | metadata={ 170 | "help": ( 171 | "Create the model as an empty shell and only materialize its parameters when the pretrained weights are loaded. " 172 | "Setting this to True reduces LLM loading time and RAM consumption." 173 | ) 174 | }, 175 | ) 176 | 177 | attn_implementation: Optional[str] = field( 178 | default="flash_attention_2", 179 | metadata={ 180 | "help": ( 181 | "The attention implementation to use. The LCKV implementation uses FlashAttention-2 (`flash_attention_2`) by default." 182 | ), 183 | "choices": ["eager", "flash_attention_2", "sdpa"], 184 | }, 185 | ) 186 | 187 | 188 | @dataclass 189 | class DataTrainingArguments: 190 | """ 191 | Arguments pertaining to what data we are going to input our model for training and eval.
192 | """ 193 | 194 | dataset_name: Optional[str] = field( 195 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 196 | ) 197 | dataset_config_name: Optional[str] = field( 198 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 199 | ) 200 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 201 | validation_file: Optional[str] = field( 202 | default=None, 203 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 204 | ) 205 | max_train_samples: Optional[int] = field( 206 | default=None, 207 | metadata={ 208 | "help": ( 209 | "For debugging purposes or quicker training, truncate the number of training examples to this " 210 | "value if set." 211 | ) 212 | }, 213 | ) 214 | max_eval_samples: Optional[int] = field( 215 | default=None, 216 | metadata={ 217 | "help": ( 218 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 219 | "value if set." 220 | ) 221 | }, 222 | ) 223 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 224 | block_size: Optional[int] = field( 225 | default=None, 226 | metadata={ 227 | "help": ( 228 | "Optional input sequence length after tokenization. " 229 | "The training dataset will be truncated into blocks of this size for training. " 230 | "Default to the model max input length for single sentence inputs (take into account special tokens)."
231 | ) 232 | }, 233 | ) 234 | overwrite_cache: bool = field( 235 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 236 | ) 237 | validation_split_percentage: Optional[int] = field( 238 | default=5, 239 | metadata={ 240 | "help": "The percentage of the train set used as validation set in case there's no validation split" 241 | }, 242 | ) 243 | preprocessing_num_workers: Optional[int] = field( 244 | default=None, 245 | metadata={"help": "The number of processes to use for the preprocessing."}, 246 | ) 247 | keep_linebreaks: bool = field( 248 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."} 249 | ) 250 | 251 | def __post_init__(self): 252 | if self.streaming: 253 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") 254 | 255 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 256 | raise ValueError("Need either a dataset name or a training/validation file.") 257 | else: 258 | if self.train_file is not None: 259 | extension = self.train_file.split(".")[-1] 260 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file." 261 | if self.validation_file is not None: 262 | extension = self.validation_file.split(".")[-1] 263 | assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file." 264 | 265 | 266 | def main(): 267 | # See all possible arguments in src/transformers/training_args.py 268 | # or by passing the --help flag to this script. 269 | # We now keep distinct sets of args, for a cleaner separation of concerns. 270 | 271 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 272 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 273 | # If we pass only one argument to the script and it's the path to a json file, 274 | # let's parse it to get our arguments. 
275 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 276 | else: 277 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 278 | 279 | if model_args.use_auth_token is not None: 280 | warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning) 281 | if model_args.token is not None: 282 | raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.") 283 | model_args.token = model_args.use_auth_token 284 | 285 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 286 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 287 | # send_example_telemetry("run_clm", model_args, data_args) 288 | 289 | # Setup logging 290 | logging.basicConfig( 291 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 292 | datefmt="%m/%d/%Y %H:%M:%S", 293 | handlers=[logging.StreamHandler(sys.stdout)], 294 | ) 295 | 296 | if training_args.should_log: 297 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
298 | transformers.utils.logging.set_verbosity_info() 299 | 300 | log_level = training_args.get_process_log_level() 301 | logger.setLevel(log_level) 302 | datasets.utils.logging.set_verbosity(log_level) 303 | transformers.utils.logging.set_verbosity(log_level) 304 | transformers.utils.logging.enable_default_handler() 305 | transformers.utils.logging.enable_explicit_format() 306 | 307 | # Log a short summary on each process: 308 | logger.warning( 309 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 310 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 311 | ) 312 | logger.info(f"Training/evaluation parameters {training_args}") 313 | 314 | # Detecting last checkpoint. 315 | last_checkpoint = None 316 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 317 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 318 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 319 | raise ValueError( 320 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 321 | "Use --overwrite_output_dir to overcome." 322 | ) 323 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 324 | logger.info( 325 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 326 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 327 | ) 328 | 329 | # Set seed before initializing model. 330 | set_seed(training_args.seed) 331 | 332 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) 333 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 334 | # (the dataset will be downloaded automatically from the datasets Hub).
335 | # 336 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called 337 | # 'text' is found. You can easily tweak this behavior (see below). 338 | # 339 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 340 | # download the dataset. 341 | if data_args.dataset_name is not None: 342 | # Downloading and loading a dataset from the hub. 343 | raw_datasets = load_dataset( 344 | data_args.dataset_name, 345 | data_args.dataset_config_name, 346 | cache_dir=model_args.cache_dir, 347 | token=model_args.token, 348 | streaming=data_args.streaming, 349 | download_config=datasets.DownloadConfig(resume_download=True, max_retries=10000) 350 | ) 351 | if "validation" not in raw_datasets.keys(): 352 | raw_datasets["validation"] = load_dataset( 353 | data_args.dataset_name, 354 | data_args.dataset_config_name, 355 | split=f"train[:{data_args.validation_split_percentage}%]", 356 | cache_dir=model_args.cache_dir, 357 | token=model_args.token, 358 | streaming=data_args.streaming, 359 | ) 360 | raw_datasets["train"] = load_dataset( 361 | data_args.dataset_name, 362 | data_args.dataset_config_name, 363 | split=f"train[{data_args.validation_split_percentage}%:]", 364 | cache_dir=model_args.cache_dir, 365 | token=model_args.token, 366 | streaming=data_args.streaming, 367 | ) 368 | else: 369 | data_files = {} 370 | dataset_args = {} 371 | if data_args.train_file is not None: 372 | data_files["train"] = data_args.train_file 373 | if data_args.validation_file is not None: 374 | data_files["validation"] = data_args.validation_file 375 | extension = ( 376 | data_args.train_file.split(".")[-1] 377 | if data_args.train_file is not None 378 | else data_args.validation_file.split(".")[-1] 379 | ) 380 | if extension == "txt": 381 | extension = "text" 382 | dataset_args["keep_linebreaks"] = data_args.keep_linebreaks 383 | raw_datasets = load_dataset( 384 | extension, 385 |
data_files=data_files, 386 | cache_dir=model_args.cache_dir, 387 | token=model_args.token, 388 | **dataset_args, 389 | ) 390 | # If no validation data is there, validation_split_percentage will be used to divide the dataset. 391 | if "validation" not in raw_datasets.keys(): 392 | raw_datasets["validation"] = load_dataset( 393 | extension, 394 | data_files=data_files, 395 | split=f"train[:{data_args.validation_split_percentage}%]", 396 | cache_dir=model_args.cache_dir, 397 | token=model_args.token, 398 | **dataset_args, 399 | ) 400 | raw_datasets["train"] = load_dataset( 401 | extension, 402 | data_files=data_files, 403 | split=f"train[{data_args.validation_split_percentage}%:]", 404 | cache_dir=model_args.cache_dir, 405 | token=model_args.token, 406 | **dataset_args, 407 | ) 408 | 409 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 410 | # https://huggingface.co/docs/datasets/loading_datasets.html. 411 | 412 | # Load pretrained model and tokenizer 413 | # 414 | # Distributed training: 415 | # The .from_pretrained methods guarantee that only one local process can concurrently 416 | # download model & vocab. 
417 | 418 | config_kwargs = { 419 | "cache_dir": model_args.cache_dir, 420 | "revision": model_args.model_revision, 421 | "token": model_args.token, 422 | "trust_remote_code": model_args.trust_remote_code, 423 | } 424 | if model_args.config_name: 425 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 426 | elif model_args.model_name_or_path: 427 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 428 | else: 429 | config = CONFIG_MAPPING[model_args.model_type]() 430 | logger.warning("You are instantiating a new config instance from scratch.") 431 | 432 | # allow config overrides under all circumstances 433 | if model_args.config_overrides is not None: 434 | logger.info(f"Overriding config: {model_args.config_overrides}") 435 | config.update_from_string(model_args.config_overrides) 436 | logger.info(f"New config: {config}") 437 | 438 | tokenizer_kwargs = { 439 | "cache_dir": model_args.cache_dir, 440 | "use_fast": model_args.use_fast_tokenizer, 441 | "revision": model_args.model_revision, 442 | "token": model_args.token, 443 | "trust_remote_code": model_args.trust_remote_code, 444 | } 445 | if model_args.tokenizer_name: 446 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs) 447 | elif model_args.model_name_or_path: 448 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs) 449 | else: 450 | raise ValueError( 451 | "You are instantiating a new tokenizer from scratch. This is not supported by this script. " 452 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
453 | ) 454 | 455 | if model_args.model_name_or_path: 456 | torch_dtype = ( 457 | model_args.torch_dtype 458 | if model_args.torch_dtype in ["auto", None] 459 | else getattr(torch, model_args.torch_dtype) 460 | ) 461 | model = AutoModelForCausalLM.from_pretrained( 462 | model_args.model_name_or_path, 463 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 464 | config=config, 465 | cache_dir=model_args.cache_dir, 466 | revision=model_args.model_revision, 467 | token=model_args.token, 468 | trust_remote_code=model_args.trust_remote_code, 469 | torch_dtype=torch_dtype, 470 | low_cpu_mem_usage=model_args.low_cpu_mem_usage, 471 | attn_implementation=model_args.attn_implementation, 472 | ) 473 | else: 474 | torch_dtype = ( 475 | model_args.torch_dtype 476 | if model_args.torch_dtype in ["auto", None] 477 | else getattr(torch, model_args.torch_dtype) 478 | ) 479 | config._attn_implementation = model_args.attn_implementation 480 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code, torch_dtype=torch_dtype) 481 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values()) 482 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params") 483 | 484 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 485 | # on a small vocab and want a smaller embedding size, remove this test. 486 | embedding_size = model.get_input_embeddings().weight.shape[0] 487 | if len(tokenizer) > embedding_size: 488 | model.resize_token_embeddings(len(tokenizer)) 489 | 490 | # Preprocessing the datasets. 491 | # First we tokenize all the texts. 
492 | if training_args.do_train: 493 | column_names = list(raw_datasets["train"].features) 494 | else: 495 | column_names = list(raw_datasets["validation"].features) 496 | text_column_name = "text" if "text" in column_names else column_names[0] 497 | 498 | # Since tokenize_function will be pickled, force the logger to load before it is defined to avoid a _LazyModule error in the Hasher 499 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base") 500 | 501 | def tokenize_function(examples): 502 | with CaptureLogger(tok_logger) as cl: 503 | output = tokenizer(examples[text_column_name]) 504 | # clm input could be much much longer than block_size 505 | if "Token indices sequence length is longer than the" in cl.out: 506 | tok_logger.warning( 507 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" 508 | " before being passed to the model." 509 | ) 510 | return output 511 | 512 | with training_args.main_process_first(desc="dataset map tokenization"): 513 | if not data_args.streaming: 514 | tokenized_datasets = raw_datasets.map( 515 | tokenize_function, 516 | batched=True, 517 | num_proc=data_args.preprocessing_num_workers, 518 | remove_columns=column_names, 519 | load_from_cache_file=not data_args.overwrite_cache, 520 | desc="Running tokenizer on dataset", 521 | ) 522 | else: 523 | tokenized_datasets = raw_datasets.map( 524 | tokenize_function, 525 | batched=True, 526 | remove_columns=column_names, 527 | ) 528 | 529 | if data_args.block_size is None: 530 | block_size = tokenizer.model_max_length 531 | if block_size > 1024: 532 | logger.warning( 533 | "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value" 534 | " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can" 535 | " override this default with `--block_size xxx`."
536 | ) 537 | block_size = 1024 538 | else: 539 | if data_args.block_size > tokenizer.model_max_length: 540 | logger.warning( 541 | f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model " 542 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." 543 | ) 544 | block_size = min(data_args.block_size, tokenizer.model_max_length) 545 | 546 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. 547 | def group_texts(examples): 548 | # Concatenate all texts. 549 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} 550 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 551 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. 552 | # We could add padding instead of this drop if the model supported it; you can customize this part to your needs. 553 | total_length = (total_length // block_size) * block_size 554 | # Split into chunks of block_size. 555 | result = { 556 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)] 557 | for k, t in concatenated_examples.items() 558 | } 559 | result["labels"] = result["input_ids"].copy() 560 | return result 561 | 562 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder 563 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower 564 | # to preprocess. 565 | # 566 | # To speed up this part, we use multiprocessing.
See the documentation of the map method for more information: 567 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 568 | 569 | with training_args.main_process_first(desc="grouping texts together"): 570 | if not data_args.streaming: 571 | lm_datasets = tokenized_datasets.map( 572 | group_texts, 573 | batched=True, 574 | num_proc=data_args.preprocessing_num_workers, 575 | load_from_cache_file=not data_args.overwrite_cache, 576 | desc=f"Grouping texts in chunks of {block_size}", 577 | ) 578 | else: 579 | lm_datasets = tokenized_datasets.map( 580 | group_texts, 581 | batched=True, 582 | ) 583 | 584 | if training_args.do_train: 585 | if "train" not in tokenized_datasets: 586 | raise ValueError("--do_train requires a train dataset") 587 | train_dataset = lm_datasets["train"] 588 | if data_args.max_train_samples is not None: 589 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 590 | train_dataset = train_dataset.select(range(max_train_samples)) 591 | 592 | if training_args.do_eval: 593 | if "validation" not in tokenized_datasets: 594 | raise ValueError("--do_eval requires a validation dataset") 595 | eval_dataset = lm_datasets["validation"] 596 | if data_args.max_eval_samples is not None: 597 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 598 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 599 | 600 | def preprocess_logits_for_metrics(logits, labels): 601 | if isinstance(logits, tuple): 602 | # Depending on the model and config, logits may contain extra tensors, 603 | # like past_key_values, but logits always come first 604 | logits = logits[0] 605 | return logits.argmax(dim=-1) 606 | 607 | metric = evaluate.load("accuracy") 608 | 609 | def compute_metrics(eval_preds): 610 | preds, labels = eval_preds 611 | # preds have the same shape as the labels, after the argmax(-1) has been calculated 612 | # by preprocess_logits_for_metrics but we need to shift the 
labels 613 | labels = labels[:, 1:].reshape(-1) 614 | preds = preds[:, :-1].reshape(-1) 615 | return metric.compute(predictions=preds, references=labels) 616 | 617 | if training_args.do_predict: 618 | if "test" not in tokenized_datasets: 619 | raise ValueError("--do_predict requires a test dataset") 620 | test_dataset = lm_datasets["test"] 621 | 622 | # Initialize our Trainer 623 | trainer = Trainer( 624 | model=model, 625 | args=training_args, 626 | train_dataset=train_dataset if training_args.do_train else None, 627 | eval_dataset=eval_dataset if training_args.do_eval else None, 628 | tokenizer=tokenizer, 629 | # Data collator will default to DataCollatorWithPadding, so we change it. 630 | data_collator=default_data_collator, 631 | compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None, 632 | preprocess_logits_for_metrics=preprocess_logits_for_metrics 633 | if training_args.do_eval and not is_torch_tpu_available() 634 | else None, 635 | ) 636 | 637 | # Training 638 | if training_args.do_train: 639 | checkpoint = None 640 | if training_args.resume_from_checkpoint is not None: 641 | checkpoint = training_args.resume_from_checkpoint 642 | elif last_checkpoint is not None: 643 | checkpoint = last_checkpoint 644 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 645 | trainer.save_model() # Saves the tokenizer too for easy upload 646 | 647 | metrics = train_result.metrics 648 | 649 | max_train_samples = ( 650 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 651 | ) 652 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 653 | 654 | trainer.log_metrics("train", metrics) 655 | trainer.save_metrics("train", metrics) 656 | trainer.save_state() 657 | 658 | # Evaluation 659 | if training_args.do_eval: 660 | logger.info("*** Evaluate ***") 661 | 662 | metrics = trainer.evaluate() 663 | 664 | max_eval_samples = data_args.max_eval_samples if 
data_args.max_eval_samples is not None else len(eval_dataset) 665 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 666 | try: 667 | perplexity = math.exp(metrics["eval_loss"]) 668 | except OverflowError: 669 | perplexity = float("inf") 670 | metrics["eval_perplexity"] = perplexity 671 | 672 | trainer.log_metrics("eval", metrics) 673 | trainer.save_metrics("eval", metrics) 674 | 675 | # Predict 676 | if training_args.do_predict: 677 | logger.info("*** Predict ***") 678 | 679 | metrics = trainer.predict(test_dataset).metrics 680 | 681 | metrics["test_samples"] = len(test_dataset) 682 | try: 683 | perplexity = math.exp(metrics["test_loss"]) 684 | except OverflowError: 685 | perplexity = float("inf") 686 | metrics["test_perplexity"] = perplexity 687 | 688 | trainer.log_metrics("test", metrics) 689 | trainer.save_metrics("test", metrics) 690 | 691 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"} 692 | if data_args.dataset_name is not None: 693 | kwargs["dataset_tags"] = data_args.dataset_name 694 | if data_args.dataset_config_name is not None: 695 | kwargs["dataset_args"] = data_args.dataset_config_name 696 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 697 | else: 698 | kwargs["dataset"] = data_args.dataset_name 699 | 700 | # if training_args.push_to_hub: 701 | # trainer.push_to_hub(**kwargs) 702 | # else: 703 | # trainer.create_model_card(**kwargs) 704 | 705 | 706 | def _mp_fn(index): 707 | # For xla_spawn (TPUs) 708 | main() 709 | 710 | 711 | if __name__ == "__main__": 712 | main() 713 | -------------------------------------------------------------------------------- /run_clm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ## pretrain code for llama-tiny 4 | # - to pretrain a tinyllama, change the config to `TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T` 5 | # - to initialize the model with a
pretrained model, add `--model_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T` 6 | # - to use the minipile dataset, use `--dataset_name JeanKaddour/minipile`, with proper `--preprocessing_num_workers` 7 | # - to enable wandb, use `--report_to wandb` 8 | accelerate launch run_clm.py \ 9 | --tokenizer_name TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T \ 10 | --config_name configs/llama_tiny_lckv.json \ 11 | --config_overrides layer_types=0_6_6_6_6_6_6_7,forward_passes=7,backward_passes=2 \ 12 | --dataset_name wikitext \ 13 | --dataset_config_name wikitext-103-raw-v1 \ 14 | --per_device_train_batch_size 32 \ 15 | --per_device_eval_batch_size 32 \ 16 | --auto_find_batch_size \ 17 | --gradient_accumulation_steps 1 \ 18 | --block_size 1024 \ 19 | --lr_scheduler_type cosine \ 20 | --warmup_ratio 0.015 \ 21 | --learning_rate 3e-4 \ 22 | --weight_decay 1e-1 \ 23 | --bf16 \ 24 | --torch_dtype bfloat16 \ 25 | --do_train \ 26 | --do_eval \ 27 | --use_liger_kernel \ 28 | --num_train_epochs 3 \ 29 | --save_total_limit 1 \ 30 | --save_strategy steps \ 31 | --save_steps 500 \ 32 | --evaluation_strategy steps \ 33 | --eval_steps 500 \ 34 | --load_best_model_at_end True \ 35 | --metric_for_best_model eval_loss \ 36 | --report_to none \ 37 | --run_name llamatiny-test \ 38 | --overwrite_output_dir \ 39 | --output_dir outputs/llamatiny-test 40 | -------------------------------------------------------------------------------- /run_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) 18 | """ 19 | 20 | 21 | import argparse 22 | import inspect 23 | import logging 24 | from typing import Tuple 25 | 26 | import torch 27 | from accelerate import PartialState 28 | from accelerate.utils import set_seed 29 | 30 | from models import LCKVLlamaForCausalLM 31 | from models.cache_utils import IndexedHybridCache 32 | from models.utils import LayerTypeParser 33 | from transformers import ( 34 | AutoTokenizer, 35 | BloomForCausalLM, 36 | BloomTokenizerFast, 37 | CTRLLMHeadModel, 38 | CTRLTokenizer, 39 | GenerationMixin, 40 | GPT2LMHeadModel, 41 | GPT2Tokenizer, 42 | GPTJForCausalLM, 43 | LlamaForCausalLM, 44 | LlamaTokenizer, 45 | OpenAIGPTLMHeadModel, 46 | OpenAIGPTTokenizer, 47 | OPTForCausalLM, 48 | TransfoXLLMHeadModel, 49 | TransfoXLTokenizer, 50 | XLMTokenizer, 51 | XLMWithLMHeadModel, 52 | XLNetLMHeadModel, 53 | XLNetTokenizer, 54 | ) 55 | from transformers.cache_utils import SinkCache 56 | from transformers.modeling_outputs import CausalLMOutputWithPast 57 | 58 | 59 | logging.basicConfig( 60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 61 | datefmt="%m/%d/%Y %H:%M:%S", 62 | level=logging.INFO, 63 | ) 64 | logger = logging.getLogger(__name__) 65 | 66 | MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop 67 | 68 | MODEL_CLASSES = { 69 | "lckv-llama": (LCKVLlamaForCausalLM, LlamaTokenizer), 70 | "gpt2": (GPT2LMHeadModel, GPT2Tokenizer), 71 | "ctrl": 
(CTRLLMHeadModel, CTRLTokenizer), 72 | "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer), 73 | "xlnet": (XLNetLMHeadModel, XLNetTokenizer), 74 | "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer), 75 | "xlm": (XLMWithLMHeadModel, XLMTokenizer), 76 | "gptj": (GPTJForCausalLM, AutoTokenizer), 77 | "bloom": (BloomForCausalLM, BloomTokenizerFast), 78 | "llama": (LlamaForCausalLM, LlamaTokenizer), 79 | "opt": (OPTForCausalLM, GPT2Tokenizer), 80 | } 81 | 82 | # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia 83 | # in https://github.com/rusiaaman/XLNet-gen#methodology 84 | # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e 85 | PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family 86 | (except for Alexei and Maria) are discovered. 87 | The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the 88 | remainder of the story. 1883 Western Siberia, 89 | a young Grigori Rasputin is asked by his father and a group of men to perform magic. 90 | Rasputin has a vision and denounces one of the men as a horse thief. Although his 91 | father initially slaps him for making such an accusation, Rasputin watches as the 92 | man is chased outside and beaten. Twenty years later, Rasputin sees a vision of 93 | the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, 94 | with people, even a bishop, begging for his blessing. """ 95 | 96 | 97 | # 98 | # Functions to prepare models' input 99 | # 100 | 101 | 102 | def prepare_ctrl_input(args, _, tokenizer, prompt_text): 103 | if args.temperature > 0.7: 104 | logger.info("CTRL typically works better with lower temperatures (and lower top_k).") 105 | 106 | encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False) 107 | if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()): 108 | logger.info("WARNING! 
You are not starting your generation from a control code so you won't get good results") 109 | return prompt_text 110 | 111 | 112 | def prepare_xlm_input(args, model, tokenizer, prompt_text): 113 | # kwargs = {"language": None, "mask_token_id": None} 114 | 115 | # Set the language 116 | use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb 117 | if hasattr(model.config, "lang2id") and use_lang_emb: 118 | available_languages = model.config.lang2id.keys() 119 | if args.xlm_language in available_languages: 120 | language = args.xlm_language 121 | else: 122 | language = None 123 | while language not in available_languages: 124 | language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ") 125 | 126 | model.config.lang_id = model.config.lang2id[language] 127 | # kwargs["language"] = tokenizer.lang2id[language] 128 | 129 | # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers 130 | # XLM masked-language modeling (MLM) models need masked token 131 | # is_xlm_mlm = "mlm" in args.model_name_or_path 132 | # if is_xlm_mlm: 133 | # kwargs["mask_token_id"] = tokenizer.mask_token_id 134 | 135 | return prompt_text 136 | 137 | 138 | def prepare_xlnet_input(args, _, tokenizer, prompt_text): 139 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX 140 | prompt_text = prefix + prompt_text 141 | return prompt_text 142 | 143 | 144 | def prepare_transfoxl_input(args, _, tokenizer, prompt_text): 145 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX 146 | prompt_text = prefix + prompt_text 147 | return prompt_text 148 | 149 | 150 | PREPROCESSING_FUNCTIONS = { 151 | "ctrl": prepare_ctrl_input, 152 | "xlm": prepare_xlm_input, 153 | "xlnet": prepare_xlnet_input, 154 | "transfo-xl": prepare_transfoxl_input, 155 | } 156 | 157 | 158 | def adjust_length_to_model(length, max_sequence_length): 159 | if 
length < 0 and max_sequence_length > 0: 160 | length = max_sequence_length 161 | elif 0 < max_sequence_length < length: 162 | length = max_sequence_length # No generation bigger than model size 163 | elif length < 0: 164 | length = MAX_LENGTH # avoid infinite loop 165 | return length 166 | 167 | 168 | def sparse_model_config(model_config): 169 | embedding_size = None 170 | if hasattr(model_config, "hidden_size"): 171 | embedding_size = model_config.hidden_size 172 | elif hasattr(model_config, "n_embed"): 173 | embedding_size = model_config.n_embed 174 | elif hasattr(model_config, "n_embd"): 175 | embedding_size = model_config.n_embd 176 | 177 | num_head = None 178 | if hasattr(model_config, "num_attention_heads"): 179 | num_head = model_config.num_attention_heads 180 | elif hasattr(model_config, "n_head"): 181 | num_head = model_config.n_head 182 | 183 | if embedding_size is None or num_head is None or num_head == 0: 184 | raise ValueError("Check the model config") 185 | 186 | num_embedding_size_per_head = int(embedding_size / num_head) 187 | if hasattr(model_config, "n_layer"): 188 | num_layer = model_config.n_layer 189 | elif hasattr(model_config, "num_hidden_layers"): 190 | num_layer = model_config.num_hidden_layers 191 | else: 192 | raise ValueError("Number of hidden layers couldn't be determined from the model config") 193 | 194 | return num_layer, num_head, num_embedding_size_per_head 195 | 196 | 197 | def generate_past_key_values(model, batch_size, seq_len): 198 | num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config) 199 | if model.config.model_type == "bloom": 200 | past_key_values = tuple( 201 | ( 202 | torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len) 203 | .to(model.dtype) 204 | .to(model.device), 205 | torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head) 206 | .to(model.dtype) 207 | .to(model.device), 208 | ) 209 | for _ in 
range(num_block_layers) 210 | ) 211 | else: 212 | past_key_values = tuple( 213 | ( 214 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) 215 | .to(model.dtype) 216 | .to(model.device), 217 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head) 218 | .to(model.dtype) 219 | .to(model.device), 220 | ) 221 | for _ in range(num_block_layers) 222 | ) 223 | return past_key_values 224 | 225 | 226 | def prepare_jit_inputs(inputs, model, tokenizer): 227 | batch_size = len(inputs) 228 | dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt") 229 | dummy_input = dummy_input.to(model.device) 230 | if model.config.use_cache: 231 | dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1) 232 | dummy_input["attention_mask"] = torch.cat( 233 | [ 234 | torch.zeros(dummy_input["attention_mask"].shape[0], 1) 235 | .to(dummy_input["attention_mask"].dtype) 236 | .to(model.device), 237 | dummy_input["attention_mask"], 238 | ], 239 | -1, 240 | ) 241 | return dummy_input 242 | 243 | 244 | class _ModelFallbackWrapper(GenerationMixin): 245 | __slots__ = ("_optimized", "_default") 246 | 247 | def __init__(self, optimized, default): 248 | self._optimized = optimized 249 | self._default = default 250 | 251 | def __call__(self, *args, **kwargs): 252 | if kwargs["past_key_values"] is None and self._default.config.use_cache: 253 | kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0) 254 | kwargs.pop("position_ids", None) 255 | for k in list(kwargs.keys()): 256 | if kwargs[k] is None or isinstance(kwargs[k], bool): 257 | kwargs.pop(k) 258 | outputs = self._optimized(**kwargs) 259 | lm_logits = outputs[0] 260 | past_key_values = outputs[1] 261 | fixed_output = CausalLMOutputWithPast( 262 | loss=None, 263 | logits=lm_logits, 264 | past_key_values=past_key_values, 265 | hidden_states=None, 266 | attentions=None, 267 | ) 268 | return fixed_output 269 | 
270 | def __getattr__(self, item): 271 | return getattr(self._default, item) 272 | 273 | def prepare_inputs_for_generation( 274 | self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs 275 | ): 276 | return self._default.prepare_inputs_for_generation( 277 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs 278 | ) 279 | 280 | def _reorder_cache( 281 | self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor 282 | ) -> Tuple[Tuple[torch.Tensor]]: 283 | """ 284 | This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or 285 | [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct 286 | beam_idx at every generation step. 287 | """ 288 | return self._default._reorder_cache(past_key_values, beam_idx) 289 | 290 | 291 | def main(): 292 | parser = argparse.ArgumentParser() 293 | parser.add_argument( 294 | "--model_type", 295 | default=None, 296 | type=str, 297 | required=True, 298 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 299 | ) 300 | parser.add_argument( 301 | "--model_name_or_path", 302 | default=None, 303 | type=str, 304 | required=True, 305 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()), 306 | ) 307 | parser.add_argument( 308 | "--tokenizer_name", 309 | default=None, 310 | type=str, 311 | help="Pretrained tokenizer name or path if not the same as model_name", 312 | ) 313 | parser.add_argument( 314 | "--config_overrides", 315 | default=None, 316 | type=str, 317 | help=( 318 | "Override some existing default config settings when a model is trained from scratch. 
Example: " 319 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 320 | ), 321 | ) 322 | 323 | parser.add_argument("--prompt", type=str, default="") 324 | parser.add_argument("--length", type=int, default=20) 325 | parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped") 326 | 327 | parser.add_argument( 328 | "--temperature", 329 | type=float, 330 | default=1.0, 331 | help="A temperature of 1.0 has no effect; lower values tend toward greedy sampling", 332 | ) 333 | parser.add_argument( 334 | "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" 335 | ) 336 | parser.add_argument("--k", type=int, default=0) 337 | parser.add_argument("--p", type=float, default=0.9) 338 | 339 | parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.") 340 | parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.") 341 | parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.") 342 | 343 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 344 | parser.add_argument( 345 | "--use_cpu", 346 | action="store_true", 347 | help="Whether or not to use cpu. If set to False, " "we will use gpu/npu or mps device if available", 348 | ) 349 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.") 350 | parser.add_argument( 351 | "--fp16", 352 | action="store_true", 353 | help="Whether to cast the model to 16-bit (half) precision with `model.half()` instead of 32-bit", 354 | ) 355 | parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference") 356 | parser.add_argument("--torch_dtype", type=str, default=None, help=( 357 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
If `auto` is passed, the " 358 | "dtype will be automatically derived from the model's weights." 359 | ), choices=["auto", "bfloat16", "float16", "float32"]) 360 | 361 | # sink cache related arguments 362 | parser.add_argument("--sink_cache", action="store_true", help="Whether to use sink cache.") 363 | parser.add_argument("--window_length", type=int, default=256, help="Window size for sink cache.") 364 | parser.add_argument("--num_sink_tokens", type=int, default=2, help="Number of sink tokens.") 365 | 366 | args = parser.parse_args() 367 | 368 | # Initialize the distributed state. 369 | distributed_state = PartialState(cpu=args.use_cpu) 370 | 371 | logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}") 372 | 373 | if args.seed is not None: 374 | set_seed(args.seed) 375 | 376 | # Initialize the model and tokenizer 377 | try: 378 | args.model_type = args.model_type.lower() 379 | model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 380 | except KeyError: 381 | raise KeyError(f"the model {args.model_type} you specified is not supported. 
You are welcome to add it and open a PR :)") 382 | 383 | if args.tokenizer_name: 384 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name) 385 | else: 386 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) 387 | if tokenizer.pad_token is None: 388 | tokenizer.pad_token = tokenizer.eos_token 389 | # maybe we need to override configs 390 | config = model_class.config_class.from_pretrained(args.model_name_or_path) 391 | if args.config_overrides is not None: 392 | logger.info(f"Overriding config: {args.config_overrides}") 393 | config.update_from_string(args.config_overrides) 394 | logger.info(f"New config: {config}") 395 | torch_dtype = ( 396 | args.torch_dtype 397 | if args.torch_dtype in ["auto", None] 398 | else getattr(torch, args.torch_dtype) 399 | ) 400 | model = model_class.from_pretrained(args.model_name_or_path, config=config, torch_dtype=torch_dtype, attn_implementation="flash_attention_2") 401 | 402 | # Set the model to the right device 403 | model.to(distributed_state.device) 404 | 405 | if args.fp16: 406 | model.half() 407 | # XXX: we should not adjust the length to the model configs 408 | # max_seq_length = getattr(model.config, "max_position_embeddings", 0) 409 | # args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length) 410 | logger.info(args) 411 | 412 | prompt_text = args.prompt if args.prompt else input("Model prompt >>> ") 413 | 414 | # Different models need different input formatting and/or extra arguments 415 | requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys() 416 | if requires_preprocessing: 417 | prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type) 418 | preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text) 419 | 420 | if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: 421 | tokenizer_kwargs = {"add_space_before_punct_symbol": True} 422 | else: 423 | tokenizer_kwargs = {} 424 | 425 | encoded_prompt = tokenizer.encode( 426 | 
preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs 427 | ) 428 | else: 429 | prefix = args.prefix if args.prefix else args.padding_text 430 | encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt") 431 | encoded_prompt = encoded_prompt.to(distributed_state.device) 432 | 433 | if encoded_prompt.size()[-1] == 0: 434 | input_ids = None 435 | else: 436 | input_ids = encoded_prompt 437 | 438 | if args.jit: 439 | jit_input_texts = ["enable jit"] 440 | jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer) 441 | torch._C._jit_set_texpr_fuser_enabled(False) 442 | model.config.return_dict = False 443 | if hasattr(model, "forward"): 444 | sig = inspect.signature(model.forward) 445 | else: 446 | sig = inspect.signature(model.__call__) 447 | jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None) 448 | traced_model = torch.jit.trace(model, jit_inputs, strict=False) 449 | traced_model = torch.jit.freeze(traced_model.eval()) 450 | traced_model(*jit_inputs) 451 | traced_model(*jit_inputs) 452 | 453 | model = _ModelFallbackWrapper(traced_model, model) 454 | 455 | kwargs = {} 456 | if hasattr(model.config, "layer_types"): 457 | parser = LayerTypeParser(model.config.layer_types) 458 | if parser.use_sliding_window(): 459 | kwargs["past_key_values"] = IndexedHybridCache(parser, model.config.sliding_window) 460 | if args.sink_cache: 461 | raise ValueError("Sliding window and sink cache cannot be used together.") 462 | if args.sink_cache: 463 | kwargs["past_key_values"] = SinkCache(args.window_length, args.num_sink_tokens) 464 | 465 | output_sequences = model.generate( 466 | input_ids=input_ids, 467 | max_length=args.length + len(encoded_prompt[0]), 468 | temperature=args.temperature, 469 | top_k=args.k, 470 | top_p=args.p, 471 | repetition_penalty=args.repetition_penalty, 472 | do_sample=True, 473 | 
num_return_sequences=args.num_return_sequences, 474 | **kwargs, 475 | ) 476 | 477 | # Remove the batch dimension when returning multiple sequences 478 | if len(output_sequences.shape) > 2: 479 | output_sequences.squeeze_() 480 | 481 | generated_sequences = [] 482 | 483 | for generated_sequence_idx, generated_sequence in enumerate(output_sequences): 484 | print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===") 485 | generated_sequence = generated_sequence.tolist() 486 | 487 | # Decode text 488 | text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True) 489 | 490 | # Remove all text after the stop token, only if it occurs (str.find returns -1 otherwise, which would drop the last character) 491 | text = text[: text.find(args.stop_token)] if args.stop_token and args.stop_token in text else text 492 | 493 | # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing 494 | total_sequence = ( 495 | prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :] 496 | ) 497 | 498 | generated_sequences.append(total_sequence) 499 | print(total_sequence) 500 | 501 | return generated_sequences 502 | 503 | if __name__ == "__main__": 504 | main() 505 | -------------------------------------------------------------------------------- /run_generation.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # run generation 4 | python run_generation.py \ 5 | --model_type lckv-llama \ 6 | --torch_dtype bfloat16 \ 7 | --tokenizer_name TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T \ 8 | --model_name_or_path outputs/llamatiny-test \ 9 | --num_return_sequences 1 \ 10 | --prompt "the meaning of life is" \ 11 | --length 512 12 | 13 | # run streaming 14 | python run_generation.py \ 15 | --model_type lckv-llama \ 16 | --torch_dtype bfloat16 \ 17 | --tokenizer_name TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T \ 18 | --model_name_or_path outputs/llamatiny-test \ 19 | --num_return_sequences 1 \ 20 | --prompt "the
meaning of life is" \ 21 | --length 2048 \ 22 | --sink_cache \ 23 | --window_length 1024 \ 24 | --num_sink_tokens 4 \ 25 | -------------------------------------------------------------------------------- /run_sft.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from trl.commands.cli_utils import SFTScriptArguments, TrlParser 17 | 18 | 19 | from datasets import load_dataset 20 | 21 | from transformers import AutoTokenizer 22 | import models 23 | 24 | from trl import ( 25 | ModelConfig, 26 | SFTConfig, 27 | SFTTrainer, 28 | get_peft_config, 29 | get_quantization_config, 30 | get_kbit_device_map, 31 | ) 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig)) 36 | args, training_args, model_config = parser.parse_args_and_config() 37 | 38 | ################ 39 | # Model init kwargs & Tokenizer 40 | ################ 41 | quantization_config = get_quantization_config(model_config) 42 | model_kwargs = dict( 43 | revision=model_config.model_revision, 44 | trust_remote_code=model_config.trust_remote_code, 45 | attn_implementation=model_config.attn_implementation, 46 | torch_dtype=model_config.torch_dtype, 47 | use_cache=False, 48 | device_map=get_kbit_device_map() if quantization_config is not None else 
None, 49 | quantization_config=quantization_config, 50 | ) 51 | training_args.model_init_kwargs = model_kwargs 52 | tokenizer = AutoTokenizer.from_pretrained( 53 | model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True 54 | ) 55 | tokenizer.pad_token = tokenizer.eos_token 56 | 57 | ################ 58 | # Dataset 59 | ################ 60 | dataset = load_dataset(args.dataset_name) 61 | 62 | ################ 63 | # Training 64 | ################ 65 | trainer = SFTTrainer( 66 | model=model_config.model_name_or_path, 67 | args=training_args, 68 | train_dataset=dataset[args.dataset_train_split], 69 | eval_dataset=dataset[args.dataset_test_split], 70 | tokenizer=tokenizer, 71 | peft_config=get_peft_config(model_config), 72 | ) 73 | 74 | trainer.train() 75 | trainer.save_model(training_args.output_dir) 76 | -------------------------------------------------------------------------------- /run_sft.sh: -------------------------------------------------------------------------------- 1 | # Full training 2 | python run_sft.py \ 3 | --model_name_or_path outputs/llamatiny-test \ 4 | --dataset_name timdettmers/openassistant-guanaco \ 5 | --dataset_text_field text \ 6 | --max_seq_length 1024 \ 7 | --use_liger \ 8 | --bf16 \ 9 | --torch_dtype bfloat16 \ 10 | --attn_implementation flash_attention_2 \ 11 | --per_device_train_batch_size 16 \ 12 | --gradient_accumulation_steps 1 \ 13 | --learning_rate 1.41e-5 \ 14 | --logging_steps 1 \ 15 | --num_train_epochs 3 \ 16 | --report_to none \ 17 | --output_dir outputs/llamatiny-sft-test 18 | -------------------------------------------------------------------------------- /run_streaming.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | 6 | import requests 7 | import torch 8 | 9 | import models 10 | from transformers import AutoModelForCausalLM, AutoTokenizer, logging 11 | from 
transformers.cache_utils import SinkCache 12 | from transformers.generation.streamers import TextStreamer 13 | 14 | 15 | logging.get_logger("transformers.tokenization_utils").setLevel(logging.ERROR) 16 | logging.get_logger("transformers.generation.utils").setLevel(logging.ERROR) 17 | logging.get_logger("transformers.models.llama.modeling_llama").setLevel(logging.ERROR) 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--model_name_or_path", type=str, default="lmsys/vicuna-13b-v1.3") 22 | parser.add_argument("--data_root", type=str, default="data/") 23 | 24 | # generation related arguments 25 | parser.add_argument("--length", type=int, default=1000, help="Maximum number of new tokens generated per reply.") 26 | 27 | parser.add_argument( 28 | "--temperature", 29 | type=float, 30 | default=1.0, 31 | help="temperature of 1.0 has no effect, lower values tend toward greedy sampling", 32 | ) 33 | parser.add_argument( 34 | "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2" 35 | ) 36 | parser.add_argument("--k", type=int, default=0) 37 | parser.add_argument("--p", type=float, default=0.9) 38 | 39 | # sink cache related arguments 40 | parser.add_argument("--sink_cache", action="store_true", help="Whether to use sink cache.") 41 | parser.add_argument("--window_length", type=int, default=256, help="Window size for sink cache.") 42 | parser.add_argument("--num_sink_tokens", type=int, default=2, help="Number of sink tokens.") 43 | 44 | args = parser.parse_args() 45 | 46 | # load model 47 | print(f"Loading model from {args.model_name_or_path} ...") 48 | tokenizer = AutoTokenizer.from_pretrained( 49 | args.model_name_or_path, 50 | trust_remote_code=True, 51 | chat_template=r""" 52 | {%- for message in messages %} 53 | {%- if message['role'] == 'user' %} 54 | {{- 'USER: ' + message['content'].strip() + '\n' }} 55 | {%- elif message['role'] == 'assistant' %} 56 | {{- 'ASSISTANT: ' + message['content'] + ' \n\n' }} 57 | {%- endif %} 58 | {%- endfor %}
59 | {%- if add_generation_prompt %} 60 | {{- 'ASSISTANT: ' }} 61 | {%- endif %} 62 | """ 63 | ) 64 | model = AutoModelForCausalLM.from_pretrained( 65 | args.model_name_or_path, 66 | device_map="auto", 67 | torch_dtype=torch.float16, 68 | trust_remote_code=True, 69 | attn_implementation="flash_attention_2", 70 | ) 71 | model.eval() 72 | 73 | # load data 74 | print(f"Loading data from {args.data_root} ...") 75 | mt_bench = Path(args.data_root) / "mt_bench.jsonl" 76 | if not mt_bench.exists(): 77 | print("Downloading mt_bench data ...") 78 | os.makedirs(args.data_root, exist_ok=True) 79 | url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl" 80 | response = requests.get(url, timeout=60) 81 | response.raise_for_status() # fail loudly instead of caching an error page or an empty file 82 | mt_bench.write_text(response.text) 83 | 84 | prompts = [] 85 | with open(mt_bench, "r") as f: 86 | for line in f: 87 | prompts += json.loads(line)["turns"] 88 | 89 | # streaming inference 90 | kwargs = {} 91 | if args.sink_cache: 92 | kwargs["past_key_values"] = SinkCache(args.window_length, args.num_sink_tokens) 93 | 94 | chat_history = [] 95 | streamer = TextStreamer(tokenizer, skip_prompt=True) 96 | for prompt in prompts: 97 | new_prompt = {"role": "user", "content": prompt} 98 | print(tokenizer.apply_chat_template([new_prompt], add_generation_prompt=True, tokenize=False), end="") 99 | 100 | chat_history.append(new_prompt) 101 | input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt").to(model.device) 102 | 103 | output_sequences = model.generate( 104 | input_ids=input_ids, 105 | max_new_tokens=args.length, 106 | temperature=args.temperature, 107 | top_k=args.k, 108 | top_p=args.p, 109 | repetition_penalty=args.repetition_penalty, 110 | do_sample=True, 111 | streamer=streamer, 112 | **kwargs, 113 | ) 114 | 115 | chat_history.append({"role": "assistant", "content": tokenizer.decode(output_sequences[0, input_ids.shape[-1]:], skip_special_tokens=True)}) 116 |
print() 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /run_streaming.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # streaming test 4 | python test_streaming.py \ 5 | --model_name_or_path outputs/llamatiny-test \ 6 | --output_dir streaming/llamatiny-test \ 7 | --dataset_name pg19 \ 8 | --download_streaming \ 9 | --num_eval_tokens 4000000 \ 10 | --sink_cache \ 11 | --num_sink_tokens 4 \ 12 | --window_length 1024 \ 13 | 14 | # performance test 15 | python test_streaming.py \ 16 | --model_name_or_path outputs/llamatiny-test \ 17 | --output_dir streaming/llamatiny-2048 \ 18 | --dataset_name pg19 \ 19 | --download_streaming \ 20 | --num_eval_tokens 65133 \ 21 | --sink_cache \ 22 | --num_sink_tokens 4 \ 23 | --window_length 2048 \ 24 | -------------------------------------------------------------------------------- /test_harness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | import datasets 6 | import lm_eval 7 | 8 | import models 9 | 10 | 11 | datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T") 16 | parser.add_argument("--dtype", type=str, default="bfloat16") 17 | parser.add_argument("--tasks", type=str, nargs="+", default=["hellaswag", "openbookqa", "winogrande", "arc_challenge", "arc_easy", "boolq", "piqa"]) 18 | parser.add_argument("--num_fewshot", type=int, default=0) 19 | parser.add_argument("--output", type=str, default="harness/result.json") 20 | 21 | args = parser.parse_args() 22 | 23 | task_manager = lm_eval.tasks.TaskManager() 24 | results = lm_eval.simple_evaluate( # call simple_evaluate 25 | 
model="hf", 26 | model_args=f"pretrained={args.model_name_or_path},dtype={args.dtype},attn_implementation=flash_attention_2", 27 | tasks=args.tasks, 28 | num_fewshot=args.num_fewshot, 29 | log_samples=False, 30 | task_manager=task_manager, 31 | ) 32 | 33 | output_path = Path(args.output) 34 | output_path.parent.mkdir(parents=True, exist_ok=True) 35 | 36 | with open(output_path, "w") as f: 37 | json.dump(results, f, indent=4, default=lambda o: '') 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /test_latency.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections.abc import Iterable 3 | 4 | import torch 5 | from accelerate import PartialState 6 | from accelerate.utils import set_seed 7 | from tqdm import tqdm 8 | 9 | from models import LCKVLlamaConfig 10 | from transformers import ( 11 | AutoConfig, 12 | AutoModelForCausalLM, 13 | AutoTokenizer, 14 | logging, 15 | ) 16 | from transformers.trainer_pt_utils import get_model_param_count 17 | 18 | 19 | set_seed(42) 20 | 21 | logging.get_logger("transformers.tokenization_utils").setLevel(logging.ERROR) 22 | logging.get_logger("transformers.generation.utils").setLevel(logging.ERROR) 23 | logging.get_logger("transformers.models.llama.modeling_llama").setLevel(logging.ERROR) 24 | 25 | prompt_text_2048 = """The meaning of life is the solution to life's problems.' 26 | 27 | It was difficult to know whether the speech was a lie or not. At first Lily was determined to accept that her brother was deeply ill and needed my help in trying to help him. However, this was hardly the right way to treat him. He was a healthy young man and I didn't want to upset his sisters by making them feel as though they had been abandoned. So Lily couldn't see the point of my helping. 28 | 29 | 'Is it true that you only said you were ill because you were selfish?' 
30 | 31 | 'Well, yes, and I'm sorry. And, in a way, yes. Because it's true that I care about you. But first and foremost, I care about my brother.' 32 | 33 | 'OK,' Lily said, clearly annoyed at the interference in her own life. 'I do see the point of you.' 34 | 35 | 'So why don't you, forgive me?' 36 | 37 | 'Why don't you forgive me for what?' 38 | 39 | 'The part I played in making you miserable.' 40 | 41 | 'I don't forgive myself. I'm simply uncomfortable to talk to people like you.' 42 | 43 | 'OK, well, maybe it's not fair of me to start talking to people about myself. My mum is not going to come on the same terms as me,' I said defensively. 44 | 45 | 'I guess not,' Lily said. 46 | 47 | 'Yes, she is coming round,' I said hopefully. 'Better late than never.' 48 | 49 | 'So why do you come to see me?' Lily asked me. 'I mean, I was surprised. I've been neglecting you. You don't like me,' she said, somewhat alarmed at the harshness of her tone. 50 | 51 | 'Don't ask me to explain,' I told her, 'I don't really know.' 52 | 53 | 'But I thought you wanted to help,' Lily said slowly. 'What were you expecting me to do? When she became ill, you just seemed so upset about it.' 54 | 55 | 'I'm not in a good place right now,' I admitted. 'I've always hated end-of-life situations. I just don't know how to deal with it, Lily. It makes me feel bad.' 56 | 57 | 'I don't think I understand,' Lily replied. 'Have you tried to talk to her?' 58 | 59 | 'OK, I'll try. She used to live with me, but then she left.' 60 | 61 | 'Do you want to find out why she's given you up?' Lily asked. 62 | 63 | 'She wouldn't talk to me about it,' I said, embarrassed by the pressure I was under. 'She's tried talking to me, but I'm not open to anything.' 64 | 65 | 'You don't look open to anything,' Lily said with a note of contempt. 'She didn't put it that way. She was in love with you. I thought she wanted you to marry her.' 66 | 67 | 'She didn't say anything like that,' I said softly, biting my tongue. 
'She just... didn't talk to me.' 68 | 69 | Lily's tone changed from brusque to disapproving. 'I don't care if you are in love with her. She's leaving you.' 70 | 71 | 'Lily, I want to know why she's leaving you.' 72 | 73 | 'Because she's giving you up,' Lily declared. 'That's why.' 74 | 75 | 'I don't want to know.' I glanced over my shoulder. The door was open a crack. I wondered how it was that my brother could have found his way out of it. I wondered if my brother had had help. 76 | 77 | 'You don't want to know the answer?' Lily persisted. 'Get me a drink. Maybe you can find out what you're looking for.' 78 | 79 | I decided to go for it. Lily had the bottle of whisky that I bought for her on her birthday a few months earlier. It was a lovely bottle with a small picture of our father on it. 80 | 81 | 'Well, I'm not leaving you,' Lily told me. 'Go and sit with my mum. I'm going to check on Peter. You need to concentrate on getting through this morning. You can give your sister some solace later.' 82 | 83 | 'Lily, I don't want to see any of you,' I told her. 'You told me you weren't going to make it to the pub but I'm already here.' 84 | 85 | 'OK, don't make me answer that,' Lily said in a firm tone. 'Look, I'm not convinced that any of you are going to make it. I've been left alone for far too long. If I stay for a while, I might be able to take the edge off your stress, but I'm not going to make it my business to see how you're doing. Goodbye, Lily.' 86 | 87 | 'OK,' I told her. 'Bye.' 88 | 89 | 'Where are you going?' Lily asked me from the door. 90 | 91 | 'To see my mum,' I told her. 'I'm taking the day off.' 92 | 93 | 'OK,' Lily said with a slight smile, 'good luck with that.' 94 | 95 | She then went downstairs to order something to eat from the freezer. I stayed with my mum for a few hours, because she wouldn't leave. She told me she was now staying with the sisters, and that she had just been to see Peter. 
I told her about Peter's wife and how she had made her decision. 96 | 97 | 'Don't you dare leave me,' she told me. 'I'm not going to let you have this bit between us any more.' 98 | 99 | 'You know you have to work some more,' I told her. 'I'm getting up, so will you have a nice day?' 100 | 101 | She said she would. I was touched by the way she told me she didn't hate me. 'I don't want to hear about anything, just know that I'm thinking of you. No matter what happens, I'm there.' 102 | 103 | Then, I packed up my things and made my way to Lily's house. I had to make an effort to get her to talk about her brother. 104 | 105 | 'OK,' I told her, 'I guess I don't have to talk about Peter. I don't want to ask any questions. But, you don't mind if I do, do you?' 106 | 107 | 'Not if you do,' Lily told me. 'But I'd like to know why you're upset.' 108 | 109 | 'You tell me when you want to.' 110 | 111 | 'OK,' Lily said, then shook her head. 'I don't think I'll be telling you. I'll make it my business to see Peter.' 112 | 113 | 'Good,' I said, somewhat disappointed. 114 | 115 | 'Goodbye.' 116 | 117 | 'Goodbye,' I said and walked away. I couldn't help but feel that the night was too soon for Lily to suddenly get so upset about something like this. 118 | 119 | So, what was I going to do? I didn't like getting up at nine in the morning, not just to be alone but to get up. I felt as if I couldn't handle being alone for long periods of time. Maybe I would go for a walk. Or, if I wasn't feeling up to that, I would make some sandwiches and I would go for a stroll. I could go to the pub, and talk about something that didn't involve anything very serious. If I went to the pub and wasn't forced to say anything, then I could relax and get on with my day. 120 | 121 | I went for a stroll down the hill and ended up going into London. It was a lovely early spring morning. The trees were still very early, so I stopped at a shop to buy some fresh produce for lunch. 
As I was buying it, a cyclist approached me and asked me how I was. I didn't know what to say. He was really nice, though, so I explained about the situations of my brother and sister. 122 | 123 | 'I'm not sure how you're supposed to take that,' he told me, not sure what to say, 'but, hopefully, everything will work out for the best. I'm going off to work today, so I'll see you when I get back to Tottenham.' 124 | 125 | 'Ok,' I told him. 126 | 127 | 'Goodbye,' he added. 128 | 129 | I must have looked relieved when I said goodbye. He did head to work and I went home to find Lily sitting in her chair on her sofa. 130 | 131 | 'So how are you? How is the Wiltshire weather?' 132 | 133 | Lily told me she had been back on a walk in that great, new National Trust park. 'It was lovely,' she said. 'The sheep ran wild.' 134 | 135 | 'Did you manage to see who was out hunting?' I asked. 'Lily, are you watching it on television?' 136 | 137 | Lily nodded. 138 | 139 | 'OK,' I said. 'Do you want me to watch?' 140 | 141 | 'You'""" 142 | prompt_text_512 = """The meaning of life is the solution to life's problems.' 143 | 144 | It was difficult to know whether the speech was a lie or not. At first Lily was determined to accept that her brother was deeply ill and needed my help in trying to help him. However, this was hardly the right way to treat him. He was a healthy young man and I didn't want to upset his sisters by making them feel as though they had been abandoned. So Lily couldn't see the point of my helping. 145 | 146 | 'Is it true that you only said you were ill because you were selfish?' 147 | 148 | 'Well, yes, and I'm sorry. And, in a way, yes. Because it's true that I care about you. But first and foremost, I care about my brother.' 149 | 150 | 'OK,' Lily said, clearly annoyed at the interference in her own life. 'I do see the point of you.' 151 | 152 | 'So why don't you, forgive me?' 153 | 154 | 'Why don't you forgive me for what?' 
155 | 156 | 'The part I played in making you miserable.' 157 | 158 | 'I don't forgive myself. I'm simply uncomfortable to talk to people like you.' 159 | 160 | 'OK, well, maybe it's not fair of me to start talking to people about myself. My mum is not going to come on the same terms as me,' I said defensively. 161 | 162 | 'I guess not,' Lily said. 163 | 164 | 'Yes, she is coming round,' I said hopefully. 'Better late than never.' 165 | 166 | 'So why do you come to see me?' Lily asked me. 'I mean, I was surprised. I've been neglecting you. You don't like me,' she said, somewhat alarmed at the harshness of her tone. 167 | 168 | 'Don't ask me to explain,' I told her, 'I don't really know.' 169 | 170 | 'But I thought you wanted to help,' Lily said slowly. 'What were you expecting me to do? When she became ill, you just seemed so upset about it.' 171 | 172 | 'I'm not in a good place right now,' I admitted. 'I've always hated end-of-life situations. I just don't know how to deal with it, Lily. It makes me""" 173 | prompt_text_5 = "The meaning of life is" 174 | prompt_text_4096 = prompt_text_2048*2 175 | 176 | 177 | def empty_cache(): 178 | for _ in range(10): 179 | torch.cuda.empty_cache() 180 | 181 | class BinarySearch: 182 | """ 183 | Binary Search w/o maximum limit, search the upper bound by doubling the index. 
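A minimal standalone sketch of the same doubling-then-bisecting idea, for illustration only (it is not this class's interface): `max_feasible` and `is_feasible` are hypothetical names, and the sketch assumes feasibility is monotone (if n works, every smaller n works too) and that some value eventually fails.

```python
from typing import Callable


def max_feasible(is_feasible: Callable[[int], bool], low: int = 0) -> int:
    # Hypothetical helper (not the BinarySearch API): returns the largest n >= low
    # with is_feasible(n) True. Assumes feasibility is monotone, that `low` is the
    # fallback floor, and that some value eventually fails (otherwise phase 1 never ends).
    # Phase 1: probe low + 1, then keep doubling until a probe fails.
    probe = low + 1
    while is_feasible(probe):
        low = probe
        probe *= 2
    # The failed probe is an exclusive upper bound.
    high = probe - 1
    # Phase 2: bisect the closed interval [low, high].
    while low < high:
        # Round the midpoint up so that mid > low and the interval always shrinks.
        mid = (low + high + 1) // 2
        if is_feasible(mid):
            low = mid
        else:
            high = mid - 1
    return low


# Example: the largest batch size that fits when the (made-up) limit is 5.
print(max_feasible(lambda n: n <= 5))  # 5
```

Rounding the midpoint up keeps every bisection probe strictly above `low`, so a value already known to be feasible is never probed again.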
184 | """ 185 | def __init__(self, min: int = 0): 186 | self.low = min 187 | self.high = None 188 | self.nxt = min + 1 189 | 190 | def __iter__(self): 191 | return self 192 | 193 | def __next__(self): 194 | if self.high is None: 195 | return self.nxt 196 | elif self.low == self.high: 197 | raise StopIteration 198 | else: 199 | return self.nxt 200 | 201 | def report(self, idx: int, result: bool): 202 | if self.high is None: 203 | if idx <= self.low: 204 | if not result: 205 | raise ValueError(f"We've proven that {self.low} is feasible, but {idx} is not.") 206 | else: 207 | if not result: 208 | self.high = idx - 1 209 | self.nxt = (self.low + self.high) // 2 210 | else: 211 | self.low = idx 212 | self.nxt = idx * 2 213 | else: 214 | if idx < self.low or idx > self.high: 215 | raise ValueError(f"Index {idx} is out of range [{self.low}, {self.high}]") 216 | else: 217 | if not result: 218 | self.high = idx - 1 219 | else: 220 | self.low = idx 221 | self.nxt = (self.low + self.high) // 2 222 | if self.nxt == self.low: 223 | self.nxt = self.high 224 | 225 | class Streamer: 226 | def __init__(self, pbar): 227 | self.pbar = pbar 228 | 229 | def put(self, x): 230 | self.pbar.update(1) 231 | 232 | def end(self, *args, **kwargs): 233 | self.pbar.close() 234 | 235 | def get_ds_model(model_name, config, dtype, cpu_offload, disk_offload, offload_dir, num_gpus = 1): 236 | import deepspeed 237 | import torch.distributed as dist 238 | from transformers.deepspeed import HfDeepSpeedConfig 239 | 240 | hidden_size = config.hidden_size 241 | # debug: ModuleNotFoundError: No module named 'mpi4py' 242 | ## launch with deepspeed .py 243 | deepspeed.init_distributed("nccl") 244 | rank = dist.get_rank() 245 | pin_memory = True 246 | 247 | ds_config = { 248 | "fp16": { 249 | "enabled": dtype == torch.float16, 250 | }, 251 | "bf16": { 252 | "enabled": dtype == torch.bfloat16, 253 | }, 254 | "zero_optimization": { 255 | "stage": 3, 256 | "stage3_prefetch_bucket_size": hidden_size * hidden_size, 
257 | "stage3_param_persistence_threshold": 0, 258 | }, 259 | "steps_per_print": 2000, 260 | "train_batch_size": 1, 261 | "wall_clock_breakdown": False, 262 | } 263 | 264 | if cpu_offload: 265 | ds_config["zero_optimization"]["offload_param"] = { 266 | "device": "cpu", 267 | "pin_memory": pin_memory 268 | } 269 | 270 | if disk_offload: 271 | ds_config["zero_optimization"]["offload_param"] = { 272 | "device": "nvme", 273 | "pin_memory": True, 274 | "nvme_path": offload_dir, 275 | "buffer_count": 5, 276 | "buffer_size": 2 * (1 << 30), 277 | } 278 | ds_config["aio"] = { 279 | "block_size": 1048576, 280 | "queue_depth": 8, 281 | "thread_count": 1, 282 | "single_submit": False, 283 | "overlap_events": True, 284 | } 285 | 286 | # dschf = HfDeepSpeedConfig(ds_config) 287 | 288 | model = AutoModelForCausalLM.from_pretrained( 289 | model_name, 290 | config=config, 291 | torch_dtype=dtype 292 | ) 293 | model = model.eval() 294 | ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0] 295 | ds_engine.module.eval() 296 | model = ds_engine.module 297 | 298 | return model 299 | 300 | def get_hf_model(model_name, config, dtype, cpu_offload, disk_offload, offload_dir, num_gpus): 301 | if num_gpus == 1 and dtype != torch.int8: 302 | # Here we use a custom device_map instead of device_map == "auto" 303 | # because we want to offload as many weights as possible out of GPU 304 | # to allow a larger batch size. 305 | if cpu_offload: 306 | # NOTE: We must put some weights on GPU. Otherwise, huggingface reports errors.
307 | device_map = { 308 | "model.embed_tokens.weight": 0, 309 | "model.norm": "cpu", 310 | "model.layers": "cpu", 311 | "lm_head.weight": 0, 312 | } 313 | elif disk_offload: 314 | device_map = { 315 | "model.embed_tokens.weight": 0, 316 | "model.norm": "disk", 317 | "model.layers": "disk", 318 | "lm_head.weight": 0, 319 | } 320 | else: 321 | device_map = None 322 | max_memory = None 323 | else: 324 | # Here we use device_map == "auto", but set a low `max_memory` threshold 325 | # because we want to offload as many weights as possible out of GPU 326 | # to allow a larger batch size. 327 | device_map = "auto" 328 | if cpu_offload: 329 | # `max_memory` should be larger than the embedding. 330 | # We use 2GB here because the embedding of opt-175b is 1.2GB. 331 | max_memory = {k: "2GB" for k in range(num_gpus)} 332 | elif disk_offload: 333 | max_memory = {k: "2GB" for k in range(num_gpus)} 334 | else: 335 | max_memory = {k: "14GB" for k in range(num_gpus)} 336 | max_memory["cpu"] = "160GB" 337 | 338 | if dtype == torch.int8: 339 | kwargs = {"load_in_8bit": True} 340 | else: 341 | kwargs = {"torch_dtype": dtype} 342 | 343 | model = AutoModelForCausalLM.from_pretrained(model_name, config=config, 344 | device_map=device_map, max_memory=max_memory, 345 | offload_folder=offload_dir, **kwargs) 346 | if device_map is None: 347 | model.cuda() 348 | 349 | model.eval() 350 | return model 351 | 352 | 353 | distributed_state = PartialState() 354 | 355 | def prepare(model: str, size: str, cpu_offload: str = "none", warmup: int = 2): 356 | """ 357 | Prepare the tokenizer and model. 358 | 359 | Arguments: 360 | model (str): The model name. Options are "llama" and "lckv-llama". 361 | size (str): The model size. Options are "50m", "1.1b", "7b", and "30b". 362 | cpu_offload (str): The CPU offload strategy. Options are "none", "hf", and "ds". 363 | warmup (int): The number of warmup layers for LCKV-LLAMA.
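The `layer_types` string that this function builds from `warmup` can be reproduced in isolation with the hypothetical helper `build_layer_types` (not part of this script). The bottom `warmup // 2` and top `warmup // 2` layers keep their own index, and every layer in between is mapped to `end - 1`, the topmost layer of the middle block. Assuming, as in LCKV, that each entry names the layer whose KVs that layer attends to, all middle layers then share a single KV cache.

```python
def build_layer_types(num_hidden_layers: int, warmup: int) -> str:
    # Hypothetical helper that mirrors the layer_types construction in prepare().
    # The bottom and top warmup // 2 layers keep their own index; every layer in
    # between is mapped to end - 1, the topmost layer of the middle block.
    start = warmup // 2
    end = num_hidden_layers - warmup // 2
    layer_types = [(end - 1 if start <= i < end else i) for i in range(num_hidden_layers)]
    return "_".join(map(str, layer_types))


print(build_layer_types(8, 2))  # 0_6_6_6_6_6_6_7
```

Because each side receives `warmup // 2` layers, an odd `warmup` is effectively rounded down (`warmup=3` yields the same string as `warmup=2`), and `warmup=0` maps every layer to the top layer.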
364 | """ 365 | 366 | CONFIG_MAPPING = { 367 | "50m": "configs/llama_tiny.json", 368 | "1.1b": "TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T", 369 | "7b": "huggyllama/llama-7b", # "yahma/llama-7b-hf" 370 | "30b": "huggyllama/llama-30b", 371 | } 372 | 373 | # prepare tokenizer 374 | tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") 375 | 376 | # prepare config 377 | if model == "lckv-llama": 378 | config = LCKVLlamaConfig.from_pretrained(CONFIG_MAPPING[size]) 379 | config._attn_implementation = "flash_attention_2" 380 | start, end = int(warmup) // 2, config.num_hidden_layers - int(warmup) // 2 381 | layer_types = [(end - 1 if i in range(start, end) else i) for i in range(config.num_hidden_layers)] 382 | config.layer_types = "_".join(map(str, layer_types)) 383 | 384 | elif model == "llama": 385 | config = AutoConfig.from_pretrained(CONFIG_MAPPING[size]) 386 | config._attn_implementation = "flash_attention_2" 387 | 388 | else: 389 | raise ValueError(f"Unknown model {model}") 390 | 391 | # prepare model 392 | if cpu_offload == "none": 393 | model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16) 394 | model.to(distributed_state.device) 395 | 396 | elif cpu_offload == "hf": 397 | model = get_hf_model( 398 | model_name = CONFIG_MAPPING[size], 399 | config = config, 400 | dtype = torch.bfloat16, 401 | cpu_offload = True, 402 | disk_offload = False, 403 | offload_dir = None, 404 | num_gpus = 1 405 | ) 406 | 407 | elif cpu_offload == "ds": 408 | model = get_ds_model( 409 | model_name = CONFIG_MAPPING[size], 410 | config = config, 411 | dtype = torch.bfloat16, 412 | cpu_offload = True, 413 | disk_offload = False, 414 | offload_dir = None 415 | ) 416 | 417 | else: 418 | raise ValueError(f"Unknown CPU offload strategy {cpu_offload}") 419 | 420 | print("# of trainable parameters:", get_model_param_count(model, True)) 421 | model.eval() 422 | 423 | return tokenizer, model 424 | 425 | 426 | def inject_callback(model, callback): 427 
| """Wrap `model.forward` so that `callback` is called after every forward pass. 428 | """ 429 | forward_func = model.forward 430 | 431 | def forward(self, *args, **kwargs): 432 | result = forward_func(*args, **kwargs) 433 | callback() 434 | return result 435 | 436 | model.forward = forward.__get__(model, type(model)) 437 | 438 | 439 | def experiment(tokenizer, model, prompt, max_new_tokens, iterator=None, verbose=False): 440 | 441 | input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") 442 | 443 | # move to the current device 444 | input_ids = input_ids.to(distributed_state.device) 445 | model.eval() 446 | 447 | # Warm up the GPU 448 | _ = model.generate(input_ids[:1,:1], do_sample=True, max_length=50) 449 | # record the time at which the first forward pass (the prefill) finishes 450 | def callback(): 451 | if callback.timer is None: 452 | callback.timer = time.time() 453 | inject_callback(model, callback) 454 | 455 | def run(num_return_sequences = 32): 456 | 457 | callback.timer = None 458 | 459 | if verbose: 460 | pbar = tqdm(total=max_new_tokens, leave=False, desc="Generating") 461 | streamer = Streamer(pbar) 462 | else: 463 | streamer = None 464 | 465 | start = time.time() 466 | 467 | model.generate( 468 | input_ids=input_ids, 469 | max_new_tokens=max_new_tokens, 470 | temperature=1.0, 471 | top_k=0, 472 | top_p=0.9, 473 | repetition_penalty=1.0, 474 | do_sample=True, 475 | num_return_sequences=num_return_sequences, 476 | streamer=streamer, 477 | ) 478 | 479 | end = time.time() 480 | 481 | return start, callback.timer, end 482 | 483 | if isinstance(iterator, Iterable): 484 | 485 | for i in iterator: 486 | # print the current date and time 487 | try: 488 | start, mid, end = run(i) 489 | length = max_new_tokens 490 | print(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] bsz: {i:3}, time: {end-start:.3f} = {mid-start:.3f} + {end-mid:.3f}, throughput: {i*length/(end-start):.3f} tokens/s") 491 | except Exception as e: 492 | if not str(e).startswith("CUDA out of memory."): 493 | print(e) 494 | raise 495 | return 496 | empty_cache() 497 | 498 | else:
499 | 500 | iterator = BinarySearch(0 if iterator is None else iterator) 501 | for i in iterator: 502 | try: 503 | start, mid, end = run(i) 504 | length = max_new_tokens 505 | print(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] bsz: {i:3}, time: {end-start:.3f} = {mid-start:.3f} + {end-mid:.3f}, throughput: {i*length/(end-start):.3f} tokens/s") 506 | iterator.report(i, True) 507 | except Exception as e: 508 | print(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] bsz: {i:3}, FAIL") 509 | if not str(e).startswith("CUDA out of memory."): 510 | print(e) 511 | raise 512 | iterator.report(i, False) 513 | empty_cache() 514 | return iterator.nxt 515 | 516 | 517 | def main(): 518 | 519 | # do not run the same model consecutively; it will cause CUDA out of memory 520 | 521 | print(">>> 7b 2048+2048 llama") 522 | tokenizer, model = prepare("llama", "7b") 523 | experiment(tokenizer, model, prompt_text_2048, max_new_tokens=2048) 524 | 525 | print(">>> 7b 2048+2048 lckv-llama") 526 | tokenizer, model = prepare("lckv-llama", "7b") 527 | experiment(tokenizer, model, prompt_text_2048, max_new_tokens=2048) 528 | 529 | print(">>> 7b 2048+2048 lckv-llama w=10") 530 | tokenizer, model = prepare("lckv-llama", "7b", warmup=10) 531 | experiment(tokenizer, model, prompt_text_2048, max_new_tokens=2048) 532 | 533 | 534 | if __name__ == "__main__": 535 | main() 536 | -------------------------------------------------------------------------------- /test_streaming.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | from datasets import load_dataset 6 | from torch.nn import CrossEntropyLoss 7 | from tqdm import tqdm 8 | 9 | import models 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | from transformers.cache_utils import SinkCache 12 | 13 | 14 | device = "cuda" if torch.cuda.is_available() else "cpu" 15 | 16 | def parse_args(): 17 | parser =
argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T" 20 | ) 21 | parser.add_argument("--revision", type=str, default="main") 22 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None) 23 | parser.add_argument("--dataset_name", type=str, default=None) 24 | parser.add_argument("--task", type=str, default=None) 25 | parser.add_argument( 26 | "--split", type=str, default="test", choices=["validation", "test"] 27 | ) 28 | parser.add_argument( 29 | "--download_streaming", action="store_true", help="enable streaming download" 30 | ) 31 | 32 | parser.add_argument( 33 | "--num_samples", 34 | type=int, 35 | default=None, 36 | ) 37 | 38 | parser.add_argument( 39 | "--output_dir", 40 | type=str, 41 | default="streaming/debug", 42 | ) 43 | 44 | parser.add_argument( 45 | "--torch_dtype", 46 | type=str, 47 | default="bfloat16", 48 | ) 49 | 50 | # sink cache related arguments 51 | parser.add_argument("--sink_cache", action="store_true", help="Whether to use sink cache.") 52 | parser.add_argument("--window_length", type=int, default=256, help="Window size for sink cache.") 53 | parser.add_argument("--num_sink_tokens", type=int, default=2, help="Number of sink tokens.") 54 | 55 | parser.add_argument("--num_eval_tokens", type=int, default=None) 56 | 57 | args = parser.parse_args() 58 | return args 59 | 60 | def load(model_name_or_path, torch_dtype): 61 | print(f"Loading model from {model_name_or_path} ...") 62 | # if only the model type and size are given (e.g. "llama;7b"), build the model from scratch via test_latency.prepare 63 | if ";" in model_name_or_path: 64 | from test_latency import prepare 65 | tokenizer, model = prepare(*model_name_or_path.split(";")) 66 | return model, tokenizer 67 | # note: tensor parallelism is known to cause bugs when running Falcon 68 | tokenizer = AutoTokenizer.from_pretrained( 69 | model_name_or_path, 70 | trust_remote_code=True, 71 | ) 72 | torch_dtype = ( 73 | torch_dtype 74 | if torch_dtype in
["auto", None] 75 | else getattr(torch, torch_dtype) 76 | ) 77 | model = AutoModelForCausalLM.from_pretrained( 78 | model_name_or_path, 79 | device_map="auto", 80 | torch_dtype=torch_dtype, 81 | trust_remote_code=True, 82 | ) 83 | if tokenizer.pad_token_id is None: 84 | if tokenizer.eos_token_id is not None: 85 | tokenizer.pad_token_id = tokenizer.eos_token_id 86 | else: 87 | tokenizer.pad_token_id = 0 88 | model.eval() 89 | return model, tokenizer 90 | 91 | def main(): 92 | 93 | args = parse_args() 94 | 95 | data = load_dataset(args.dataset_name, args.task, split=args.split, streaming=args.download_streaming) 96 | if args.num_samples is not None: 97 | data = data.take(args.num_samples) if args.download_streaming else data.select(range(args.num_samples)) 98 | 99 | model, tokenizer = load(args.model_name_or_path, args.torch_dtype) 100 | 101 | nlls = [] 102 | loss_fn = CrossEntropyLoss(reduction="none") 103 | 104 | # streaming inference 105 | past_key_values = None 106 | if args.sink_cache: 107 | past_key_values = SinkCache(args.window_length, args.num_sink_tokens) 108 | 109 | ## to enable latency measurement, uncomment the `time`-related lines in the loop below 110 | os.makedirs(args.output_dir, exist_ok=True) 111 | with open(f"{args.output_dir}/log.txt", "w") as f: 112 | 113 | num_eval_tokens = 0 114 | for item in data: 115 | text = item['text'] 116 | encodings = tokenizer(text, return_tensors="pt") 117 | 118 | print(encodings.input_ids[:, :10]) 119 | 120 | seq_len = encodings.input_ids.size(1) 121 | print(f"num_eval_tokens: {num_eval_tokens}, seq_len: {seq_len}") 122 | pbar = tqdm(range(0, seq_len - 1)) 123 | 124 | # import time 125 | for idx in pbar: 126 | # if idx == args.start_size + args.recent_size: 127 | # print("Starting timer...") 128 | # start = time.time() 129 | input_ids = encodings.input_ids[:, idx : idx + 1].to(device) 130 | with torch.no_grad(): 131 | outputs = model( 132 | input_ids, 133 | past_key_values=past_key_values, 134 | use_cache=True, 135 | ) 136 | logits = outputs.logits.view(-1, model.config.vocab_size) 137 |
past_key_values = outputs.past_key_values 138 | label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1) 139 | neg_log_likelihood = loss_fn(logits, label) 140 | nlls.append(neg_log_likelihood) 141 | pbar.set_description( 142 | f"nll: {neg_log_likelihood.item():.2f}, ppl: {torch.exp(neg_log_likelihood).item():.2f}" 143 | ) 144 | print(neg_log_likelihood.item(), file=f, flush=True) 145 | num_eval_tokens += 1 146 | if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens: 147 | # print(f"time: {time.time() - start:.2f}") 148 | break 149 | if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens: 150 | break 151 | 152 | ppl = torch.exp(torch.stack(nlls).mean()) 153 | print(ppl.item()) 154 | with open(f"{args.output_dir}/ppl.txt", "w") as f: 155 | f.write(f"{ppl.item()}\n") 156 | 157 | if __name__ == "__main__": 158 | main() 159 | -------------------------------------------------------------------------------- /tests/test_kernel.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from transformers import is_torch_available 4 | from transformers.testing_utils import require_torch_gpu 5 | 6 | 7 | if is_torch_available(): 8 | import torch 9 | 10 | from models.kernel import liger_rotary 11 | from models.modeling_lckv import apply_rotary 12 | 13 | 14 | @require_torch_gpu 15 | class LigerRotaryTest(unittest.TestCase): 16 | def test_liger_rotary(self): 17 | # Test case: Test liger_rotary function 18 | q = torch.randn(2, 3, 4, 6, device="cuda") 19 | freq = torch.randn(1, 4, 3, device="cuda") 20 | embed = torch.cat((freq, freq), dim=-1) 21 | cos = embed.cos() 22 | sin = embed.sin() 23 | unsqueeze_dim = 1 24 | 25 | result_q = liger_rotary(q, cos, sin, unsqueeze_dim) 26 | self.assertEqual(result_q.shape, (2, 3, 4, 6)) 27 | 28 | ref_q = apply_rotary(q, cos, sin, unsqueeze_dim) 29 | self.assertTrue(torch.allclose(result_q, ref_q)) 30 | 31 | 32 | if __name__ 
== '__main__': 33 | unittest.main() 34 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from models.utils import LayerTypeParser 4 | from transformers import is_torch_available 5 | from transformers.testing_utils import require_flash_attn, require_torch_gpu 6 | 7 | 8 | if is_torch_available(): 9 | import torch 10 | 11 | from models.utils import flash_attention_forward 12 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 13 | 14 | 15 | class LayerTypeParserTest(unittest.TestCase): 16 | def test_init_with_valid_input(self): 17 | # Test case: Initialization with valid input 18 | layer_type = "1_2s_0" 19 | parser = LayerTypeParser(layer_type) 20 | self.assertEqual([item.attends_to for item in parser], [1, 2, 0]) 21 | self.assertEqual([item.use_sliding_window for item in parser], [False, True, False]) 22 | 23 | def test_init_with_invalid_input(self): 24 | # Test case: Initialization with invalid input 25 | with self.assertRaises(Exception): 26 | LayerTypeParser("invalid_input") 27 | 28 | with self.assertRaises(Exception): 29 | LayerTypeParser("0_1t_2") 30 | 31 | def test_init_with_empty_string(self): 32 | # Test case: Initialization with an empty string 33 | with self.assertRaises(Exception): 34 | LayerTypeParser("") 35 | 36 | def test_len(self): 37 | # Test case: Check the length of the LayerTypeParser object 38 | layer_type = "1_2s_0" 39 | parser = LayerTypeParser(layer_type) 40 | self.assertEqual(len(parser), 3) 41 | 42 | def test_getitem(self): 43 | # Test case: Get the layer type information for a specific layer index 44 | layer_type = "1_2s_0" 45 | parser = LayerTypeParser(layer_type) 46 | layer_type_info = parser[1] 47 | self.assertEqual(layer_type_info.attends_to, 2) 48 | self.assertEqual(layer_type_info.attends_top, True) 49 | 
self.assertEqual(layer_type_info.use_sliding_window, True) 50 | 51 | def test_use_sliding_window(self): 52 | # Test case: Check if there exists a layer that uses sliding window attention 53 | layer_type = "1_2s_0" 54 | parser = LayerTypeParser(layer_type) 55 | self.assertEqual(parser.use_sliding_window(), True) 56 | 57 | layer_type = "0_0_2" 58 | parser = LayerTypeParser(layer_type) 59 | self.assertEqual(parser.use_sliding_window(), False) 60 | 61 | def test_attends_top(self): 62 | # Test case: Check if there exists a layer that attends to layers above it 63 | layer_type = "1_2s_0" 64 | parser = LayerTypeParser(layer_type) 65 | self.assertEqual(parser.attends_top(), True) 66 | 67 | layer_type = "0s_0s_2" 68 | parser = LayerTypeParser(layer_type) 69 | self.assertEqual(parser.attends_top(), False) 70 | 71 | def test_iteration_plan(self): 72 | # Test case: Check the iteration plan for the layer types 73 | layer_type = "1_2s_0" 74 | parser = LayerTypeParser(layer_type) 75 | iteration_plan = parser.iteration_plan(forward_passes=7, backward_passes=2) 76 | # Add assertions for the iteration plan 77 | self.assertEqual(len(iteration_plan), 9) 78 | 79 | # each layer should be updated exactly once 80 | updated_slices = [step.layer_slice for step in iteration_plan if step.update] 81 | updated_layers = [set(range(len(parser))[layer_slice]) for layer_slice in updated_slices] 82 | self.assertEqual(set.union(*updated_layers), set(range(len(parser)))) 83 | 84 | # cyclic dependencies should be resolved 85 | self.assertEqual([step.requires_grad for step in iteration_plan], [False, False, False, False, False, False, False, True, True]) 86 | 87 | 88 | # Test for the case where there is no cyclic dependency 89 | layer_type = "0_1s_2" 90 | parser = LayerTypeParser(layer_type) 91 | iteration_plan = parser.iteration_plan(forward_passes=7, backward_passes=2) 92 | self.assertEqual(len(iteration_plan), 1) 93 | self.assertTrue(iteration_plan[0].requires_grad) 94 | 
self.assertTrue(iteration_plan[0].update) 95 | 96 | def test_check(self): 97 | # Test case: Check if the layer type is valid 98 | num_hidden_layers = 3 99 | layer_type = "1_2s_0" 100 | parser = LayerTypeParser(layer_type) 101 | self.assertIsNone(parser.check(num_hidden_layers)) 102 | 103 | # Test case: Check for invalid layer type 104 | num_hidden_layers = 3 105 | layer_type = "1_2s_3" 106 | parser = LayerTypeParser(layer_type) 107 | with self.assertRaises(Exception): 108 | parser.check(num_hidden_layers) 109 | 110 | num_hidden_layers = 3 111 | layer_type = "0_1_2s_0" 112 | parser = LayerTypeParser(layer_type) 113 | with self.assertRaises(Exception): 114 | parser.check(num_hidden_layers) 115 | 116 | 117 | @require_torch_gpu 118 | @require_flash_attn 119 | class FlashAttentionForwardTest(unittest.TestCase): 120 | def test_no_diag(self): 121 | # Test case: Test flash_attention_forward with no_diag=True 122 | query_states = torch.randn(2, 5, 3, 4, dtype=torch.bfloat16, device="cuda") 123 | key_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda") 124 | value_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda") 125 | attention_mask = None 126 | query_length = 5 127 | is_causal = True 128 | no_diag = True 129 | 130 | result = flash_attention_forward( 131 | query_states=query_states, 132 | key_states=key_states, 133 | value_states=value_states, 134 | attention_mask=attention_mask, 135 | query_length=query_length, 136 | is_causal=is_causal, 137 | no_diag=no_diag 138 | ) 139 | 140 | self.assertEqual(result.shape, (2, 5, 3, 4)) 141 | 142 | # Test case: attention_mask is not None, square attention matrix 143 | query_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda") 144 | attention_mask = torch.ones(2, 6, dtype=torch.long, device="cuda") 145 | attention_mask[1, 2:] = 0 146 | result = flash_attention_forward( 147 | query_states=query_states, 148 | key_states=key_states, 149 | value_states=value_states, 150 | 
attention_mask=attention_mask, 151 | query_length=query_length, 152 | is_causal=is_causal, 153 | no_diag=no_diag 154 | ) 155 | 156 | self.assertEqual(result.shape, (2, 5, 3, 4)) 157 | 158 | def test_with_diag(self): 159 | # Test case: Test flash_attention_forward with no_diag=False 160 | query_states = torch.randn(2, 5, 3, 4, dtype=torch.bfloat16, device="cuda") 161 | key_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda") 162 | value_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda") 163 | attention_mask = None 164 | query_length = 5 165 | is_causal = True 166 | no_diag = False 167 | 168 | result = flash_attention_forward( 169 | query_states=query_states, 170 | key_states=key_states, 171 | value_states=value_states, 172 | attention_mask=attention_mask, 173 | query_length=query_length, 174 | is_causal=is_causal, 175 | no_diag=no_diag 176 | ) 177 | 178 | ref = _flash_attention_forward( 179 | query_states=query_states, 180 | key_states=key_states, 181 | value_states=value_states, 182 | attention_mask=attention_mask, 183 | query_length=query_length, 184 | is_causal=is_causal, 185 | ) 186 | 187 | self.assertEqual(result.shape, (2, 5, 3, 4)) 188 | self.assertTrue(torch.allclose(result, ref)) 189 | 190 | 191 | if __name__ == '__main__': 192 | unittest.main() 193 | --------------------------------------------------------------------------------
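In `prepare()` in test_latency.py, the `layer_types` string for LCKV-LLAMA is built from `warmup` with a single list comprehension. Each entry is the index of the layer whose KVs that layer attends to. The sketch below lifts that expression into a standalone function so its effect can be inspected without loading a model. The helper name `lckv_layer_types` is hypothetical and does not exist in the repository; the body copies the expression from `prepare()`.

```python
from typing import List


def lckv_layer_types(num_hidden_layers: int, warmup: int = 2) -> str:
    """Hypothetical helper mirroring the `layer_types` expression in prepare()."""
    # The bottom `warmup // 2` and top `warmup // 2` layers keep their own KVs.
    start, end = int(warmup) // 2, num_hidden_layers - int(warmup) // 2
    # Layers start..end-1 all use the KVs of layer end-1, the highest of them.
    layer_types: List[int] = [
        (end - 1 if i in range(start, end) else i) for i in range(num_hidden_layers)
    ]
    return "_".join(map(str, layer_types))


if __name__ == "__main__":
    for warmup in (0, 2, 4):
        print(warmup, lckv_layer_types(8, warmup))
    # 0 7_7_7_7_7_7_7_7
    # 2 0_6_6_6_6_6_6_7
    # 4 0_1_5_5_5_5_6_7
```

With `warmup=0`, every layer attends to the top layer's KVs, which matches the setting described in the README. Because both halves use `warmup // 2`, an odd `warmup` behaves like the next lower even value.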
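The tests in tests/test_utils.py pin down the layer-type string format that `LayerTypeParser` accepts:

- Entries are separated by underscores, one per layer.
- Each entry is the index of the layer whose KVs that layer attends to.
- An optional `s` suffix enables sliding-window attention.
- A layer "attends top" when its target index is above its own.

The sketch below reconstructs only those observed behaviours. It is a hypothetical stand-in, not `models.utils.LayerTypeParser`: the names `LayerSpec` and `parse_layer_types` are invented here, and parsing and the layer-count check are merged into one call, whereas the real class exposes a separate `check(num_hidden_layers)`.

```python
import re
from dataclasses import dataclass
from typing import List


@dataclass
class LayerSpec:
    attends_to: int           # index of the layer whose KVs this layer attends to
    use_sliding_window: bool  # True when the entry carries an "s" suffix
    attends_top: bool         # True when attends_to is above this layer


def parse_layer_types(layer_types: str, num_hidden_layers: int) -> List[LayerSpec]:
    """Hypothetical re-implementation, NOT models.utils.LayerTypeParser."""
    entries = layer_types.split("_")
    # one entry per hidden layer (the property LayerTypeParser.check is tested for)
    if len(entries) != num_hidden_layers:
        raise ValueError(f"expected {num_hidden_layers} entries, got {len(entries)}")
    specs = []
    for idx, entry in enumerate(entries):
        # an entry is a layer index, optionally followed by "s"
        match = re.fullmatch(r"(\d+)(s?)", entry)
        if match is None:
            raise ValueError(f"invalid entry {entry!r} for layer {idx}")
        target = int(match.group(1))
        if target >= num_hidden_layers:
            raise ValueError(f"layer {idx} attends to non-existent layer {target}")
        specs.append(LayerSpec(target, match.group(2) == "s", target > idx))
    return specs


if __name__ == "__main__":
    for spec in parse_layer_types("1_2s_0", num_hidden_layers=3):
        print(spec)
```

The real parser also builds an `iteration_plan` to resolve cyclic dependencies, such as layers 0 to 2 in "1_2s_0" each depending on a higher layer. This sketch does not attempt that.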