├── .gitignore
├── README.md
├── configs
│   ├── default_chat_config.yaml
│   ├── llama_small.json
│   ├── llama_small_lckv.json
│   ├── llama_tiny.json
│   ├── llama_tiny_lckv.json
│   ├── tinyllama.json
│   └── tinyllama_lckv.json
├── convert_pretrained.py
├── models
│   ├── __init__.py
│   ├── cache_utils.py
│   ├── configuration_lckv.py
│   ├── kernel.py
│   ├── modeling_lckv.py
│   ├── ops_rope.py
│   └── utils.py
├── pyproject.toml
├── requirements.txt
├── run_chat.py
├── run_clm.py
├── run_clm.sh
├── run_generation.py
├── run_generation.sh
├── run_sft.py
├── run_sft.sh
├── run_streaming.py
├── run_streaming.sh
├── test_harness.py
├── test_latency.py
├── test_streaming.py
└── tests
    ├── test_kernel.py
    └── test_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | wandb
3 | outputs
4 | harness
5 | streaming
6 | run*.sh
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Layer-Condensed KV Cache
2 |
3 |
4 |

5 |
6 | The KVs of the top layer are the most informative and important.
7 |
8 | So why bother caching the rest?
9 |
10 |
11 |
12 |
13 |
14 | The code base for the project **Layer-Condensed KV Cache**, a new variant of transformer decoders in which queries of all layers are paired with keys and values of just the top layer. It reduces memory and computation costs, reduces the number of parameters, and significantly improves inference throughput while achieving comparable or better task performance. The paper "[Layer-Condensed KV Cache for Efficient Inference of Large Language Models](https://arxiv.org/abs/2405.10637)" was accepted to the ACL 2024 main conference.
15 |
16 | This work is inspired by [Probabilistic Transformer](https://github.com/whyNLP/Probabilistic-Transformer), where we consider the stacking layer structure of a transformer as an iterative process of improving token representation.
17 |
18 |
19 | The Map of AI Approaches
20 |
21 |

22 |
23 |
24 |
25 | ## News
26 |
27 | - [25/01/23] Our paper "[A Systematic Study of Cross-Layer KV Sharing for Efficient LLM Inference](http://arxiv.org/abs/2410.14442)" was accepted to the NAACL 2025 main conference.
28 | - [24/12/08] We release the main branch, with a general framework for Cross-Layer KV Sharing. An illustrative post can be found on [PaperWeekly](https://mp.weixin.qq.com/s/Nr7K-xgcQRvHYNs82HU4gQ) (in Chinese). See the [published branch](https://github.com/whyNLP/LCKV/tree/dev-lckv-publish) for the old version of the code.
29 | - [24/10/18] Our new empirical study "[A Systematic Study of Cross-Layer KV Sharing for Efficient LLM Inference](http://arxiv.org/abs/2410.14442)" was released on arXiv. A new configuration has been found to be more efficient than the original LCKV.
30 | - [24/05/28] This code base now also supports Cross-Layer Attention (CLA). The idea is similar, but they 1) divide the transformer layers into small groups of 2-4 layers each; and 2) pair the queries of all layers with the keys and values of the bottom layer in each group. See details in their paper "[Reducing Transformer Key-Value Cache Size with Cross-Layer Attention](http://arxiv.org/abs/2405.12981)".
31 | - [24/05/20] LCKV initial paper and code release.
32 | - [24/05/12] Our paper "[Layer-Condensed KV Cache for Efficient Inference of Large Language Models](http://arxiv.org/abs/2405.10637)" was accepted to the ACL 2024 main conference.
33 | - [24/02/14] Our paper "[Layer-Condensed KV Cache for Efficient Inference of Large Language Models](http://arxiv.org/abs/2405.10637)" was submitted to the ARR February 2024 cycle.
34 |
35 | ## Quick Start
36 |
37 | We have released on HuggingFace a series of pre-trained models described in our paper. There is no need to clone this repo if you just want to use the pre-trained models. Load the model with the following code:
38 |
39 | ```python
40 | # Use a pipeline as a high-level helper
41 | from transformers import pipeline
42 | pipe = pipeline("text-generation", model="whynlp/tinyllama-lckv-w2-ft-100b", trust_remote_code=True)
43 |
44 | # Load model directly
45 | from transformers import AutoModelForCausalLM
46 | model = AutoModelForCausalLM.from_pretrained("whynlp/tinyllama-lckv-w2-ft-100b", trust_remote_code=True)
47 | ```
48 |
49 | See more models on the [HuggingFace model hub](https://huggingface.co/models?search=whynlp). Note that these models are for research purposes only and may not be suitable for production.
50 |
51 | | Model | Paper Section | Dev ppl. | Common-sense Reasoning |
52 | | --------------------------------------------------------------------------------------------- | ------------------------------ | -------- | ---------------------- |
53 | | [whynlp/tinyllama-lckv-w10-ft-250b](https://huggingface.co/whynlp/tinyllama-lckv-w10-ft-250b) | -- | 7.939 | 50.86 |
54 | | [whynlp/tinyllama-lckv-w2-ft-100b](https://huggingface.co/whynlp/tinyllama-lckv-w2-ft-100b) | Appendix C.1, Table 7 (line 5) | 8.514 | 49.55 |
55 | | [whynlp/tinyllama-lckv-w10-100b](https://huggingface.co/whynlp/tinyllama-lckv-w10-100b) | Section 3.2, Table 2 (line 3) | 9.265 | 46.84 |
56 | | [whynlp/tinyllama-lckv-w2-100b](https://huggingface.co/whynlp/tinyllama-lckv-w2-100b) | Section 3.2, Table 2 (line 2) | 9.746 | 45.45 |
57 |
58 | ## Installation
59 |
60 | You may install the dependencies with the following commands:
61 |
62 | ```sh
63 | conda install pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
64 | pip install -r requirements.txt
65 | ```
66 |
67 | where the CUDA version is set to `12.1`. For other CUDA versions, please refer to the installation instructions of [PyTorch](https://pytorch.org/get-started/locally/). See [Troubleshooting](#troubleshooting) for more details.
68 |
69 | ## Usage
70 |
71 | Our implementation is based on HuggingFace `transformers`. We register a new model type, `lckv-llama`, which inherits from the `llama` model and adds support for the Layer-Condensed KV Cache.
72 |
73 | > [!NOTE]
74 | > It is difficult to support the Layer-Condensed KV Cache for a wide variety of models with a small amount of code, because it requires modifying the attention mechanism and the training recipe of the transformer decoder. Currently, we have only implemented the Layer-Condensed KV Cache for the `llama` model, but it should be possible to extend it to other models with similar structures.
75 |
76 | ```python
77 | import models # register the lckv-llama model
78 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
79 |
80 | tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
81 | model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("configs/tinyllama_lckv.json"))
82 | ```
83 |
84 | and now you have a randomly initialized model with the Layer-Condensed KV Cache.
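As a quick sanity check, you can run generation with this randomly initialized model (the output will be nonsense until the model is trained). This is only a usage sketch that reuses the `tokenizer` and `model` from the block above:

```python
import torch

# encode a prompt and greedily generate a few tokens with the untrained model
inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```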
85 |
86 | ### Optimization
87 |
88 | To accelerate the training and inference of the model, one can apply the Liger kernel supported by the `transformers` library. The provided training script `run_clm.py` already activates the Liger kernel. See more details [here](https://huggingface.co/docs/transformers/v4.45.2/en/trainer#liger-kernel).
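If you write your own training loop with the HuggingFace `Trainer`, the same kernels can be switched on through `TrainingArguments` (a minimal sketch; the output directory is just a placeholder):

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="outputs/liger-test",  # placeholder path
    use_liger_kernel=True,  # the Trainer applies the patch registered for `lckv-llama` in models/__init__.py
)
```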
89 |
90 | ### Configuration
91 |
92 | We provide some sample configuration files in the `configs` folder. The config settings are defined in [models/configuration_lckv.py](models/configuration_lckv.py). You may refer to this file for more details.
93 |
94 | #### Option 1: Modify the configurations in Python:
95 |
96 | ```python
97 | from models import LCKVLlamaConfig
98 |
99 | # we have prepared a sample configuration file
100 | config = LCKVLlamaConfig.from_pretrained("configs/tinyllama_lckv.json")
101 |
102 | # below is the LCKV config. you may modify the configuration as you like
103 | config.forward_passes = 7 # m in the paper
104 | config.backward_passes = 2 # b in the paper
105 | config.layer_types = "0_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_21" # for each layer, which layer to attend to
106 |
107 | # we also support this
108 | config.layer_types = "0_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_10_21" # the sandwich-middle configuration
109 | config.layer_types = "0_1_2_3_4_5_6_7_8_9_10_11_12_13_14_15_16_17_18_19_20_21" # Llama config
110 | config.layer_types = "0_0_2_2_4_4_6_6_8_8_10_10_12_12_14_14_16_16_18_18_20_20" # CLA config
111 |
112 | config.sliding_window = 1024 # the window size for the sliding window attention
113 | config.layer_types = "0s_1s_2s_3s_4s_5s_6s_7s_8s_9s_10s_11_11_11_11_11_11_11_11_11_11_11" # YOCO config, 's' is for sliding window
114 |
115 | config.sliding_window = 1024 # the window size for the sliding window attention
116 | config.layer_types = "0_1s_1s_3s_3s_3s_0_7s_7s_9s_9s_9s_12_13s_13s_15s_15s_15s_12_19s_19s_19s" # MixAttention (Pairs) config
117 |
118 | # we also support sequential training / inference, which will process the tokens one by one
119 | # corresponding to LCKV paper Figure 2(a)
120 | config.use_sequential = True
121 | ```
122 |
123 | #### Option 2: Modify the configurations in the shell script (via `--config_overrides`):
124 |
125 | ```sh
126 | accelerate launch run_clm.py \
127 | --config_name configs/tinyllama_lckv.json \
128 | --config_overrides forward_passes=7,backward_passes=2,layer_types=0_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_21 \
129 | ...
130 | ```
131 |
132 | With the above configurations, you can create [CLA](http://arxiv.org/abs/2405.12981), [YOCO](https://arxiv.org/abs/2405.05254) or any configurations in [Cross-Layer KV Sharing](http://arxiv.org/abs/2410.14442) or [MixAttention](http://arxiv.org/abs/2409.15012) without changing the code. The only thing you need to do is to write the correct `layer_types` in the configuration file.
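For deeper models it can be convenient to build the `layer_types` string programmatically rather than by hand. Below is a small illustration in plain Python (not part of the repo) that reproduces two of the strings above for a 22-layer model:

```python
num_layers = 22

# CLA with group size 2: each pair of layers shares the KVs of the lower layer in the pair
cla = "_".join(str(i - i % 2) for i in range(num_layers))
# -> "0_0_2_2_..._20_20"

# the LCKV configuration used in configs/tinyllama_lckv.json: the middle layers all attend to layer 20
lckv = "_".join(["0"] + [str(num_layers - 2)] * (num_layers - 2) + [str(num_layers - 1)])
# -> "0_20_20_..._20_21"
```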
133 |
134 | ### Pre-training
135 |
136 | We use the same [training script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_clm.py) as the original `transformers` library. You may refer to the [official documentation](https://huggingface.co/transformers/training.html) for more details.
137 |
138 | We provide a training script `run_clm.sh` for training a 50M parameter model on the `wikitext-103` dataset. You may run the script with:
139 |
140 | ```sh
141 | bash run_clm.sh
142 | ```
143 |
144 | See the script for more details. For pretraining on SlimPajama, please follow the instructions in [tinyllama-zh](https://github.com/whyNLP/tinyllama-zh) and replace the dataset with SlimPajama.
145 |
146 |
147 | #### Initializing from a Pretrained Model
148 |
149 | We may initialize our LCKV model from a pretrained model. Most parts of the model structure are consistent with the standard transformer model, so we can directly inherit the weights. For the KV weights $W_K, W_V$, there are two main options:
150 |
151 | ##### Option 1: Directly Copy the Weights
152 |
153 | Simply add `--model_name_or_path` to the training script:
154 |
155 | ```sh
156 | accelerate launch run_clm.py \
157 | --model_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T \
158 | --config_name configs/tinyllama_lckv.json \
159 | ...
160 | ```
161 |
162 | See the script `run_clm.sh` for more details.
163 |
164 | ##### Option 2: Average the Weights from Multiple Layers
165 |
166 | Following [MLKV](http://arxiv.org/abs/2406.09297), we may average the weights from multiple layers to initialize the KV weights. We provide a script `convert_pretrained.py` to convert the pretrained model to the LCKV model. You may run the following command:
167 |
168 | ```sh
169 | python convert_pretrained.py --model_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T --config_name configs/tinyllama_lckv.json --output_dir outputs/tinyllama-converted
170 | ```
171 |
172 | The KV weights of each layer will be the average of the weights from all the layers that attend to it. For example,
173 |
174 | ```python
175 | # the CLA / MLKV config
176 | config.layer_types = "0_0_2_2_4_4_6_6"
177 | # then layer 0 will have the average KV weights from layer 0 and 1 in the pretrained model
178 | # layer 2 will have the average KV weights from layer 2 and 3 in the pretrained model
179 |
180 | # the LCKV config
181 | config.layer_types = "0_6_6_6_6_6_6_7"
182 | # then layer 0 will inherit the KV weights from layer 0 in the pretrained model
183 | # layer 6 will have the average KV weights from layer 1, 2, 3, 4, 5, 6 in the pretrained model
184 | # layer 7 will inherit the KV weights from layer 7 in the pretrained model
185 | ```
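To double-check which pretrained layers will be averaged into each target layer, the mapping implied by a `layer_types` string can be reproduced with a few lines of plain Python (an illustration of the rule above, not the converter itself):

```python
layer_types = "0_6_6_6_6_6_6_7"
attends_to = [int(t.rstrip("s")) for t in layer_types.split("_")]  # strip the sliding-window suffix if present

# collect, for every target layer, the pretrained layers that attend to it
sources = {}
for layer_idx, target in enumerate(attends_to):
    sources.setdefault(target, []).append(layer_idx)

print(sources)  # {0: [0], 6: [1, 2, 3, 4, 5, 6], 7: [7]}
```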
186 |
187 | Then, use the converted model to initialize the LCKV model:
188 |
189 | ```sh
190 | accelerate launch run_clm.py \
191 | --model_name_or_path outputs/tinyllama-converted \
192 | ...
193 | ```
194 |
195 | Our experiments show that such an initialization strategy can effectively improve the performance of the model in most cases.
196 |
197 |
198 | ### Inference
199 |
200 | We use the same [inference script](https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-generation/run_generation.py) as the original `transformers` library. To perform inference, you may run the following command:
201 |
202 | ```sh
203 | bash run_generation.sh
204 | ```
205 |
206 | You may get responses from the trained model given any prompts. See the script for more details.
207 |
208 | #### Streaming
209 |
210 | We integrate our model with [StreamingLLM](https://github.com/mit-han-lab/streaming-llm). To perform streaming inference, you may run the following command:
211 |
212 | ```sh
213 | bash run_streaming.sh
214 | ```
215 |
216 | See the script for more details. The `run_generation.py` script also supports streaming inference with the `--sink_cache` flag.
217 |
218 | #### Sliding Window Attention
219 |
220 | The generation script also supports sliding window attention. If the model was trained with sliding window attention, the generation script will automatically use it for inference.
221 |
222 | ### Evaluation
223 |
224 | We use [LM-Harness](https://github.com/EleutherAI/lm-evaluation-harness) to evaluate the model. You may run the following command:
225 |
226 | ```sh
227 | python test_harness.py --model_name_or_path ...
228 | ```
229 |
230 | with the path to the model checkpoint. Run `python test_harness.py --help` for more details.
231 |
232 | ### Latency Testing
233 |
234 | To test the latency of the model, you may run the following command:
235 |
236 | ```sh
237 | python test_latency.py
238 | ```
239 |
240 | ### Instruction Fine-tuning
241 |
242 | > [!WARNING]
243 | > This section is currently experimental and may not work as expected.
244 |
245 | We provide a script `run_sft.sh` for supervised instruction fine-tuning. The code is consistent with the official `trl` library from HuggingFace. You may run the script with:
246 |
247 | ```sh
248 | bash run_sft.sh
249 | ```
250 |
251 | See the script for more details.
252 |
253 | To chat with the fine-tuned model, you may run the following command:
254 |
255 | ```sh
256 | python run_chat.py --model_name_or_path outputs/llamatiny-sft-test
257 | ```
258 |
259 | It will load the fine-tuned model and you can chat with it.
260 |
261 | ## Code Style
262 |
263 | We mostly follow the code style of `transformers`. Run the following command to check the code style:
264 |
265 | ```sh
266 | # Use `pip install ruff` to install ruff if it is not available
267 | ruff check models
268 | ```
269 |
270 | See more details in `pyproject.toml`.
271 |
272 |
273 | ## Troubleshooting
274 |
275 | ### Flash-Attn Installation
276 |
277 | https://github.com/Dao-AILab/flash-attention/issues/451
278 |
279 | Behavior:
280 |
281 | Runtime error.
282 | ```sh
283 | ImportError: /home/.../flash_attn_2_cuda.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops9_pad_enum4callERKNS_6TensorEN3c108ArrayRefINS5_6SymIntEEElNS5_...
284 | ```
285 |
286 | Solution:
287 | ```sh
288 | pip uninstall flash-attn
289 | FLASH_ATTENTION_FORCE_BUILD=TRUE pip install flash-attn
290 | ```
291 |
292 | ### CUDA version
293 |
294 | The CUDA version may affect the installation of:
295 | - [PyTorch](https://pytorch.org/get-started/locally/)
296 | - [Flash-Attn](https://github.com/Dao-AILab/flash-attention)
297 |
298 | Please make sure to install the correct versions of the packages (as long as they are consistent, the code will work). Also make sure that `nvcc` is installed and available in the path.
299 |
300 | Our experiment environment uses `CUDA 12.1`, and you may install the dependencies with
301 | ```sh
302 | conda install pytorch==2.5.0 pytorch-cuda=12.1 -c pytorch -c nvidia
303 | pip install -r requirements.txt
304 | ```
305 |
306 | ### Sequential update produces different outputs
307 |
308 | Behavior: model inference with sequential update produces different outputs from parallel update.
309 |
310 | This is due to precision issues. We find that with `bfloat16`, the down projection in the Llama MLP produces different results when run with different numbers of tokens.
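A minimal way to observe this kind of discrepancy is sketched below; treat it as an illustration, since the exact behavior depends on the hardware and backend:

```python
import torch

torch.manual_seed(0)
# a bfloat16 projection with TinyLlama-like shapes (intermediate 5632 -> hidden 2048)
down_proj = torch.nn.Linear(5632, 2048, bias=False, dtype=torch.bfloat16)
x = torch.randn(8, 5632, dtype=torch.bfloat16)

batched = down_proj(x)  # all 8 tokens in one call
one_by_one = torch.cat([down_proj(x[i:i + 1]) for i in range(8)])  # token by token

print(torch.equal(batched, one_by_one))    # may be False: the reduction order can differ between the two paths
print((batched - one_by_one).abs().max())  # typically a small nonzero value when they differ
```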
311 |
312 | ## Questions
313 |
314 | > 1. Is it possible to integrate the LCKV with MQA / GQA?
315 |
316 | Yes. In fact, we have already done this in our experiments. TinyLlama uses 32 attention heads and 4 KV heads, and we follow the same setting. If you want to experiment with different settings, you may modify `num_attention_heads` and `num_key_value_heads` in the configuration file.
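For example, starting from the provided TinyLlama config:

```python
from models import LCKVLlamaConfig

config = LCKVLlamaConfig.from_pretrained("configs/tinyllama_lckv.json")
config.num_attention_heads = 32  # number of query heads
config.num_key_value_heads = 4   # GQA: every 8 query heads share one KV head
# setting num_key_value_heads = 1 would correspond to MQA
```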
317 |
--------------------------------------------------------------------------------
/configs/default_chat_config.yaml:
--------------------------------------------------------------------------------
1 | examples:
2 | llama:
3 | text: There is a Llama in my lawn, how can I get rid of it?
4 | code:
5 | text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].
6 | helicopter:
7 | text: How many helicopters can a human eat in one sitting?
8 | numbers:
9 | text: Count to 10 but skip every number ending with an 'e'
10 | birds:
11 | text: Why aren't birds real?
12 | socks:
13 | text: Why is it important to eat socks after meditating?
--------------------------------------------------------------------------------
/configs/llama_small.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "meta-llama/Llama-2-7b-hf",
3 | "architectures": [
4 | "LlamaForCausalLM"
5 | ],
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 768,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 2048,
12 | "max_position_embeddings": 1024,
13 | "model_type": "lckv-llama",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "num_key_value_heads": 6,
17 | "pretraining_tp": 1,
18 | "rms_norm_eps": 1e-05,
19 | "rope_scaling": null,
20 | "tie_word_embeddings": false,
21 | "tokenizer_class": "LlamaTokenizer",
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.45.2",
24 | "use_cache": true,
25 | "vocab_size": 32000
26 | }
--------------------------------------------------------------------------------
/configs/llama_small_lckv.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "meta-llama/Llama-2-7b-hf",
3 | "architectures": [
4 | "LCKVLlamaForCausalLM"
5 | ],
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 768,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 2048,
12 | "max_position_embeddings": 1024,
13 | "model_type": "lckv-llama",
14 | "num_attention_heads": 12,
15 | "num_hidden_layers": 12,
16 | "num_key_value_heads": 6,
17 | "pretraining_tp": 1,
18 | "rms_norm_eps": 1e-05,
19 | "rope_scaling": null,
20 | "tie_word_embeddings": false,
21 | "tokenizer_class": "LlamaTokenizer",
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.45.2",
24 | "use_cache": true,
25 | "vocab_size": 32000,
26 | "layer_types": "0_10_10_10_10_10_10_10_10_10_10_11",
27 | "forward_passes": 7,
28 | "backward_passes": 2
29 | }
--------------------------------------------------------------------------------
/configs/llama_tiny.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "meta-llama/Llama-2-7b-hf",
3 | "architectures": [
4 | "LlamaForCausalLM"
5 | ],
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 512,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 1024,
12 | "max_position_embeddings": 1024,
13 | "model_type": "llama",
14 | "num_attention_heads": 8,
15 | "num_hidden_layers": 8,
16 | "num_key_value_heads": 4,
17 | "pretraining_tp": 1,
18 | "rms_norm_eps": 1e-05,
19 | "rope_scaling": null,
20 | "tie_word_embeddings": false,
21 | "tokenizer_class": "LlamaTokenizer",
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.45.2",
24 | "use_cache": true,
25 | "vocab_size": 32000
26 | }
--------------------------------------------------------------------------------
/configs/llama_tiny_lckv.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "meta-llama/Llama-2-7b-hf",
3 | "architectures": [
4 | "LCKVLlamaForCausalLM"
5 | ],
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 512,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 1024,
12 | "max_position_embeddings": 1024,
13 | "model_type": "lckv-llama",
14 | "num_attention_heads": 8,
15 | "num_hidden_layers": 8,
16 | "num_key_value_heads": 4,
17 | "pretraining_tp": 1,
18 | "rms_norm_eps": 1e-05,
19 | "rope_scaling": null,
20 | "tie_word_embeddings": false,
21 | "tokenizer_class": "LlamaTokenizer",
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.45.2",
24 | "use_cache": true,
25 | "vocab_size": 32000,
26 | "layer_types": "0_6_6_6_6_6_6_7",
27 | "forward_passes": 7,
28 | "backward_passes": 2
29 | }
--------------------------------------------------------------------------------
/configs/tinyllama.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "meta-llama/Llama-2-7b-hf",
3 | "architectures": [
4 | "LlamaForCausalLM"
5 | ],
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 2048,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 5632,
12 | "max_position_embeddings": 2048,
13 | "model_type": "llama",
14 | "num_attention_heads": 32,
15 | "num_hidden_layers": 22,
16 | "num_key_value_heads": 4,
17 | "pretraining_tp": 1,
18 | "rms_norm_eps": 1e-05,
19 | "rope_scaling": null,
20 | "tie_word_embeddings": false,
21 | "tokenizer_class": "LlamaTokenizer",
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.45.2",
24 | "use_cache": true,
25 | "vocab_size": 32000
26 | }
--------------------------------------------------------------------------------
/configs/tinyllama_lckv.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "meta-llama/Llama-2-7b-hf",
3 | "architectures": [
4 | "LCKVLlamaForCausalLM"
5 | ],
6 | "bos_token_id": 1,
7 | "eos_token_id": 2,
8 | "hidden_act": "silu",
9 | "hidden_size": 2048,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 5632,
12 | "max_position_embeddings": 2048,
13 | "model_type": "lckv-llama",
14 | "num_attention_heads": 32,
15 | "num_hidden_layers": 22,
16 | "num_key_value_heads": 4,
17 | "pretraining_tp": 1,
18 | "rms_norm_eps": 1e-05,
19 | "rope_scaling": null,
20 | "tie_word_embeddings": false,
21 | "tokenizer_class": "LlamaTokenizer",
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.45.2",
24 | "use_cache": true,
25 | "vocab_size": 32000,
26 | "layer_types": "0_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_20_21",
27 | "forward_passes": 7,
28 | "backward_passes": 2
29 | }
--------------------------------------------------------------------------------
/convert_pretrained.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 |
4 | import models  # register the lckv-llama model with the Auto classes
5 | from models.utils import LayerTypeParser
6 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
7 |
8 |
9 | def main():
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--model_name_or_path", help="The pretrained llama model to convert", required=True)
13 | parser.add_argument("--config_name", help="The config file of the expected LCKV model", required=True)
14 | parser.add_argument("--config_overrides", help="Override some existing config settings. Example: layer_types=0_6_6_6_6_6_6_7,forward_passes=7", default=None, required=False)
15 | parser.add_argument("--tokenizer_name", help="Pretrained tokenizer name or path if not the same as the pretrained model.", default=None, required=False)
16 | parser.add_argument("--output_dir", help="The output directory where the converted model will be written.", required=True)
17 | args = parser.parse_args()
18 |
19 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path)
20 | config = AutoConfig.from_pretrained(args.config_name)
21 | pt_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
22 | pt_model_state_dict = pt_model.state_dict()
23 |
24 | assert config.model_type == "lckv-llama", "The target model must be a LCKV model"
25 | # allow config overrides under all circumstances
26 | if args.config_overrides is not None:
27 | print(f"Overriding config: {args.config_overrides}")
28 | config.update_from_string(args.config_overrides)
29 | print(f"New config: {config}")
30 |
31 | model = AutoModelForCausalLM.from_config(config)
32 | model_state_dict = model.state_dict()
33 |
34 | # Copy the weights from the pretrained model to the LCKV model
35 | print("Copying weights from the pretrained model to the LCKV model...")
36 | for name, param in pt_model.named_parameters():
37 | if ('k_proj' in name or 'v_proj' in name):
38 | continue
39 |
40 | if name in model_state_dict:
41 | model_state_dict[name].copy_(param.data)
42 | else:
43 | print(f"WARNING: {name} not found in the model")
44 |
45 | # Average the weights of the k_proj and v_proj layers
46 | # The pretrained layer weights will contribute to the layer it attends to
47 | # XXX: how to align heads?
48 | print("Averaging the weights of the k_proj and v_proj layers...")
49 | parser = LayerTypeParser(config.layer_types)
50 | k_proj, v_proj = defaultdict(list), defaultdict(list)
51 | for layer_type in parser:
52 | k_proj[layer_type.attends_to].append(pt_model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.k_proj.weight"])
53 | v_proj[layer_type.attends_to].append(pt_model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.v_proj.weight"])
54 |
55 | for layer_type in parser:
56 | if layer_type.computes_kv:
57 | model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.k_proj.weight"].copy_(sum(k_proj[layer_type.layer_idx]) / len(k_proj[layer_type.layer_idx]))
58 | model_state_dict[f"model.layers.{layer_type.layer_idx}.self_attn.v_proj.weight"].copy_(sum(v_proj[layer_type.layer_idx]) / len(v_proj[layer_type.layer_idx]))
59 |
60 | # Save the model
61 | print(f"Saving the model to {args.output_dir}...")
62 | model.save_pretrained(args.output_dir)
63 | tokenizer.save_pretrained(args.output_dir)
64 |
65 | print("Model convertion finished successfully")
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
2 | from transformers.utils import is_liger_kernel_available
3 |
4 | from .configuration_lckv import LCKVLlamaConfig
5 | from .modeling_lckv import LCKVLlamaForCausalLM, LCKVLlamaModel
6 |
7 |
8 | AutoConfig.register("lckv-llama", LCKVLlamaConfig)
9 | AutoModel.register(LCKVLlamaConfig, LCKVLlamaModel)
10 | AutoModelForCausalLM.register(LCKVLlamaConfig, LCKVLlamaForCausalLM)
11 |
12 |
13 | if is_liger_kernel_available():
14 | from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
15 |
16 | from .kernel import apply_liger_kernel_to_lckv_llama
17 | MODEL_TYPE_TO_APPLY_LIGER_FN["lckv-llama"] = apply_liger_kernel_to_lckv_llama
18 |
--------------------------------------------------------------------------------
/models/cache_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple
2 |
3 | import torch
4 |
5 | from transformers.cache_utils import Cache, DynamicCache, SinkCache
6 |
7 | from .utils import LayerTypeParser
8 |
9 |
10 | class IndexedCache(Cache):
11 | """
12 | Similar to the `DynamicCache` class, but with the ability to index the cache by layer index. DynamicCache
13 | assumes that all layers compute KVs, while IndexedCache allows for a more flexible cache structure.
14 | """
15 | build_position_ids_based_on_cache = False
16 |
17 | def __init__(self) -> None:
18 | super().__init__()
19 | self.key_cache: Dict[int, torch.Tensor] = {}
20 | self.value_cache: Dict[int, torch.Tensor] = {}
21 | self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
22 | self._update = True # set to False to prevent the cache from updating during iterative inference
23 |
24 | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
25 | """
26 | Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
27 | sequence length.
28 | """
29 | if layer_idx in self.key_cache:
30 | return (self.key_cache[layer_idx], self.value_cache[layer_idx])
31 | else:
32 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
33 |
34 | def __iter__(self):
35 | """
36 | Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
37 | keys and values
38 | """
39 | for layer_idx in sorted(self.key_cache.keys()):
40 | yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
41 |
42 | def __len__(self):
43 | """
44 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
45 | to the number of layers that compute KVs in the model.
46 | """
47 | return len(self.key_cache)
48 |
49 | @property
50 | def min_layer(self) -> int:
51 | return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None
52 |
53 | def is_min_layer(self, layer_idx: int) -> bool:
54 | return self.min_layer is None or self.min_layer == layer_idx
55 |
56 | def update(
57 | self,
58 | key_states: torch.Tensor,
59 | value_states: torch.Tensor,
60 | layer_idx: int,
61 | cache_kwargs: Optional[Dict[str, Any]] = None,
62 | ) -> Tuple[torch.Tensor, torch.Tensor]:
63 | """
64 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
65 |
66 | Parameters:
67 | key_states (`torch.Tensor`):
68 | The new key states to cache.
69 | value_states (`torch.Tensor`):
70 | The new value states to cache.
71 | layer_idx (`int`):
72 | The index of the layer to cache the states for.
73 | cache_kwargs (`Dict[str, Any]`, `optional`):
74 | Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
75 |
76 | Return:
77 | A tuple containing the updated key and value states.
78 | """
79 | # Update the number of seen tokens
80 | if self.is_min_layer(layer_idx):
81 | self._seen_tokens += key_states.shape[-2]
82 |
83 | # Retrieve the cache
84 | if layer_idx not in self.key_cache:
85 | new_key_states = key_states
86 | new_value_states = value_states
87 | else:
88 | new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
89 | new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
90 |
91 | # Update the cache
92 | if self._update:
93 | self.key_cache[layer_idx] = new_key_states
94 | self.value_cache[layer_idx] = new_value_states
95 |
96 | return new_key_states, new_value_states
97 |
98 | def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
99 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
100 | if layer_idx is None:
101 | layer_idx = self.min_layer
102 |
103 | # TODO: deprecate this function in favor of `cache_position`
104 | is_empty_layer = (
105 | (len(self.key_cache) == 0) # no cache in any layer
106 | or (layer_idx not in self.key_cache) # skipped `layer_idx` and hasn't run a layer with cache after it
107 | )
108 | layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
109 | return layer_seq_length
110 |
111 | def get_max_length(self) -> Optional[int]:
112 | """Returns the maximum sequence length of the cached states. IndexedCache does not have a maximum length."""
113 | return None
114 |
115 | @classmethod
116 | def from_cache(cls, dynamic_cache: DynamicCache, *args, **kwargs) -> "IndexedCache":
117 | """Converts a dynamic cache into an equivalent `IndexedCache`."""
118 | cache = cls(*args, **kwargs)
119 |
120 | cache._seen_tokens = dynamic_cache._seen_tokens
121 | for layer_idx in range(len(dynamic_cache.key_cache)):
122 | key_states, value_states = dynamic_cache[layer_idx]
123 | cache.update(key_states, value_states, layer_idx)
124 |
125 | return cache
126 |
127 |
128 | class IndexedSinkCache(Cache):
129 | """
130 | This is a fix to the SinkCache class in the transformers library. It also allows for the cache to be indexed by
131 | layer index, similar to the `IndexedCache` class.
132 | """
133 | build_position_ids_based_on_cache = True
134 |
135 | def __init__(self, window_length: int = None, num_sink_tokens: int = None) -> None:
136 | super().__init__()
137 | self.key_cache: Dict[int, torch.Tensor] = {}
138 | self.value_cache: Dict[int, torch.Tensor] = {}
139 | self.window_length = window_length
140 | self.num_sink_tokens = num_sink_tokens
141 | self.cos_sin_rerotation_cache = {}
142 | self._cos_cache = None
143 | self._sin_cache = None
144 | self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
145 | self._update = True # set to False to prevent the cache from updating during iterative inference
146 |
147 | @staticmethod
148 | def _rotate_half(x):
149 | x1 = x[..., : x.shape[-1] // 2]
150 | x2 = x[..., x.shape[-1] // 2 :]
151 | return torch.cat((-x2, x1), dim=-1)
152 |
153 | def _apply_key_rotary_pos_emb(
154 | self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
155 | ) -> torch.Tensor:
156 | rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
157 | return rotated_key_states
158 |
159 | def _get_rerotation_cos_sin(
160 | self, offset: int, dtype: torch.dtype, cos: torch.Tensor, sin: torch.Tensor
161 | ) -> Tuple[torch.Tensor, torch.Tensor]:
162 | if offset not in self.cos_sin_rerotation_cache:
163 | # Upcast to float32 temporarily for better accuracy
164 | cos = cos.to(torch.float32)
165 | sin = sin.to(torch.float32)
166 |
167 | # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
168 | original_cos = cos[self.num_sink_tokens + offset :]
169 | shifted_cos = cos[self.num_sink_tokens : -offset]
170 | original_sin = sin[self.num_sink_tokens + offset :]
171 | shifted_sin = sin[self.num_sink_tokens : -offset]
172 | rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
173 | rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
174 |
175 | self.cos_sin_rerotation_cache[offset] = (
176 | rerotation_cos.to(dtype).unsqueeze(0),
177 | rerotation_sin.to(dtype).unsqueeze(0),
178 | )
179 | return self.cos_sin_rerotation_cache[offset]
180 |
181 | @property
182 | def min_layer(self) -> int:
183 | return min(self.key_cache.keys()) if len(self.key_cache) > 0 else None
184 |
185 | def is_min_layer(self, layer_idx: int) -> bool:
186 | return self.min_layer is None or self.min_layer == layer_idx
187 |
188 | def get_seq_length(self, layer_idx: Optional[int] = None) -> int:
189 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
190 | # TODO: deprecate this function in favor of `cache_position`
191 | # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
192 | if layer_idx is None:
193 | layer_idx = self.min_layer
194 |
195 | if layer_idx not in self.key_cache:
196 | return 0
197 |
198 | return self.key_cache[layer_idx].shape[-2]
199 |
200 | def get_max_length(self) -> Optional[int]:
201 | """Returns the maximum sequence length of the cached states."""
202 | return self.window_length
203 |
204 | def update(
205 | self,
206 | key_states: torch.Tensor,
207 | value_states: torch.Tensor,
208 | layer_idx: int,
209 | cache_kwargs: Optional[Dict[str, Any]] = None,
210 | ) -> Tuple[torch.Tensor, torch.Tensor]:
211 | """
212 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
213 |
214 | Parameters:
215 | key_states (`torch.Tensor`):
216 | The new key states to cache.
217 | value_states (`torch.Tensor`):
218 | The new value states to cache.
219 | layer_idx (`int`):
220 | The index of the layer to cache the states for.
221 | cache_kwargs (`Dict[str, Any]`, `optional`):
222 | Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
223 | `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
224 | rotation as the tokens are shifted.
225 |
226 | Return:
227 | A tuple containing the updated key and value states.
228 | """
229 | # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
230 | # with partially rotated position embeddings, like Phi or Persimmon.
231 | sin = cache_kwargs.get("sin")
232 | cos = cache_kwargs.get("cos")
233 | partial_rotation_size = cache_kwargs.get("partial_rotation_size")
234 | using_rope = cos is not None and sin is not None
235 |
236 | # Update the number of seen tokens
237 | if self.is_min_layer(layer_idx):
238 | self._seen_tokens += key_states.shape[-2]
239 |
240 | # Update the sin/cos cache, which holds sin/cos values for all possible positions
241 | if using_rope and self.is_min_layer(layer_idx):
242 | # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
243 | # after all RoPE models have a llama-like cache utilization.
244 | if cos.dim() == 2:
245 | self._cos_cache = cos
246 | self._sin_cache = sin
247 | else:
248 | if self._cos_cache is None:
249 | self._cos_cache = cos[0, ...]
250 | self._sin_cache = sin[0, ...]
251 | elif self._cos_cache.shape[0] < self.window_length + key_states.shape[-2]:
252 | self._cos_cache = torch.cat([self._cos_cache[: self.window_length], cos[0, ...]], dim=0)
253 | self._sin_cache = torch.cat([self._sin_cache[: self.window_length], sin[0, ...]], dim=0)
254 |
255 | # [bsz, num_heads, seq_len, head_dim]
256 | if layer_idx not in self.key_cache:
257 | # Empty cache
258 | new_key_states = key_states
259 | new_value_states = value_states
260 |
261 | else:
262 | # Growing cache
263 | new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
264 | new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
265 |
266 | if self._update:
267 | self.key_cache[layer_idx] = new_key_states
268 | self.value_cache[layer_idx] = new_value_states
269 |
270 | # If the cache is full, we need to shift the cache
271 | if (seq_length := self.get_seq_length(layer_idx)) > self.window_length:
272 | # Shifting cache
273 | keys_to_keep = self.key_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]
274 |
275 | # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
276 | if using_rope:
277 | rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
278 | seq_length - self.window_length,
279 | key_states.dtype,
280 | self._cos_cache[:seq_length],
281 | self._sin_cache[:seq_length],
282 | )
283 | if partial_rotation_size is not None:
284 | keys_to_keep, keys_pass = (
285 | keys_to_keep[..., :partial_rotation_size],
286 | keys_to_keep[..., partial_rotation_size:],
287 | )
288 | keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
289 | if partial_rotation_size is not None:
290 | keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
291 |
292 | # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
293 | sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
294 | self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep], dim=-2)
295 |
296 | sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
297 | values_to_keep = self.value_cache[layer_idx][:, :, -self.window_length + self.num_sink_tokens :]
298 | self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep], dim=-2)
299 |
300 | return new_key_states, new_value_states
301 |
302 | @classmethod
303 | def from_cache(cls, sink_cache: SinkCache, *args, **kwargs) -> "IndexedSinkCache":
304 | """Converts a dynamic cache into an equivalent `IndexedCache`."""
305 | cache = cls(*args, **kwargs)
306 |
307 | cache.window_length = sink_cache.window_length
308 | cache.num_sink_tokens = sink_cache.num_sink_tokens
309 | cache._seen_tokens = sink_cache._seen_tokens
310 | cache._cos_cache = sink_cache._cos_cache
311 | cache._sin_cache = sink_cache._sin_cache
312 | cache.cos_sin_rerotation_cache = sink_cache.cos_sin_rerotation_cache
313 | for layer_idx in range(len(sink_cache.key_cache)):
314 | cache.key_cache[layer_idx] = sink_cache.key_cache[layer_idx]
315 | cache.value_cache[layer_idx] = sink_cache.value_cache[layer_idx]
316 |
317 | return cache
318 |
319 |
320 | class IndexedSlidingWindowCache(IndexedCache):
321 | """
322 | Similar to the `SlidingWindowCache` class, but with the ability to index the cache by layer index. It is no longer
323 | a subclass of `StaticCache` as it is dynamic.
324 | """
325 | build_position_ids_based_on_cache = False
326 |
327 | def __init__(self, sliding_window: int = None) -> None:
328 | super().__init__()
329 | self.sliding_window = sliding_window
330 |
331 | def update(
332 | self,
333 | key_states: torch.Tensor,
334 | value_states: torch.Tensor,
335 | layer_idx: int,
336 | cache_kwargs: Optional[Dict[str, Any]] = None,
337 | ) -> Tuple[torch.Tensor]:
338 | # Update the number of seen tokens
339 | if self.is_min_layer(layer_idx):
340 | self._seen_tokens += key_states.shape[-2]
341 |
342 | # [bsz, num_heads, seq_len, head_dim]
343 | if layer_idx not in self.key_cache:
344 | # Empty cache
345 | new_key_states = key_states
346 | new_value_states = value_states
347 |
348 | else:
349 | # Growing cache
350 | new_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
351 | new_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
352 |
353 | if self._update:
354 | self.key_cache[layer_idx] = new_key_states
355 | self.value_cache[layer_idx] = new_value_states
356 |
357 | # If the cache is full, we need to shift the cache
358 | if self.get_seq_length(layer_idx) > self.sliding_window:
359 | self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, -self.sliding_window :]
360 | self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, -self.sliding_window :]
361 |
362 | return new_key_states, new_value_states
363 |
364 | def get_max_length(self) -> Optional[int]:
365 | return self.sliding_window
366 |
367 | @classmethod
368 | def from_cache(cls, sliding_window_cache: "IndexedSlidingWindowCache", *args, **kwargs) -> "IndexedSlidingWindowCache":
369 | """This is to override the `from_cache` method in the `IndexedCache` class."""
370 | cache = cls(*args, **kwargs)
371 |
372 | cache._seen_tokens = sliding_window_cache._seen_tokens
373 | cache.sliding_window = sliding_window_cache.sliding_window
374 | for layer_idx in range(len(sliding_window_cache.key_cache)):
375 | cache.key_cache[layer_idx] = sliding_window_cache.key_cache[layer_idx]
376 | cache.value_cache[layer_idx] = sliding_window_cache.value_cache[layer_idx]
377 |
378 | return cache
379 |
380 |
381 | class IndexedHybridCache(IndexedSlidingWindowCache, IndexedCache):
382 | """
383 | Hybrid Cache class to be used for models that alternate between a local sliding window attention and global
384 | attention in every other layer. Under the hood, the hybrid cache leverages `IndexedSlidingWindowCache` for
385 | sliding window attention and `IndexedCache` for global attention.
386 | """
387 | build_position_ids_based_on_cache = False
388 |
389 | def __init__(self, parser: LayerTypeParser = None, sliding_window: int = None) -> None:
390 | super().__init__(sliding_window=sliding_window)
391 | self.parser = parser
392 |
393 | def update(
394 | self,
395 | key_states: torch.Tensor,
396 | value_states: torch.Tensor,
397 | layer_idx: int,
398 | cache_kwargs: Optional[Dict[str, Any]] = None,
399 | ) -> Tuple[torch.Tensor]:
400 | if self.parser[layer_idx].use_sliding_window:
401 | return IndexedSlidingWindowCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
402 | else:
403 | return IndexedCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
404 |
405 | def get_max_length(self) -> Optional[int]:
406 | return IndexedCache.get_max_length(self)
407 |
408 | @classmethod
409 | def from_cache(cls, hybrid_cache: "IndexedHybridCache", *args, **kwargs) -> "IndexedHybridCache":
410 | """This is to override the `from_cache` method in the `IndexedSlidingWindowCache` class."""
411 | cache = cls(*args, **kwargs)
412 |
413 | cache._seen_tokens = hybrid_cache._seen_tokens
414 | cache.sliding_window = hybrid_cache.sliding_window
415 | cache.parser = hybrid_cache.parser
416 | for layer_idx in range(len(hybrid_cache.key_cache)):
417 | cache.key_cache[layer_idx] = hybrid_cache.key_cache[layer_idx]
418 | cache.value_cache[layer_idx] = hybrid_cache.value_cache[layer_idx]
419 |
420 | return cache
421 |
422 |
423 | class LayerCache(torch.nn.Module):
424 | """
425 | A cache for storing the key-value pairs for layers.
426 | """
427 | def __init__(self) -> None:
428 | """
429 | The placeholder is used to expand the key-value pairs if the layer attends to the top layers.
430 | Size: (batch_size, num_key_value_heads, 1, head_dim)
431 | """
432 | super().__init__()
433 | self.key_layer_cache: Dict[int, torch.Tensor] = {}
434 | self.value_layer_cache: Dict[int, torch.Tensor] = {}
435 | self.layer_type = None
436 | self.placeholder = None
437 |
438 | def setup(self, placeholder: torch.Tensor):
439 | """setup the cache, calling this function is necessary if there is a layer that attends to the top layers"""
440 | self.placeholder = placeholder
441 |
442 | def initialize(self, parser: LayerTypeParser, sequence_length: int):
443 | """initialize the cache"""
444 | layers_to_init = {parser[idx].attends_to for idx in range(len(parser)) if parser[idx].attends_top}
445 |
446 | if layers_to_init:
447 | b, h, _, d = self.placeholder.size()
448 | init_kvs = self.placeholder.new_zeros((b, h, sequence_length, d))
449 |
450 | for layer_idx in layers_to_init:
451 | self.layer_append(layer_idx, init_kvs, init_kvs)
452 |
453 | def layer_get(self, layer_idx: int, zerofill: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
454 | key_states = self.key_layer_cache.get(layer_idx, None)
455 | value_states = self.value_layer_cache.get(layer_idx, None)
456 |
457 | if zerofill:
458 | if key_states is None:
459 | key_states = self.placeholder
460 | value_states = self.placeholder
461 | else:
462 | key_states = torch.cat([self.placeholder, key_states], dim=2)
463 | value_states = torch.cat([self.placeholder, value_states], dim=2)
464 |
465 | return key_states, value_states
466 |
467 | def layer_set(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
468 | self.key_layer_cache[layer_idx] = key
469 | self.value_layer_cache[layer_idx] = value
470 |
471 | def layer_append(self, layer_idx: int, key: torch.Tensor, value: torch.Tensor):
472 | if layer_idx not in self.key_layer_cache:
473 | self.key_layer_cache[layer_idx] = key
474 | self.value_layer_cache[layer_idx] = value
475 | else:
476 | self.key_layer_cache[layer_idx] = torch.cat([self.key_layer_cache[layer_idx], key], dim=2)
477 | self.value_layer_cache[layer_idx] = torch.cat([self.value_layer_cache[layer_idx], value], dim=2)
478 |
479 |
480 | class LayerIndexedCache(LayerCache, IndexedCache):
481 | """
482 | A cache for storing the key-value pairs for layers, in combination with the ability of standard KV cache.
483 | """
484 | def __init__(self) -> None:
485 | LayerCache.__init__(self)
486 | IndexedCache.__init__(self)
487 |
488 |
489 | class LayerIndexedSinkCache(LayerCache, IndexedSinkCache):
490 | """
491 | A cache for storing the key-value pairs for layers, in combination with the ability of sink KV cache.
492 | """
493 | def __init__(self) -> None:
494 | LayerCache.__init__(self)
495 | IndexedSinkCache.__init__(self)
496 |
497 |
498 | class LayerIndexedSlidingWindowCache(LayerCache, IndexedSlidingWindowCache):
499 | """
500 | A cache for storing the key-value pairs for layers, in combination with the ability of sliding window KV cache.
501 | """
502 | def __init__(self) -> None:
503 | LayerCache.__init__(self)
504 | IndexedSlidingWindowCache.__init__(self)
505 |
506 |
507 | class LayerIndexedHybridCache(LayerCache, IndexedHybridCache):
508 | """
509 | A cache for storing the key-value pairs for layers, in combination with the ability of hybrid KV cache.
510 | """
511 | def __init__(self) -> None:
512 | LayerCache.__init__(self)
513 | IndexedHybridCache.__init__(self)
514 |
515 |
516 | class AutoLayerCache(torch.nn.Module):
517 | """
518 | AutoLayerCache is a module that automatically creates a cache from an existing cache.
519 | """
520 | CACHE_MAPPING = {
521 | DynamicCache: LayerIndexedCache,
522 | SinkCache: LayerIndexedSinkCache,
523 | IndexedSlidingWindowCache: LayerIndexedSlidingWindowCache,
524 | IndexedHybridCache: LayerIndexedHybridCache,
525 | }
526 |
527 | def __init__(self, *args, **kwargs):
528 | raise RuntimeError(
529 | f"{self.__class__.__name__} is designed to be instantiated "
530 | f"using the `{self.__class__.__name__}.from_cache(cache)` method."
531 | )
532 |
533 | @classmethod
534 | def from_cache(cls, cache: Cache, *args, **kwargs):
535 | """
536 | Create a new cache from an existing cache. The new cache type is chosen from `CACHE_MAPPING` according to the type of the original cache.
537 | """
538 | cache_type = type(cache)
539 | if cache_type not in cls.CACHE_MAPPING:
540 | raise ValueError(f"Cache type {cache_type} is not supported by {cls.__name__}.")
541 |
542 | cache_class = cls.CACHE_MAPPING[cache_type]
543 |
544 | if hasattr(cache_class, "from_cache"):
545 | return cache_class.from_cache(cache, *args, **kwargs)
546 | else:
547 | # we init an empty cache and copy the attributes
548 | new_cache = cache_class(*args, **kwargs)
549 | new_cache.__dict__.update(cache.__dict__)
550 | return new_cache
551 |
--------------------------------------------------------------------------------
/models/configuration_lckv.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ LCKV LLaMA model configuration"""
21 | from transformers.models.llama.configuration_llama import LlamaConfig
22 |
23 | from .utils import LayerTypeParser
24 |
25 |
26 | class LCKVLlamaConfig(LlamaConfig):
27 |
28 | model_type = "lckv-llama"
29 |
30 | def __init__(
31 | self,
32 | layer_types: str = None,
33 | forward_passes: int = 7,
34 | backward_passes: int = 2,
35 | sliding_window: int = 4096,
36 | use_sequential: bool = False,
37 | force_nodiag: bool = False,
38 | **kwargs,
39 | ):
40 | """
41 | Initialize a LCKV LLaMA configuration. Instantiating a configuration with the defaults
42 | will yield a similar configuration to that of the LLaMA-7B with the standard transformer
43 | training scheme.
44 |
45 | Args:
46 | layer_types (`str`, *optional*):
47 | A string of integers separated by underscores. The i-th integer specifies which
48 | layer's key-value pair the i-th layer uses as its KV cache. Special characters
49 | may be placed after the integers:
50 | - `s` means the layer will use sliding window attention.
51 | The default value is "0_1_2_..." up to the number of layers in the current config.
52 | forward_passes (`int`, *optional*, defaults to 7):
53 | The number of forward passes during training and prompt encoding. Equivalent
54 | to `m` in the paper.
55 | backward_passes (`int`, *optional*, defaults to 2):
56 | The number of backward passes during training and prompt encoding. Equivalent
57 | to `b` in the paper.
58 | sliding_window (`int`, *optional*, defaults to 4096):
59 | Sliding window attention window size. If not specified, will default to `4096`.
60 | It will only be effective if the corresponding layer uses sliding window attention.
61 | use_sequential (`bool`, *optional*, defaults to False):
62 | Whether to do the forward pass sequentially, token by token. Useful for testing
63 | models with cyclic dependencies. It can also be used for sequential training.
64 | force_nodiag (`bool`, *optional*, defaults to False):
65 | Whether to force the model to not use the diagonal attention. By default, the model
66 | will mask the diagonal attention only in the layers where necessary. If set to `True`, the model
67 | will never use the diagonal attention in any layer. This is mainly for backward compatibility.
68 | """
69 | super().__init__(**kwargs)
70 | self.layer_types = layer_types
71 | self.forward_passes = forward_passes
72 | self.backward_passes = backward_passes
73 | self.sliding_window = sliding_window
74 | self.use_sequential = use_sequential
75 | self.force_nodiag = force_nodiag
76 |
77 | if self.layer_types is None:
78 | self.layer_types = "_".join(map(str, range(self.num_hidden_layers)))
79 |
80 | # post check
81 | LayerTypeParser(self.layer_types).check(self.num_hidden_layers)
82 |
--------------------------------------------------------------------------------
/models/kernel.py:
--------------------------------------------------------------------------------
1 | from liger_kernel.transformers.monkey_patch import (
2 | LigerCrossEntropyLoss,
3 | LigerRMSNorm,
4 | LigerSwiGLUMLP,
5 | PreTrainedModel,
6 | _bind_method_to_module,
7 | _patch_rms_norm_module,
8 | llama_lce_forward,
9 | )
10 |
11 | from .ops_rope import SingleLigerRopeFunction
12 |
13 |
14 | def liger_rotary(q, cos, sin, unsqueeze_dim=1):
15 | """
16 | Applies Rotary Positional Embedding (RoPE) operation to query and key states.
17 |
18 | Args:
19 | q (torch.Tensor): The query tensor of shape (bsz, n_q_head, seq_len, head_dim).
20 | cos (torch.Tensor): The cosine tensor of shape (1, seq_len, head_dim).
21 | sin (torch.Tensor): The sine tensor of shape (1, seq_len, head_dim).
22 | unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.
23 |
24 | Returns:
25 | Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after applying the RoPE operation.
26 | """
27 |
28 | return SingleLigerRopeFunction.apply(q, cos, sin, None, unsqueeze_dim)
29 |
30 |
31 | def apply_liger_kernel_to_lckv_llama(
32 | rope: bool = True,
33 | cross_entropy: bool = False,
34 | fused_linear_cross_entropy: bool = True,
35 | rms_norm: bool = True,
36 | swiglu: bool = True,
37 | model: PreTrainedModel = None,
38 | ) -> None:
39 | """
40 | Apply Liger kernels to replace original implementation in LCKV models.
41 |
42 | Args:
43 | rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
44 | cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
45 | fused_linear_cross_entropy (bool):
46 | Whether to apply Liger's fused linear cross entropy loss. Default is True.
47 | `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
48 | If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
49 | rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
50 | swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
51 | model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
52 | loaded. Default is None.
53 | """
54 |
55 | assert not (
56 | cross_entropy and fused_linear_cross_entropy
57 | ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
58 |
59 | from transformers.models.llama import modeling_llama
60 |
61 | from . import modeling_lckv
62 |
63 | if rope:
64 | modeling_lckv.apply_rotary = liger_rotary
65 | if rms_norm:
66 | modeling_llama.LlamaRMSNorm = LigerRMSNorm
67 | if swiglu:
68 | modeling_llama.LlamaMLP = LigerSwiGLUMLP
69 | if cross_entropy:
70 | modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
71 | if fused_linear_cross_entropy:
72 | modeling_lckv.LlamaForCausalLM.forward = llama_lce_forward
73 |
74 | if model is not None:
75 | # The model instance already exists, so we need to additionally patch the
76 | # instance variables that reference already-instantiated modules (e.g. LlamaRMSNorm or LlamaMLP)
77 |
78 | base_model = getattr(model, model.base_model_prefix, model)
79 |
80 | if rms_norm:
81 | _patch_rms_norm_module(base_model.norm)
82 |
83 | for decoder_layer in base_model.layers:
84 | if swiglu:
85 | _bind_method_to_module(
86 | decoder_layer.mlp, "forward", LigerSwiGLUMLP.forward
87 | )
88 | if rms_norm:
89 | _patch_rms_norm_module(decoder_layer.input_layernorm)
90 | _patch_rms_norm_module(decoder_layer.post_attention_layernorm)
91 |
--------------------------------------------------------------------------------
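
A minimal sketch of applying the helper above to an already-loaded model; the checkpoint path is a placeholder and the dtype/attention settings are only illustrative:

```python
# Sketch: patch an already-loaded LCKV model with the Liger kernels.
# "path/to/lckv-checkpoint" is a placeholder for a trained LCKV model directory.
import torch

import models  # noqa: F401  -- registers the LCKV architectures with transformers
from models.kernel import apply_liger_kernel_to_lckv_llama
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "path/to/lckv-checkpoint", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
)
# rope / rms_norm / swiglu / fused_linear_cross_entropy are enabled by default;
# passing `model=` additionally patches the already-instantiated modules.
apply_liger_kernel_to_lckv_llama(model=model)
```
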
/models/ops_rope.py:
--------------------------------------------------------------------------------
1 | """https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/rope.py"""
2 |
3 | import torch
4 | import triton
5 | import triton.language as tl
6 |
7 |
8 | @triton.jit
9 | def _triton_rope(
10 | q_ptr,
11 | q_row_stride,
12 | cos,
13 | cos_row_stride,
14 | sin,
15 | sin_row_stride,
16 | sl,
17 | bs: tl.constexpr,
18 | n_qh: tl.constexpr,
19 | hd: tl.constexpr,
20 | pad_n_qh: tl.constexpr,
21 | pad_hd: tl.constexpr,
22 | BLOCK_SIZE: tl.constexpr,
23 | BACKWARD_PASS: tl.constexpr = False,
24 | ):
25 | # q size: (bsz, seq_len, num_q_heads, head_dim)
26 | # q stride: (seq_len * num_q_heads * head_dim, num_q_heads * head_dim, head_dim, 1)
27 |
28 | # cos size: (1, seq_len, head_dim)
29 | # stride: (seq_len * head_dim, head_dim, 1)
30 | pid = tl.program_id(0)
31 |
32 | # locate start address
33 | q_ptr = q_ptr + pid * q_row_stride
34 |
35 | # ####################################################################
36 | # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
37 | # m of this program instance
38 | # ####################################################################
39 |
40 | # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which
41 | # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension
42 | # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index
43 | # and pid % sl to get the sequence index.
44 | # 2. We only need the left half of cos and sin matrix because the right half is just
45 | # a clone of the left half.
46 | cos_row_idx = pid % (sl)
47 | cos = cos + cos_row_idx * cos_row_stride
48 | sin = sin + cos_row_idx * sin_row_stride
49 | cos_offsets = tl.arange(0, pad_hd // 2)
50 | cos_mask = cos_offsets < hd // 2
51 | cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)
52 | sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)
53 |
54 | # ####################################################################
55 | # Load the left and right half of q for the current
56 | # program instance (i.e. for the current token) separately
57 | # ####################################################################
58 | # left half of the head
59 | first_half_q_offsets = (
60 | tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
61 | )
62 | first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
63 | tl.arange(0, pad_hd // 2)[None, :] < hd // 2
64 | )
65 | q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
66 | sin_row.dtype
67 | )
68 |
69 | # right half of the head
70 | second_half_q_offsets = first_half_q_offsets + (hd // 2)
71 | second_q_mask = first_q_mask
72 | q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
73 | sin_row.dtype
74 | )
75 |
76 | if not BACKWARD_PASS:
77 | # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
78 | new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
79 | tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
80 | new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
81 | tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
82 | else:
83 | # with some math, we can get:
84 | # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
85 | new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
86 | tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
87 | new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
88 | tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
89 |
90 |
91 | def rope_forward(q, cos, sin):
92 |
93 | # transpose it back to the physical shape because Triton looks at the physical storage
94 | # note: q is non-contiguous before the transformation and becomes contiguous after the transpose
95 | q = q.transpose(1, 2)
96 |
97 | batch_size, seq_len, n_q_head, head_dim = q.shape
98 | pad_hd = triton.next_power_of_2(head_dim)
99 | pad_n_q_head = triton.next_power_of_2(n_q_head)
100 | BLOCK_SIZE = pad_n_q_head
101 |
102 | n_row = batch_size * seq_len
103 |
104 | # ensure tensors passed into the kernel are contiguous. It will be a no-op if they are already contiguous
105 | q = q.contiguous()
106 | cos = cos.contiguous()
107 | sin = sin.contiguous()
108 |
109 | _triton_rope[(n_row,)](
110 | q,
111 | q.stride(1),
112 | cos,
113 | cos.stride(-2),
114 | sin,
115 | sin.stride(-2),
116 | seq_len,
117 | batch_size,
118 | n_q_head,
119 | head_dim,
120 | pad_n_q_head,
121 | pad_hd,
122 | BLOCK_SIZE=BLOCK_SIZE,
123 | BACKWARD_PASS=False,
124 | )
125 | return q.transpose(1, 2), cos, sin
126 |
127 |
128 | def rope_backward(dq, cos, sin):
129 | dq = dq.transpose(1, 2)
130 |
131 | batch_size, seq_len, n_q_head, head_dim = dq.shape
132 | pad_hd = triton.next_power_of_2(head_dim)
133 | pad_n_q_head = triton.next_power_of_2(n_q_head)
134 | BLOCK_SIZE = pad_n_q_head
135 |
136 | n_row = batch_size * seq_len
137 |
138 | # ensure dq is contiguous
139 | dq = dq.contiguous()
140 |
141 | # backward is similar to forward except for swapping a few ops
142 | _triton_rope[(n_row,)](
143 | dq,
144 | dq.stride(1),
145 | cos,
146 | cos.stride(-2),
147 | sin,
148 | sin.stride(-2),
149 | seq_len,
150 | batch_size,
151 | n_q_head,
152 | head_dim,
153 | pad_n_q_head,
154 | pad_hd,
155 | BLOCK_SIZE=BLOCK_SIZE,
156 | BACKWARD_PASS=True,
157 | )
158 | return dq.transpose(1, 2)
159 |
160 |
161 | class SingleLigerRopeFunction(torch.autograd.Function):
162 | """
163 | This function re-implements the RoPE operation with only one input tensor.
164 | """
165 |
166 | @staticmethod
167 | def forward(ctx, q, cos, sin, position_ids=None, unsqueeze_dim=1):
168 | """
169 | q size: (bsz, n_q_head, seq_len, head_dim)
170 | cos size: (1, seq_len, head_dim)
171 | sin size: (1, seq_len, head_dim)
172 | """
173 | q, cos, sin = rope_forward(q, cos, sin)
174 | ctx.save_for_backward(cos, sin)
175 | return q
176 | @staticmethod
177 | def backward(ctx, dq):
178 | """
179 | dq size: (bsz, n_q_head, seq_len, head_dim)
180 | cos size: (1, seq_len, head_dim)
181 | sin size: (1, seq_len, head_dim)
182 | """
183 |
184 | cos, sin = ctx.saved_tensors
185 | dq = rope_backward(dq, cos, sin)
186 | return dq, None, None, None, None
187 |
--------------------------------------------------------------------------------
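
For reference, a plain-PyTorch sketch of the rotation that the forward kernel above performs (only the left half of `cos`/`sin` is used, since the right half duplicates it); the backward kernel applies the same rotation with the sign of `sin` flipped, i.e. the inverse rotation:

```python
import torch

def rope_reference(q, cos, sin):
    # Non-Triton reference of the forward rotation in _triton_rope.
    # q: (bsz, n_q_head, seq_len, head_dim); cos/sin: (1, seq_len, head_dim)
    q1, q2 = q.chunk(2, dim=-1)                              # split the head dimension in half
    cos_half = cos[..., : cos.shape[-1] // 2].unsqueeze(1)   # (1, 1, seq_len, head_dim // 2)
    sin_half = sin[..., : sin.shape[-1] // 2].unsqueeze(1)
    return torch.cat(
        [q1 * cos_half - q2 * sin_half,                      # new left half
         q2 * cos_half + q1 * sin_half],                     # new right half
        dim=-1,
    )
```
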
/models/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | from dataclasses import dataclass
3 | from typing import List, Optional
4 |
5 | import torch
6 |
7 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
8 |
9 |
10 | @dataclass
11 | class IterStep:
12 | """A helper class for the iteration plan"""
13 | layer_slice: slice = slice(None)
14 | requires_grad: bool = True
15 | update: bool = True
16 |
17 | @dataclass
18 | class LayerType:
19 | """A helper class to collect the layer type information"""
20 | layer_idx: int
21 | use_sliding_window: bool
22 | attends_to: int
23 | attends_top: bool
24 | computes_kv: bool
25 |
26 | class LayerTypeParser:
27 | """
28 | A helper class to parse the layer type string and provide some useful methods.
29 |
30 | Arguments:
31 | layer_type (str): A string of integers separated by underscores. The i-th integer
32 | specifies which layer's key-value pair the i-th layer will use as its kv cache.
33 | Special characters may be placed after the integers:
34 | - `s` means the layer will use sliding window attention.
35 |
36 | >>> layer_type = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8")[3]
37 | >>> layer_type.attends_to
38 | 5
39 | >>> layer_type.attends_top
40 | True
41 | >>> layer_type.use_sliding_window
42 | True
43 | """
44 | def __init__(self, layer_type: str):
45 | self._layer_type = layer_type
46 |
47 | # parse the layer type
48 | self.layer_indices = []
49 | self.sliding_window = []
50 | for s in layer_type.split("_"):
51 | layer_idx, sliding_window = re.match(r"^(\d+)(s)?$", s).groups()
52 | self.layer_indices.append(int(layer_idx))
53 | self.sliding_window.append(bool(sliding_window))
54 |
55 | def __len__(self):
56 | return len(self.layer_indices)
57 |
58 | def __getitem__(self, layer_idx: int) -> LayerType:
59 | """return the layer type information for the given layer index"""
60 | return LayerType(
61 | layer_idx=layer_idx,
62 | use_sliding_window=self.sliding_window[layer_idx],
63 | attends_to=self.layer_indices[layer_idx],
64 | attends_top=self.layer_indices[layer_idx] > layer_idx,
65 | computes_kv=layer_idx in self.layer_indices,
66 | )
67 |
68 | def use_sliding_window(self) -> bool:
69 | """whether there exists a layer that uses sliding window attention"""
70 | return any(self.sliding_window)
71 |
72 | def attends_top(self) -> bool:
73 | """whether there exists a layer that attends to layers above it"""
74 | return any(self.layer_indices[i] > i for i in range(len(self)))
75 |
76 | def iteration_plan(self, forward_passes: int = 7, backward_passes: int = 2) -> List[IterStep]:
77 | """
78 | Return an iteration plan for the layer types. The plan is a list of IterStep objects.
79 | """
80 | # if there is no cyclic dependency, return the default plan
81 | if not self.attends_top():
82 | return [IterStep()]
83 |
84 | # otherwise, return the plan for the cyclic dependency
85 | plan = []
86 | i = 0
87 | while i < len(self):
88 |
89 | # if the layer attends to top layers, resolve the cyclic dependency
90 | if self[i].attends_top:
91 |
92 | # find the top layer in the cyclic dependency
93 | top = self[i].attends_to
94 | while top < max(self.layer_indices[i: top + 1]):
95 | top = max(self.layer_indices[i: top + 1])
96 | top += 1
97 |
98 | # create iteration plan for this group
99 | layer_slice = slice(i, top)
100 | plan.extend([
101 | *forward_passes * [IterStep(layer_slice, requires_grad=False, update=False)],
102 | *(backward_passes - 1) * [IterStep(layer_slice, update=False)],
103 | IterStep(layer_slice)
104 | ])
105 |
106 | # otherwise, create a default plan
107 | else:
108 |
109 | top = i + 1
110 | while top < len(self) and not self[top].attends_top:
111 | top += 1
112 | plan.append(IterStep(slice(i, top)))
113 |
114 | # update the index
115 | i = top
116 |
117 | return plan
118 |
119 | def check(self, num_hidden_layers: int):
120 | """Check if the layer type is valid"""
121 | if len(self.layer_indices) != num_hidden_layers:
122 | raise ValueError("The number of layer types should be equal to the number of hidden layers.")
123 | for i in range(num_hidden_layers):
124 | if self.layer_indices[i] not in range(num_hidden_layers):
125 | raise ValueError("The layer type should be in the range of the number of hidden layers.")
126 |
127 |
128 | def flash_attention_forward(
129 | query_states: torch.Tensor,
130 | key_states: torch.Tensor,
131 | value_states: torch.Tensor,
132 | attention_mask: torch.Tensor,
133 | query_length: int,
134 | is_causal: bool,
135 | dropout: float = 0.0,
136 | position_ids: Optional[torch.Tensor] = None,
137 | softmax_scale: Optional[float] = None,
138 | sliding_window: Optional[int] = None,
139 | use_top_left_mask: bool = False,
140 | softcap: Optional[float] = None,
141 | deterministic: bool = None,
142 | no_diag: bool = False,
143 | ):
144 | """
145 | This function is a wrapper around the _flash_attention_forward function in the
146 | transformers library. It adds support for masking the diagonal elements of the attention
147 | matrix. The diagonal mask is used to resolve the cyclic dependencies in the LCKV model.
148 | """
149 | prune_query = False
150 | if no_diag:
151 | if key_states.size(1) == 1:
152 | b, l, _, d = value_states.size()
153 | _, _, h, _ = query_states.size()
154 | return value_states.new_zeros((b, l, h, d))
155 |
156 | if key_states.size(1) == query_states.size(1):
157 | prune_query = True
158 | query_states = query_states[:, 1:, :, :]
159 | query_length -= 1
160 |
161 | if attention_mask is not None:
162 | attention_mask = attention_mask[:, 1:]
163 |
164 | key_states = key_states[:, :-1, :, :]
165 | value_states = value_states[:, :-1, :, :]
166 |
167 | if sliding_window is not None:
168 | sliding_window = sliding_window - 1
169 |
170 | result: torch.Tensor = _flash_attention_forward(
171 | query_states=query_states,
172 | key_states=key_states,
173 | value_states=value_states,
174 | attention_mask=attention_mask,
175 | query_length=query_length,
176 | is_causal=is_causal,
177 | dropout=dropout,
178 | position_ids=position_ids,
179 | softmax_scale=softmax_scale,
180 | sliding_window=sliding_window,
181 | use_top_left_mask=use_top_left_mask,
182 | softcap=softcap,
183 | deterministic=deterministic,
184 | )
185 |
186 | if prune_query:
187 | b, _, h, d = result.size()
188 | result = torch.cat([result.new_zeros((b, 1, h, d)), result], dim=1)
189 |
190 | return result
191 |
--------------------------------------------------------------------------------
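
A small sketch of how `LayerTypeParser.iteration_plan` resolves a cyclic dependency; the pattern and pass counts below are chosen only for illustration:

```python
from models.utils import LayerTypeParser

# Layer 1 attends to the KV of layer 2, which is not yet computed on the first
# pass, so layers 1-2 form a cyclic group that must be iterated over.
parser = LayerTypeParser("0_2_2")
for step in parser.iteration_plan(forward_passes=2, backward_passes=1):
    print(step.layer_slice, step.requires_grad, step.update)
# slice(0, 1, None) True True     -- layer 0 has no cyclic dependency
# slice(1, 3, None) False False   -- gradient-free warm-up pass 1 over layers 1-2
# slice(1, 3, None) False False   -- gradient-free warm-up pass 2
# slice(1, 3, None) True True     -- final pass with gradients and KV updates
```
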
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.ruff]
2 | line-length = 119
3 |
4 | [tool.ruff.lint]
5 | # Never enforce `E501` (line length violations).
6 | ignore = ["C901", "E501", "E741", "F402", "F823" ]
7 | select = ["C", "E", "F", "I", "W"]
8 |
9 | # Ignore import violations in all `__init__.py` files.
10 | [tool.ruff.lint.per-file-ignores]
11 | "__init__.py" = ["E402", "F401", "F403", "F811"]
12 |
13 | [tool.ruff.lint.isort]
14 | lines-after-imports = 2
15 | known-first-party = ["transformers"]
16 |
17 | [tool.ruff.format]
18 | # Like Black, use double quotes for strings.
19 | quote-style = "double"
20 |
21 | # Like Black, indent with spaces, rather than tabs.
22 | indent-style = "space"
23 |
24 | # Like Black, respect magic trailing commas.
25 | skip-magic-trailing-comma = false
26 |
27 | # Like Black, automatically detect the appropriate line ending.
28 | line-ending = "auto"
29 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers ~= 4.45.2
2 | trl ~= 0.11.4
3 | accelerate ~= 1.0.1
4 | datasets
5 | evaluate
6 | scikit-learn
7 | sentencepiece
8 | liger-kernel ~= 0.3.0
9 | flash-attn >= 2.6.3
10 | git+https://github.com/EleutherAI/lm-evaluation-harness@v0.4.5
11 |
--------------------------------------------------------------------------------
/run_chat.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # Copyright 2024 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from trl.commands.cli_utils import init_zero_verbose
18 |
19 | init_zero_verbose()
20 |
21 | import copy
22 | import json
23 | import os
24 | import sys
25 | import pwd
26 | import re
27 | import time
28 | from threading import Thread
29 |
30 | import torch
31 | from rich.console import Console
32 | from rich.live import Live
33 | from rich.markdown import Markdown
34 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
35 | import models
36 |
37 | from trl.commands.cli_utils import ChatArguments, TrlParser, init_zero_verbose
38 | from trl.trainer.utils import get_quantization_config
39 |
40 |
41 | HELP_STRING = """\
42 |
43 | **TRL CHAT INTERFACE**
44 |
45 | The chat interface is a simple tool to try out a chat model.
46 |
47 | Besides talking to the model there are several commands:
48 | - **clear**: clears the current conversation and starts a new one
49 | - **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
50 | - **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
51 | - **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
52 | - **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
53 | - **exit**: closes the interface
54 | """
55 |
56 | SUPPORTED_GENERATION_KWARGS = [
57 | "max_new_tokens",
58 | "do_sample",
59 | "num_beams",
60 | "temperature",
61 | "top_p",
62 | "top_k",
63 | "repetition_penalty",
64 | ]
65 |
66 | SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
67 |
68 |
69 | # the default chat template for timdettmers/openassistant-guanaco
70 | DEFAULT_CHAT_TEMPLATE = """
71 | {%- for message in messages %}
72 | {%- if message['role'] == 'user' %}
73 | {{- 'Human: ' + message['content'].strip() }}
74 | {%- elif message['role'] == 'assistant' %}
75 | {{- 'Assistant: ' + message['content'].strip() }}
76 | {%- endif %}
77 | {%- endfor %}
78 | """
79 |
80 |
81 | class RichInterface:
82 | def __init__(self, model_name=None, user_name=None):
83 | self._console = Console()
84 | if model_name is None:
85 | self.model_name = "assistant"
86 | else:
87 | self.model_name = model_name
88 | if user_name is None:
89 | self.user_name = "user"
90 | else:
91 | self.user_name = user_name
92 |
93 | def stream_output(self, output_stream):
94 | """Stream output from a role."""
95 | # This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
96 | # Create a Live context for updating the console output
97 | text = ""
98 | self._console.print(f"[bold blue]<{self.model_name}>:")
99 | with Live(console=self._console, refresh_per_second=4) as live:
100 | # Read lines from the stream
101 | for i, outputs in enumerate(output_stream):
102 | if not outputs or i == 0:
103 | continue
104 | text += outputs
105 | # Render the accumulated text as Markdown
106 | # NOTE: this is a workaround for rendering "non-standard markdown"
107 | # in rich. The chatbot's output treats "\n" as a new line for
108 | # better compatibility with real-world text. However, rendering it
109 | # as markdown would break the format, because standard markdown
110 | # treats a single "\n" in normal text as a space.
111 | # Our workaround is to add two spaces at the end of each line.
112 | # This is not a perfect solution, as it
113 | # introduces trailing spaces (only) in code blocks, but it works well,
114 | # especially for console output, because in general the console does not
115 | # care about trailing spaces.
116 | lines = []
117 | for line in text.splitlines():
118 | lines.append(line)
119 | if line.startswith("```"):
120 | # Code block marker - do not add trailing spaces, as it would
121 | # break the syntax highlighting
122 | lines.append("\n")
123 | else:
124 | lines.append(" \n")
125 | markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
126 | # Update the Live console output
127 | live.update(markdown)
128 | self._console.print()
129 | return text
130 |
131 | def input(self):
132 | input = self._console.input(f"[bold red]<{self.user_name}>:\n")
133 | self._console.print()
134 | return input
135 |
136 | def clear(self):
137 | self._console.clear()
138 |
139 | def print_user_message(self, text):
140 | self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
141 | self._console.print()
142 |
143 | def print_green(self, text):
144 | self._console.print(f"[bold green]{text}")
145 | self._console.print()
146 |
147 | def print_red(self, text):
148 | self._console.print(f"[bold red]{text}")
149 | self._console.print()
150 |
151 | def print_help(self):
152 | self._console.print(Markdown(HELP_STRING))
153 | self._console.print()
154 |
155 |
156 | def get_username():
157 | return pwd.getpwuid(os.getuid())[0]
158 |
159 |
160 | def create_default_filename(model_name):
161 | time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
162 | return f"{model_name}/chat_{time_str}.json"
163 |
164 |
165 | def save_chat(chat, args, filename):
166 | output_dict = {}
167 | output_dict["settings"] = vars(args)
168 | output_dict["chat_history"] = chat
169 |
170 | folder = args.save_folder
171 |
172 | if filename is None:
173 | filename = create_default_filename(args.model_name_or_path)
174 | filename = os.path.join(folder, filename)
175 | os.makedirs(os.path.dirname(filename), exist_ok=True)
176 |
177 | with open(filename, "w") as f:
178 | json.dump(output_dict, f, indent=4)
179 | return os.path.abspath(filename)
180 |
181 |
182 | def clear_chat_history(system_prompt):
183 | if system_prompt is None:
184 | chat = []
185 | else:
186 | chat = [{"role": "system", "content": system_prompt}]
187 | return chat
188 |
189 |
190 | def parse_settings(user_input, current_args, interface):
191 | settings = user_input[4:].strip().split(";")
192 | settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
193 | settings = dict(settings)
194 | error = False
195 |
196 | for name in settings:
197 | if hasattr(current_args, name):
198 | try:
199 | if isinstance(getattr(current_args, name), bool):
200 | if settings[name] == "True":
201 | settings[name] = True
202 | elif settings[name] == "False":
203 | settings[name] = False
204 | else:
205 | raise ValueError
206 | else:
207 | settings[name] = type(getattr(current_args, name))(settings[name])
208 | except ValueError:
209 | interface.print_red(
210 | f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
211 | )
212 | else:
213 | interface.print_red(f"There is no '{name}' setting.")
214 |
215 | if error:
216 | interface.print_red("There was an issue parsing the settings. No settings have been changed.")
217 | return current_args, False
218 | else:
219 | for name in settings:
220 | setattr(current_args, name, settings[name])
221 | interface.print_green(f"Set {name} to {settings[name]}.")
222 |
223 | time.sleep(1.5) # so the user has time to read the changes
224 | return current_args, True
225 |
226 |
227 | def load_model_and_tokenizer(args):
228 | tokenizer = AutoTokenizer.from_pretrained(
229 | args.model_name_or_path,
230 | revision=args.model_revision,
231 | trust_remote_code=args.trust_remote_code,
232 | )
233 |
234 | torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
235 | quantization_config = get_quantization_config(args)
236 | model_kwargs = dict(
237 | revision=args.model_revision,
238 | attn_implementation=args.attn_implementation,
239 | torch_dtype=torch_dtype,
240 | device_map="auto",
241 | quantization_config=quantization_config,
242 | )
243 | model = AutoModelForCausalLM.from_pretrained(
244 | args.model_name_or_path, trust_remote_code=args.trust_remote_code, **model_kwargs
245 | )
246 |
247 | if getattr(model, "hf_device_map", None) is None:
248 | model = model.to(args.device)
249 |
250 | return model, tokenizer
251 |
252 |
253 | def parse_eos_tokens(tokenizer, eos_tokens, eos_token_ids):
254 | if tokenizer.pad_token_id is None:
255 | pad_token_id = tokenizer.eos_token_id
256 | else:
257 | pad_token_id = tokenizer.pad_token_id
258 |
259 | all_eos_token_ids = []
260 |
261 | if eos_tokens is not None:
262 | all_eos_token_ids.extend(tokenizer.convert_tokens_to_ids(eos_tokens.split(",")))
263 |
264 | if eos_token_ids is not None:
265 | all_eos_token_ids.extend([int(token_id) for token_id in eos_token_ids.split(",")])
266 |
267 | if len(all_eos_token_ids) == 0:
268 | all_eos_token_ids.append(tokenizer.eos_token_id)
269 |
270 | return pad_token_id, all_eos_token_ids
271 |
272 |
273 | def chat_cli():
274 | parser = TrlParser(ChatArguments)
275 |
276 | if "--config" not in sys.argv:
277 | sys.argv.append("--config")
278 | sys.argv.append(os.path.join(os.path.dirname(__file__), "configs/default_chat_config.yaml"))
279 | args = parser.parse_args_and_config()[0]
280 | if args.examples is None:
281 | args.examples = {}
282 |
283 | current_args = copy.deepcopy(args)
284 |
285 | if args.user is None:
286 | user = get_username()
287 | else:
288 | user = args.user
289 |
290 | model, tokenizer = load_model_and_tokenizer(args)
291 | generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
292 |
293 | pad_token_id, eos_token_ids = parse_eos_tokens(tokenizer, args.eos_tokens, args.eos_token_ids)
294 |
295 | interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
296 | interface.clear()
297 | chat = clear_chat_history(current_args.system_prompt)
298 | while True:
299 | try:
300 | user_input = interface.input()
301 |
302 | if user_input == "clear":
303 | chat = clear_chat_history(current_args.system_prompt)
304 | interface.clear()
305 | continue
306 |
307 | if user_input == "help":
308 | interface.print_help()
309 | continue
310 |
311 | if user_input == "exit":
312 | break
313 |
314 | if user_input == "reset":
315 | interface.clear()
316 | current_args = copy.deepcopy(args)
317 | chat = clear_chat_history(current_args.system_prompt)
318 | continue
319 |
320 | if user_input.startswith("save") and len(user_input.split()) < 3:
321 | split_input = user_input.split()
322 |
323 | if len(split_input) == 2:
324 | filename = split_input[1]
325 | else:
326 | filename = None
327 | filename = save_chat(chat, current_args, filename)
328 | interface.print_green(f"Chat saved in {filename}!")
329 | continue
330 |
331 | if re.match(SETTING_RE, user_input):
332 | current_args, success = parse_settings(user_input, current_args, interface)
333 | if success:
334 | chat = []
335 | interface.clear()
336 | continue
337 |
338 | if user_input.startswith("example") and len(user_input.split()) == 2:
339 | example_name = user_input.split()[1]
340 | if example_name in current_args.examples:
341 | interface.clear()
342 | chat = []
343 | interface.print_user_message(current_args.examples[example_name]["text"])
344 | user_input = current_args.examples[example_name]["text"]
345 | else:
346 | interface.print_red(
347 | f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}."
348 | )
349 | continue
350 |
351 | chat.append({"role": "user", "content": user_input})
352 |
353 | if tokenizer.chat_template is None:
354 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
355 |
356 | generation_kwargs = dict(
357 | inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
358 | model.device
359 | ),
360 | streamer=generation_streamer,
361 | max_new_tokens=current_args.max_new_tokens,
362 | do_sample=current_args.do_sample,
363 | num_beams=current_args.num_beams,
364 | temperature=current_args.temperature,
365 | top_k=current_args.top_k,
366 | top_p=current_args.top_p,
367 | repetition_penalty=current_args.repetition_penalty,
368 | pad_token_id=pad_token_id,
369 | eos_token_id=eos_token_ids,
370 | )
371 |
372 | thread = Thread(target=model.generate, kwargs=generation_kwargs)
373 | thread.start()
374 | model_output = interface.stream_output(generation_streamer)
375 | thread.join()
376 | chat.append({"role": "assistant", "content": model_output})
377 |
378 | except KeyboardInterrupt:
379 | break
380 |
381 |
382 | if __name__ == "__main__":
383 | chat_cli()
384 |
--------------------------------------------------------------------------------
/run_clm.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18 |
19 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20 | https://huggingface.co/models?filter=text-generation
21 | """
22 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23 |
24 | import logging
25 | import math
26 | import os
27 | import sys
28 | import warnings
29 | from dataclasses import dataclass, field
30 | from itertools import chain
31 | from typing import Optional
32 |
33 | import datasets
34 | import evaluate
35 | import torch
36 | from datasets import load_dataset
37 |
38 | import transformers
39 | from transformers import (
40 | CONFIG_MAPPING,
41 | MODEL_FOR_CAUSAL_LM_MAPPING,
42 | AutoConfig,
43 | AutoModelForCausalLM,
44 | AutoTokenizer,
45 | HfArgumentParser,
46 | Trainer,
47 | TrainingArguments,
48 | default_data_collator,
49 | is_torch_tpu_available,
50 | set_seed,
51 | )
52 | from transformers.testing_utils import CaptureLogger
53 | from transformers.trainer_utils import get_last_checkpoint
54 | from transformers.utils import check_min_version, send_example_telemetry
55 | from transformers.utils.versions import require_version
56 |
57 | import models
58 |
59 |
60 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
61 | # check_min_version("4.33.0.dev0")
62 |
63 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
64 |
65 | logger = logging.getLogger(__name__)
66 |
67 |
68 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
69 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
70 |
71 |
72 | # learning rate decay scheduler (cosine with warmup)
73 | def tinyllama_get_lr(
74 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
75 | ):
76 | final_div_factor = 0.1
77 | if current_step < num_warmup_steps:
78 | return float(current_step) / float(max(1, num_warmup_steps))
79 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
80 | coeff = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
81 | return final_div_factor + coeff * (1 - final_div_factor)
82 |
83 | ## Uncomment the following line to use the cosine schedule with the minimum lr set to 10% of the initial lr
84 | # transformers.optimization._get_cosine_schedule_with_warmup_lr_lambda = tinyllama_get_lr
85 |
86 |
87 | @dataclass
88 | class ModelArguments:
89 | """
90 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
91 | """
92 |
93 | model_name_or_path: Optional[str] = field(
94 | default=None,
95 | metadata={
96 | "help": (
97 | "The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
98 | )
99 | },
100 | )
101 | model_type: Optional[str] = field(
102 | default=None,
103 | metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
104 | )
105 | config_overrides: Optional[str] = field(
106 | default=None,
107 | metadata={
108 | "help": (
109 | "Override some existing default config settings when a model is trained from scratch. Example: "
110 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
111 | )
112 | },
113 | )
114 | config_name: Optional[str] = field(
115 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
116 | )
117 | tokenizer_name: Optional[str] = field(
118 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
119 | )
120 | cache_dir: Optional[str] = field(
121 | default=None,
122 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
123 | )
124 | use_fast_tokenizer: bool = field(
125 | default=True,
126 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
127 | )
128 | model_revision: str = field(
129 | default="main",
130 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
131 | )
132 | token: str = field(
133 | default=None,
134 | metadata={
135 | "help": (
136 | "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
137 | "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
138 | )
139 | },
140 | )
141 | use_auth_token: bool = field(
142 | default=None,
143 | metadata={
144 | "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
145 | },
146 | )
147 | trust_remote_code: bool = field(
148 | default=False,
149 | metadata={
150 | "help": (
151 | "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
152 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
153 | "execute code present on the Hub on your local machine."
154 | )
155 | },
156 | )
157 | torch_dtype: Optional[str] = field(
158 | default=None,
159 | metadata={
160 | "help": (
161 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
162 | "dtype will be automatically derived from the model's weights."
163 | ),
164 | "choices": ["auto", "bfloat16", "float16", "float32"],
165 | },
166 | )
167 | low_cpu_mem_usage: bool = field(
168 | default=False,
169 | metadata={
170 | "help": (
171 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
172 | "set True will benefit LLM loading time and RAM consumption."
173 | )
174 | },
175 | )
176 |
177 | attn_implementation: Optional[str] = field(
178 | default="flash_attention_2",
179 | metadata={
180 | "help": (
181 | "The attention implementation to use. By default the LCKV implementation uses flash attention implementation."
182 | ),
183 | "choices": ["eager", "flash_attention_2", "sdpa"],
184 | },
185 | )
186 |
187 |
188 | @dataclass
189 | class DataTrainingArguments:
190 | """
191 | Arguments pertaining to what data we are going to input our model for training and eval.
192 | """
193 |
194 | dataset_name: Optional[str] = field(
195 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
196 | )
197 | dataset_config_name: Optional[str] = field(
198 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
199 | )
200 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
201 | validation_file: Optional[str] = field(
202 | default=None,
203 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
204 | )
205 | max_train_samples: Optional[int] = field(
206 | default=None,
207 | metadata={
208 | "help": (
209 | "For debugging purposes or quicker training, truncate the number of training examples to this "
210 | "value if set."
211 | )
212 | },
213 | )
214 | max_eval_samples: Optional[int] = field(
215 | default=None,
216 | metadata={
217 | "help": (
218 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
219 | "value if set."
220 | )
221 | },
222 | )
223 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
224 | block_size: Optional[int] = field(
225 | default=None,
226 | metadata={
227 | "help": (
228 | "Optional input sequence length after tokenization. "
229 | "The training dataset will be truncated in block of this size for training. "
230 | "Default to the model max input length for single sentence inputs (take into account special tokens)."
231 | )
232 | },
233 | )
234 | overwrite_cache: bool = field(
235 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
236 | )
237 | validation_split_percentage: Optional[int] = field(
238 | default=5,
239 | metadata={
240 | "help": "The percentage of the train set used as validation set in case there's no validation split"
241 | },
242 | )
243 | preprocessing_num_workers: Optional[int] = field(
244 | default=None,
245 | metadata={"help": "The number of processes to use for the preprocessing."},
246 | )
247 | keep_linebreaks: bool = field(
248 | default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
249 | )
250 |
251 | def __post_init__(self):
252 | if self.streaming:
253 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
254 |
255 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
256 | raise ValueError("Need either a dataset name or a training/validation file.")
257 | else:
258 | if self.train_file is not None:
259 | extension = self.train_file.split(".")[-1]
260 | assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
261 | if self.validation_file is not None:
262 | extension = self.validation_file.split(".")[-1]
263 | assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
264 |
265 |
266 | def main():
267 | # See all possible arguments in src/transformers/training_args.py
268 | # or by passing the --help flag to this script.
269 | # We now keep distinct sets of args, for a cleaner separation of concerns.
270 |
271 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
272 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
273 | # If we pass only one argument to the script and it's the path to a json file,
274 | # let's parse it to get our arguments.
275 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
276 | else:
277 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
278 |
279 | if model_args.use_auth_token is not None:
280 | warnings.warn("The `use_auth_token` argument is deprecated and will be removed in v4.34.", FutureWarning)
281 | if model_args.token is not None:
282 | raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
283 | model_args.token = model_args.use_auth_token
284 |
285 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
286 | # information sent is the one passed as arguments along with your Python/PyTorch versions.
287 | # send_example_telemetry("run_clm", model_args, data_args)
288 |
289 | # Setup logging
290 | logging.basicConfig(
291 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
292 | datefmt="%m/%d/%Y %H:%M:%S",
293 | handlers=[logging.StreamHandler(sys.stdout)],
294 | )
295 |
296 | if training_args.should_log:
297 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
298 | transformers.utils.logging.set_verbosity_info()
299 |
300 | log_level = training_args.get_process_log_level()
301 | logger.setLevel(log_level)
302 | datasets.utils.logging.set_verbosity(log_level)
303 | transformers.utils.logging.set_verbosity(log_level)
304 | transformers.utils.logging.enable_default_handler()
305 | transformers.utils.logging.enable_explicit_format()
306 |
307 | # Log on each process the small summary:
308 | logger.warning(
309 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
310 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
311 | )
312 | logger.info(f"Training/evaluation parameters {training_args}")
313 |
314 | # Detecting last checkpoint.
315 | last_checkpoint = None
316 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
317 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
318 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
319 | raise ValueError(
320 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
321 | "Use --overwrite_output_dir to overcome."
322 | )
323 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
324 | logger.info(
325 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
326 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
327 | )
328 |
329 | # Set seed before initializing model.
330 | set_seed(training_args.seed)
331 |
332 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
333 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
334 | # (the dataset will be downloaded automatically from the datasets Hub).
335 | #
336 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
337 | # 'text' is found. You can easily tweak this behavior (see below).
338 | #
339 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
340 | # download the dataset.
341 | if data_args.dataset_name is not None:
342 | # Downloading and loading a dataset from the hub.
343 | raw_datasets = load_dataset(
344 | data_args.dataset_name,
345 | data_args.dataset_config_name,
346 | cache_dir=model_args.cache_dir,
347 | token=model_args.token,
348 | streaming=data_args.streaming,
349 | download_config=datasets.DownloadConfig(resume_download=True,max_retries=10000)
350 | )
351 | if "validation" not in raw_datasets.keys():
352 | raw_datasets["validation"] = load_dataset(
353 | data_args.dataset_name,
354 | data_args.dataset_config_name,
355 | split=f"train[:{data_args.validation_split_percentage}%]",
356 | cache_dir=model_args.cache_dir,
357 | token=model_args.token,
358 | streaming=data_args.streaming,
359 | )
360 | raw_datasets["train"] = load_dataset(
361 | data_args.dataset_name,
362 | data_args.dataset_config_name,
363 | split=f"train[{data_args.validation_split_percentage}%:]",
364 | cache_dir=model_args.cache_dir,
365 | token=model_args.token,
366 | streaming=data_args.streaming,
367 | )
368 | else:
369 | data_files = {}
370 | dataset_args = {}
371 | if data_args.train_file is not None:
372 | data_files["train"] = data_args.train_file
373 | if data_args.validation_file is not None:
374 | data_files["validation"] = data_args.validation_file
375 | extension = (
376 | data_args.train_file.split(".")[-1]
377 | if data_args.train_file is not None
378 | else data_args.validation_file.split(".")[-1]
379 | )
380 | if extension == "txt":
381 | extension = "text"
382 | dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
383 | raw_datasets = load_dataset(
384 | extension,
385 | data_files=data_files,
386 | cache_dir=model_args.cache_dir,
387 | token=model_args.token,
388 | **dataset_args,
389 | )
390 | # If no validation data is there, validation_split_percentage will be used to divide the dataset.
391 | if "validation" not in raw_datasets.keys():
392 | raw_datasets["validation"] = load_dataset(
393 | extension,
394 | data_files=data_files,
395 | split=f"train[:{data_args.validation_split_percentage}%]",
396 | cache_dir=model_args.cache_dir,
397 | token=model_args.token,
398 | **dataset_args,
399 | )
400 | raw_datasets["train"] = load_dataset(
401 | extension,
402 | data_files=data_files,
403 | split=f"train[{data_args.validation_split_percentage}%:]",
404 | cache_dir=model_args.cache_dir,
405 | token=model_args.token,
406 | **dataset_args,
407 | )
408 |
409 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
410 | # https://huggingface.co/docs/datasets/loading_datasets.html.
411 |
412 | # Load pretrained model and tokenizer
413 | #
414 | # Distributed training:
415 | # The .from_pretrained methods guarantee that only one local process can concurrently
416 | # download model & vocab.
417 |
418 | config_kwargs = {
419 | "cache_dir": model_args.cache_dir,
420 | "revision": model_args.model_revision,
421 | "token": model_args.token,
422 | "trust_remote_code": model_args.trust_remote_code,
423 | }
424 | if model_args.config_name:
425 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
426 | elif model_args.model_name_or_path:
427 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
428 | else:
429 | config = CONFIG_MAPPING[model_args.model_type]()
430 | logger.warning("You are instantiating a new config instance from scratch.")
431 |
432 | # allow config overrides under all circumstances
433 | if model_args.config_overrides is not None:
434 | logger.info(f"Overriding config: {model_args.config_overrides}")
435 | config.update_from_string(model_args.config_overrides)
436 | logger.info(f"New config: {config}")
437 |
438 | tokenizer_kwargs = {
439 | "cache_dir": model_args.cache_dir,
440 | "use_fast": model_args.use_fast_tokenizer,
441 | "revision": model_args.model_revision,
442 | "token": model_args.token,
443 | "trust_remote_code": model_args.trust_remote_code,
444 | }
445 | if model_args.tokenizer_name:
446 | tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
447 | elif model_args.model_name_or_path:
448 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
449 | else:
450 | raise ValueError(
451 | "You are instantiating a new tokenizer from scratch. This is not supported by this script."
452 | "You can do it from another script, save it, and load it from here, using --tokenizer_name."
453 | )
454 |
455 | if model_args.model_name_or_path:
456 | torch_dtype = (
457 | model_args.torch_dtype
458 | if model_args.torch_dtype in ["auto", None]
459 | else getattr(torch, model_args.torch_dtype)
460 | )
461 | model = AutoModelForCausalLM.from_pretrained(
462 | model_args.model_name_or_path,
463 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
464 | config=config,
465 | cache_dir=model_args.cache_dir,
466 | revision=model_args.model_revision,
467 | token=model_args.token,
468 | trust_remote_code=model_args.trust_remote_code,
469 | torch_dtype=torch_dtype,
470 | low_cpu_mem_usage=model_args.low_cpu_mem_usage,
471 | attn_implementation=model_args.attn_implementation,
472 | )
473 | else:
474 | torch_dtype = (
475 | model_args.torch_dtype
476 | if model_args.torch_dtype in ["auto", None]
477 | else getattr(torch, model_args.torch_dtype)
478 | )
479 | config._attn_implementation = model_args.attn_implementation
480 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code, torch_dtype=torch_dtype)
481 | n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
482 | logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
483 |
484 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
485 | # on a small vocab and want a smaller embedding size, remove this test.
486 | embedding_size = model.get_input_embeddings().weight.shape[0]
487 | if len(tokenizer) > embedding_size:
488 | model.resize_token_embeddings(len(tokenizer))
489 |
490 | # Preprocessing the datasets.
491 | # First we tokenize all the texts.
492 | if training_args.do_train:
493 | column_names = list(raw_datasets["train"].features)
494 | else:
495 | column_names = list(raw_datasets["validation"].features)
496 | text_column_name = "text" if "text" in column_names else column_names[0]
497 |
498 | # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
499 | tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
500 |
501 | def tokenize_function(examples):
502 | with CaptureLogger(tok_logger) as cl:
503 | output = tokenizer(examples[text_column_name])
504 | # clm input could be much much longer than block_size
505 | if "Token indices sequence length is longer than the" in cl.out:
506 | tok_logger.warning(
507 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
508 | " before being passed to the model."
509 | )
510 | return output
511 |
512 | with training_args.main_process_first(desc="dataset map tokenization"):
513 | if not data_args.streaming:
514 | tokenized_datasets = raw_datasets.map(
515 | tokenize_function,
516 | batched=True,
517 | num_proc=data_args.preprocessing_num_workers,
518 | remove_columns=column_names,
519 | load_from_cache_file=not data_args.overwrite_cache,
520 | desc="Running tokenizer on dataset",
521 | )
522 | else:
523 | tokenized_datasets = raw_datasets.map(
524 | tokenize_function,
525 | batched=True,
526 | remove_columns=column_names,
527 | )
528 |
529 | if data_args.block_size is None:
530 | block_size = tokenizer.model_max_length
531 | if block_size > 1024:
532 | logger.warning(
533 | "The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
534 | " of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
535 | " override this default with `--block_size xxx`."
536 | )
537 | block_size = 1024
538 | else:
539 | if data_args.block_size > tokenizer.model_max_length:
540 | logger.warning(
541 | f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
542 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
543 | )
544 | block_size = min(data_args.block_size, tokenizer.model_max_length)
545 |
546 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
547 | def group_texts(examples):
548 | # Concatenate all texts.
549 | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
550 | total_length = len(concatenated_examples[list(examples.keys())[0]])
551 | # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
552 | # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
553 | total_length = (total_length // block_size) * block_size
554 | # Split by chunks of max_len.
555 | result = {
556 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
557 | for k, t in concatenated_examples.items()
558 | }
559 | result["labels"] = result["input_ids"].copy()
560 | return result
561 |
562 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
563 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
564 | # to preprocess.
565 | #
566 | # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
567 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
568 |
569 | with training_args.main_process_first(desc="grouping texts together"):
570 | if not data_args.streaming:
571 | lm_datasets = tokenized_datasets.map(
572 | group_texts,
573 | batched=True,
574 | num_proc=data_args.preprocessing_num_workers,
575 | load_from_cache_file=not data_args.overwrite_cache,
576 | desc=f"Grouping texts in chunks of {block_size}",
577 | )
578 | else:
579 | lm_datasets = tokenized_datasets.map(
580 | group_texts,
581 | batched=True,
582 | )
583 |
584 | if training_args.do_train:
585 | if "train" not in tokenized_datasets:
586 | raise ValueError("--do_train requires a train dataset")
587 | train_dataset = lm_datasets["train"]
588 | if data_args.max_train_samples is not None:
589 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
590 | train_dataset = train_dataset.select(range(max_train_samples))
591 |
592 | if training_args.do_eval:
593 | if "validation" not in tokenized_datasets:
594 | raise ValueError("--do_eval requires a validation dataset")
595 | eval_dataset = lm_datasets["validation"]
596 | if data_args.max_eval_samples is not None:
597 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
598 | eval_dataset = eval_dataset.select(range(max_eval_samples))
599 |
600 | def preprocess_logits_for_metrics(logits, labels):
601 | if isinstance(logits, tuple):
602 | # Depending on the model and config, logits may contain extra tensors,
603 | # like past_key_values, but logits always come first
604 | logits = logits[0]
605 | return logits.argmax(dim=-1)
606 |
607 | metric = evaluate.load("accuracy")
608 |
609 | def compute_metrics(eval_preds):
610 | preds, labels = eval_preds
611 | # preds have the same shape as the labels, after the argmax(-1) has been calculated
612 | # by preprocess_logits_for_metrics but we need to shift the labels
613 | labels = labels[:, 1:].reshape(-1)
614 | preds = preds[:, :-1].reshape(-1)
615 | return metric.compute(predictions=preds, references=labels)
616 |
617 | if training_args.do_predict:
618 | if "test" not in tokenized_datasets:
619 | raise ValueError("--do_predict requires a test dataset")
620 | test_dataset = lm_datasets["test"]
621 |
622 | # Initialize our Trainer
623 | trainer = Trainer(
624 | model=model,
625 | args=training_args,
626 | train_dataset=train_dataset if training_args.do_train else None,
627 | eval_dataset=eval_dataset if training_args.do_eval else None,
628 | tokenizer=tokenizer,
629 | # Data collator will default to DataCollatorWithPadding, so we change it.
630 | data_collator=default_data_collator,
631 | compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
632 | preprocess_logits_for_metrics=preprocess_logits_for_metrics
633 | if training_args.do_eval and not is_torch_tpu_available()
634 | else None,
635 | )
636 |
637 | # Training
638 | if training_args.do_train:
639 | checkpoint = None
640 | if training_args.resume_from_checkpoint is not None:
641 | checkpoint = training_args.resume_from_checkpoint
642 | elif last_checkpoint is not None:
643 | checkpoint = last_checkpoint
644 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
645 | trainer.save_model() # Saves the tokenizer too for easy upload
646 |
647 | metrics = train_result.metrics
648 |
649 | max_train_samples = (
650 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
651 | )
652 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
653 |
654 | trainer.log_metrics("train", metrics)
655 | trainer.save_metrics("train", metrics)
656 | trainer.save_state()
657 |
658 | # Evaluation
659 | if training_args.do_eval:
660 | logger.info("*** Evaluate ***")
661 |
662 | metrics = trainer.evaluate()
663 |
664 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
665 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
666 | try:
667 | perplexity = math.exp(metrics["eval_loss"])
668 | except OverflowError:
669 | perplexity = float("inf")
670 | metrics["eval_perplexity"] = perplexity
671 |
672 | trainer.log_metrics("eval", metrics)
673 | trainer.save_metrics("eval", metrics)
674 |
675 | # Predict
676 | if training_args.do_predict:
677 | logger.info("*** Predict ***")
678 |
679 | metrics = trainer.predict(test_dataset).metrics
680 |
681 | metrics["test_samples"] = len(test_dataset)
682 | try:
683 | perplexity = math.exp(metrics["test_loss"])
684 | except OverflowError:
685 | perplexity = float("inf")
686 | metrics["test_perplexity"] = perplexity
687 |
688 | trainer.log_metrics("test", metrics)
689 | trainer.save_metrics("test", metrics)
690 |
691 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
692 | if data_args.dataset_name is not None:
693 | kwargs["dataset_tags"] = data_args.dataset_name
694 | if data_args.dataset_config_name is not None:
695 | kwargs["dataset_args"] = data_args.dataset_config_name
696 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
697 | else:
698 | kwargs["dataset"] = data_args.dataset_name
699 |
700 | # if training_args.push_to_hub:
701 | # trainer.push_to_hub(**kwargs)
702 | # else:
703 | # trainer.create_model_card(**kwargs)
704 |
705 |
706 | def _mp_fn(index):
707 | # For xla_spawn (TPUs)
708 | main()
709 |
710 |
711 | if __name__ == "__main__":
712 | main()
713 |
--------------------------------------------------------------------------------
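
The `group_texts` step above concatenates every tokenized document in a batch and slices the result into fixed `block_size` chunks, dropping the tail remainder and copying `input_ids` into `labels` (the causal-LM shift happens inside the model). A minimal standalone sketch of that chunking on toy data (field names mirror the script; the numbers are illustrative only):

```python
# Minimal sketch of the group_texts chunking used in run_clm.py (toy data, illustrative only).
from itertools import chain


def group_texts(examples, block_size=4):
    # Concatenate every field across the batch, then cut into block_size chunks,
    # dropping the remainder that does not fill a full block.
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # For causal LM training, labels are a plain copy; the model shifts them internally.
    result["labels"] = result["input_ids"].copy()
    return result


batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8, 9]]}
print(group_texts(batch))
# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]], 'labels': [[1, 2, 3, 4], [5, 6, 7, 8]]}
```
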
/run_clm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ## pretrain code for llama-tiny
4 | # - to pretrain a tinyllama, change the config to `TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T`
5 | # - to initialize the model from a pretrained checkpoint, add `--model_name_or_path TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T`
6 | # - to use the minipile dataset, use `--dataset_name JeanKaddour/minipile`, with an appropriate `--preprocessing_num_workers`
7 | # - to enable wandb, use `--report_to wandb`
8 | accelerate launch run_clm.py \
9 | --tokenizer_name TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T \
10 | --config_name configs/llama_tiny_lckv.json \
11 | --config_overrides layer_types=0_6_6_6_6_6_6_7,forward_passes=7,backward_passes=2 \
12 | --dataset_name wikitext \
13 | --dataset_config_name wikitext-103-raw-v1 \
14 | --per_device_train_batch_size 32 \
15 | --per_device_eval_batch_size 32 \
16 | --auto_find_batch_size \
17 | --gradient_accumulation_steps 1 \
18 | --block_size 1024 \
19 | --lr_scheduler_type cosine \
20 | --warmup_ratio 0.015 \
21 | --learning_rate 3e-4 \
22 | --weight_decay 1e-1 \
23 | --bf16 \
24 | --torch_dtype bfloat16 \
25 | --do_train \
26 | --do_eval \
27 | --use_liger_kernel \
28 | --num_train_epochs 3 \
29 | --save_total_limit 1 \
30 | --save_strategy steps \
31 | --save_steps 500 \
32 | --evaluation_strategy steps \
33 | --eval_steps 500 \
34 | --load_best_model_at_end True \
35 | --metric_for_best_model eval_loss \
36 | --report_to none \
37 | --run_name llamatiny-test \
38 | --overwrite_output_dir \
39 | --output_dir outputs/llamatiny-test
40 |
--------------------------------------------------------------------------------
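
The `--config_overrides layer_types=0_6_6_6_6_6_6_7,...` flag above is applied with `PretrainedConfig.update_from_string` (the same call `run_generation.py` uses). As we read the layer-typing scheme in this repo, entry `i` names the layer whose KVs layer `i` attends to, so `0_6_6_6_6_6_6_7` makes layers 1-6 share the keys/values of layer 6 while layers 0 and 7 keep their own. A hedged sketch that rebuilds this string with the warmup rule from `test_latency.py`, assuming an 8-layer config:

```python
# Hedged sketch: rebuild the layer_types override from run_clm.sh, assuming an
# 8-layer config and the warmup rule used in test_latency.py.
num_hidden_layers, warmup = 8, 2  # assumed sizes, for illustration only
start, end = warmup // 2, num_hidden_layers - warmup // 2
layer_types = [(end - 1 if i in range(start, end) else i) for i in range(num_hidden_layers)]
override = "layer_types=" + "_".join(map(str, layer_types))
print(override)  # layer_types=0_6_6_6_6_6_6_7, matching the override in run_clm.sh

# Applying it to a config would mirror run_generation.py:
# from models import LCKVLlamaConfig
# config = LCKVLlamaConfig.from_pretrained("configs/llama_tiny_lckv.json")
# config.update_from_string(override + ",forward_passes=7,backward_passes=2")
```
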
/run_generation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
18 | """
19 |
20 |
21 | import argparse
22 | import inspect
23 | import logging
24 | from typing import Tuple
25 |
26 | import torch
27 | from accelerate import PartialState
28 | from accelerate.utils import set_seed
29 |
30 | from models import LCKVLlamaForCausalLM
31 | from models.cache_utils import IndexedHybridCache
32 | from models.utils import LayerTypeParser
33 | from transformers import (
34 | AutoTokenizer,
35 | BloomForCausalLM,
36 | BloomTokenizerFast,
37 | CTRLLMHeadModel,
38 | CTRLTokenizer,
39 | GenerationMixin,
40 | GPT2LMHeadModel,
41 | GPT2Tokenizer,
42 | GPTJForCausalLM,
43 | LlamaForCausalLM,
44 | LlamaTokenizer,
45 | OpenAIGPTLMHeadModel,
46 | OpenAIGPTTokenizer,
47 | OPTForCausalLM,
48 | TransfoXLLMHeadModel,
49 | TransfoXLTokenizer,
50 | XLMTokenizer,
51 | XLMWithLMHeadModel,
52 | XLNetLMHeadModel,
53 | XLNetTokenizer,
54 | )
55 | from transformers.cache_utils import SinkCache
56 | from transformers.modeling_outputs import CausalLMOutputWithPast
57 |
58 |
59 | logging.basicConfig(
60 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
61 | datefmt="%m/%d/%Y %H:%M:%S",
62 | level=logging.INFO,
63 | )
64 | logger = logging.getLogger(__name__)
65 |
66 | MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop
67 |
68 | MODEL_CLASSES = {
69 | "lckv-llama": (LCKVLlamaForCausalLM, LlamaTokenizer),
70 | "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
71 | "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
72 | "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
73 | "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
74 | "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
75 | "xlm": (XLMWithLMHeadModel, XLMTokenizer),
76 | "gptj": (GPTJForCausalLM, AutoTokenizer),
77 | "bloom": (BloomForCausalLM, BloomTokenizerFast),
78 | "llama": (LlamaForCausalLM, LlamaTokenizer),
79 | "opt": (OPTForCausalLM, GPT2Tokenizer),
80 | }
81 |
82 | # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
83 | # in https://github.com/rusiaaman/XLNet-gen#methodology
84 | # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
85 | PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
86 | (except for Alexei and Maria) are discovered.
87 | The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
88 | remainder of the story. 1883 Western Siberia,
89 | a young Grigori Rasputin is asked by his father and a group of men to perform magic.
90 | Rasputin has a vision and denounces one of the men as a horse thief. Although his
91 | father initially slaps him for making such an accusation, Rasputin watches as the
92 | man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
93 | the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
94 | with people, even a bishop, begging for his blessing. """
95 |
96 |
97 | #
98 | # Functions to prepare models' input
99 | #
100 |
101 |
102 | def prepare_ctrl_input(args, _, tokenizer, prompt_text):
103 | if args.temperature > 0.7:
104 | logger.info("CTRL typically works better with lower temperatures (and lower top_k).")
105 |
106 | encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
107 | if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
108 | logger.info("WARNING! You are not starting your generation from a control code so you won't get good results")
109 | return prompt_text
110 |
111 |
112 | def prepare_xlm_input(args, model, tokenizer, prompt_text):
113 | # kwargs = {"language": None, "mask_token_id": None}
114 |
115 | # Set the language
116 | use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
117 | if hasattr(model.config, "lang2id") and use_lang_emb:
118 | available_languages = model.config.lang2id.keys()
119 | if args.xlm_language in available_languages:
120 | language = args.xlm_language
121 | else:
122 | language = None
123 | while language not in available_languages:
124 | language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")
125 |
126 | model.config.lang_id = model.config.lang2id[language]
127 | # kwargs["language"] = tokenizer.lang2id[language]
128 |
129 | # TODO fix mask_token_id setup when configurations will be synchronized between models and tokenizers
130 | # XLM masked-language modeling (MLM) models need masked token
131 | # is_xlm_mlm = "mlm" in args.model_name_or_path
132 | # if is_xlm_mlm:
133 | # kwargs["mask_token_id"] = tokenizer.mask_token_id
134 |
135 | return prompt_text
136 |
137 |
138 | def prepare_xlnet_input(args, _, tokenizer, prompt_text):
139 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
140 | prompt_text = prefix + prompt_text
141 | return prompt_text
142 |
143 |
144 | def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
145 | prefix = args.prefix if args.prefix else args.padding_text if args.padding_text else PREFIX
146 | prompt_text = prefix + prompt_text
147 | return prompt_text
148 |
149 |
150 | PREPROCESSING_FUNCTIONS = {
151 | "ctrl": prepare_ctrl_input,
152 | "xlm": prepare_xlm_input,
153 | "xlnet": prepare_xlnet_input,
154 | "transfo-xl": prepare_transfoxl_input,
155 | }
156 |
157 |
158 | def adjust_length_to_model(length, max_sequence_length):
159 | if length < 0 and max_sequence_length > 0:
160 | length = max_sequence_length
161 | elif 0 < max_sequence_length < length:
162 | length = max_sequence_length # No generation bigger than model size
163 | elif length < 0:
164 | length = MAX_LENGTH # avoid infinite loop
165 | return length
166 |
167 |
168 | def sparse_model_config(model_config):
169 | embedding_size = None
170 | if hasattr(model_config, "hidden_size"):
171 | embedding_size = model_config.hidden_size
172 | elif hasattr(model_config, "n_embed"):
173 | embedding_size = model_config.n_embed
174 | elif hasattr(model_config, "n_embd"):
175 | embedding_size = model_config.n_embd
176 |
177 | num_head = None
178 | if hasattr(model_config, "num_attention_heads"):
179 | num_head = model_config.num_attention_heads
180 | elif hasattr(model_config, "n_head"):
181 | num_head = model_config.n_head
182 |
183 | if embedding_size is None or num_head is None or num_head == 0:
184 | raise ValueError("Check the model config")
185 |
186 | num_embedding_size_per_head = int(embedding_size / num_head)
187 | if hasattr(model_config, "n_layer"):
188 | num_layer = model_config.n_layer
189 | elif hasattr(model_config, "num_hidden_layers"):
190 | num_layer = model_config.num_hidden_layers
191 | else:
192 | raise ValueError("Number of hidden layers couldn't be determined from the model config")
193 |
194 | return num_layer, num_head, num_embedding_size_per_head
195 |
196 |
197 | def generate_past_key_values(model, batch_size, seq_len):
198 | num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
199 | if model.config.model_type == "bloom":
200 | past_key_values = tuple(
201 | (
202 | torch.empty(int(num_attention_heads * batch_size), num_embedding_size_per_head, seq_len)
203 | .to(model.dtype)
204 | .to(model.device),
205 | torch.empty(int(num_attention_heads * batch_size), seq_len, num_embedding_size_per_head)
206 | .to(model.dtype)
207 | .to(model.device),
208 | )
209 | for _ in range(num_block_layers)
210 | )
211 | else:
212 | past_key_values = tuple(
213 | (
214 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
215 | .to(model.dtype)
216 | .to(model.device),
217 | torch.empty(batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)
218 | .to(model.dtype)
219 | .to(model.device),
220 | )
221 | for _ in range(num_block_layers)
222 | )
223 | return past_key_values
224 |
225 |
226 | def prepare_jit_inputs(inputs, model, tokenizer):
227 | batch_size = len(inputs)
228 | dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt")
229 | dummy_input = dummy_input.to(model.device)
230 | if model.config.use_cache:
231 | dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
232 | dummy_input["attention_mask"] = torch.cat(
233 | [
234 | torch.zeros(dummy_input["attention_mask"].shape[0], 1)
235 | .to(dummy_input["attention_mask"].dtype)
236 | .to(model.device),
237 | dummy_input["attention_mask"],
238 | ],
239 | -1,
240 | )
241 | return dummy_input
242 |
243 |
244 | class _ModelFallbackWrapper(GenerationMixin):
245 | __slots__ = ("_optimized", "_default")
246 |
247 | def __init__(self, optimized, default):
248 | self._optimized = optimized
249 | self._default = default
250 |
251 | def __call__(self, *args, **kwargs):
252 | if kwargs["past_key_values"] is None and self._default.config.use_cache:
253 | kwargs["past_key_values"] = generate_past_key_values(self._default, kwargs["input_ids"].shape[0], 0)
254 | kwargs.pop("position_ids", None)
255 | for k in list(kwargs.keys()):
256 | if kwargs[k] is None or isinstance(kwargs[k], bool):
257 | kwargs.pop(k)
258 | outputs = self._optimized(**kwargs)
259 | lm_logits = outputs[0]
260 | past_key_values = outputs[1]
261 | fixed_output = CausalLMOutputWithPast(
262 | loss=None,
263 | logits=lm_logits,
264 | past_key_values=past_key_values,
265 | hidden_states=None,
266 | attentions=None,
267 | )
268 | return fixed_output
269 |
270 | def __getattr__(self, item):
271 | return getattr(self._default, item)
272 |
273 | def prepare_inputs_for_generation(
274 | self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
275 | ):
276 | return self._default.prepare_inputs_for_generation(
277 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
278 | )
279 |
280 | def _reorder_cache(
281 | self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
282 | ) -> Tuple[Tuple[torch.Tensor]]:
283 | """
284 | This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
285 | [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
286 | beam_idx at every generation step.
287 | """
288 | return self._default._reorder_cache(past_key_values, beam_idx)
289 |
290 |
291 | def main():
292 | parser = argparse.ArgumentParser()
293 | parser.add_argument(
294 | "--model_type",
295 | default=None,
296 | type=str,
297 | required=True,
298 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
299 | )
300 | parser.add_argument(
301 | "--model_name_or_path",
302 | default=None,
303 | type=str,
304 | required=True,
305 | help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
306 | )
307 | parser.add_argument(
308 | "--tokenizer_name",
309 | default=None,
310 | type=str,
311 | help="Pretrained tokenizer name or path if not the same as model_name",
312 | )
313 | parser.add_argument(
314 | "--config_overrides",
315 | default=None,
316 | type=str,
317 | help=(
318 | "Override some existing default config settings when a model is trained from scratch. Example: "
319 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
320 | ),
321 | )
322 |
323 | parser.add_argument("--prompt", type=str, default="")
324 | parser.add_argument("--length", type=int, default=20)
325 | parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
326 |
327 | parser.add_argument(
328 | "--temperature",
329 | type=float,
330 | default=1.0,
331 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
332 | )
333 | parser.add_argument(
334 | "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
335 | )
336 | parser.add_argument("--k", type=int, default=0)
337 | parser.add_argument("--p", type=float, default=0.9)
338 |
339 | parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
340 | parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
341 | parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
342 |
343 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
344 | parser.add_argument(
345 | "--use_cpu",
346 | action="store_true",
347 | help="Whether or not to use cpu. If set to False, " "we will use gpu/npu or mps device if available",
348 | )
349 | parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
350 | parser.add_argument(
351 | "--fp16",
352 | action="store_true",
353 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
354 | )
355 | parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
356 | parser.add_argument("--torch_dtype", type=str, default=None, help=(
357 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
358 | "dtype will be automatically derived from the model's weights."
359 | ), choices=["auto", "bfloat16", "float16", "float32"])
360 |
361 | # sink cache related arguments
362 | parser.add_argument("--sink_cache", action="store_true", help="Whether to use sink cache.")
363 | parser.add_argument("--window_length", type=int, default=256, help="Window size for sink cache.")
364 | parser.add_argument("--num_sink_tokens", type=int, default=2, help="Number of sink tokens.")
365 |
366 | args = parser.parse_args()
367 |
368 | # Initialize the distributed state.
369 | distributed_state = PartialState(cpu=args.use_cpu)
370 |
371 | logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}")
372 |
373 | if args.seed is not None:
374 | set_seed(args.seed)
375 |
376 | # Initialize the model and tokenizer
377 | try:
378 | args.model_type = args.model_type.lower()
379 | model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
380 | except KeyError:
381 |         raise KeyError(f"the model {args.model_type} you specified is not supported. You are welcome to add it and open a PR :)")
382 |
383 | if args.tokenizer_name:
384 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
385 | else:
386 | tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
387 | if tokenizer.pad_token is None:
388 | tokenizer.pad_token = tokenizer.eos_token
389 | # maybe we need to override configs
390 | config = model_class.config_class.from_pretrained(args.model_name_or_path)
391 | if args.config_overrides is not None:
392 | logger.info(f"Overriding config: {args.config_overrides}")
393 | config.update_from_string(args.config_overrides)
394 | logger.info(f"New config: {config}")
395 | torch_dtype = (
396 | args.torch_dtype
397 | if args.torch_dtype in ["auto", None]
398 | else getattr(torch, args.torch_dtype)
399 | )
400 | model = model_class.from_pretrained(args.model_name_or_path, config=config, torch_dtype=torch_dtype, attn_implementation="flash_attention_2")
401 |
402 | # Set the model to the right device
403 | model.to(distributed_state.device)
404 |
405 | if args.fp16:
406 | model.half()
407 | # XXX: we should not adjust the length to the model configs
408 | # max_seq_length = getattr(model.config, "max_position_embeddings", 0)
409 | # args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
410 | logger.info(args)
411 |
412 | prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")
413 |
414 | # Different models need different input formatting and/or extra arguments
415 | requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
416 | if requires_preprocessing:
417 | prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
418 | preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
419 |
420 | if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
421 | tokenizer_kwargs = {"add_space_before_punct_symbol": True}
422 | else:
423 | tokenizer_kwargs = {}
424 |
425 | encoded_prompt = tokenizer.encode(
426 | preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", **tokenizer_kwargs
427 | )
428 | else:
429 | prefix = args.prefix if args.prefix else args.padding_text
430 | encoded_prompt = tokenizer.encode(prefix + prompt_text, add_special_tokens=False, return_tensors="pt")
431 | encoded_prompt = encoded_prompt.to(distributed_state.device)
432 |
433 | if encoded_prompt.size()[-1] == 0:
434 | input_ids = None
435 | else:
436 | input_ids = encoded_prompt
437 |
438 | if args.jit:
439 | jit_input_texts = ["enable jit"]
440 | jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
441 | torch._C._jit_set_texpr_fuser_enabled(False)
442 | model.config.return_dict = False
443 | if hasattr(model, "forward"):
444 | sig = inspect.signature(model.forward)
445 | else:
446 | sig = inspect.signature(model.__call__)
447 | jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
448 | traced_model = torch.jit.trace(model, jit_inputs, strict=False)
449 | traced_model = torch.jit.freeze(traced_model.eval())
450 | traced_model(*jit_inputs)
451 | traced_model(*jit_inputs)
452 |
453 | model = _ModelFallbackWrapper(traced_model, model)
454 |
455 | kwargs = {}
456 | if hasattr(model.config, "layer_types"):
457 | parser = LayerTypeParser(model.config.layer_types)
458 | if parser.use_sliding_window():
459 | kwargs["past_key_values"] = IndexedHybridCache(parser, model.config.sliding_window)
460 | if args.sink_cache:
461 | raise ValueError("Sliding window and sink cache cannot be used together.")
462 | if args.sink_cache:
463 | kwargs["past_key_values"] = SinkCache(args.window_length, args.num_sink_tokens)
464 |
465 | output_sequences = model.generate(
466 | input_ids=input_ids,
467 | max_length=args.length + len(encoded_prompt[0]),
468 | temperature=args.temperature,
469 | top_k=args.k,
470 | top_p=args.p,
471 | repetition_penalty=args.repetition_penalty,
472 | do_sample=True,
473 | num_return_sequences=args.num_return_sequences,
474 | **kwargs,
475 | )
476 |
477 | # Remove the batch dimension when returning multiple sequences
478 | if len(output_sequences.shape) > 2:
479 | output_sequences.squeeze_()
480 |
481 | generated_sequences = []
482 |
483 | for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
484 | print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
485 | generated_sequence = generated_sequence.tolist()
486 |
487 | # Decode text
488 | text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
489 |
490 | # Remove all text after the stop token
491 | text = text[: text.find(args.stop_token) if args.stop_token else None]
492 |
493 | # Add the prompt at the beginning of the sequence. Remove the excess text that was used for pre-processing
494 | total_sequence = (
495 | prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
496 | )
497 |
498 | generated_sequences.append(total_sequence)
499 | print(total_sequence)
500 |
501 | return generated_sequences
502 |
503 | if __name__ == "__main__":
504 | main()
505 |
--------------------------------------------------------------------------------
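
`run_generation.py` above selects the KV cache at generation time: an `IndexedHybridCache` when the layer typing uses a sliding window, or a `SinkCache` when `--sink_cache` is given, and passes it to `generate` through `past_key_values`. A minimal sketch of the `SinkCache` path with a stock checkpoint (the model name is only an illustrative choice; any causal LM works):

```python
# Minimal sketch of passing a SinkCache through generate(), as run_generation.py does.
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import SinkCache

name = "TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16)

inputs = tokenizer("the meaning of life is", return_tensors="pt").to(model.device)
cache = SinkCache(window_length=256, num_sink_tokens=2)  # same defaults as the script
output = model.generate(**inputs, max_new_tokens=64, do_sample=True, past_key_values=cache)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
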
/run_generation.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # run generation
4 | python run_generation.py \
5 | --model_type lckv-llama \
6 | --torch_dtype bfloat16 \
7 | --tokenizer_name TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T \
8 | --model_name_or_path outputs/llamatiny-test \
9 | --num_return_sequences 1 \
10 | --prompt "the meaning of life is" \
11 | --length 512
12 |
13 | # run streaming
14 | python run_generation.py \
15 | --model_type lckv-llama \
16 | --torch_dtype bfloat16 \
17 | --tokenizer_name TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T \
18 | --model_name_or_path outputs/llamatiny-test \
19 | --num_return_sequences 1 \
20 | --prompt "the meaning of life is" \
21 | --length 2048 \
22 | --sink_cache \
23 | --window_length 1024 \
24 |     --num_sink_tokens 4
25 |
--------------------------------------------------------------------------------
/run_sft.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from trl.commands.cli_utils import SFTScriptArguments, TrlParser
17 |
18 |
19 | from datasets import load_dataset
20 |
21 | from transformers import AutoTokenizer
22 | import models
23 |
24 | from trl import (
25 | ModelConfig,
26 | SFTConfig,
27 | SFTTrainer,
28 | get_peft_config,
29 | get_quantization_config,
30 | get_kbit_device_map,
31 | )
32 |
33 |
34 | if __name__ == "__main__":
35 | parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
36 | args, training_args, model_config = parser.parse_args_and_config()
37 |
38 | ################
39 | # Model init kwargs & Tokenizer
40 | ################
41 | quantization_config = get_quantization_config(model_config)
42 | model_kwargs = dict(
43 | revision=model_config.model_revision,
44 | trust_remote_code=model_config.trust_remote_code,
45 | attn_implementation=model_config.attn_implementation,
46 | torch_dtype=model_config.torch_dtype,
47 | use_cache=False,
48 | device_map=get_kbit_device_map() if quantization_config is not None else None,
49 | quantization_config=quantization_config,
50 | )
51 | training_args.model_init_kwargs = model_kwargs
52 | tokenizer = AutoTokenizer.from_pretrained(
53 | model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
54 | )
55 | tokenizer.pad_token = tokenizer.eos_token
56 |
57 | ################
58 | # Dataset
59 | ################
60 | dataset = load_dataset(args.dataset_name)
61 |
62 | ################
63 | # Training
64 | ################
65 | trainer = SFTTrainer(
66 | model=model_config.model_name_or_path,
67 | args=training_args,
68 | train_dataset=dataset[args.dataset_train_split],
69 | eval_dataset=dataset[args.dataset_test_split],
70 | tokenizer=tokenizer,
71 | peft_config=get_peft_config(model_config),
72 | )
73 |
74 | trainer.train()
75 | trainer.save_model(training_args.output_dir)
76 |
--------------------------------------------------------------------------------
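
`run_sft.py` never constructs the model directly: it hands `model_name_or_path` plus `model_init_kwargs` to `SFTTrainer` and relies on `import models` so that the LCKV architecture can be resolved. A hedged sketch of the same pattern outside the trainer, assuming (as the other scripts here appear to) that importing `models` registers the LCKV classes with the `Auto*` factories; the checkpoint path is the illustrative one used in the shell scripts:

```python
# Hedged sketch: load an LCKV checkpoint through the Auto* API, as run_sft.py does
# implicitly when SFTTrainer builds the model from model_name_or_path.
import torch

import models  # side-effect import: assumed to register the LCKV config/model classes
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "outputs/llamatiny-test"  # illustrative path, taken from the shell scripts above
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
print(type(model).__name__)  # expected to be LCKVLlamaForCausalLM if registration works
```
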
/run_sft.sh:
--------------------------------------------------------------------------------
1 | # Full training
2 | python run_sft.py \
3 | --model_name_or_path outputs/llamatiny-test \
4 | --dataset_name timdettmers/openassistant-guanaco \
5 | --dataset_text_field text \
6 | --max_seq_length 1024 \
7 | --use_liger \
8 | --bf16 \
9 | --torch_dtype bfloat16 \
10 | --attn_implementation flash_attention_2 \
11 | --per_device_train_batch_size 16 \
12 | --gradient_accumulation_steps 1 \
13 | --learning_rate 1.41e-5 \
14 | --logging_steps 1 \
15 | --num_train_epochs 3 \
16 | --report_to none \
17 | --output_dir outputs/llamatiny-sft-test
18 |
--------------------------------------------------------------------------------
/run_streaming.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from pathlib import Path
5 |
6 | import requests
7 | import torch
8 |
9 | import models
10 | from transformers import AutoModelForCausalLM, AutoTokenizer, logging
11 | from transformers.cache_utils import SinkCache
12 | from transformers.generation.streamers import TextStreamer
13 |
14 |
15 | logging.get_logger("transformers.tokenization_utils").setLevel(logging.ERROR)
16 | logging.get_logger("transformers.generation.utils").setLevel(logging.ERROR)
17 | logging.get_logger("transformers.models.llama.modeling_llama").setLevel(logging.ERROR)
18 |
19 | def main():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--model_name_or_path", type=str, default="lmsys/vicuna-13b-v1.3")
22 | parser.add_argument("--data_root", type=str, default="data/")
23 |
24 | # generation related arguments
25 | parser.add_argument("--length", type=int, default=1000)
26 |
27 | parser.add_argument(
28 | "--temperature",
29 | type=float,
30 | default=1.0,
31 | help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
32 | )
33 | parser.add_argument(
34 | "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
35 | )
36 | parser.add_argument("--k", type=int, default=0)
37 | parser.add_argument("--p", type=float, default=0.9)
38 |
39 | # sink cache related arguments
40 | parser.add_argument("--sink_cache", action="store_true", help="Whether to use sink cache.")
41 | parser.add_argument("--window_length", type=int, default=256, help="Window size for sink cache.")
42 | parser.add_argument("--num_sink_tokens", type=int, default=2, help="Number of sink tokens.")
43 |
44 | args = parser.parse_args()
45 |
46 | # load model
47 | print(f"Loading model from {args.model_name_or_path} ...")
48 | tokenizer = AutoTokenizer.from_pretrained(
49 | args.model_name_or_path,
50 | trust_remote_code=True,
51 | chat_template=r"""
52 | {%- for message in messages %}
53 | {%- if message['role'] == 'user' %}
54 | {{- 'USER: ' + message['content'].strip() + '\n' }}
55 | {%- elif message['role'] == 'assistant' %}
56 | {{- 'ASSISTANT: ' + message['content'] + ' \n\n' }}
57 | {%- endif %}
58 | {%- endfor %}
59 | {%- if add_generation_prompt %}
60 | {{- 'ASSISTANT: ' }}
61 | {%- endif %}
62 | """
63 | )
64 | model = AutoModelForCausalLM.from_pretrained(
65 | args.model_name_or_path,
66 | device_map="auto",
67 | torch_dtype=torch.float16,
68 | trust_remote_code=True,
69 | attn_implementation="flash_attention_2",
70 | )
71 | model.eval()
72 |
73 | # load data
74 | print(f"Loading data from {args.data_root} ...")
75 | mt_bench = Path(args.data_root) / "mt_bench.jsonl"
76 | if not mt_bench.exists():
77 | print("Downloading mt_bench data ...")
78 | os.makedirs(args.data_root, exist_ok=True)
79 | with open(mt_bench, "w") as f:
80 | url = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl"
81 | response = requests.get(url)
82 | f.write(response.text)
83 |
84 | prompts = []
85 | with open(mt_bench, "r") as f:
86 | for line in f:
87 | prompts += json.loads(line)["turns"]
88 |
89 | # streaming inference
90 | kwargs = {}
91 | if args.sink_cache:
92 | kwargs["past_key_values"] = SinkCache(args.window_length, args.num_sink_tokens)
93 |
94 | chat_history = []
95 | streamer = TextStreamer(tokenizer, skip_prompt=True)
96 | for prompt in prompts:
97 | new_prompt = {"role": "user", "content": prompt}
98 | print(tokenizer.apply_chat_template([new_prompt], add_generation_prompt=True, tokenize=False), end="")
99 |
100 | chat_history.append(new_prompt)
101 | input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt").to(model.device)
102 |
103 | output_sequences = model.generate(
104 | input_ids=input_ids,
105 | max_new_tokens=args.length,
106 | temperature=args.temperature,
107 | top_k=args.k,
108 | top_p=args.p,
109 | repetition_penalty=args.repetition_penalty,
110 | do_sample=True,
111 | streamer=streamer,
112 | **kwargs,
113 | )
114 |
115 | chat_history.append({"role": "assistant", "content": tokenizer.decode(output_sequences[0, input_ids.shape[-1]:], skip_special_tokens=True)})
116 | print()
117 |
118 |
119 | if __name__ == "__main__":
120 | main()
121 |
--------------------------------------------------------------------------------
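
The inline chat template above renders each turn as `USER: ...` / `ASSISTANT: ...` and appends a trailing `ASSISTANT: ` when `add_generation_prompt=True`; the growing `chat_history` is re-rendered every turn while the optional `SinkCache` keeps the KV footprint bounded. A small sketch of what the rendered prompt looks like (template copied from the script; the tokenizer name is illustrative, since only text is rendered here):

```python
# Sketch: render the chat template from run_streaming.py on a toy two-turn history.
from transformers import AutoTokenizer

template = r"""
{%- for message in messages %}
{%- if message['role'] == 'user' %}
{{- 'USER: ' + message['content'].strip() + '\n' }}
{%- elif message['role'] == 'assistant' %}
{{- 'ASSISTANT: ' + message['content'] + ' \n\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- 'ASSISTANT: ' }}
{%- endif %}
"""

# Any tokenizer works here, since we only render text; the name mirrors the script's default.
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-13b-v1.3", chat_template=template)
history = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi, how can I help?"},
    {"role": "user", "content": "Tell me a joke."},
]
print(tokenizer.apply_chat_template(history, add_generation_prompt=True, tokenize=False))
# USER: Hello!
# ASSISTANT: Hi, how can I help?
#
# USER: Tell me a joke.
# ASSISTANT:
```
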
/run_streaming.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # streaming test
4 | python test_streaming.py \
5 | --model_name_or_path outputs/llamatiny-test \
6 | --output_dir streaming/llamatiny-test \
7 | --dataset_name pg19 \
8 | --download_streaming \
9 | --num_eval_tokens 4000000 \
10 | --sink_cache \
11 | --num_sink_tokens 4 \
12 |     --window_length 1024
13 |
14 | # performance test
15 | python test_streaming.py \
16 | --model_name_or_path outputs/llamatiny-test \
17 | --output_dir streaming/llamatiny-2048 \
18 | --dataset_name pg19 \
19 | --download_streaming \
20 | --num_eval_tokens 65133 \
21 | --sink_cache \
22 | --num_sink_tokens 4 \
23 |     --window_length 2048
24 |
--------------------------------------------------------------------------------
/test_harness.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 |
5 | import datasets
6 | import lm_eval
7 |
8 | import models
9 |
10 |
11 | datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
12 |
13 | def main():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T")
16 | parser.add_argument("--dtype", type=str, default="bfloat16")
17 | parser.add_argument("--tasks", type=str, nargs="+", default=["hellaswag", "openbookqa", "winogrande", "arc_challenge", "arc_easy", "boolq", "piqa"])
18 | parser.add_argument("--num_fewshot", type=int, default=0)
19 | parser.add_argument("--output", type=str, default="harness/result.json")
20 |
21 | args = parser.parse_args()
22 |
23 | task_manager = lm_eval.tasks.TaskManager()
24 | results = lm_eval.simple_evaluate( # call simple_evaluate
25 | model="hf",
26 | model_args=f"pretrained={args.model_name_or_path},dtype={args.dtype},attn_implementation=flash_attention_2",
27 | tasks=args.tasks,
28 | num_fewshot=args.num_fewshot,
29 | log_samples=False,
30 | task_manager=task_manager,
31 | )
32 |
33 | output_path = Path(args.output)
34 | output_path.parent.mkdir(parents=True, exist_ok=True)
35 |
36 | with open(output_path, "w") as f:
37 | json.dump(results, f, indent=4, default=lambda o: '')
38 |
39 | if __name__ == "__main__":
40 | main()
41 |
--------------------------------------------------------------------------------
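
`test_harness.py` dumps the raw `lm_eval.simple_evaluate` output to JSON. A hedged follow-up sketch for summarizing such a file, assuming the usual lm-eval-harness layout where per-task metrics sit under the top-level `"results"` key (metric names may differ across harness versions):

```python
# Hedged sketch: summarize a saved harness result file.
# Assumes the usual lm-eval layout: {"results": {task: {metric: value, ...}, ...}, ...}.
import json

with open("harness/result.json") as f:
    results = json.load(f)

for task, metrics in results.get("results", {}).items():
    # Keep only numeric metrics (e.g. acc, acc_norm) for a compact overview.
    numeric = {k: round(v, 4) for k, v in metrics.items() if isinstance(v, (int, float))}
    print(f"{task:15s} {numeric}")
```
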
/test_latency.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections.abc import Iterable
3 |
4 | import torch
5 | from accelerate import PartialState
6 | from accelerate.utils import set_seed
7 | from tqdm import tqdm
8 |
9 | from models import LCKVLlamaConfig
10 | from transformers import (
11 | AutoConfig,
12 | AutoModelForCausalLM,
13 | AutoTokenizer,
14 | logging,
15 | )
16 | from transformers.trainer_pt_utils import get_model_param_count
17 |
18 |
19 | set_seed(42)
20 |
21 | logging.get_logger("transformers.tokenization_utils").setLevel(logging.ERROR)
22 | logging.get_logger("transformers.generation.utils").setLevel(logging.ERROR)
23 | logging.get_logger("transformers.models.llama.modeling_llama").setLevel(logging.ERROR)
24 |
25 | prompt_text_2048 = """The meaning of life is the solution to life's problems.'
26 |
27 | It was difficult to know whether the speech was a lie or not. At first Lily was determined to accept that her brother was deeply ill and needed my help in trying to help him. However, this was hardly the right way to treat him. He was a healthy young man and I didn't want to upset his sisters by making them feel as though they had been abandoned. So Lily couldn't see the point of my helping.
28 |
29 | 'Is it true that you only said you were ill because you were selfish?'
30 |
31 | 'Well, yes, and I'm sorry. And, in a way, yes. Because it's true that I care about you. But first and foremost, I care about my brother.'
32 |
33 | 'OK,' Lily said, clearly annoyed at the interference in her own life. 'I do see the point of you.'
34 |
35 | 'So why don't you, forgive me?'
36 |
37 | 'Why don't you forgive me for what?'
38 |
39 | 'The part I played in making you miserable.'
40 |
41 | 'I don't forgive myself. I'm simply uncomfortable to talk to people like you.'
42 |
43 | 'OK, well, maybe it's not fair of me to start talking to people about myself. My mum is not going to come on the same terms as me,' I said defensively.
44 |
45 | 'I guess not,' Lily said.
46 |
47 | 'Yes, she is coming round,' I said hopefully. 'Better late than never.'
48 |
49 | 'So why do you come to see me?' Lily asked me. 'I mean, I was surprised. I've been neglecting you. You don't like me,' she said, somewhat alarmed at the harshness of her tone.
50 |
51 | 'Don't ask me to explain,' I told her, 'I don't really know.'
52 |
53 | 'But I thought you wanted to help,' Lily said slowly. 'What were you expecting me to do? When she became ill, you just seemed so upset about it.'
54 |
55 | 'I'm not in a good place right now,' I admitted. 'I've always hated end-of-life situations. I just don't know how to deal with it, Lily. It makes me feel bad.'
56 |
57 | 'I don't think I understand,' Lily replied. 'Have you tried to talk to her?'
58 |
59 | 'OK, I'll try. She used to live with me, but then she left.'
60 |
61 | 'Do you want to find out why she's given you up?' Lily asked.
62 |
63 | 'She wouldn't talk to me about it,' I said, embarrassed by the pressure I was under. 'She's tried talking to me, but I'm not open to anything.'
64 |
65 | 'You don't look open to anything,' Lily said with a note of contempt. 'She didn't put it that way. She was in love with you. I thought she wanted you to marry her.'
66 |
67 | 'She didn't say anything like that,' I said softly, biting my tongue. 'She just... didn't talk to me.'
68 |
69 | Lily's tone changed from brusque to disapproving. 'I don't care if you are in love with her. She's leaving you.'
70 |
71 | 'Lily, I want to know why she's leaving you.'
72 |
73 | 'Because she's giving you up,' Lily declared. 'That's why.'
74 |
75 | 'I don't want to know.' I glanced over my shoulder. The door was open a crack. I wondered how it was that my brother could have found his way out of it. I wondered if my brother had had help.
76 |
77 | 'You don't want to know the answer?' Lily persisted. 'Get me a drink. Maybe you can find out what you're looking for.'
78 |
79 | I decided to go for it. Lily had the bottle of whisky that I bought for her on her birthday a few months earlier. It was a lovely bottle with a small picture of our father on it.
80 |
81 | 'Well, I'm not leaving you,' Lily told me. 'Go and sit with my mum. I'm going to check on Peter. You need to concentrate on getting through this morning. You can give your sister some solace later.'
82 |
83 | 'Lily, I don't want to see any of you,' I told her. 'You told me you weren't going to make it to the pub but I'm already here.'
84 |
85 | 'OK, don't make me answer that,' Lily said in a firm tone. 'Look, I'm not convinced that any of you are going to make it. I've been left alone for far too long. If I stay for a while, I might be able to take the edge off your stress, but I'm not going to make it my business to see how you're doing. Goodbye, Lily.'
86 |
87 | 'OK,' I told her. 'Bye.'
88 |
89 | 'Where are you going?' Lily asked me from the door.
90 |
91 | 'To see my mum,' I told her. 'I'm taking the day off.'
92 |
93 | 'OK,' Lily said with a slight smile, 'good luck with that.'
94 |
95 | She then went downstairs to order something to eat from the freezer. I stayed with my mum for a few hours, because she wouldn't leave. She told me she was now staying with the sisters, and that she had just been to see Peter. I told her about Peter's wife and how she had made her decision.
96 |
97 | 'Don't you dare leave me,' she told me. 'I'm not going to let you have this bit between us any more.'
98 |
99 | 'You know you have to work some more,' I told her. 'I'm getting up, so will you have a nice day?'
100 |
101 | She said she would. I was touched by the way she told me she didn't hate me. 'I don't want to hear about anything, just know that I'm thinking of you. No matter what happens, I'm there.'
102 |
103 | Then, I packed up my things and made my way to Lily's house. I had to make an effort to get her to talk about her brother.
104 |
105 | 'OK,' I told her, 'I guess I don't have to talk about Peter. I don't want to ask any questions. But, you don't mind if I do, do you?'
106 |
107 | 'Not if you do,' Lily told me. 'But I'd like to know why you're upset.'
108 |
109 | 'You tell me when you want to.'
110 |
111 | 'OK,' Lily said, then shook her head. 'I don't think I'll be telling you. I'll make it my business to see Peter.'
112 |
113 | 'Good,' I said, somewhat disappointed.
114 |
115 | 'Goodbye.'
116 |
117 | 'Goodbye,' I said and walked away. I couldn't help but feel that the night was too soon for Lily to suddenly get so upset about something like this.
118 |
119 | So, what was I going to do? I didn't like getting up at nine in the morning, not just to be alone but to get up. I felt as if I couldn't handle being alone for long periods of time. Maybe I would go for a walk. Or, if I wasn't feeling up to that, I would make some sandwiches and I would go for a stroll. I could go to the pub, and talk about something that didn't involve anything very serious. If I went to the pub and wasn't forced to say anything, then I could relax and get on with my day.
120 |
121 | I went for a stroll down the hill and ended up going into London. It was a lovely early spring morning. The trees were still very early, so I stopped at a shop to buy some fresh produce for lunch. As I was buying it, a cyclist approached me and asked me how I was. I didn't know what to say. He was really nice, though, so I explained about the situations of my brother and sister.
122 |
123 | 'I'm not sure how you're supposed to take that,' he told me, not sure what to say, 'but, hopefully, everything will work out for the best. I'm going off to work today, so I'll see you when I get back to Tottenham.'
124 |
125 | 'Ok,' I told him.
126 |
127 | 'Goodbye,' he added.
128 |
129 | I must have looked relieved when I said goodbye. He did head to work and I went home to find Lily sitting in her chair on her sofa.
130 |
131 | 'So how are you? How is the Wiltshire weather?'
132 |
133 | Lily told me she had been back on a walk in that great, new National Trust park. 'It was lovely,' she said. 'The sheep ran wild.'
134 |
135 | 'Did you manage to see who was out hunting?' I asked. 'Lily, are you watching it on television?'
136 |
137 | Lily nodded.
138 |
139 | 'OK,' I said. 'Do you want me to watch?'
140 |
141 | 'You'"""
142 | prompt_text_512 = """The meaning of life is the solution to life's problems.'
143 |
144 | It was difficult to know whether the speech was a lie or not. At first Lily was determined to accept that her brother was deeply ill and needed my help in trying to help him. However, this was hardly the right way to treat him. He was a healthy young man and I didn't want to upset his sisters by making them feel as though they had been abandoned. So Lily couldn't see the point of my helping.
145 |
146 | 'Is it true that you only said you were ill because you were selfish?'
147 |
148 | 'Well, yes, and I'm sorry. And, in a way, yes. Because it's true that I care about you. But first and foremost, I care about my brother.'
149 |
150 | 'OK,' Lily said, clearly annoyed at the interference in her own life. 'I do see the point of you.'
151 |
152 | 'So why don't you, forgive me?'
153 |
154 | 'Why don't you forgive me for what?'
155 |
156 | 'The part I played in making you miserable.'
157 |
158 | 'I don't forgive myself. I'm simply uncomfortable to talk to people like you.'
159 |
160 | 'OK, well, maybe it's not fair of me to start talking to people about myself. My mum is not going to come on the same terms as me,' I said defensively.
161 |
162 | 'I guess not,' Lily said.
163 |
164 | 'Yes, she is coming round,' I said hopefully. 'Better late than never.'
165 |
166 | 'So why do you come to see me?' Lily asked me. 'I mean, I was surprised. I've been neglecting you. You don't like me,' she said, somewhat alarmed at the harshness of her tone.
167 |
168 | 'Don't ask me to explain,' I told her, 'I don't really know.'
169 |
170 | 'But I thought you wanted to help,' Lily said slowly. 'What were you expecting me to do? When she became ill, you just seemed so upset about it.'
171 |
172 | 'I'm not in a good place right now,' I admitted. 'I've always hated end-of-life situations. I just don't know how to deal with it, Lily. It makes me"""
173 | prompt_text_5 = "The meaning of life is"
174 | prompt_text_4096 = prompt_text_2048*2
175 |
176 |
177 | def empty_cache():
178 | for _ in range(10):
179 | torch.cuda.empty_cache()
180 |
181 | class BinarySearch:
182 | """
183 |     Binary search without a preset upper limit: the upper bound is found by doubling the index until a trial fails, then the remaining range is bisected.
184 | """
185 | def __init__(self, min: int = 0):
186 | self.low = min
187 | self.high = None
188 | self.nxt = min + 1
189 |
190 | def __iter__(self):
191 | return self
192 |
193 | def __next__(self):
194 | if self.high is None:
195 | return self.nxt
196 | elif self.low == self.high:
197 | raise StopIteration
198 | else:
199 | return self.nxt
200 |
201 | def report(self, idx: int, result: bool):
202 | if self.high is None:
203 | if idx <= self.low:
204 | if not result:
205 | raise ValueError(f"We've proven that {self.low} is feasible, but {idx} is not.")
206 | else:
207 | if not result:
208 | self.high = idx - 1
209 | self.nxt = (self.low + self.high) // 2
210 | else:
211 | self.low = idx
212 | self.nxt = idx * 2
213 | else:
214 | if idx < self.low or idx > self.high:
215 | raise ValueError(f"Index {idx} is out of range [{self.low}, {self.high}]")
216 | else:
217 | if not result:
218 | self.high = idx - 1
219 | else:
220 | self.low = idx
221 | self.nxt = (self.low + self.high) // 2
222 | if self.nxt == self.low:
223 | self.nxt = self.high
224 |
225 | class Streamer:
226 | def __init__(self, pbar):
227 | self.pbar = pbar
228 |
229 | def put(self, x):
230 | self.pbar.update(1)
231 |
232 | def end(self, *args, **kwargs):
233 | self.pbar.close()
234 |
235 | def get_ds_model(model_name, config, dtype, cpu_offload, disk_offload, offload_dir, num_gpus = 1):
236 | import deepspeed
237 | import torch.distributed as dist
238 | from transformers.deepspeed import HfDeepSpeedConfig
239 |
240 | hidden_size = config.hidden_size
241 |     # note: if this raises "ModuleNotFoundError: No module named 'mpi4py'",
242 |     # launch the script with `deepspeed <script>.py` instead of plain python
243 | deepspeed.init_distributed("nccl")
244 | rank = dist.get_rank()
245 | pin_memory = True
246 |
247 | ds_config = {
248 | "fp16": {
249 | "enabled": dtype == torch.float16,
250 | },
251 | "bf16": {
252 | "enabled": dtype == torch.bfloat16,
253 | },
254 | "zero_optimization": {
255 | "stage": 3,
256 | "stage3_prefetch_bucket_size": hidden_size * hidden_size,
257 | "stage3_param_persistence_threshold": 0,
258 | },
259 | "steps_per_print": 2000,
260 | "train_batch_size": 1,
261 | "wall_clock_breakdown": False,
262 | }
263 |
264 | if cpu_offload:
265 | ds_config["zero_optimization"]["offload_param"] = {
266 | "device": "cpu",
267 | "pin_memory": pin_memory
268 | }
269 |
270 | if disk_offload:
271 | ds_config["zero_optimization"]["offload_param"] = {
272 | "device": "nvme",
273 | "pin_memory": True,
274 | "nvme_path": offload_dir,
275 | "buffer_count": 5,
276 | "buffer_size": 2 * (1 << 30),
277 | }
278 | ds_config["aio"] = {
279 | "block_size": 1048576,
280 | "queue_depth": 8,
281 | "thread_count": 1,
282 | "single_submit": False,
283 | "overlap_events": True,
284 | }
285 |
286 | # dschf = HfDeepSpeedConfig(ds_config)
287 |
288 | model = AutoModelForCausalLM.from_pretrained(
289 | model_name,
290 | config=config,
291 | torch_dtype=dtype
292 | )
293 | model = model.eval()
294 | ds_engine = deepspeed.initialize(model=model, config_params=ds_config)[0]
295 | ds_engine.module.eval()
296 | model = ds_engine.module
297 |
298 | return model
299 |
300 | def get_hf_model(model_name, config, dtype, cpu_offload, disk_offload, offload_dir, num_gpus):
301 | if num_gpus == 1 and dtype != torch.int8:
302 | # Here we use a custom device_map instead of device_map == "auto"
303 |         # because we want to offload as many weights as possible out of the GPU
304 | # to allow a larger batch size.
305 | if cpu_offload:
306 | # NOTE: We must put some weights on GPU. Otherwise, huggingface reports errors.
307 | device_map = {
308 | "model.embed_tokens.weight": 0,
309 | "model.norm": "cpu",
310 | "model.layers": "cpu",
311 | "lm_head.weight": 0,
312 | }
313 | elif disk_offload:
314 | device_map = {
315 | "model.embed_tokens.weight": 0,
316 | "model.norm": "disk",
317 | "model.layers": "disk",
318 | "lm_head.weight": 0,
319 | }
320 | else:
321 | device_map = None
322 | max_memory = None
323 | else:
324 | # Here we use device_map == "auto", but set a low `max_memory` threshold
325 |         # because we want to offload as many weights as possible out of the GPU
326 | # to allow a larger batch size.
327 | device_map = "auto"
328 | if cpu_offload:
329 | # `max_memory` should be larger than the embedding.
330 |             # We use 2GB here because the embedding of opt-175b is 1.2GB.
331 | max_memory = {k: "2GB" for k in range(num_gpus)}
332 | elif disk_offload:
333 | max_memory = {k: "2GB" for k in range(num_gpus)}
334 | else:
335 | max_memory = {k: "14GB" for k in range(num_gpus)}
336 | max_memory["cpu"] = "160GB"
337 |
338 | if dtype == torch.int8:
339 | kwargs = {"load_in_8bit": True}
340 | else:
341 | kwargs = {"torch_dtype": dtype}
342 |
343 | model = AutoModelForCausalLM.from_pretrained(model_name, config=config,
344 | device_map=device_map, max_memory=max_memory,
345 | offload_folder=offload_dir, **kwargs)
346 | if device_map is None:
347 | model.cuda()
348 |
349 | model.eval()
350 | return model
351 |
352 |
353 | distributed_state = PartialState()
354 |
355 | def prepare(model: str, size: str, cpu_offload: str = "none", warmup: int = 2):
356 | """
357 | Prepare the tokenizer and model.
358 |
359 | Arguments:
360 | model (str): The model name. Options are "llama" and "lckv-llama".
361 | size (str): The model size. Options are "50m", "1.1b", "7b", and "30b".
362 | cpu_offload (str): The CPU offload strategy. Options are "none", "hf", and "ds".
363 | warmup (int): The number of warmup layers for LCKV-LLAMA.
364 | """
365 |
366 | CONFIG_MAPPING = {
367 | "50m": "configs/llama_tiny.json",
368 | "1.1b": "TinyLlama/TinyLlama-1.1B-intermediate-step-955k-token-2T",
369 | "7b": "huggyllama/llama-7b", # "yahma/llama-7b-hf"
370 | "30b": "huggyllama/llama-30b",
371 | }
372 |
373 | # prepare tokenizer
374 | tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
375 |
376 | # prepare config
377 | if model == "lckv-llama":
378 | config = LCKVLlamaConfig.from_pretrained(CONFIG_MAPPING[size])
379 | config._attn_implementation = "flash_attention_2"
380 | start, end = int(warmup) // 2, config.num_hidden_layers - int(warmup) // 2
381 | layer_types = [(end - 1 if i in range(start, end) else i) for i in range(config.num_hidden_layers)]
382 | config.layer_types = "_".join(map(str, layer_types))
383 |
384 | elif model == "llama":
385 | config = AutoConfig.from_pretrained(CONFIG_MAPPING[size])
386 | config._attn_implementation = "flash_attention_2"
387 |
388 | else:
389 | raise ValueError(f"Unknown model {model}")
390 |
391 | # prepare model
392 | if cpu_offload == "none":
393 | model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)
394 | model.to(distributed_state.device)
395 |
396 | elif cpu_offload == "hf":
397 | model = get_hf_model(
398 | model_name = CONFIG_MAPPING[size],
399 | config = config,
400 | dtype = torch.bfloat16,
401 | cpu_offload = True,
402 | disk_offload = False,
403 | offload_dir = None,
404 | num_gpus = 1
405 | )
406 |
407 | elif cpu_offload == "ds":
408 | model = get_ds_model(
409 | model_name = CONFIG_MAPPING[size],
410 | config = config,
411 | dtype = torch.bfloat16,
412 | cpu_offload = True,
413 | disk_offload = False,
414 | offload_dir = None
415 | )
416 |
417 | else:
418 | raise ValueError(f"Unknown CPU offload strategy {cpu_offload}")
419 |
420 | print("# of trainable parameters:", get_model_param_count(model, True))
421 | model.eval()
422 |
423 | return tokenizer, model
424 |
425 |
426 | def inject_callback(model, callback):
427 |     """Inject a callback that fires after every forward pass of the model
428 |     (used here to record when the first forward, i.e. the prefill, finishes)."""
429 | forward_func = model.forward
430 |
431 | def forward(self, *args, **kwargs):
432 | result = forward_func(*args, **kwargs)
433 | callback()
434 | return result
435 |
436 | model.forward = forward.__get__(model, type(model))
437 |
438 |
439 | def experiment(tokenizer, model, prompt, max_new_tokens, iterator=None, verbose=False):
440 |
441 | input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
442 |
443 | # move to cuda
444 | input_ids = input_ids.to(distributed_state.device)
445 | model.eval()
446 |
447 | # Pre-warmup the GPU
448 | _ = model.generate(input_ids[:1,:1], do_sample=True, max_length=50)
449 |
450 | def callback():
451 | if callback.timer is None:
452 | callback.timer = time.time()
453 | inject_callback(model, callback)
454 |
455 | def run(num_return_sequences = 32):
456 |
457 | callback.timer = None
458 |
459 | if verbose:
460 | pbar = tqdm(total=max_new_tokens, leave=False, desc="Generating")
461 | streamer = Streamer(pbar)
462 | else:
463 | streamer = None
464 |
465 | start = time.time()
466 |
467 | model.generate(
468 | input_ids=input_ids,
469 | max_new_tokens=max_new_tokens,
470 | temperature=1.0,
471 | top_k=0,
472 | top_p=0.9,
473 | repetition_penalty=1.0,
474 | do_sample=True,
475 | num_return_sequences=num_return_sequences,
476 | streamer=streamer,
477 | )
478 |
479 | end = time.time()
480 |
481 | return start, callback.timer, end
482 |
483 | if isinstance(iterator, Iterable):
484 |
485 | for i in iterator:
486 | # print the current date and time
487 | try:
488 | start, mid, end = run(i)
489 | length = max_new_tokens
490 | print(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] bsz: {i:3}, time: {end-start:.3f} = {mid-start:.3f} + {end-mid:.3f}, throughput: {i*length/(end-start):.3f} tokens/s")
491 | except Exception as e:
492 | if not str(e).startswith("CUDA out of memory."):
493 | print(e)
494 | raise
495 | return
496 | empty_cache()
497 |
498 | else:
499 |
500 | iterator = BinarySearch(0 if iterator is None else iterator)
501 | for i in iterator:
502 | try:
503 | start, mid, end = run(i)
504 | length = max_new_tokens
505 | print(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] bsz: {i:3}, time: {end-start:.3f} = {mid-start:.3f} + {end-mid:.3f}, throughput: {i*length/(end-start):.3f} tokens/s")
506 | iterator.report(i, True)
507 | except Exception as e:
508 | print(f"[{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] bsz: {i:3}, FAIL")
509 | if not str(e).startswith("CUDA out of memory."):
510 | print(e)
511 | raise
512 | iterator.report(i, False)
513 | empty_cache()
514 | return iterator.nxt
515 |
516 |
517 | def main():
518 |
519 |     # do not run the same model consecutively; it will cause CUDA out of memory
520 |
521 | print(">>> 7b 2048+2048 llama")
522 | tokenizer, model = prepare("llama", "7b")
523 | experiment(tokenizer, model, prompt_text_2048, max_new_tokens=2048)
524 |
525 | print(">>> 7b 2048+2048 lckv-llama")
526 | tokenizer, model = prepare("lckv-llama", "7b")
527 | experiment(tokenizer, model, prompt_text_2048, max_new_tokens=2048)
528 |
529 | print(">>> 7b 2048+2048 lckv-llama w=10")
530 | tokenizer, model = prepare("lckv-llama", "7b", warmup=10)
531 | experiment(tokenizer, model, prompt_text_2048, max_new_tokens=2048)
532 |
533 |
534 | if __name__ == "__main__":
535 | main()
536 |
--------------------------------------------------------------------------------
/test_streaming.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | from datasets import load_dataset
6 | from torch.nn import CrossEntropyLoss
7 | from tqdm import tqdm
8 |
9 | import models
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 | from transformers.cache_utils import SinkCache
12 |
13 |
14 | device = "cuda" if torch.cuda.is_available() else "cpu"
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "--model_name_or_path", type=str, default="TinyLlama/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T"
20 | )
21 | parser.add_argument("--revision", type=str, default="main")
22 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
23 | parser.add_argument("--dataset_name", type=str, default=None)
24 | parser.add_argument("--task", type=str, default=None)
25 | parser.add_argument(
26 | "--split", type=str, default="test", choices=["validation", "test"]
27 | )
28 | parser.add_argument(
29 | "--download_streaming", action="store_true", help="enable streaming download"
30 | )
31 |
32 | parser.add_argument(
33 | "--num_samples",
34 | type=int,
35 | default=None,
36 | )
37 |
38 | parser.add_argument(
39 | "--output_dir",
40 | type=str,
41 | default="streaming/debug",
42 | )
43 |
44 | parser.add_argument(
45 | "--torch_dtype",
46 | type=str,
47 | default="bfloat16",
48 | )
49 |
50 | # sink cache related arguments
51 | parser.add_argument("--sink_cache", action="store_true", help="Whether to use sink cache.")
52 | parser.add_argument("--window_length", type=int, default=256, help="Window size for sink cache.")
53 | parser.add_argument("--num_sink_tokens", type=int, default=2, help="Number of sink tokens.")
54 |
55 | parser.add_argument("--num_eval_tokens", type=int, default=None)
56 |
57 | args = parser.parse_args()
58 | return args
59 |
60 | def load(model_name_or_path, torch_dtype):
61 | print(f"Loading model from {model_name_or_path} ...")
62 |     # a "type;size" spec (e.g. "llama;7b") builds a model from scratch via test_latency.prepare
63 | if ";" in model_name_or_path:
64 | from test_latency import prepare
65 | tokenizer, model = prepare(*model_name_or_path.split(";"))
66 | return model, tokenizer
67 |     # note: tensor parallelism triggers bugs when running Falcon models
68 | tokenizer = AutoTokenizer.from_pretrained(
69 | model_name_or_path,
70 | trust_remote_code=True,
71 | )
72 | torch_dtype = (
73 | torch_dtype
74 | if torch_dtype in ["auto", None]
75 | else getattr(torch, torch_dtype)
76 | )
77 | model = AutoModelForCausalLM.from_pretrained(
78 | model_name_or_path,
79 | device_map="auto",
80 | torch_dtype=torch_dtype,
81 | trust_remote_code=True,
82 | )
83 | if tokenizer.pad_token_id is None:
84 | if tokenizer.eos_token_id is not None:
85 | tokenizer.pad_token_id = tokenizer.eos_token_id
86 | else:
87 | tokenizer.pad_token_id = 0
88 | model.eval()
89 | return model, tokenizer
90 |
91 | def main():
92 |
93 | args = parse_args()
94 |
95 | data = load_dataset(args.dataset_name, args.task, split=args.split, streaming=args.download_streaming)
96 | if args.num_samples is not None:
97 |         data = data.take(args.num_samples) if args.download_streaming else data.select(range(args.num_samples))  # streaming IterableDataset has no .select()
98 |
99 | model, tokenizer = load(args.model_name_or_path, args.torch_dtype)
100 |
101 | nlls = []
102 | loss_fn = CrossEntropyLoss(reduction="none")
103 |
104 | # streaming inference
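    |     # feed tokens one at a time, passing the cache back through past_key_values, and
    |     # accumulate per-token NLLs to compute perplexity over the stream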
105 | past_key_values = None
106 | if args.sink_cache:
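    |         # SinkCache keeps the first num_sink_tokens tokens plus a sliding window of recent
    |         # tokens (StreamingLLM-style attention sinks), bounding cache memory on long streams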
107 | past_key_values = SinkCache(args.window_length, args.num_sink_tokens)
108 |
109 |     # the commented-out time-related lines below can be re-enabled for latency measurement
110 | os.makedirs(args.output_dir, exist_ok=True)
111 | with open(f"{args.output_dir}/log.txt", "w") as f:
112 |
113 | num_eval_tokens = 0
114 | for item in data:
115 | text = item['text']
116 | encodings = tokenizer(text, return_tensors="pt")
117 |
118 | print(encodings.input_ids[:, :10])
119 |
120 | seq_len = encodings.input_ids.size(1)
121 | print(f"num_eval_tokens: {num_eval_tokens}, seq_len: {seq_len}")
122 | pbar = tqdm(range(0, seq_len - 1))
123 |
124 | # import time
125 | for idx in pbar:
126 | # if idx == args.start_size + args.recent_size:
127 | # print("Starting timer...")
128 | # start = time.time()
129 | input_ids = encodings.input_ids[:, idx : idx + 1].to(device)
130 | with torch.no_grad():
131 | outputs = model(
132 | input_ids,
133 | past_key_values=past_key_values,
134 | use_cache=True,
135 | )
136 | logits = outputs.logits.view(-1, model.config.vocab_size)
137 | past_key_values = outputs.past_key_values
138 | label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
139 | neg_log_likelihood = loss_fn(logits, label)
140 | nlls.append(neg_log_likelihood)
141 | pbar.set_description(
142 | f"nll: {neg_log_likelihood.item():.2f}, ppl: {torch.exp(neg_log_likelihood).item():.2f}"
143 | )
144 | print(neg_log_likelihood.item(), file=f, flush=True)
145 | num_eval_tokens += 1
146 | if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens:
147 | # print(f"time: {time.time() - start:.2f}")
148 | break
149 | if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens:
150 | break
151 |
152 | ppl = torch.exp(torch.stack(nlls).mean())
153 | print(ppl.item())
154 | with open(f"{args.output_dir}/ppl.txt", "w") as f:
155 | f.write(f"{ppl.item()}\n")
156 |
157 | if __name__ == "__main__":
158 | main()
159 |
--------------------------------------------------------------------------------
/tests/test_kernel.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from transformers import is_torch_available
4 | from transformers.testing_utils import require_torch_gpu
5 |
6 |
7 | if is_torch_available():
8 | import torch
9 |
10 | from models.kernel import liger_rotary
11 | from models.modeling_lckv import apply_rotary
12 |
13 |
14 | @require_torch_gpu
15 | class LigerRotaryTest(unittest.TestCase):
16 | def test_liger_rotary(self):
17 | # Test case: Test liger_rotary function
18 | q = torch.randn(2, 3, 4, 6, device="cuda")
19 | freq = torch.randn(1, 4, 3, device="cuda")
20 | embed = torch.cat((freq, freq), dim=-1)
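   |         # duplicating the half-dimension frequencies before taking cos/sin presumably matches
   |         # the rotate-half RoPE convention of the Llama-style apply_rotary reference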
21 | cos = embed.cos()
22 | sin = embed.sin()
23 | unsqueeze_dim = 1
24 |
25 | result_q = liger_rotary(q, cos, sin, unsqueeze_dim)
26 | self.assertEqual(result_q.shape, (2, 3, 4, 6))
27 |
28 | ref_q = apply_rotary(q, cos, sin, unsqueeze_dim)
29 | self.assertTrue(torch.allclose(result_q, ref_q))
30 |
31 |
32 | if __name__ == '__main__':
33 | unittest.main()
34 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from models.utils import LayerTypeParser
4 | from transformers import is_torch_available
5 | from transformers.testing_utils import require_flash_attn, require_torch_gpu
6 |
7 |
8 | if is_torch_available():
9 | import torch
10 |
11 | from models.utils import flash_attention_forward
12 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
13 |
14 |
15 | class LayerTypeParserTest(unittest.TestCase):
16 | def test_init_with_valid_input(self):
17 | # Test case: Initialization with valid input
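   |         # "1_2s_0": layer 0 attends to layer 1's KV, layer 1 attends to layer 2's KV with a
   |         # sliding window ("s" suffix), and layer 2 attends to layer 0's KV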
18 | layer_type = "1_2s_0"
19 | parser = LayerTypeParser(layer_type)
20 | self.assertEqual([item.attends_to for item in parser], [1, 2, 0])
21 | self.assertEqual([item.use_sliding_window for item in parser], [False, True, False])
22 |
23 | def test_init_with_invalid_input(self):
24 | # Test case: Initialization with invalid input
25 | with self.assertRaises(Exception):
26 | LayerTypeParser("invalid_input")
27 |
28 | with self.assertRaises(Exception):
29 | LayerTypeParser("0_1t_2")
30 |
31 | def test_init_with_empty_string(self):
32 | # Test case: Initialization with an empty string
33 | with self.assertRaises(Exception):
34 | LayerTypeParser("")
35 |
36 | def test_len(self):
37 | # Test case: Check the length of the LayerTypeParser object
38 | layer_type = "1_2s_0"
39 | parser = LayerTypeParser(layer_type)
40 | self.assertEqual(len(parser), 3)
41 |
42 | def test_getitem(self):
43 | # Test case: Get the layer type information for a specific layer index
44 | layer_type = "1_2s_0"
45 | parser = LayerTypeParser(layer_type)
46 | layer_type_info = parser[1]
47 | self.assertEqual(layer_type_info.attends_to, 2)
48 | self.assertEqual(layer_type_info.attends_top, True)
49 | self.assertEqual(layer_type_info.use_sliding_window, True)
50 |
51 | def test_use_sliding_window(self):
52 | # Test case: Check if there exists a layer that uses sliding window attention
53 | layer_type = "1_2s_0"
54 | parser = LayerTypeParser(layer_type)
55 | self.assertEqual(parser.use_sliding_window(), True)
56 |
57 | layer_type = "0_0_2"
58 | parser = LayerTypeParser(layer_type)
59 | self.assertEqual(parser.use_sliding_window(), False)
60 |
61 | def test_attends_top(self):
62 | # Test case: Check if there exists a layer that attends to layers above it
63 | layer_type = "1_2s_0"
64 | parser = LayerTypeParser(layer_type)
65 | self.assertEqual(parser.attends_top(), True)
66 |
67 | layer_type = "0s_0s_2"
68 | parser = LayerTypeParser(layer_type)
69 | self.assertEqual(parser.attends_top(), False)
70 |
71 | def test_iteration_plan(self):
72 | # Test case: Check the iteration plan for the layer types
73 | layer_type = "1_2s_0"
74 | parser = LayerTypeParser(layer_type)
75 | iteration_plan = parser.iteration_plan(forward_passes=7, backward_passes=2)
76 |         # 7 gradient-free forward passes followed by 2 passes with gradients -> 9 plan steps
77 | self.assertEqual(len(iteration_plan), 9)
78 |
79 | # each layer should be updated exactly once
80 | updated_slices = [step.layer_slice for step in iteration_plan if step.update]
81 | updated_layers = [set(range(len(parser))[layer_slice]) for layer_slice in updated_slices]
82 | self.assertEqual(set.union(*updated_layers), set(range(len(parser))))
83 |
84 |         # the cyclic dependency (layers attending to layers above them) is resolved by running the gradient-free passes first
85 | self.assertEqual([step.requires_grad for step in iteration_plan], [False, False, False, False, False, False, False, True, True])
86 |
87 |
88 | # Test for the case where there is no cyclic dependency
89 | layer_type = "0_1s_2"
90 | parser = LayerTypeParser(layer_type)
91 | iteration_plan = parser.iteration_plan(forward_passes=7, backward_passes=2)
92 | self.assertEqual(len(iteration_plan), 1)
93 | self.assertTrue(iteration_plan[0].requires_grad)
94 | self.assertTrue(iteration_plan[0].update)
95 |
96 | def test_check(self):
97 | # Test case: Check if the layer type is valid
98 | num_hidden_layers = 3
99 | layer_type = "1_2s_0"
100 | parser = LayerTypeParser(layer_type)
101 | self.assertIsNone(parser.check(num_hidden_layers))
102 |
103 | # Test case: Check for invalid layer type
104 | num_hidden_layers = 3
105 | layer_type = "1_2s_3"
106 | parser = LayerTypeParser(layer_type)
107 | with self.assertRaises(Exception):
108 | parser.check(num_hidden_layers)
109 |
110 | num_hidden_layers = 3
111 | layer_type = "0_1_2s_0"
112 | parser = LayerTypeParser(layer_type)
113 | with self.assertRaises(Exception):
114 | parser.check(num_hidden_layers)
115 |
116 |
117 | @require_torch_gpu
118 | @require_flash_attn
119 | class FlashAttentionForwardTest(unittest.TestCase):
120 | def test_no_diag(self):
121 | # Test case: Test flash_attention_forward with no_diag=True
122 | query_states = torch.randn(2, 5, 3, 4, dtype=torch.bfloat16, device="cuda")
123 | key_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda")
124 | value_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda")
125 | attention_mask = None
126 | query_length = 5
127 | is_causal = True
128 | no_diag = True
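    |         # no_diag is expected to mask each query's attention to its own (diagonal) KV pair,
    |         # the case LCKV hits when a token's top-layer KV is not yet computed at that position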
129 |
130 | result = flash_attention_forward(
131 | query_states=query_states,
132 | key_states=key_states,
133 | value_states=value_states,
134 | attention_mask=attention_mask,
135 | query_length=query_length,
136 | is_causal=is_causal,
137 | no_diag=no_diag
138 | )
139 |
140 | self.assertEqual(result.shape, (2, 5, 3, 4))
141 |
142 | # Test case: attention_mask is not None, square attention matrix
143 | query_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda")
144 | attention_mask = torch.ones(2, 6, dtype=torch.long, device="cuda")
145 | attention_mask[1, 2:] = 0
146 | result = flash_attention_forward(
147 | query_states=query_states,
148 | key_states=key_states,
149 | value_states=value_states,
150 | attention_mask=attention_mask,
151 | query_length=query_length,
152 | is_causal=is_causal,
153 | no_diag=no_diag
154 | )
155 |
156 | self.assertEqual(result.shape, (2, 5, 3, 4))
157 |
158 | def test_with_diag(self):
159 | # Test case: Test flash_attention_forward with no_diag=False
160 | query_states = torch.randn(2, 5, 3, 4, dtype=torch.bfloat16, device="cuda")
161 | key_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda")
162 | value_states = torch.randn(2, 6, 3, 4, dtype=torch.bfloat16, device="cuda")
163 | attention_mask = None
164 | query_length = 5
165 | is_causal = True
166 | no_diag = False
167 |
168 | result = flash_attention_forward(
169 | query_states=query_states,
170 | key_states=key_states,
171 | value_states=value_states,
172 | attention_mask=attention_mask,
173 | query_length=query_length,
174 | is_causal=is_causal,
175 | no_diag=no_diag
176 | )
177 |
178 | ref = _flash_attention_forward(
179 | query_states=query_states,
180 | key_states=key_states,
181 | value_states=value_states,
182 | attention_mask=attention_mask,
183 | query_length=query_length,
184 | is_causal=is_causal,
185 | )
186 |
187 | self.assertEqual(result.shape, (2, 5, 3, 4))
188 | self.assertTrue(torch.allclose(result, ref))
189 |
190 |
191 | if __name__ == '__main__':
192 | unittest.main()
193 |
--------------------------------------------------------------------------------