├── img
│   └── main_fig.jpg
├── llama_real_share
│   ├── __pycache__
│   │   ├── cache_utils.cpython-310.pyc
│   │   └── modeling_llama_kvsharer.cpython-310.pyc
│   ├── cache_utils.py
│   └── modeling_llama.py
├── internlm2_real_share
│   ├── __pycache__
│   │   ├── cache_utils.cpython-310.pyc
│   │   ├── configuration_internlm2.cpython-310.pyc
│   │   └── modeling_internlm2_kvsharer.cpython-310.pyc
│   ├── configuration_internlm2.py
│   └── cache_utils.py
├── README.md
├── test_llama.ipynb
└── test_internlm.ipynb
/img/main_fig.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangyifei729/KVSharer/HEAD/img/main_fig.jpg
--------------------------------------------------------------------------------
/llama_real_share/__pycache__/cache_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangyifei729/KVSharer/HEAD/llama_real_share/__pycache__/cache_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/internlm2_real_share/__pycache__/cache_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangyifei729/KVSharer/HEAD/internlm2_real_share/__pycache__/cache_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/llama_real_share/__pycache__/modeling_llama_kvsharer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangyifei729/KVSharer/HEAD/llama_real_share/__pycache__/modeling_llama_kvsharer.cpython-310.pyc
--------------------------------------------------------------------------------
/internlm2_real_share/__pycache__/configuration_internlm2.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangyifei729/KVSharer/HEAD/internlm2_real_share/__pycache__/configuration_internlm2.cpython-310.pyc
--------------------------------------------------------------------------------
/internlm2_real_share/__pycache__/modeling_internlm2_kvsharer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangyifei729/KVSharer/HEAD/internlm2_real_share/__pycache__/modeling_internlm2_kvsharer.cpython-310.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KVSharer: Efficient Inference via Layer-Wise Dissimilar KV Cache Sharing
2 |
3 |
4 |
5 |
6 |
7 |
8 | ## Recommended software environment
9 | - python == 3.10
10 | - torch == 2.1.2
11 | - transformers >= 4.38.0
12 | - scikit-learn >= 1.0
13 | - tqdm >= 4.49.0
14 | - numpy >= 1.20.2
15 |
16 |
17 | ## Description
18 |
19 | - The KVSharer workflow is demonstrated in two runnable notebooks, `test_llama.ipynb` and `test_internlm.ipynb`, which show how to run the strategy search and how to perform inference with KVSharer.
20 | - The main implementation of KV cache sharing during inference lives in `llama_real_share/cache_utils.py` and `internlm2_real_share/cache_utils.py`, where we introduce a new class, `DynamicDictCache`, that stores only a subset of the layers' KV cache.
21 | - We add a `kv_cache_share_layers_map` argument to `LlamaForCausalLM` and `InternLM2ForCausalLM` to specify the sharing strategy; the implementation is in `llama_real_share/modeling_llama_kvsharer.py` and `internlm2_real_share/modeling_internlm2_kvsharer.py`. A minimal usage sketch is given at the end of this README.
22 | - We provide a `wiki_demo.txt` file in the `./data` folder for testing.
23 |
24 | > [!NOTE]
25 | > This repo is under construction.
26 |
27 |
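28 | ## Usage sketch
29 |
30 | The snippet below is a minimal, illustrative sketch distilled from `test_llama.ipynb`. The sharing map here is hand-written (layer 30 reusing layer 3 is only an example); in practice the map is produced by the strategy search in the notebooks.
31 |
32 | ```python
33 | from transformers import AutoTokenizer
34 | from llama_real_share.modeling_llama_kvsharer import LlamaForCausalLM
35 |
36 | llama_path = 'YOUR MODEL'
37 | tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)
38 | model = LlamaForCausalLM.from_pretrained(llama_path, device_map='auto')
39 |
40 | # Start from the identity map: every layer keeps its own KV cache.
41 | kv_cache_share_layers_map = {i: i for i in range(len(model.model.layers))}
42 | # Illustrative sharing entry: layer 30 reuses the KV cache computed by layer 3.
43 | kv_cache_share_layers_map[30] = 3
44 |
45 | inputs = tokenizer('Hello, what is your name', return_tensors='pt').to('cuda:0')
46 | pred = model.generate(**inputs, kv_cache_share_layers_map=kv_cache_share_layers_map,
47 |                       max_new_tokens=64, repetition_penalty=1.1)
48 | print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
49 | ```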
--------------------------------------------------------------------------------
/internlm2_real_share/configuration_internlm2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """ InternLM2 model configuration"""
18 |
19 | from transformers.configuration_utils import PretrainedConfig
20 | from transformers.utils import logging
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 | INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
25 |
26 |
27 | # Modified from transformers.model.llama.configuration_llama.LlamaConfig
28 | class InternLM2Config(PretrainedConfig):
29 | r"""
30 | This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
31 | an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
32 | configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
33 |
34 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35 | documentation from [`PretrainedConfig`] for more information.
36 |
37 |
38 | Args:
39 | vocab_size (`int`, *optional*, defaults to 32000):
40 | Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
41 | `inputs_ids` passed when calling [`InternLM2Model`]
42 | hidden_size (`int`, *optional*, defaults to 4096):
43 | Dimension of the hidden representations.
44 | intermediate_size (`int`, *optional*, defaults to 11008):
45 | Dimension of the MLP representations.
46 | num_hidden_layers (`int`, *optional*, defaults to 32):
47 | Number of hidden layers in the Transformer decoder.
48 | num_attention_heads (`int`, *optional*, defaults to 32):
49 | Number of attention heads for each attention layer in the Transformer decoder.
50 | num_key_value_heads (`int`, *optional*):
51 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53 |             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
54 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55 |             by mean-pooling all the original heads within that group. For more details, check out [this
56 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57 | `num_attention_heads`.
58 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
59 | The non-linear activation function (function or string) in the decoder.
60 | max_position_embeddings (`int`, *optional*, defaults to 2048):
61 | The maximum sequence length that this model might ever be used with. InternLM2 supports up to 32768 tokens.
62 | initializer_range (`float`, *optional*, defaults to 0.02):
63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
65 | The epsilon used by the rms normalization layers.
66 | use_cache (`bool`, *optional*, defaults to `True`):
67 | Whether or not the model should return the last key/values attentions (not used by all models). Only
68 | relevant if `config.is_decoder=True`.
69 | pad_token_id (`int`, *optional*):
70 | Padding token id.
71 | bos_token_id (`int`, *optional*, defaults to 1):
72 | Beginning of stream token id.
73 | eos_token_id (`int`, *optional*, defaults to 2):
74 | End of stream token id.
75 | pretraining_tp (`int`, *optional*, defaults to 1):
76 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
77 | document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism)
78 | to understand more about it. This value is necessary to ensure exact reproducibility
79 | of the pretraining results. Please refer to [this
80 | issue](https://github.com/pytorch/pytorch/issues/76232).
81 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
82 |             Whether to tie the input and output word embeddings.
83 | rope_theta (`float`, *optional*, defaults to 10000.0):
84 | The base period of the RoPE embeddings.
85 | rope_scaling (`Dict`, *optional*):
86 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
87 | strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
88 | `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
89 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
90 | these scaling strategies behave:
91 | https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
92 | experimental feature, subject to breaking API changes in future versions.
93 | """
94 | _auto_class = "AutoConfig"
95 | model_type = "internlm2"
96 | keys_to_ignore_at_inference = ["past_key_values"]
97 |
98 | def __init__( # pylint: disable=W0102
99 | self,
100 | vocab_size=103168,
101 | hidden_size=4096,
102 | intermediate_size=11008,
103 | num_hidden_layers=32,
104 | num_attention_heads=32,
105 | num_key_value_heads=None,
106 | hidden_act="silu",
107 | max_position_embeddings=2048,
108 | initializer_range=0.02,
109 | rms_norm_eps=1e-6,
110 | use_cache=True,
111 | pad_token_id=0,
112 | bos_token_id=1,
113 | eos_token_id=2,
114 | pretraining_tp=1,
115 | tie_word_embeddings=False,
116 | bias=True,
117 | rope_theta=10000,
118 | rope_scaling=None,
119 | attn_implementation=None,
120 | **kwargs,
121 | ):
122 | self.vocab_size = vocab_size
123 | self.max_position_embeddings = max_position_embeddings
124 | self.hidden_size = hidden_size
125 | self.intermediate_size = intermediate_size
126 | self.num_hidden_layers = num_hidden_layers
127 | self.num_attention_heads = num_attention_heads
128 | self.bias = bias
129 |
130 | if num_key_value_heads is None:
131 | num_key_value_heads = num_attention_heads
132 | self.num_key_value_heads = num_key_value_heads
133 |
134 | self.hidden_act = hidden_act
135 | self.initializer_range = initializer_range
136 | self.rms_norm_eps = rms_norm_eps
137 | self.pretraining_tp = pretraining_tp
138 | self.use_cache = use_cache
139 | self.rope_theta = rope_theta
140 | self.rope_scaling = rope_scaling
141 | self._rope_scaling_validation()
142 | self.attn_implementation = attn_implementation
143 | if self.attn_implementation is None:
144 | self.attn_implementation = "eager"
145 |
146 | super().__init__(
147 | pad_token_id=pad_token_id,
148 | bos_token_id=bos_token_id,
149 | eos_token_id=eos_token_id,
150 | tie_word_embeddings=tie_word_embeddings,
151 | **kwargs,
152 | )
153 |
154 | def _rope_scaling_validation(self):
155 | """
156 | Validate the `rope_scaling` configuration.
157 | """
158 | if self.rope_scaling is None:
159 | return
160 |
161 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
162 | raise ValueError(
163 | "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
164 | f"got {self.rope_scaling}"
165 | )
166 | rope_scaling_type = self.rope_scaling.get("type", None)
167 | rope_scaling_factor = self.rope_scaling.get("factor", None)
168 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
169 | raise ValueError(
170 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
171 | )
172 | if (
173 | rope_scaling_factor is None
174 | or not isinstance(rope_scaling_factor, (float, int))
175 | or rope_scaling_factor < 1.0
176 | ):
177 | raise ValueError(
178 | f"`rope_scaling`'s factor field must be a number >= 1, got {rope_scaling_factor} "
179 | f"of type {type(rope_scaling_factor)}"
180 | )
181 |
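182 |
183 | if __name__ == "__main__":
184 |     # Illustrative sketch (not part of the original file): build a small config and show the
185 |     # expected `rope_scaling` format, i.e. {"type": "linear" or "dynamic", "factor": float >= 1.0}.
186 |     cfg = InternLM2Config(
187 |         num_hidden_layers=4,
188 |         hidden_size=256,
189 |         num_attention_heads=8,
190 |         rope_scaling={"type": "dynamic", "factor": 2.0},
191 |     )
192 |     # num_key_value_heads falls back to num_attention_heads (MHA) when left unset,
193 |     # and attn_implementation defaults to "eager".
194 |     print(cfg.num_key_value_heads, cfg.attn_implementation, cfg.rope_scaling)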
--------------------------------------------------------------------------------
/test_llama.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%env CUDA_VISIBLE_DEVICES=0"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from transformers import AutoTokenizer\n",
19 | "import torch"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from llama_real_share.modeling_llama_kvsharer import LlamaForCausalLM"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "### Load Model"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "llama_path = 'YOUR MODEL'"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "tokenizer = AutoTokenizer.from_pretrained(llama_path, trust_remote_code=True)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "llama = LlamaForCausalLM.from_pretrained(llama_path, device_map='auto')"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "### Load Calibration Dataset"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "wiki_data_path = './data/wiki_demo.txt'\n",
79 | "with open(wiki_data_path, 'r') as f:\n",
80 | " wiki_data = f.readlines()\n",
81 | " f.close()"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "calibration_set = wiki_data[0:30]"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "### Calculate the Euclidean Distance between any two layers of KV cache and sort them"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "from tqdm import tqdm\n",
107 | "import torch\n",
108 | "\n",
109 | "kv_cache_share_layers_map = {i:i for i in range(len(llama.model.layers))}\n",
110 | "kv_cache_list = []\n",
111 | "with torch.no_grad():\n",
112 | " for text in tqdm(calibration_set):\n",
113 | " inp = tokenizer(text, return_tensors='pt', max_length=64, truncation=True)\n",
114 | " inp = inp.to('cuda:0')\n",
115 | " out = llama(**inp, kv_cache_share_layers_map=kv_cache_share_layers_map)\n",
116 | " past_key_values = out.past_key_values\n",
117 | " kv_cache_list.append(past_key_values)"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "num_layers = len(kv_cache_list[0])\n",
127 | "avg_past_key_values = [(torch.zeros_like(kv_cache_list[0][i][0]), torch.zeros_like(kv_cache_list[0][i][1])) for i in range(num_layers)]\n",
128 | "\n",
129 | "for past_key_values in tqdm(kv_cache_list):\n",
130 | " for i, (key, value) in enumerate(past_key_values):\n",
131 | " try:\n",
132 | " avg_past_key_values[i] = (avg_past_key_values[i][0] + key, avg_past_key_values[i][1] + value)\n",
133 | " except:\n",
134 | " pass\n",
135 | "\n",
136 | "num_elements = len(kv_cache_list)\n",
137 | "avg_past_key_values = [(key / num_elements, value / num_elements) for key, value in avg_past_key_values]\n"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "import torch\n",
147 | "import torch.nn.functional as F\n",
148 | "import matplotlib.pyplot as plt\n",
149 | "import seaborn as sns\n",
150 | "import numpy as np\n",
151 | "\n",
152 | "def compute_cosine_similarity(tensor1, tensor2):\n",
153 | " return F.cosine_similarity(tensor1.flatten(1), tensor2.flatten(1), dim=-1).mean().item()\n",
154 | "\n",
155 | "def compute_euclidean_distance(tensor1, tensor2):\n",
156 | " return torch.norm(tensor1 - tensor2, p=2, dim=-1).mean().item()\n",
157 | "\n",
158 | "num_layers = len(avg_past_key_values)\n",
159 | "similarity_matrix = np.zeros((num_layers, num_layers))\n",
160 | "\n",
161 | "for i in range(num_layers):\n",
162 | " for j in range(num_layers):\n",
163 | " if i > j:\n",
164 | " key_i, value_i = avg_past_key_values[i]\n",
165 | " key_j, value_j = avg_past_key_values[j]\n",
166 | " key_similarity = compute_euclidean_distance(key_i, key_j)\n",
167 | " value_similarity = compute_euclidean_distance(value_i, value_j) \n",
168 | " similarity_matrix[i, j] = (key_similarity + value_similarity) / 2\n",
169 | " else:\n",
170 | " similarity_matrix[i, j] = np.nan"
171 | ]
172 | },
173 | {
174 | "cell_type": "code",
175 | "execution_count": null,
176 | "metadata": {},
177 | "outputs": [],
178 | "source": [
179 | "\n",
180 | "flattened_values = similarity_matrix.flatten()\n",
181 | "valid_indices = ~np.isnan(flattened_values)\n",
182 | "\n",
183 | "valid_values = flattened_values[valid_indices]\n",
184 | "valid_flat_indices = np.where(valid_indices)[0]\n",
185 | "\n",
186 | "sorted_valid_indices = np.argsort(valid_values)[::-1]\n",
187 | "sorted_flat_indices = valid_flat_indices[sorted_valid_indices]\n",
188 | "\n",
189 | "sorted_positions = np.unravel_index(sorted_flat_indices, similarity_matrix.shape)\n",
190 | "\n",
191 | "pos_rank = []\n",
192 | "\n",
193 | "for i in range(sorted_positions[0].shape[0]):\n",
194 | " pos = (sorted_positions[0][i], sorted_positions[1][i])\n",
195 | " pos_rank.append(pos)\n",
196 | " "
197 | ]
198 | },
199 | {
200 | "cell_type": "markdown",
201 | "metadata": {},
202 | "source": [
203 | "### Initialize the Sharing Layers and THRESHOLD"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "SHARE_LAYERS = 4\n",
213 | "THRESHOLD = 0.5"
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": null,
219 | "metadata": {},
220 | "outputs": [],
221 | "source": [
222 | "import numpy as np\n",
223 | "def cal_last_hidden_sim(model1, model2, kv_cache_share_layers_map, tokenizer, sents):\n",
224 | " sim_ls = []\n",
225 | " for s in sents:\n",
226 | " encoded_inputs = tokenizer(s, max_length=64, truncation=True, return_tensors='pt')\n",
227 | " encoded_inputs.to('cuda:0')\n",
228 | " with torch.no_grad():\n",
229 | " outputs1 = model1(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map={i:i for i in range(len(model1.model.layers))})\n",
230 | " hidden_states1 = outputs1.hidden_states[-1] # (1, seq_len, hidden)\n",
231 | " with torch.no_grad():\n",
232 | " outputs2 = model2(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map=kv_cache_share_layers_map)\n",
233 | " hidden_states2 = outputs2.hidden_states[-1] # (1, seq_len, hidden)\n",
234 | " sim_ls.append(torch.cosine_similarity(hidden_states1.squeeze(0).flatten().unsqueeze(0), hidden_states2.squeeze(0).flatten().unsqueeze(0)))\n",
235 | " sim_ls = [i.item() for i in sim_ls]\n",
236 | " print(sim_ls, np.mean(sim_ls))\n",
237 | " return np.mean(sim_ls)"
238 | ]
239 | },
240 | {
241 | "cell_type": "code",
242 | "execution_count": null,
243 | "metadata": {},
244 | "outputs": [],
245 | "source": [
246 | "def re_map(kv_cache_share_layers_map):\n",
247 | " tmp_kv_cache_share_layers_map = {}\n",
248 | " for key, values in kv_cache_share_layers_map.items():\n",
249 | " if key == values:\n",
250 | " tmp_kv_cache_share_layers_map[key] = values\n",
251 | " else:\n",
252 | " tmp_kv_cache_share_layers_map[key] = tmp_kv_cache_share_layers_map[values]\n",
253 | " return tmp_kv_cache_share_layers_map"
254 | ]
255 | },
256 | {
257 | "cell_type": "markdown",
258 | "metadata": {},
259 | "source": [
260 | "### Strategy Searching"
261 | ]
262 | },
263 | {
264 | "cell_type": "code",
265 | "execution_count": null,
266 | "metadata": {},
267 | "outputs": [],
268 | "source": [
269 | "from copy import deepcopy\n",
270 | "\n",
271 | "kv_cache_share_layers_map = {i:i for i in range(len(llama.model.layers))}\n",
272 | "\n",
273 | "shared_lay = []\n",
274 | "shared_num_layers = 0\n",
275 | "\n",
276 | "for pair in tqdm(pos_rank):\n",
277 | " tmp_kv_cache_share_layers_map = deepcopy(kv_cache_share_layers_map)\n",
278 | " if pair[0] < pair[1]:\n",
279 | " pair[0], pair[1] = pair[1], pair[0]\n",
280 | " if pair[0] in shared_lay:\n",
281 | " continue\n",
282 | " tmp_kv_cache_share_layers_map[pair[0]] = pair[1]\n",
283 | " tmp_kv_cache_share_layers_map = re_map(tmp_kv_cache_share_layers_map)\n",
284 | " sim_value = cal_last_hidden_sim(llama, llama, tmp_kv_cache_share_layers_map, tokenizer, calibration_set)\n",
285 | " if sim_value > THRESHOLD:\n",
286 | " kv_cache_share_layers_map = deepcopy(tmp_kv_cache_share_layers_map)\n",
287 | " shared_lay.append(pair[0])\n",
288 | " shared_num_layers += 1\n",
289 | " if shared_num_layers >= SHARE_LAYERS:\n",
290 | " break"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": null,
296 | "metadata": {},
297 | "outputs": [],
298 | "source": [
299 | "print(kv_cache_share_layers_map)"
300 | ]
301 | },
302 | {
303 | "cell_type": "markdown",
304 | "metadata": {},
305 | "source": [
306 | "### Inference with KVSharer"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": null,
312 | "metadata": {},
313 | "outputs": [],
314 | "source": [
315 | "def generate(model, tokenizer, sent, kv_cache_share_layers_map=None):\n",
316 | " inputs = tokenizer(sent, return_tensors='pt')\n",
317 | " inputs = inputs.to('cuda:0')\n",
318 | " pred = model.generate(**inputs, kv_cache_share_layers_map=kv_cache_share_layers_map, max_new_tokens=256, repetition_penalty=1.1)\n",
319 | " print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {},
326 | "outputs": [],
327 | "source": [
328 | "sent = 'Hello, what is your name'\n",
329 | "generate(llama, tokenizer, sent, kv_cache_share_layers_map=kv_cache_share_layers_map)"
330 | ]
331 | }
332 | ],
333 | "metadata": {
334 | "kernelspec": {
335 | "display_name": "py310",
336 | "language": "python",
337 | "name": "python3"
338 | },
339 | "language_info": {
340 | "codemirror_mode": {
341 | "name": "ipython",
342 | "version": 3
343 | },
344 | "file_extension": ".py",
345 | "mimetype": "text/x-python",
346 | "name": "python",
347 | "nbconvert_exporter": "python",
348 | "pygments_lexer": "ipython3",
349 | "version": "3.10.12"
350 | }
351 | },
352 | "nbformat": 4,
353 | "nbformat_minor": 2
354 | }
355 |
--------------------------------------------------------------------------------
/test_internlm.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%env CUDA_VISIBLE_DEVICES=0"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "from transformers import AutoTokenizer\n",
19 | "import torch"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "from internlm2_real_share.modeling_internlm2_kvsharer import InternLM2ForCausalLM"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "### Load Model"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "internlm_path = 'YOUR MODEL'"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "tokenizer = AutoTokenizer.from_pretrained(internlm_path, trust_remote_code=True)"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "internlm2 = InternLM2ForCausalLM.from_pretrained(internlm_path, device_map='auto')"
63 | ]
64 | },
65 | {
66 | "cell_type": "markdown",
67 | "metadata": {},
68 | "source": [
69 | "### Load Calibration Dataset"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "wiki_data_path = './data/wiki_demo.txt'\n",
79 | "with open(wiki_data_path, 'r') as f:\n",
80 | " wiki_data = f.readlines()\n",
81 | " f.close()"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "calibration_set = wiki_data[0:30]"
91 | ]
92 | },
93 | {
94 | "cell_type": "markdown",
95 | "metadata": {},
96 | "source": [
97 | "### Calculate the Euclidean Distance between any two layers of KV cache and sort them"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": null,
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "from tqdm import tqdm\n",
107 | "import torch\n",
108 | "\n",
109 | "kv_cache_share_layers_map = {i:i for i in range(len(internlm2.model.layers))}\n",
110 | "kv_cache_list = []\n",
111 | "with torch.no_grad():\n",
112 | " for text in tqdm(calibration_set):\n",
113 | " inp = tokenizer(text, return_tensors='pt', max_length=64, truncation=True)\n",
114 | " inp = inp.to('cuda:0')\n",
115 | " out = internlm2(**inp, kv_cache_share_layers_map=kv_cache_share_layers_map)\n",
116 | " past_key_values = out.past_key_values\n",
117 | " kv_cache_list.append(past_key_values)"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "num_layers = len(kv_cache_list[0])\n",
127 | "avg_past_key_values = [(torch.zeros_like(kv_cache_list[0][i][0]), torch.zeros_like(kv_cache_list[0][i][1])) for i in range(num_layers)]\n",
128 | "\n",
129 | "for past_key_values in tqdm(kv_cache_list):\n",
130 | " for i, (key, value) in enumerate(past_key_values):\n",
131 | " try:\n",
132 | " avg_past_key_values[i] = (avg_past_key_values[i][0] + key, avg_past_key_values[i][1] + value)\n",
133 | " except:\n",
134 | " pass\n",
135 | "\n",
136 | "num_elements = len(kv_cache_list)\n",
137 | "avg_past_key_values = [(key / num_elements, value / num_elements) for key, value in avg_past_key_values]\n"
138 | ]
139 | },
140 | {
141 | "cell_type": "code",
142 | "execution_count": null,
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "import torch\n",
147 | "import torch.nn.functional as F\n",
148 | "import matplotlib.pyplot as plt\n",
149 | "import seaborn as sns\n",
150 | "import numpy as np\n",
151 | "\n",
152 | "def compute_cosine_similarity(tensor1, tensor2):\n",
153 | " return F.cosine_similarity(tensor1.flatten(1), tensor2.flatten(1), dim=-1).mean().item()\n",
154 | "\n",
155 | "def compute_euclidean_distance(tensor1, tensor2):\n",
156 | " return torch.norm(tensor1 - tensor2, p=2, dim=-1).mean().item()\n",
157 | "\n",
158 | "num_layers = len(avg_past_key_values)\n",
159 | "similarity_matrix = np.zeros((num_layers, num_layers))\n",
160 | "\n",
161 | "\n",
162 | "for i in range(num_layers):\n",
163 | " for j in range(num_layers):\n",
164 | " if i > j:\n",
165 | " key_i, value_i = avg_past_key_values[i]\n",
166 | " key_j, value_j = avg_past_key_values[j]\n",
167 | " key_similarity = compute_euclidean_distance(key_i, key_j)\n",
168 | " value_similarity = compute_euclidean_distance(value_i, value_j) \n",
169 | " similarity_matrix[i, j] = (key_similarity + value_similarity) / 2\n",
170 | " else:\n",
171 | " similarity_matrix[i, j] = np.nan"
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "\n",
181 | "flattened_values = similarity_matrix.flatten()\n",
182 | "valid_indices = ~np.isnan(flattened_values)\n",
183 | "\n",
184 | "valid_values = flattened_values[valid_indices]\n",
185 | "valid_flat_indices = np.where(valid_indices)[0]\n",
186 | "\n",
187 | "sorted_valid_indices = np.argsort(valid_values)[::-1]\n",
188 | "sorted_flat_indices = valid_flat_indices[sorted_valid_indices]\n",
189 | "\n",
190 | "sorted_positions = np.unravel_index(sorted_flat_indices, similarity_matrix.shape)\n",
191 | "\n",
192 | "pos_rank = []\n",
193 | "\n",
194 | "for i in range(sorted_positions[0].shape[0]):\n",
195 | " pos = (sorted_positions[0][i], sorted_positions[1][i])\n",
196 | " pos_rank.append(pos)\n",
197 | " "
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "metadata": {},
203 | "source": [
204 | "### Initialize the Sharing Layers and THRESHOLD"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "SHARE_LAYERS = 4\n",
214 | "THRESHOLD = 0.5"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "metadata": {},
221 | "outputs": [],
222 | "source": [
223 | "import numpy as np\n",
224 | "def cal_last_hidden_sim(model1, model2, kv_cache_share_layers_map, tokenizer, sents):\n",
225 | " sim_ls = []\n",
226 | " for s in sents:\n",
227 | " encoded_inputs = tokenizer(s, max_length=64, truncation=True, return_tensors='pt')\n",
228 | " encoded_inputs.to('cuda:0')\n",
229 | " with torch.no_grad():\n",
230 | " outputs1 = model1(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map={i:i for i in range(len(model1.model.layers))})\n",
231 | " hidden_states1 = outputs1.hidden_states[-1] # (1, seq_len, hidden)\n",
232 | " with torch.no_grad():\n",
233 | " outputs2 = model2(**encoded_inputs, output_hidden_states=True, kv_cache_share_layers_map=kv_cache_share_layers_map)\n",
234 | " hidden_states2 = outputs2.hidden_states[-1] # (1, seq_len, hidden)\n",
235 | " sim_ls.append(torch.cosine_similarity(hidden_states1.squeeze(0).flatten().unsqueeze(0), hidden_states2.squeeze(0).flatten().unsqueeze(0)))\n",
236 | " sim_ls = [i.item() for i in sim_ls]\n",
237 | " print(sim_ls, np.mean(sim_ls))\n",
238 | " return np.mean(sim_ls)"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": null,
244 | "metadata": {},
245 | "outputs": [],
246 | "source": [
247 | "def re_map(kv_cache_share_layers_map):\n",
248 | " tmp_kv_cache_share_layers_map = {}\n",
249 | " for key, values in kv_cache_share_layers_map.items():\n",
250 | " if key == values:\n",
251 | " tmp_kv_cache_share_layers_map[key] = values\n",
252 | " else:\n",
253 | " tmp_kv_cache_share_layers_map[key] = tmp_kv_cache_share_layers_map[values]\n",
254 | " return tmp_kv_cache_share_layers_map"
255 | ]
256 | },
257 | {
258 | "cell_type": "markdown",
259 | "metadata": {},
260 | "source": [
261 | "### Strategy Searching"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": null,
267 | "metadata": {},
268 | "outputs": [],
269 | "source": [
270 | "from copy import deepcopy\n",
271 | "\n",
272 | "kv_cache_share_layers_map = {i:i for i in range(len(internlm2.model.layers))}\n",
273 | "\n",
274 | "shared_lay = []\n",
275 | "shared_num_layers = 0\n",
276 | "\n",
277 | "for pair in tqdm(pos_rank):\n",
278 | " tmp_kv_cache_share_layers_map = deepcopy(kv_cache_share_layers_map)\n",
279 | " if pair[0] < pair[1]:\n",
280 | " pair[0], pair[1] = pair[1], pair[0]\n",
281 | " if pair[0] in shared_lay:\n",
282 | " continue\n",
283 | " tmp_kv_cache_share_layers_map[pair[0]] = pair[1]\n",
284 | " tmp_kv_cache_share_layers_map = re_map(tmp_kv_cache_share_layers_map)\n",
285 | " sim_value = cal_last_hidden_sim(internlm2, internlm2, tmp_kv_cache_share_layers_map, tokenizer, calibration_set)\n",
286 | " if sim_value > THRESHOLD:\n",
287 | " kv_cache_share_layers_map = deepcopy(tmp_kv_cache_share_layers_map)\n",
288 | " shared_lay.append(pair[0])\n",
289 | " shared_num_layers += 1\n",
290 | " if shared_num_layers >= SHARE_LAYERS:\n",
291 | " break"
292 | ]
293 | },
294 | {
295 | "cell_type": "code",
296 | "execution_count": null,
297 | "metadata": {},
298 | "outputs": [],
299 | "source": [
300 | "print(kv_cache_share_layers_map)"
301 | ]
302 | },
303 | {
304 | "cell_type": "markdown",
305 | "metadata": {},
306 | "source": [
307 | "### Inference with KVSharer"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": null,
313 | "metadata": {},
314 | "outputs": [],
315 | "source": [
316 | "def generate(model, tokenizer, sent, kv_cache_share_layers_map=None):\n",
317 | " inputs = tokenizer(sent, return_tensors='pt')\n",
318 | " inputs = inputs.to('cuda:0')\n",
319 | " pred = model.generate(**inputs, kv_cache_share_layers_map=kv_cache_share_layers_map, max_new_tokens=256, repetition_penalty=1.1)\n",
320 | " print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))"
321 | ]
322 | },
323 | {
324 | "cell_type": "code",
325 | "execution_count": null,
326 | "metadata": {},
327 | "outputs": [],
328 | "source": [
329 | "sent = 'Hello, what is your name'\n",
330 | "generate(internlm2, tokenizer, sent, kv_cache_share_layers_map=kv_cache_share_layers_map)"
331 | ]
332 | }
333 | ],
334 | "metadata": {
335 | "kernelspec": {
336 | "display_name": "py310",
337 | "language": "python",
338 | "name": "python3"
339 | },
340 | "language_info": {
341 | "codemirror_mode": {
342 | "name": "ipython",
343 | "version": 3
344 | },
345 | "file_extension": ".py",
346 | "mimetype": "text/x-python",
347 | "name": "python",
348 | "nbconvert_exporter": "python",
349 | "pygments_lexer": "ipython3",
350 | "version": "3.10.12"
351 | }
352 | },
353 | "nbformat": 4,
354 | "nbformat_minor": 2
355 | }
356 |
--------------------------------------------------------------------------------
/internlm2_real_share/cache_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, List, Optional, Tuple
3 |
4 | import torch
5 |
6 | from transformers.configuration_utils import PretrainedConfig
7 |
8 |
9 | @dataclass
10 | class Cache:
11 | """
12 | Base, abstract class for all caches. The actual data structure is specific to each subclass.
13 | """
14 |
15 | def update(
16 | self,
17 | key_states: torch.Tensor,
18 | value_states: torch.Tensor,
19 | layer_idx: int,
20 | cache_kwargs: Optional[Dict[str, Any]] = None,
21 | ) -> Tuple[torch.Tensor, torch.Tensor]:
22 | """
23 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
24 |
25 | Parameters:
26 | key_states (`torch.Tensor`):
27 | The new key states to cache.
28 | value_states (`torch.Tensor`):
29 | The new value states to cache.
30 | layer_idx (`int`):
31 | The index of the layer to cache the states for.
32 | cache_kwargs (`Dict[str, Any]`, `optional`):
33 | Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
34 | cache to be created.
35 |
36 | Return:
37 | A tuple containing the updated key and value states.
38 | """
39 | raise NotImplementedError("Make sure to implement `update` in a subclass.")
40 |
41 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
42 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
43 | raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
44 |
45 | def get_max_length(self) -> Optional[int]:
46 | """Returns the maximum sequence length of the cached states, if there is any."""
47 | raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
48 |
49 | def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
50 | """Given the sequence length of the new inputs, returns the usable length of the cache."""
51 | # Cache without size limit -> all cache is usable
52 |         # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum cache
53 | # length, we will need to evict part of the cache (and thus not all cache is usable)
54 | max_length = self.get_max_length()
55 | previous_seq_length = self.get_seq_length(layer_idx)
56 | if max_length is not None and previous_seq_length + new_seq_length > max_length:
57 | return max_length - new_seq_length
58 | return previous_seq_length
59 |
60 |
61 | class DynamicCache(Cache):
62 | """
63 | A cache that grows dynamically as more tokens are generated. This is the default for generative models.
64 |
65 | It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
66 | `[batch_size, num_heads, seq_len, head_dim]`.
67 | """
68 |
69 | def __init__(self) -> None:
70 | self.key_cache: List[torch.Tensor] = []
71 | self.value_cache: List[torch.Tensor] = []
72 | self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
73 |
74 | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
75 | """
76 | Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
77 | sequence length.
78 | """
79 | if layer_idx < len(self):
80 | return (self.key_cache[layer_idx], self.value_cache[layer_idx])
81 | else:
82 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
83 |
84 | def __iter__(self):
85 | """
86 | Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
87 | keys and values
88 | """
89 | for layer_idx in range(len(self)):
90 | yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
91 |
92 | def __len__(self):
93 | """
94 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
95 | to the number of layers in the model.
96 | """
97 | return len(self.key_cache)
98 |
99 | def update(
100 | self,
101 | key_states: torch.Tensor,
102 | value_states: torch.Tensor,
103 | layer_idx: int,
104 | cache_kwargs: Optional[Dict[str, Any]] = None,
105 | ) -> Tuple[torch.Tensor, torch.Tensor]:
106 | """
107 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
108 |
109 | Parameters:
110 | key_states (`torch.Tensor`):
111 | The new key states to cache.
112 | value_states (`torch.Tensor`):
113 | The new value states to cache.
114 | layer_idx (`int`):
115 | The index of the layer to cache the states for.
116 | cache_kwargs (`Dict[str, Any]`, `optional`):
117 | Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
118 |
119 | Return:
120 | A tuple containing the updated key and value states.
121 | """
122 | # Update the number of seen tokens
123 | if layer_idx == 0:
124 | self.seen_tokens += key_states.shape[-2]
125 |
126 | # Update the cache
127 | if len(self.key_cache) <= layer_idx:
128 | self.key_cache.append(key_states)
129 | self.value_cache.append(value_states)
130 | else:
131 | self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
132 | self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
133 |
134 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
135 |
136 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
137 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
138 | if len(self.key_cache) <= layer_idx:
139 | return 0
140 | return self.key_cache[layer_idx].shape[-2]
141 |
142 | def get_max_length(self) -> Optional[int]:
143 | """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
144 | return None
145 |
146 | def reorder_cache(self, beam_idx: torch.LongTensor):
147 | """Reorders the cache for beam search, given the selected beam indices."""
148 | for layer_idx in range(len(self.key_cache)):
149 | device = self.key_cache[layer_idx].device
150 | self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
151 | device = self.value_cache[layer_idx].device
152 | self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
153 |
154 | def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
155 | """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
156 | legacy_cache = ()
157 | for layer_idx in range(len(self)):
158 | legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
159 | return legacy_cache
160 |
161 | @classmethod
162 | def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
163 | """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
164 | cache = cls()
165 | if past_key_values is not None:
166 | for layer_idx in range(len(past_key_values)):
167 | key_states, value_states = past_key_values[layer_idx]
168 | cache.update(key_states, value_states, layer_idx)
169 | return cache
170 |
171 |
172 | class DynamicDictCache(Cache):
173 | """
174 |     A cache that grows dynamically as more tokens are generated, adapted for KVSharer's layer-wise KV cache sharing.
175 |
176 |     It stores the Key and Value states as a dict of tensors keyed by layer index. The expected shape for each tensor is
177 |     `[batch_size, num_heads, seq_len, head_dim]`.
178 |
179 |     A dict is used instead of a list so that only the layers that keep their own KV cache need an entry.
180 | """
181 |
182 | def __init__(self) -> None:
183 | self.key_cache: Dict[int, torch.Tensor] = {}
184 | self.value_cache: Dict[int, torch.Tensor] = {}
185 | self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
186 |
187 | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
188 | """
189 | Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
190 | sequence length.
191 | """
192 | if layer_idx < len(self):
193 | return (self.key_cache[layer_idx], self.value_cache[layer_idx])
194 | else:
195 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
196 |
197 | def __iter__(self):
198 | """
199 | Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
200 | keys and values
201 | """
202 | for layer_idx in range(len(self)):
203 | yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
204 |
205 | def __len__(self):
206 | """
207 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
208 | to the number of layers in the model.
209 | """
210 | return len(self.key_cache)
211 |
212 | def update(
213 | self,
214 | key_states: torch.Tensor,
215 | value_states: torch.Tensor,
216 | layer_idx: int,
217 | skip: bool = False,
218 | cache_kwargs: Optional[Dict[str, Any]] = None,
219 | ) -> Tuple[torch.Tensor, torch.Tensor]:
220 |         # If the current layer is skipped (its KV cache is shared), its kv_cache does not need to be updated
221 | """
222 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
223 |
224 | Parameters:
225 | key_states (`torch.Tensor`):
226 | The new key states to cache.
227 | value_states (`torch.Tensor`):
228 | The new value states to cache.
229 | layer_idx (`int`):
230 | The index of the layer to cache the states for.
231 |             skip (`bool`, *optional*, defaults to `False`):
232 |                 If `True`, this layer's cache is left untouched (used for layers whose KV cache is shared from another layer). `cache_kwargs` is unused in `DynamicDictCache`.
233 |
234 | Return:
235 | A tuple containing the updated key and value states.
236 | """
237 | # Update the number of seen tokens
238 | if layer_idx == 0:
239 | self.seen_tokens += key_states.shape[-2]
240 |
241 |
242 |         # Only update the cache when this layer is not skipped, i.e. when it keeps its own KV cache
243 | if not skip:
244 | # Update the cache
245 | if len(self.key_cache) <= layer_idx:
246 | self.key_cache[layer_idx] = key_states
247 | self.value_cache[layer_idx] = value_states
248 | else:
249 | self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
250 | self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
251 |
252 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
253 |
254 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
255 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
256 | if len(self.key_cache) <= layer_idx:
257 | return 0
258 | return self.key_cache[layer_idx].shape[-2]
259 |
260 | def get_max_length(self) -> Optional[int]:
261 | """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
262 | return None
263 |
264 | def reorder_cache(self, beam_idx: torch.LongTensor):
265 | """Reorders the cache for beam search, given the selected beam indices."""
266 | for layer_idx in range(len(self.key_cache)):
267 | device = self.key_cache[layer_idx].device
268 | self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
269 | device = self.value_cache[layer_idx].device
270 | self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
271 |
272 | def to_legacy_cache(self, kv_cache_share_layers_map=None) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
273 | """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
274 | legacy_cache = ()
275 |         # The conversion below also has to read from the dict-based cache
276 | if kv_cache_share_layers_map is None:
277 | for layer_idx in range(len(self)):
278 | legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
279 | return legacy_cache
280 | else:
281 | for layer_idx in range(len(self)):
282 |                 # TODO: do we need to keep the KV cache of every layer here, or would a subset of layers suffice? Also check how this legacy_cache is consumed downstream.
283 | legacy_cache += ((self.key_cache[kv_cache_share_layers_map[layer_idx]], self.value_cache[kv_cache_share_layers_map[layer_idx]]),)
284 | return legacy_cache
285 |
286 | @classmethod
287 | def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicDictCache":
288 | """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
289 | cache = cls()
290 | if past_key_values is not None:
291 | for layer_idx in range(len(past_key_values)):
292 | key_states, value_states = past_key_values[layer_idx]
293 | cache.update(key_states, value_states, layer_idx)
294 | return cache
295 |
296 |
297 | class SinkCache(Cache):
298 | """
299 | A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
300 | generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
301 | tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
302 |
303 | It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
304 | `[batch_size, num_heads, seq_len, head_dim]`.
305 |
306 | Parameters:
307 | window_length (`int`):
308 | The length of the context window.
309 | num_sink_tokens (`int`):
310 | The number of sink tokens. See the original paper for more information.
311 | """
312 |
313 | def __init__(self, window_length: int, num_sink_tokens: int) -> None:
314 | self.key_cache: List[torch.Tensor] = []
315 | self.value_cache: List[torch.Tensor] = []
316 | self.window_length = window_length
317 | self.num_sink_tokens = num_sink_tokens
318 | self.cos_sin_cache = {}
319 | self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
320 |
321 | @staticmethod
322 | def _rotate_half(x):
323 | x1 = x[..., : x.shape[-1] // 2]
324 | x2 = x[..., x.shape[-1] // 2 :]
325 | return torch.cat((-x2, x1), dim=-1)
326 |
327 | def _apply_key_rotary_pos_emb(
328 | self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
329 | ) -> torch.Tensor:
330 | rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
331 | return rotated_key_states
332 |
333 | def _get_rerotation_cos_sin(
334 | self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
335 | ) -> Tuple[torch.Tensor, torch.Tensor]:
336 | if key_states.shape[-2] not in self.cos_sin_cache:
337 | # Upcast to float32 temporarily for better accuracy
338 | cos = cos.to(torch.float32)
339 | sin = sin.to(torch.float32)
340 |
341 | # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
342 | original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
343 | shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
344 | original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
345 | shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
346 | rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
347 | rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
348 |
349 | self.cos_sin_cache[key_states.shape[-2]] = (
350 | rerotation_cos.to(key_states.dtype).unsqueeze(0),
351 | rerotation_sin.to(key_states.dtype).unsqueeze(0),
352 | )
353 | return self.cos_sin_cache[key_states.shape[-2]]
354 |
355 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
356 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
357 | # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
358 | if len(self.key_cache) <= layer_idx:
359 | return 0
360 | return self.key_cache[layer_idx].shape[-2]
361 |
362 | def get_max_length(self) -> Optional[int]:
363 | """Returns the maximum sequence length of the cached states."""
364 | return self.window_length
365 |
366 | def update(
367 | self,
368 | key_states: torch.Tensor,
369 | value_states: torch.Tensor,
370 | layer_idx: int,
371 | cache_kwargs: Optional[Dict[str, Any]] = None,
372 | ) -> Tuple[torch.Tensor, torch.Tensor]:
373 | """
374 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
375 |
376 | Parameters:
377 | key_states (`torch.Tensor`):
378 | The new key states to cache.
379 | value_states (`torch.Tensor`):
380 | The new value states to cache.
381 | layer_idx (`int`):
382 | The index of the layer to cache the states for.
383 | cache_kwargs (`Dict[str, Any]`, `optional`):
384 | Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
385 | `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
386 | rotation as the tokens are shifted.
387 |
388 | Return:
389 | A tuple containing the updated key and value states.
390 | """
391 | # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
392 | # with partially rotated position embeddings, like Phi or Persimmon.
393 | sin = cache_kwargs.get("sin")
394 | cos = cache_kwargs.get("cos")
395 | partial_rotation_size = cache_kwargs.get("partial_rotation_size")
396 | using_rope = cos is not None and sin is not None
397 |
398 | # Update the number of seen tokens
399 | if layer_idx == 0:
400 | self.seen_tokens += key_states.shape[-2]
401 |
402 | # [bsz, num_heads, seq_len, head_dim]
403 | if len(self.key_cache) <= layer_idx:
404 | # Empty cache
405 | self.key_cache.append(key_states)
406 | self.value_cache.append(value_states)
407 |
408 | elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
409 | # Growing cache
410 | self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
411 | self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
412 |
413 | else:
414 | # Shifting cache
415 | keys_to_keep = self.key_cache[layer_idx][
416 | :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
417 | ]
418 |
419 | # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
420 | if using_rope:
421 | rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
422 | key_states, cos[: self.window_length], sin[: self.window_length]
423 | )
424 | if partial_rotation_size is not None:
425 | keys_to_keep, keys_pass = (
426 | keys_to_keep[..., :partial_rotation_size],
427 | keys_to_keep[..., partial_rotation_size:],
428 | )
429 | keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
430 | if partial_rotation_size is not None:
431 | keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
432 |
433 | # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
434 | sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
435 | self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
436 |
437 | sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
438 | values_to_keep = self.value_cache[layer_idx][
439 | :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
440 | ]
441 | self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
442 |
443 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
444 |
445 | def reorder_cache(self, beam_idx: torch.LongTensor):
446 | """Reorders the cache for beam search, given the selected beam indices."""
447 | for layer_idx in range(len(self.key_cache)):
448 | device = self.key_cache[layer_idx].device
449 | self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
450 | device = self.value_cache[layer_idx].device
451 | self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
452 |
453 |
454 | class StaticCache(Cache):
455 | """
456 | Static Cache class to be used with `torch.compile(model)`.
457 |
458 | Parameters:
459 |         config (`PretrainedConfig`):
460 | The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
461 | required to initialize the static cache.
462 | max_batch_size (`int`):
463 | The maximum batch size with which the model will be used.
464 | max_cache_len (`int`):
465 | The maximum sequence length with which the model will be used.
466 | device (`torch.device`):
467 | The device on which the cache should be initialized. Should be the same as the layer.
468 | dtype (*optional*, defaults to `torch.float32`):
469 | The default `dtype` to use when initializing the layer.
470 | """
471 |
472 | def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
473 | super().__init__()
474 | self.max_batch_size = max_batch_size
475 | self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
476 |         # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
477 | self.head_dim = (
478 | config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
479 | )
480 |
481 | self.dtype = dtype if dtype is not None else torch.float32
482 | self.num_key_value_heads = (
483 | config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
484 | )
485 |
486 | cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
487 | self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
488 | self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
489 | self.seen_tokens = 0
490 |
491 | def update(
492 | self,
493 | key_states: torch.Tensor,
494 | value_states: torch.Tensor,
495 | layer_idx: int,
496 | cache_kwargs: Optional[Dict[str, Any]] = None,
497 | ) -> Tuple[torch.Tensor, torch.Tensor]:
498 | """
499 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
500 | It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
501 |
502 | Parameters:
503 | key_states (`torch.Tensor`):
504 | The new key states to cache.
505 | value_states (`torch.Tensor`):
506 | The new value states to cache.
507 | layer_idx (`int`):
508 | The index of the layer to cache the states for. Kept for backward compatibility
509 | cache_kwargs (`Dict[str, Any]`, `optional`):
510 |                 Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
511 |                 to know where in the pre-allocated cache to write the new states.
512 |
513 | Return:
514 | A tuple containing the updated key and value states.
515 | """
516 | new_cache_positions = cache_kwargs.get("cache_position")
517 | k_out = self.key_cache
518 | v_out = self.value_cache
519 |
520 | k_out[:, :, new_cache_positions] = key_states
521 | v_out[:, :, new_cache_positions] = value_states
522 |
523 | self.seen_tokens += key_states.shape[2]
524 | return k_out, v_out
525 |
526 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
527 | """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
528 | return self.seen_tokens
529 |
530 | def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
531 | return self.seen_tokens
532 |
533 | def get_max_length(self) -> Optional[int]:
534 |         """Returns the maximum sequence length of the cached states, i.e. the pre-allocated `max_cache_len`."""
535 | return self.max_cache_len
536 |
537 | def reorder_cache(self, beam_idx: torch.LongTensor):
538 | """Reorders the cache for beam search, given the selected beam indices."""
539 | device = self.key_cache.device
540 | self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
541 | device = self.value_cache.device
542 | self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
543 |
544 | def to_legacy_cache(self):
545 |         """Dummy function for BC. It has to be kept, otherwise the call in the forward of models would break."""
546 | return None
547 |
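# A minimal, self-contained usage sketch of the StaticCache above (illustrative only; the
# tiny config values and shapes below are assumptions, not values from this repository).
def _static_cache_usage_sketch():
    import types
    import torch
    cfg = types.SimpleNamespace(  # stands in for a PretrainedConfig with the attributes read in __init__
        max_position_embeddings=128, hidden_size=64, num_attention_heads=4, num_key_value_heads=4
    )
    cache = StaticCache(cfg, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)
    k = torch.randn(1, 4, 3, 16)  # [batch, kv_heads, new_tokens, head_dim]
    v = torch.randn(1, 4, 3, 16)
    pos = torch.arange(3)         # the slots to overwrite in the pre-allocated buffers
    k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": pos})
    assert k_out.shape == (1, 4, 16, 16) and cache.get_seq_length() == 3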
--------------------------------------------------------------------------------
/llama_real_share/cache_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, List, Optional, Tuple
3 |
4 | import torch
5 |
6 | from transformers.configuration_utils import PretrainedConfig
7 |
8 |
9 | @dataclass
10 | class Cache:
11 | """
12 | Base, abstract class for all caches. The actual data structure is specific to each subclass.
13 | """
14 |
15 | def update(
16 | self,
17 | key_states: torch.Tensor,
18 | value_states: torch.Tensor,
19 | layer_idx: int,
20 | cache_kwargs: Optional[Dict[str, Any]] = None,
21 | ) -> Tuple[torch.Tensor, torch.Tensor]:
22 | """
23 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
24 |
25 | Parameters:
26 | key_states (`torch.Tensor`):
27 | The new key states to cache.
28 | value_states (`torch.Tensor`):
29 | The new value states to cache.
30 | layer_idx (`int`):
31 | The index of the layer to cache the states for.
32 | cache_kwargs (`Dict[str, Any]`, `optional`):
33 | Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
34 | cache to be created.
35 |
36 | Return:
37 | A tuple containing the updated key and value states.
38 | """
39 | raise NotImplementedError("Make sure to implement `update` in a subclass.")
40 |
41 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
42 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
43 | raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
44 |
45 | def get_max_length(self) -> Optional[int]:
46 | """Returns the maximum sequence length of the cached states, if there is any."""
47 | raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
48 |
49 | def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
50 | """Given the sequence length of the new inputs, returns the usable length of the cache."""
51 | # Cache without size limit -> all cache is usable
52 |         # Cache with size limit -> if the cached length plus the length of the new inputs is larger than the maximum cache
53 | # length, we will need to evict part of the cache (and thus not all cache is usable)
54 | max_length = self.get_max_length()
55 | previous_seq_length = self.get_seq_length(layer_idx)
56 | if max_length is not None and previous_seq_length + new_seq_length > max_length:
57 | return max_length - new_seq_length
58 | return previous_seq_length
59 |
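# A small worked example of the eviction arithmetic in `get_usable_length` above
# (illustrative only; the stub lengths below are assumptions).
def _usable_length_sketch():
    class _StubCache(Cache):
        def get_seq_length(self, layer_idx=0):
            return 8                       # tokens already cached
        def get_max_length(self):
            return 10                      # size-limited cache
    stub = _StubCache()
    assert stub.get_usable_length(new_seq_length=4) == 6   # 8 + 4 > 10, so only 10 - 4 slots remain usable
    assert stub.get_usable_length(new_seq_length=2) == 8   # fits, so the whole cache is usable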
60 |
61 | class DynamicCache(Cache):
62 | """
63 | A cache that grows dynamically as more tokens are generated. This is the default for generative models.
64 |
65 | It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
66 | `[batch_size, num_heads, seq_len, head_dim]`.
67 | """
68 |
69 | def __init__(self) -> None:
70 | self.key_cache: List[torch.Tensor] = []
71 | self.value_cache: List[torch.Tensor] = []
72 | self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
73 |
74 | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
75 | """
76 | Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
77 | sequence length.
78 | """
79 | if layer_idx < len(self):
80 | return (self.key_cache[layer_idx], self.value_cache[layer_idx])
81 | else:
82 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
83 |
84 | def __iter__(self):
85 | """
86 | Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
87 | keys and values
88 | """
89 | for layer_idx in range(len(self)):
90 | yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
91 |
92 | def __len__(self):
93 | """
94 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
95 | to the number of layers in the model.
96 | """
97 | return len(self.key_cache)
98 |
99 | def update(
100 | self,
101 | key_states: torch.Tensor,
102 | value_states: torch.Tensor,
103 | layer_idx: int,
104 | cache_kwargs: Optional[Dict[str, Any]] = None,
105 | ) -> Tuple[torch.Tensor, torch.Tensor]:
106 | """
107 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
108 |
109 | Parameters:
110 | key_states (`torch.Tensor`):
111 | The new key states to cache.
112 | value_states (`torch.Tensor`):
113 | The new value states to cache.
114 | layer_idx (`int`):
115 | The index of the layer to cache the states for.
116 | cache_kwargs (`Dict[str, Any]`, `optional`):
117 | Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
118 |
119 | Return:
120 | A tuple containing the updated key and value states.
121 | """
122 | # Update the number of seen tokens
123 | if layer_idx == 0:
124 | self.seen_tokens += key_states.shape[-2]
125 |
126 | # Update the cache
127 | if len(self.key_cache) <= layer_idx:
128 | self.key_cache.append(key_states)
129 | self.value_cache.append(value_states)
130 | else:
131 | self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
132 | self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
133 |
134 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
135 |
136 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
137 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
138 | if len(self.key_cache) <= layer_idx:
139 | return 0
140 | return self.key_cache[layer_idx].shape[-2]
141 |
142 | def get_max_length(self) -> Optional[int]:
143 | """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
144 | return None
145 |
146 | def reorder_cache(self, beam_idx: torch.LongTensor):
147 | """Reorders the cache for beam search, given the selected beam indices."""
148 | for layer_idx in range(len(self.key_cache)):
149 | device = self.key_cache[layer_idx].device
150 | self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
151 | device = self.value_cache[layer_idx].device
152 | self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
153 |
154 | def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
155 |         """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
156 | legacy_cache = ()
157 | for layer_idx in range(len(self)):
158 | legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
159 | return legacy_cache
160 |
161 | @classmethod
162 | def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
163 | """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
164 | cache = cls()
165 | if past_key_values is not None:
166 | for layer_idx in range(len(past_key_values)):
167 | key_states, value_states = past_key_values[layer_idx]
168 | cache.update(key_states, value_states, layer_idx)
169 | return cache
170 |
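# A minimal sketch of how the DynamicCache above grows and round-trips through the legacy
# tuple-of-tuples format (illustrative only; the shapes are assumptions).
def _dynamic_cache_roundtrip_sketch():
    import torch
    cache = DynamicCache()
    k = torch.randn(1, 2, 4, 8)   # [batch, num_heads, seq_len, head_dim]
    v = torch.randn(1, 2, 4, 8)
    cache.update(k, v, layer_idx=0)
    cache.update(k, v, layer_idx=0)            # concatenated along the sequence dimension
    assert cache.get_seq_length(0) == 8
    rebuilt = DynamicCache.from_legacy_cache(cache.to_legacy_cache())
    assert rebuilt.get_seq_length(0) == 8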
171 |
172 | class DynamicDictCache(Cache):
173 | """
174 |     A cache that grows dynamically as more tokens are generated. This is the cache used by KVSharer.
175 | 
176 |     It stores the Key and Value states as a dict of tensors keyed by layer index. The expected shape for each
177 |     tensor is `[batch_size, num_heads, seq_len, head_dim]`.
178 | 
179 |     A Dict is used instead of a List so that only the layers which keep their own KV cache need an entry.
180 | """
181 |
182 | def __init__(self) -> None:
183 | self.key_cache: Dict[int, torch.Tensor] = {}
184 | self.value_cache: Dict[int, torch.Tensor] = {}
185 | self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
186 |
187 | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
188 | """
189 | Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
190 | sequence length.
191 | """
192 | if layer_idx < 0:
193 | tmp_key = list(self.key_cache.keys())[-1]
194 | return (self.key_cache[tmp_key], self.value_cache[tmp_key])
195 |         if layer_idx in self.key_cache:
196 | return (self.key_cache[layer_idx], self.value_cache[layer_idx])
197 | else:
198 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
199 |
200 | def __iter__(self):
201 | """
202 | Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
203 | keys and values
204 | """
205 | for layer_idx in range(len(self)):
206 | yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
207 |
208 | def __len__(self):
209 | """
210 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
211 | to the number of layers in the model.
212 | """
213 | return len(self.key_cache)
214 |
215 | def update(
216 | self,
217 | key_states: torch.Tensor,
218 | value_states: torch.Tensor,
219 | layer_idx: int,
220 | skip: bool = False,
221 | cache_kwargs: Optional[Dict[str, Any]] = None,
222 | ) -> Tuple[torch.Tensor, torch.Tensor]:
223 |         # If the current layer is not computed, its kv_cache does not need to be updated either
224 | """
225 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
226 |
227 | Parameters:
228 | key_states (`torch.Tensor`):
229 | The new key states to cache.
230 | value_states (`torch.Tensor`):
231 | The new value states to cache.
232 | layer_idx (`int`):
233 | The index of the layer to cache the states for.
234 | cache_kwargs (`Dict[str, Any]`, `optional`):
235 |                 Additional arguments for the cache subclass. No additional arguments are used in `DynamicDictCache`.
236 |
237 | Return:
238 | A tuple containing the updated key and value states.
239 | """
240 | # Update the number of seen tokens
241 | if layer_idx == 0:
242 | self.seen_tokens += key_states.shape[-2]
243 |
244 | if not skip:
245 | # Update the cache
246 |             if layer_idx not in self.key_cache:  # the dict may not hold an entry for every layer
247 | self.key_cache[layer_idx] = key_states
248 | self.value_cache[layer_idx] = value_states
249 | else:
250 | self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
251 | self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
252 |
253 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
254 |
255 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
256 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
257 |         if layer_idx not in self.key_cache:
258 | return 0
259 | return self.key_cache[layer_idx].shape[-2]
260 |
261 | def get_max_length(self) -> Optional[int]:
262 |         """Returns the maximum sequence length of the cached states. DynamicDictCache does not have a maximum length."""
263 | return None
264 |
265 | def reorder_cache(self, beam_idx: torch.LongTensor):
266 | """Reorders the cache for beam search, given the selected beam indices."""
267 | for layer_idx in range(len(self.key_cache)):
268 | device = self.key_cache[layer_idx].device
269 | self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
270 | device = self.value_cache[layer_idx].device
271 | self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
272 |
273 | def to_legacy_cache(self, kv_cache_share_layers_map=None) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
274 |         """Converts the `DynamicDictCache` instance into its equivalent in the legacy cache format."""
275 |         # On the next iteration, the KV cache is read back in this (legacy) format, not in the dict format
276 | legacy_cache = ()
277 |         # This should also be switched to the dict format
278 | if kv_cache_share_layers_map is None:
279 | for layer_idx in range(len(self)):
280 | legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
281 | return legacy_cache
282 | else:
283 | # print('kv_cache_share_layers_map', kv_cache_share_layers_map)
284 | for layer_idx in range(len(kv_cache_share_layers_map.keys())):
285 | if kv_cache_share_layers_map[layer_idx] == layer_idx:
286 | # print(kv_cache_share_layers_map[layer_idx], layer_idx)
287 | # print('add value')
288 |                     # TODO: do we really need to save the KV cache for every layer here, or would a few layers be enough? Also check how this legacy_cache is used later on
289 | legacy_cache += ((self.key_cache[kv_cache_share_layers_map[layer_idx]], self.value_cache[kv_cache_share_layers_map[layer_idx]]),)
290 | else:
291 | # print('add None')
292 | legacy_cache += ((None, None),)
293 |
294 | return legacy_cache
295 |
296 | @classmethod
297 | def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicDictCache":
298 | """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
299 | cache = cls()
300 | if past_key_values is not None:
301 | for layer_idx in range(len(past_key_values)):
302 | key_states, value_states = past_key_values[layer_idx]
303 | if key_states is not None:
304 | cache.update(key_states, value_states, layer_idx)
305 | return cache
306 |
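# A minimal sketch of how the DynamicDictCache above interacts with a layer-sharing map
# (illustrative only: the map below is made up; in KVSharer it comes from the searched
# strategy, where map[i] == i means layer i keeps its own KV cache and map[i] == j means
# layer i reuses layer j's cache).
def _dynamic_dict_cache_sharing_sketch():
    import torch
    kv_cache_share_layers_map = {0: 0, 1: 1, 2: 0}        # layer 2 reuses layer 0's cache
    cache = DynamicDictCache()
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    cache.update(k, v, layer_idx=0)                        # stored under key 0
    cache.update(k, v, layer_idx=1)                        # stored under key 1
    # A shared layer can read the source layer's entry without writing its own:
    k0, v0 = cache.update(k, v, layer_idx=0, skip=True)
    assert torch.equal(k0, cache.key_cache[0]) and len(cache) == 2
    legacy = cache.to_legacy_cache(kv_cache_share_layers_map)
    assert legacy[2] == (None, None)                       # no tensors stored for the shared layer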
307 |
308 | class SinkCache(Cache):
309 | """
310 |     A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
311 | generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
312 | tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
313 |
314 | It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
315 | `[batch_size, num_heads, seq_len, head_dim]`.
316 |
317 | Parameters:
318 | window_length (`int`):
319 | The length of the context window.
320 | num_sink_tokens (`int`):
321 | The number of sink tokens. See the original paper for more information.
322 | """
323 |
324 | def __init__(self, window_length: int, num_sink_tokens: int) -> None:
325 | self.key_cache: List[torch.Tensor] = []
326 | self.value_cache: List[torch.Tensor] = []
327 | self.window_length = window_length
328 | self.num_sink_tokens = num_sink_tokens
329 | self.cos_sin_cache = {}
330 | self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
331 |
332 | @staticmethod
333 | def _rotate_half(x):
334 | x1 = x[..., : x.shape[-1] // 2]
335 | x2 = x[..., x.shape[-1] // 2 :]
336 | return torch.cat((-x2, x1), dim=-1)
337 |
338 | def _apply_key_rotary_pos_emb(
339 | self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
340 | ) -> torch.Tensor:
341 | rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
342 | return rotated_key_states
343 |
344 | def _get_rerotation_cos_sin(
345 | self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
346 | ) -> Tuple[torch.Tensor, torch.Tensor]:
347 | if key_states.shape[-2] not in self.cos_sin_cache:
348 | # Upcast to float32 temporarily for better accuracy
349 | cos = cos.to(torch.float32)
350 | sin = sin.to(torch.float32)
351 |
352 | # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
353 | original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
354 | shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
355 | original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
356 | shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
357 | rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
358 | rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
359 |
360 | self.cos_sin_cache[key_states.shape[-2]] = (
361 | rerotation_cos.to(key_states.dtype).unsqueeze(0),
362 | rerotation_sin.to(key_states.dtype).unsqueeze(0),
363 | )
364 | return self.cos_sin_cache[key_states.shape[-2]]
365 |
366 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
367 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
368 | # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
369 | if len(self.key_cache) <= layer_idx:
370 | return 0
371 | return self.key_cache[layer_idx].shape[-2]
372 |
373 | def get_max_length(self) -> Optional[int]:
374 | """Returns the maximum sequence length of the cached states."""
375 | return self.window_length
376 |
377 | def update(
378 | self,
379 | key_states: torch.Tensor,
380 | value_states: torch.Tensor,
381 | layer_idx: int,
382 | cache_kwargs: Optional[Dict[str, Any]] = None,
383 | ) -> Tuple[torch.Tensor, torch.Tensor]:
384 | """
385 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
386 |
387 | Parameters:
388 | key_states (`torch.Tensor`):
389 | The new key states to cache.
390 | value_states (`torch.Tensor`):
391 | The new value states to cache.
392 | layer_idx (`int`):
393 | The index of the layer to cache the states for.
394 | cache_kwargs (`Dict[str, Any]`, `optional`):
395 | Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
396 | `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
397 | rotation as the tokens are shifted.
398 |
399 | Return:
400 | A tuple containing the updated key and value states.
401 | """
402 | # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
403 | # with partially rotated position embeddings, like Phi or Persimmon.
404 | sin = cache_kwargs.get("sin")
405 | cos = cache_kwargs.get("cos")
406 | partial_rotation_size = cache_kwargs.get("partial_rotation_size")
407 | using_rope = cos is not None and sin is not None
408 |
409 | # Update the number of seen tokens
410 | if layer_idx == 0:
411 | self.seen_tokens += key_states.shape[-2]
412 |
413 | # [bsz, num_heads, seq_len, head_dim]
414 | if len(self.key_cache) <= layer_idx:
415 | # Empty cache
416 | self.key_cache.append(key_states)
417 | self.value_cache.append(value_states)
418 |
419 | elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
420 | # Growing cache
421 | self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
422 | self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
423 |
424 | else:
425 | # Shifting cache
426 | keys_to_keep = self.key_cache[layer_idx][
427 | :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
428 | ]
429 |
430 | # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
431 | if using_rope:
432 | rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
433 | key_states, cos[: self.window_length], sin[: self.window_length]
434 | )
435 | if partial_rotation_size is not None:
436 | keys_to_keep, keys_pass = (
437 | keys_to_keep[..., :partial_rotation_size],
438 | keys_to_keep[..., partial_rotation_size:],
439 | )
440 | keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
441 | if partial_rotation_size is not None:
442 | keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
443 |
444 | # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
445 | sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
446 | self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
447 |
448 | sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
449 | values_to_keep = self.value_cache[layer_idx][
450 | :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
451 | ]
452 | self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
453 |
454 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
455 |
456 | def reorder_cache(self, beam_idx: torch.LongTensor):
457 | """Reorders the cache for beam search, given the selected beam indices."""
458 | for layer_idx in range(len(self.key_cache)):
459 | device = self.key_cache[layer_idx].device
460 | self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
461 | device = self.value_cache[layer_idx].device
462 | self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
463 |
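# A small sketch of the sink/window arithmetic in the SinkCache above (illustrative only;
# window size and shapes are assumptions, and the non-RoPE path is used so cos/sin are omitted).
def _sink_cache_window_sketch():
    import torch
    cache = SinkCache(window_length=8, num_sink_tokens=2)
    k, v = torch.randn(1, 2, 7, 4), torch.randn(1, 2, 7, 4)
    cache.update(k, v, layer_idx=0, cache_kwargs={})        # first update fills the empty cache with 7 tokens
    k_new, v_new = torch.randn(1, 2, 2, 4), torch.randn(1, 2, 2, 4)
    k_out, _ = cache.update(k_new, v_new, layer_idx=0, cache_kwargs={})
    # 2 sink tokens + (8 - 2 - 2) most recent tokens + 2 new tokens = window_length
    assert k_out.shape[-2] == 8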
464 |
465 | class StaticCache(Cache):
466 | """
467 | Static Cache class to be used with `torch.compile(model)`.
468 |
469 | Parameters:
470 |         config (`PretrainedConfig`):
471 | The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
472 | required to initialize the static cache.
473 | max_batch_size (`int`):
474 | The maximum batch size with which the model will be used.
475 | max_cache_len (`int`):
476 | The maximum sequence length with which the model will be used.
477 | device (`torch.device`):
478 | The device on which the cache should be initialized. Should be the same as the layer.
479 | dtype (*optional*, defaults to `torch.float32`):
480 | The default `dtype` to use when initializing the layer.
481 | """
482 |
483 | def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
484 | super().__init__()
485 | self.max_batch_size = max_batch_size
486 | self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
487 |         # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
488 | self.head_dim = (
489 | config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
490 | )
491 |
492 | self.dtype = dtype if dtype is not None else torch.float32
493 | self.num_key_value_heads = (
494 | config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
495 | )
496 |
497 | cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
498 | self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
499 | self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
500 | self.seen_tokens = 0
501 |
502 | def update(
503 | self,
504 | key_states: torch.Tensor,
505 | value_states: torch.Tensor,
506 | layer_idx: int,
507 | cache_kwargs: Optional[Dict[str, Any]] = None,
508 | ) -> Tuple[torch.Tensor, torch.Tensor]:
509 | """
510 | Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
511 | It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
512 |
513 | Parameters:
514 | key_states (`torch.Tensor`):
515 | The new key states to cache.
516 | value_states (`torch.Tensor`):
517 | The new value states to cache.
518 | layer_idx (`int`):
519 | The index of the layer to cache the states for. Kept for backward compatibility
520 | cache_kwargs (`Dict[str, Any]`, `optional`):
521 |                 Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
522 |                 to know where in the pre-allocated cache to write the new states.
523 |
524 | Return:
525 | A tuple containing the updated key and value states.
526 | """
527 | new_cache_positions = cache_kwargs.get("cache_position")
528 | k_out = self.key_cache
529 | v_out = self.value_cache
530 |
531 | k_out[:, :, new_cache_positions] = key_states
532 | v_out[:, :, new_cache_positions] = value_states
533 |
534 | self.seen_tokens += key_states.shape[2]
535 | return k_out, v_out
536 |
537 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
538 | """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
539 | return self.seen_tokens
540 |
541 | def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
542 | return self.seen_tokens
543 |
544 | def get_max_length(self) -> Optional[int]:
545 |         """Returns the maximum sequence length of the cached states, i.e. the pre-allocated `max_cache_len`."""
546 | return self.max_cache_len
547 |
548 | def reorder_cache(self, beam_idx: torch.LongTensor):
549 | """Reorders the cache for beam search, given the selected beam indices."""
550 | device = self.key_cache.device
551 | self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
552 | device = self.value_cache.device
553 | self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
554 |
555 | def to_legacy_cache(self):
556 |         """Dummy function for BC. It has to be kept, otherwise the call in the forward of models would break."""
557 | return None
558 |
--------------------------------------------------------------------------------
/llama_real_share/modeling_llama.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ PyTorch LLaMA model."""
21 | import math
22 | import warnings
23 | from typing import List, Optional, Tuple, Union
24 |
25 | import torch
26 | import torch.nn.functional as F
27 | import torch.utils.checkpoint
28 | from torch import nn
29 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30 |
31 | from transformers.activations import ACT2FN
32 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
33 | from transformers.modeling_outputs import (
34 | BaseModelOutputWithPast,
35 | CausalLMOutputWithPast,
36 | QuestionAnsweringModelOutput,
37 | SequenceClassifierOutputWithPast,
38 | )
39 | from transformers.modeling_utils import PreTrainedModel
40 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
41 | from transformers.utils import (
42 | add_start_docstrings,
43 | add_start_docstrings_to_model_forward,
44 | is_flash_attn_2_available,
45 | is_flash_attn_greater_or_equal_2_10,
46 | logging,
47 | replace_return_docstrings,
48 | )
49 | from transformers.models.llama.configuration_llama import LlamaConfig
50 |
51 |
52 | if is_flash_attn_2_available():
53 | from flash_attn import flash_attn_func, flash_attn_varlen_func
54 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
55 |
56 |
57 | logger = logging.get_logger(__name__)
58 |
59 | _CONFIG_FOR_DOC = "LlamaConfig"
60 |
61 |
62 | def _get_unpad_data(attention_mask):
63 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
64 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
65 | max_seqlen_in_batch = seqlens_in_batch.max().item()
66 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
67 | return (
68 | indices,
69 | cu_seqlens,
70 | max_seqlen_in_batch,
71 | )
72 |
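# A worked example of `_get_unpad_data` above (illustrative only): for the padded batch mask
# [[1, 1, 0], [1, 1, 1]], it recovers the real-token indices, the cumulative sequence lengths
# used by flash-attn's varlen kernels, and the longest sequence in the batch.
def _get_unpad_data_sketch():
    import torch
    mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
    indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
    assert indices.tolist() == [0, 1, 3, 4, 5]     # flattened positions of non-padding tokens
    assert cu_seqlens.tolist() == [0, 2, 5]        # prefix sums of the per-sequence lengths [2, 3]
    assert max_seqlen == 3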
73 |
74 | class LlamaRMSNorm(nn.Module):
75 | def __init__(self, hidden_size, eps=1e-6):
76 | """
77 | LlamaRMSNorm is equivalent to T5LayerNorm
78 | """
79 | super().__init__()
80 | self.weight = nn.Parameter(torch.ones(hidden_size))
81 | self.variance_epsilon = eps
82 |
83 | def forward(self, hidden_states):
84 | input_dtype = hidden_states.dtype
85 | hidden_states = hidden_states.to(torch.float32)
86 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
87 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
88 | return self.weight * hidden_states.to(input_dtype)
89 |
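# A minimal check of what LlamaRMSNorm above computes (illustrative only): each hidden vector
# is scaled by the reciprocal root-mean-square of its elements (no mean subtraction, unlike
# LayerNorm), then multiplied by a learned per-channel weight (initialised to ones).
def _rms_norm_sketch():
    import torch
    norm = LlamaRMSNorm(hidden_size=4, eps=1e-6)
    x = torch.randn(2, 3, 4)
    manual = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    assert torch.allclose(norm(x), manual, atol=1e-6)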
90 |
91 | ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)
92 |
93 |
94 | class LlamaRotaryEmbedding(nn.Module):
95 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
96 | super().__init__()
97 | self.scaling_factor = scaling_factor
98 | self.dim = dim
99 | self.max_position_embeddings = max_position_embeddings
100 | self.base = base
101 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
102 | self.register_buffer("inv_freq", inv_freq, persistent=False)
103 | # For BC we register cos and sin cached
104 | self.max_seq_len_cached = max_position_embeddings
105 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
106 | t = t / self.scaling_factor
107 | freqs = torch.outer(t, self.inv_freq)
108 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
109 | emb = torch.cat((freqs, freqs), dim=-1)
110 | self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
111 | self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
112 |
113 | @property
114 | def sin_cached(self):
115 | logger.warning_once(
116 | "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
117 | "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
118 | )
119 | return self._sin_cached
120 |
121 | @property
122 | def cos_cached(self):
123 | logger.warning_once(
124 | "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use "
125 | "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class"
126 | )
127 | return self._cos_cached
128 |
129 | @torch.no_grad()
130 | def forward(self, x, position_ids, seq_len=None):
131 | if seq_len is not None:
132 | logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")
133 |
134 | # x: [bs, num_attention_heads, seq_len, head_size]
135 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
136 | position_ids_expanded = position_ids[:, None, :].float()
137 | # Force float32 since bfloat16 loses precision on long contexts
138 | # See https://github.com/huggingface/transformers/pull/29285
139 | device_type = x.device.type
140 | device_type = device_type if isinstance(device_type, str) else "cpu"
141 | with torch.autocast(device_type=device_type, enabled=False):
142 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
143 | emb = torch.cat((freqs, freqs), dim=-1)
144 | cos = emb.cos()
145 | sin = emb.sin()
146 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
147 |
148 |
149 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
150 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
151 |
152 | def forward(self, x, position_ids, seq_len=None):
153 |         # difference to the original RoPE: a scaling factor is applied to the position ids
154 | position_ids = position_ids.float() / self.scaling_factor
155 | cos, sin = super().forward(x, position_ids, seq_len)
156 | return cos, sin
157 |
158 |
159 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
160 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
161 |
162 | def forward(self, x, position_ids, seq_len=None):
163 | # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
164 | seq_len = torch.max(position_ids) + 1
165 | if seq_len > self.max_position_embeddings:
166 | base = self.base * (
167 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
168 | ) ** (self.dim / (self.dim - 2))
169 | inv_freq = 1.0 / (
170 | base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
171 | )
172 | self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
173 |
174 | cos, sin = super().forward(x, position_ids, seq_len)
175 | return cos, sin
176 |
177 |
178 | def rotate_half(x):
179 | """Rotates half the hidden dims of the input."""
180 | x1 = x[..., : x.shape[-1] // 2]
181 | x2 = x[..., x.shape[-1] // 2 :]
182 | return torch.cat((-x2, x1), dim=-1)
183 |
184 |
185 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
186 | """Applies Rotary Position Embedding to the query and key tensors.
187 |
188 | Args:
189 | q (`torch.Tensor`): The query tensor.
190 | k (`torch.Tensor`): The key tensor.
191 | cos (`torch.Tensor`): The cosine part of the rotary embedding.
192 | sin (`torch.Tensor`): The sine part of the rotary embedding.
193 | position_ids (`torch.Tensor`, *optional*):
194 | Deprecated and unused.
195 | unsqueeze_dim (`int`, *optional*, defaults to 1):
196 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
197 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
198 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
199 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
200 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
201 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
202 | Returns:
203 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
204 | """
205 | cos = cos.unsqueeze(unsqueeze_dim)
206 | sin = sin.unsqueeze(unsqueeze_dim)
207 | q_embed = (q * cos) + (rotate_half(q) * sin)
208 | k_embed = (k * cos) + (rotate_half(k) * sin)
209 | return q_embed, k_embed
210 |
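# A minimal sketch of applying rotary position embeddings with the helpers above
# (illustrative only; the shapes are assumptions). Each pair of dimensions is rotated by a
# position-dependent angle, so per-vector norms are preserved.
def _rope_sketch():
    import torch
    bsz, heads, seq, dim = 1, 2, 5, 8
    q = torch.randn(bsz, heads, seq, dim)
    k = torch.randn(bsz, heads, seq, dim)
    rope = LlamaRotaryEmbedding(dim, max_position_embeddings=16)
    position_ids = torch.arange(seq).unsqueeze(0)           # [1, seq]
    cos, sin = rope(q, position_ids)                         # each of shape [1, seq, dim]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)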
211 |
212 | class LlamaMLP(nn.Module):
213 | def __init__(self, config):
214 | super().__init__()
215 | self.config = config
216 | self.hidden_size = config.hidden_size
217 | self.intermediate_size = config.intermediate_size
218 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
219 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
220 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
221 | self.act_fn = ACT2FN[config.hidden_act]
222 |
223 | def forward(self, x):
224 | if self.config.pretraining_tp > 1:
225 | slice = self.intermediate_size // self.config.pretraining_tp
226 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
227 | up_proj_slices = self.up_proj.weight.split(slice, dim=0)
228 | down_proj_slices = self.down_proj.weight.split(slice, dim=1)
229 |
230 | gate_proj = torch.cat(
231 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
232 | )
233 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
234 |
235 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
236 | down_proj = [
237 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
238 | ]
239 | down_proj = sum(down_proj)
240 | else:
241 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
242 |
243 | return down_proj
244 |
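# A minimal check of the gated (SwiGLU-style) MLP above (illustrative only): in the
# non-tensor-parallel path the forward pass is down_proj(silu(gate_proj(x)) * up_proj(x)).
# The tiny config is a stand-in namespace, not a real LlamaConfig.
def _llama_mlp_sketch():
    import types
    import torch
    import torch.nn.functional as F
    cfg = types.SimpleNamespace(hidden_size=8, intermediate_size=16, hidden_act="silu", pretraining_tp=1)
    mlp = LlamaMLP(cfg)
    x = torch.randn(2, 3, 8)
    manual = mlp.down_proj(F.silu(mlp.gate_proj(x)) * mlp.up_proj(x))
    assert torch.allclose(mlp(x), manual)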
245 |
246 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
247 | """
248 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
249 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
250 | """
251 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
252 | if n_rep == 1:
253 | return hidden_states
254 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
255 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
256 |
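# A minimal check of `repeat_kv` above (illustrative only): it expands grouped-query KV heads
# so every query head in a group sees the same K/V, matching torch.repeat_interleave on the
# head dimension.
def _repeat_kv_sketch():
    import torch
    kv = torch.randn(1, 2, 5, 4)                  # [batch, kv_heads, seq_len, head_dim]
    expanded = repeat_kv(kv, n_rep=3)             # -> [1, 6, 5, 4]
    assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=3, dim=1))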
257 |
258 | class LlamaAttention(nn.Module):
259 | """Multi-headed attention from 'Attention Is All You Need' paper"""
260 |
261 | def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
262 | super().__init__()
263 | self.config = config
264 | self.layer_idx = layer_idx
265 | if layer_idx is None:
266 | logger.warning_once(
267 | f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
268 | "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
269 | "when creating this class."
270 | )
271 |
272 | self.attention_dropout = config.attention_dropout
273 | self.hidden_size = config.hidden_size
274 | self.num_heads = config.num_attention_heads
275 | self.head_dim = self.hidden_size // self.num_heads
276 | self.num_key_value_heads = config.num_key_value_heads
277 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads
278 | self.max_position_embeddings = config.max_position_embeddings
279 | self.rope_theta = config.rope_theta
280 | self.is_causal = True
281 |
282 | if (self.head_dim * self.num_heads) != self.hidden_size:
283 | raise ValueError(
284 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
285 | f" and `num_heads`: {self.num_heads})."
286 | )
287 |
288 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
289 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
290 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
291 | self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
292 | self._init_rope()
293 |
294 | def _init_rope(self):
295 | if self.config.rope_scaling is None:
296 | self.rotary_emb = LlamaRotaryEmbedding(
297 | self.head_dim,
298 | max_position_embeddings=self.max_position_embeddings,
299 | base=self.rope_theta,
300 | )
301 | else:
302 | scaling_type = self.config.rope_scaling["type"]
303 | scaling_factor = self.config.rope_scaling["factor"]
304 | if scaling_type == "linear":
305 | self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
306 | self.head_dim,
307 | max_position_embeddings=self.max_position_embeddings,
308 | scaling_factor=scaling_factor,
309 | base=self.rope_theta,
310 | )
311 | elif scaling_type == "dynamic":
312 | self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
313 | self.head_dim,
314 | max_position_embeddings=self.max_position_embeddings,
315 | scaling_factor=scaling_factor,
316 | base=self.rope_theta,
317 | )
318 | else:
319 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
320 |
321 | def forward(
322 | self,
323 | hidden_states: torch.Tensor,
324 | attention_mask: Optional[torch.Tensor] = None,
325 | position_ids: Optional[torch.LongTensor] = None,
326 | past_key_value: Optional[Cache] = None,
327 | output_attentions: bool = False,
328 | use_cache: bool = False,
329 | cache_position: Optional[torch.LongTensor] = None,
330 | **kwargs,
331 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
332 | bsz, q_len, _ = hidden_states.size()
333 |
334 | if self.config.pretraining_tp > 1:
335 | key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
336 | query_slices = self.q_proj.weight.split(
337 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
338 | )
339 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
340 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
341 |
342 | query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
343 | query_states = torch.cat(query_states, dim=-1)
344 |
345 | key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
346 | key_states = torch.cat(key_states, dim=-1)
347 |
348 | value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
349 | value_states = torch.cat(value_states, dim=-1)
350 |
351 | else:
352 | query_states = self.q_proj(hidden_states)
353 | key_states = self.k_proj(hidden_states)
354 | value_states = self.v_proj(hidden_states)
355 |
356 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
357 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
358 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
359 |
360 | past_key_value = getattr(self, "past_key_value", past_key_value)
361 | cos, sin = self.rotary_emb(value_states, position_ids)
362 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
363 |
364 | if past_key_value is not None:
365 | # sin and cos are specific to RoPE models; position_ids needed for the static cache
366 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
367 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
368 |
369 | key_states = repeat_kv(key_states, self.num_key_value_groups)
370 | value_states = repeat_kv(value_states, self.num_key_value_groups)
371 |
372 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
373 |
374 | if attention_mask is not None: # no matter the length, we just slice it
375 | causal_mask = attention_mask
376 | if cache_position is not None:
377 | causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
378 | attn_weights = attn_weights + causal_mask
379 |
380 | # upcast attention to fp32
381 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
382 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
383 | attn_output = torch.matmul(attn_weights, value_states)
384 |
385 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
386 | raise ValueError(
387 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
388 | f" {attn_output.size()}"
389 | )
390 |
391 | attn_output = attn_output.transpose(1, 2).contiguous()
392 |
393 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
394 |
395 | if self.config.pretraining_tp > 1:
396 | attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
397 | o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
398 | attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
399 | else:
400 | attn_output = self.o_proj(attn_output)
401 |
402 | if not output_attentions:
403 | attn_weights = None
404 |
405 | return attn_output, attn_weights, past_key_value
406 |
407 |
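# A minimal sketch of the eager attention math used in LlamaAttention.forward above
# (illustrative only; shapes are assumptions): softmax(Q K^T / sqrt(head_dim) + causal_mask) V
# matches PyTorch's scaled_dot_product_attention with a causal mask.
def _eager_attention_sketch():
    import math
    import torch
    q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))    # [batch, heads, seq, head_dim]
    causal_mask = torch.full((4, 4), float("-inf")).triu(1)  # additive mask, -inf above the diagonal
    weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(8) + causal_mask, dim=-1)
    manual = weights @ v
    sdpa = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    assert torch.allclose(manual, sdpa, atol=1e-5)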
408 | class LlamaFlashAttention2(LlamaAttention):
409 | """
410 |     Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stay
411 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
412 | flash attention and deal with padding tokens in case the input contains any of them.
413 | """
414 |
415 | def __init__(self, *args, **kwargs):
416 | super().__init__(*args, **kwargs)
417 |
418 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
419 |         # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
420 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
421 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
422 |
423 | def forward(
424 | self,
425 | hidden_states: torch.Tensor,
426 | attention_mask: Optional[torch.LongTensor] = None,
427 | position_ids: Optional[torch.LongTensor] = None,
428 | past_key_value: Optional[Cache] = None,
429 | output_attentions: bool = False,
430 | use_cache: bool = False,
431 | cache_position: Optional[torch.LongTensor] = None,
432 | **kwargs,
433 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
434 | output_attentions = False
435 |
436 | bsz, q_len, _ = hidden_states.size()
437 |
438 | query_states = self.q_proj(hidden_states)
439 | key_states = self.k_proj(hidden_states)
440 | value_states = self.v_proj(hidden_states)
441 |
442 | # Flash attention requires the input to have the shape
443 | # batch_size x seq_length x head_dim x hidden_dim
444 | # therefore we just need to keep the original shape
445 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
446 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
447 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
448 |
449 | cos, sin = self.rotary_emb(value_states, position_ids)
450 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
451 |
452 | past_key_value = getattr(self, "past_key_value", past_key_value)
453 |
454 | if past_key_value is not None:
455 | # sin and cos are specific to RoPE models; position_ids needed for the static cache
456 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
457 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
458 |
459 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
460 | # to be able to avoid many of these transpose/reshape/view.
461 | query_states = query_states.transpose(1, 2)
462 | key_states = key_states.transpose(1, 2)
463 | value_states = value_states.transpose(1, 2)
464 |
465 | dropout_rate = self.attention_dropout if self.training else 0.0
466 |
467 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
468 |         # therefore the input hidden states get silently cast to float32. Hence, we need to
469 |         # cast them back to the correct dtype just to be sure everything works as expected.
470 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms
471 | # in fp32. (LlamaRMSNorm handles it correctly)
472 |
473 | input_dtype = query_states.dtype
474 | if input_dtype == torch.float32:
475 | if torch.is_autocast_enabled():
476 | target_dtype = torch.get_autocast_gpu_dtype()
477 | # Handle the case where the model is quantized
478 | elif hasattr(self.config, "_pre_quantization_dtype"):
479 | target_dtype = self.config._pre_quantization_dtype
480 | else:
481 | target_dtype = self.q_proj.weight.dtype
482 |
483 | logger.warning_once(
484 | f"The input hidden states seems to be silently casted in float32, this might be related to"
485 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
486 | f" {target_dtype}."
487 | )
488 |
489 | query_states = query_states.to(target_dtype)
490 | key_states = key_states.to(target_dtype)
491 | value_states = value_states.to(target_dtype)
492 |
493 | attn_output = self._flash_attention_forward(
494 | query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
495 | )
496 |
497 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
498 | attn_output = self.o_proj(attn_output)
499 |
500 | if not output_attentions:
501 | attn_weights = None
502 |
503 | return attn_output, attn_weights, past_key_value
504 |
505 | def _flash_attention_forward(
506 | self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
507 | ):
508 | """
509 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
510 | first unpad the input, then computes the attention scores and pad the final attention scores.
511 |
512 | Args:
513 | query_states (`torch.Tensor`):
514 | Input query states to be passed to Flash Attention API
515 | key_states (`torch.Tensor`):
516 | Input key states to be passed to Flash Attention API
517 | value_states (`torch.Tensor`):
518 | Input value states to be passed to Flash Attention API
519 | attention_mask (`torch.Tensor`):
520 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
521 | position of padding tokens and 1 for the position of non-padding tokens.
522 |             dropout (`float`, *optional*):
523 | Attention dropout
524 | softmax_scale (`float`, *optional*):
525 |                 The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim)
526 | """
527 | if not self._flash_attn_uses_top_left_mask:
528 | causal = self.is_causal
529 | else:
530 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
531 | causal = self.is_causal and query_length != 1
532 |
533 | # Contains at least one padding token in the sequence
534 | if attention_mask is not None:
535 | batch_size = query_states.shape[0]
536 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
537 | query_states, key_states, value_states, attention_mask, query_length
538 | )
539 |
540 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens
541 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
542 |
543 | attn_output_unpad = flash_attn_varlen_func(
544 | query_states,
545 | key_states,
546 | value_states,
547 | cu_seqlens_q=cu_seqlens_q,
548 | cu_seqlens_k=cu_seqlens_k,
549 | max_seqlen_q=max_seqlen_in_batch_q,
550 | max_seqlen_k=max_seqlen_in_batch_k,
551 | dropout_p=dropout,
552 | softmax_scale=softmax_scale,
553 | causal=causal,
554 | )
555 |
556 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
557 | else:
558 | attn_output = flash_attn_func(
559 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
560 | )
561 |
562 | return attn_output
563 |
564 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
565 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
566 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
567 |
568 | key_layer = index_first_axis(
569 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
570 | )
571 | value_layer = index_first_axis(
572 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
573 | )
574 | if query_length == kv_seq_len:
575 | query_layer = index_first_axis(
576 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
577 | )
578 | cu_seqlens_q = cu_seqlens_k
579 | max_seqlen_in_batch_q = max_seqlen_in_batch_k
580 | indices_q = indices_k
581 | elif query_length == 1:
582 | max_seqlen_in_batch_q = 1
583 | cu_seqlens_q = torch.arange(
584 | batch_size + 1, dtype=torch.int32, device=query_layer.device
585 | ) # There is a memcpy here, that is very bad.
586 | indices_q = cu_seqlens_q[:-1]
587 | query_layer = query_layer.squeeze(1)
588 | else:
589 | # The -q_len: slice assumes left padding.
590 | attention_mask = attention_mask[:, -query_length:]
591 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
592 |
593 | return (
594 | query_layer,
595 | key_layer,
596 | value_layer,
597 | indices_q,
598 | (cu_seqlens_q, cu_seqlens_k),
599 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
600 | )
601 |
602 |
603 | class LlamaSdpaAttention(LlamaAttention):
604 | """
605 | Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
606 |     `LlamaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
607 | SDPA API.
608 | """
609 |
610 | # Adapted from LlamaAttention.forward
611 | def forward(
612 | self,
613 | hidden_states: torch.Tensor,
614 | attention_mask: Optional[torch.Tensor] = None,
615 | position_ids: Optional[torch.LongTensor] = None,
616 | past_key_value: Optional[Cache] = None,
617 | output_attentions: bool = False,
618 | use_cache: bool = False,
619 | cache_position: Optional[torch.LongTensor] = None,
620 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
621 | if output_attentions:
622 | # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
623 | logger.warning_once(
624 | "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
625 | 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
626 | )
627 | return super().forward(
628 | hidden_states=hidden_states,
629 | attention_mask=attention_mask,
630 | position_ids=position_ids,
631 | past_key_value=past_key_value,
632 | output_attentions=output_attentions,
633 | use_cache=use_cache,
634 | cache_position=cache_position,
635 | )
636 |
637 | bsz, q_len, _ = hidden_states.size()
638 |
639 | query_states = self.q_proj(hidden_states)
640 | key_states = self.k_proj(hidden_states)
641 | value_states = self.v_proj(hidden_states)
642 |
643 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
644 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
645 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
646 |
647 | cos, sin = self.rotary_emb(value_states, position_ids)
648 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
649 |
650 | past_key_value = getattr(self, "past_key_value", past_key_value)
651 |
652 | if past_key_value is not None:
653 | # sin and cos are specific to RoPE models; position_ids needed for the static cache
654 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
655 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
656 |
657 | key_states = repeat_kv(key_states, self.num_key_value_groups)
658 | value_states = repeat_kv(value_states, self.num_key_value_groups)
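# Grouped-query attention: repeat_kv tiles each of the num_key_value_heads key/value heads
# num_key_value_groups (= num_heads // num_key_value_heads) times so that K and V match the
# query head count before scaled_dot_product_attention is called.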
659 |
660 | causal_mask = attention_mask
661 | if attention_mask is not None and cache_position is not None:
662 | causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
663 |
664 | # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
665 | # Reference: https://github.com/pytorch/pytorch/issues/112577.
666 | if query_states.device.type == "cuda" and causal_mask is not None:
667 | query_states = query_states.contiguous()
668 | key_states = key_states.contiguous()
669 | value_states = value_states.contiguous()
670 |
671 | attn_output = torch.nn.functional.scaled_dot_product_attention(
672 | query_states,
673 | key_states,
674 | value_states,
675 | attn_mask=causal_mask,
676 | dropout_p=self.attention_dropout if self.training else 0.0,
677 | )
678 |
679 | attn_output = attn_output.transpose(1, 2).contiguous()
680 | attn_output = attn_output.view(bsz, q_len, self.hidden_size)
681 |
682 | attn_output = self.o_proj(attn_output)
683 |
684 | return attn_output, None, past_key_value
685 |
686 |
687 | LLAMA_ATTENTION_CLASSES = {
688 | "eager": LlamaAttention,
689 | "flash_attention_2": LlamaFlashAttention2,
690 | "sdpa": LlamaSdpaAttention,
691 | }
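# The mapping above is keyed by `config._attn_implementation`. As a minimal usage sketch (assuming
# the standard `attn_implementation` argument of `from_pretrained` in transformers >= 4.38, not
# anything specific to this repository), the backend can be selected at load time:
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Llama-2-7b-hf",       # example checkpoint, as in the docstrings below
#       attn_implementation="sdpa",       # or "eager" / "flash_attention_2"
#   )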
692 |
693 |
694 | class LlamaDecoderLayer(nn.Module):
695 | def __init__(self, config: LlamaConfig, layer_idx: int):
696 | super().__init__()
697 | self.hidden_size = config.hidden_size
698 |
699 | self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
700 |
701 | self.mlp = LlamaMLP(config)
702 | self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
703 | self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
704 |
705 | def forward(
706 | self,
707 | hidden_states: torch.Tensor,
708 | attention_mask: Optional[torch.Tensor] = None,
709 | position_ids: Optional[torch.LongTensor] = None,
710 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
711 | output_attentions: Optional[bool] = False,
712 | use_cache: Optional[bool] = False,
713 | cache_position: Optional[torch.LongTensor] = None,
714 | **kwargs,
715 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
716 | """
717 | Args:
718 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
719 | attention_mask (`torch.FloatTensor`, *optional*):
720 | attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
721 | query_sequence_length, key_sequence_length)` if default attention is used.
722 | output_attentions (`bool`, *optional*):
723 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
724 | returned tensors for more detail.
725 | use_cache (`bool`, *optional*):
726 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
727 | (see `past_key_values`).
728 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
729 | """
730 | if "padding_mask" in kwargs:
731 | warnings.warn(
732 | "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
733 | )
734 |
735 | residual = hidden_states
736 |
737 | hidden_states = self.input_layernorm(hidden_states)
738 |
739 | # Self Attention
740 | hidden_states, self_attn_weights, present_key_value = self.self_attn(
741 | hidden_states=hidden_states,
742 | attention_mask=attention_mask,
743 | position_ids=position_ids,
744 | past_key_value=past_key_value,
745 | output_attentions=output_attentions,
746 | use_cache=use_cache,
747 | cache_position=cache_position,
748 | **kwargs,
749 | )
750 | hidden_states = residual + hidden_states
751 |
752 | # Fully Connected
753 | residual = hidden_states
754 | hidden_states = self.post_attention_layernorm(hidden_states)
755 | hidden_states = self.mlp(hidden_states)
756 | hidden_states = residual + hidden_states
757 |
758 | outputs = (hidden_states,)
759 |
760 | if output_attentions:
761 | outputs += (self_attn_weights,)
762 |
763 | if use_cache:
764 | outputs += (present_key_value,)
765 |
766 | return outputs
767 |
768 |
769 | LLAMA_START_DOCSTRING = r"""
770 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
771 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
772 | etc.)
773 |
774 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
775 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
776 | and behavior.
777 |
778 | Parameters:
779 | config ([`LlamaConfig`]):
780 | Model configuration class with all the parameters of the model. Initializing with a config file does not
781 | load the weights associated with the model, only the configuration. Check out the
782 | [`~PreTrainedModel.from_pretrained`] method to load the model weights.
783 | """
784 |
785 |
786 | @add_start_docstrings(
787 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
788 | LLAMA_START_DOCSTRING,
789 | )
790 | class LlamaPreTrainedModel(PreTrainedModel):
791 | config_class = LlamaConfig
792 | base_model_prefix = "model"
793 | supports_gradient_checkpointing = True
794 | _no_split_modules = ["LlamaDecoderLayer"]
795 | _skip_keys_device_placement = ["past_key_values", "causal_mask"]
796 | _supports_flash_attn_2 = True
797 | _supports_sdpa = True
798 | _supports_cache_class = True
799 |
800 | def _init_weights(self, module):
801 | std = self.config.initializer_range
802 | if isinstance(module, nn.Linear):
803 | module.weight.data.normal_(mean=0.0, std=std)
804 | if module.bias is not None:
805 | module.bias.data.zero_()
806 | elif isinstance(module, nn.Embedding):
807 | module.weight.data.normal_(mean=0.0, std=std)
808 | if module.padding_idx is not None:
809 | module.weight.data[module.padding_idx].zero_()
810 |
811 | def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
812 | if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
813 | raise ValueError(
814 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
815 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
816 | )
817 |
818 | if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
819 | causal_mask = torch.full(
820 | (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
821 | )
822 | self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
823 |
824 | for layer in self.model.layers:
825 | weights = layer.self_attn.o_proj.weight
826 | layer.self_attn.past_key_value = cache_cls(
827 | self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
828 | )
829 |
830 | def _reset_cache(self):
831 | for layer in self.model.layers:
832 | layer.self_attn.past_key_value = None
833 |
834 |
835 | LLAMA_INPUTS_DOCSTRING = r"""
836 | Args:
837 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
838 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
839 | it.
840 |
841 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
842 | [`PreTrainedTokenizer.__call__`] for details.
843 |
844 | [What are input IDs?](../glossary#input-ids)
845 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
846 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
847 |
848 | - 1 for tokens that are **not masked**,
849 | - 0 for tokens that are **masked**.
850 |
851 | [What are attention masks?](../glossary#attention-mask)
852 |
853 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
854 | [`PreTrainedTokenizer.__call__`] for details.
855 |
856 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
857 | `past_key_values`).
858 |
859 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
860 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
861 | information on the default strategy.
862 |
863 | - 1 indicates the head is **not masked**,
864 | - 0 indicates the head is **masked**.
865 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
866 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
867 | config.n_positions - 1]`.
868 |
869 | [What are position IDs?](../glossary#position-ids)
870 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
871 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
872 | blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
873 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
874 |
875 | Two formats are allowed:
876 | - a [`~cache_utils.Cache`] instance;
877 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
878 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
879 | cache format.
880 |
881 | The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
882 | legacy cache format will be returned.
883 |
884 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
885 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
886 | of shape `(batch_size, sequence_length)`.
887 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
888 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
889 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
890 | model's internal embedding lookup matrix.
891 | use_cache (`bool`, *optional*):
892 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
893 | `past_key_values`).
894 | output_attentions (`bool`, *optional*):
895 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
896 | tensors for more detail.
897 | output_hidden_states (`bool`, *optional*):
898 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
899 | more detail.
900 | return_dict (`bool`, *optional*):
901 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
902 | """
903 |
904 |
905 | @add_start_docstrings(
906 | "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
907 | LLAMA_START_DOCSTRING,
908 | )
909 | class LlamaModel(LlamaPreTrainedModel):
910 | """
911 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
912 |
913 | Args:
914 | config: LlamaConfig
915 | """
916 |
917 | def __init__(self, config: LlamaConfig):
918 | super().__init__(config)
919 | self.padding_idx = config.pad_token_id
920 | self.vocab_size = config.vocab_size
921 |
922 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
923 | self.layers = nn.ModuleList(
924 | [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
925 | )
926 | self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
927 | self.gradient_checkpointing = False
928 |
929 | # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
930 | # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
931 | causal_mask = torch.full(
932 | (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
933 | )
934 | self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
935 | # Initialize weights and apply final processing
936 | self.post_init()
937 |
938 | def get_input_embeddings(self):
939 | return self.embed_tokens
940 |
941 | def set_input_embeddings(self, value):
942 | self.embed_tokens = value
943 |
944 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
945 | def forward(
946 | self,
947 | input_ids: torch.LongTensor = None,
948 | attention_mask: Optional[torch.Tensor] = None,
949 | position_ids: Optional[torch.LongTensor] = None,
950 | past_key_values: Optional[List[torch.FloatTensor]] = None,
951 | inputs_embeds: Optional[torch.FloatTensor] = None,
952 | use_cache: Optional[bool] = None,
953 | output_attentions: Optional[bool] = None,
954 | output_hidden_states: Optional[bool] = None,
955 | return_dict: Optional[bool] = None,
956 | cache_position: Optional[torch.LongTensor] = None,
957 | ) -> Union[Tuple, BaseModelOutputWithPast]:
958 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
959 | output_hidden_states = (
960 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
961 | )
962 | use_cache = use_cache if use_cache is not None else self.config.use_cache
963 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
964 |
965 | if (input_ids is None) ^ (inputs_embeds is not None):
966 | raise ValueError(
967 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
968 | )
969 |
970 | if self.gradient_checkpointing and self.training and use_cache:
971 | logger.warning_once(
972 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
973 | )
974 | use_cache = False
975 |
976 | if inputs_embeds is None:
977 | inputs_embeds = self.embed_tokens(input_ids)
978 |
979 | past_seen_tokens = 0
980 | if use_cache: # kept for BC (cache positions)
981 | if not isinstance(past_key_values, StaticCache):
982 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
983 | past_seen_tokens = past_key_values.get_seq_length()
984 |
985 | if cache_position is None:
986 | cache_position = torch.arange(
987 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
988 | )
989 |
990 | if position_ids is None:
991 | position_ids = cache_position.unsqueeze(0)
992 |
993 | causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
994 | # print('past_key_values.key_cache', len(past_key_values.key_cache))  # leftover debug; fails when past_key_values is None (use_cache=False)
995 | # embed positions
996 | hidden_states = inputs_embeds
997 |
998 | # decoder layers
999 | all_hidden_states = () if output_hidden_states else None
1000 | all_self_attns = () if output_attentions else None
1001 | next_decoder_cache = None
1002 |
1003 | for decoder_layer in self.layers:
1004 | if output_hidden_states:
1005 | all_hidden_states += (hidden_states,)
1006 |
1007 | if self.gradient_checkpointing and self.training:
1008 | layer_outputs = self._gradient_checkpointing_func(
1009 | decoder_layer.__call__,
1010 | hidden_states,
1011 | causal_mask,
1012 | position_ids,
1013 | past_key_values,
1014 | output_attentions,
1015 | use_cache,
1016 | cache_position,
1017 | )
1018 | else:
1019 | layer_outputs = decoder_layer(
1020 | hidden_states,
1021 | attention_mask=causal_mask,
1022 | position_ids=position_ids,
1023 | past_key_value=past_key_values,
1024 | output_attentions=output_attentions,
1025 | use_cache=use_cache,
1026 | cache_position=cache_position,
1027 | )
1028 | # print('past_key_values.key_cache2', len(past_key_values.key_cache))  # leftover per-layer debug
1029 | hidden_states = layer_outputs[0]
1030 |
1031 | if use_cache:
1032 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1033 |
1034 | if output_attentions:
1035 | all_self_attns += (layer_outputs[1],)
1036 |
1037 | hidden_states = self.norm(hidden_states)
1038 |
1039 | # add hidden states from the last decoder layer
1040 | if output_hidden_states:
1041 | all_hidden_states += (hidden_states,)
1042 |
1043 | next_cache = None
1044 | if use_cache:
1045 | next_cache = (
1046 | next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
1047 | )
1048 | if not return_dict:
1049 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1050 | return BaseModelOutputWithPast(
1051 | last_hidden_state=hidden_states,
1052 | past_key_values=next_cache,
1053 | hidden_states=all_hidden_states,
1054 | attentions=all_self_attns,
1055 | )
1056 |
1057 | def _update_causal_mask(self, attention_mask, input_tensor):
1058 | if self.config._attn_implementation == "flash_attention_2":
1059 | if attention_mask is not None and 0.0 in attention_mask:
1060 | return attention_mask
1061 | return None
1062 |
1063 | batch_size, seq_length = input_tensor.shape[:2]
1064 | dtype = input_tensor.dtype
1065 | device = input_tensor.device
1066 |
1067 | # support going beyond cached `max_position_embedding`
1068 | if seq_length > self.causal_mask.shape[-1]:
1069 | causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1)
1070 | self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
1071 |
1072 | # We use the current dtype to avoid any overflows
1073 | min_dtype = torch.finfo(dtype).min
1074 | causal_mask = self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * min_dtype
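# At this point allowed positions hold 0.0 and masked positions hold torch.finfo(dtype).min,
# i.e. an additive bias that drives masked attention scores to ~0 after the softmax.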
1075 |
1076 | causal_mask = causal_mask.to(dtype=dtype, device=device)
1077 | if attention_mask is not None and attention_mask.dim() == 2:
1078 | mask_length = attention_mask.shape[-1]
1079 | padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
1080 | causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
1081 |
1082 | if self.config._attn_implementation == "sdpa" and attention_mask is not None:
1083 | # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
1084 | is_tracing = (
1085 | torch.jit.is_tracing()
1086 | or isinstance(input_tensor, torch.fx.Proxy)
1087 | or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
1088 | )
1089 | if not is_tracing and torch.any(attention_mask != 1):
1090 | # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
1091 | # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
1092 | # Details: https://github.com/pytorch/pytorch/issues/110213
1093 | causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
1094 |
1095 | return causal_mask
1096 |
1097 |
1098 | class LlamaForCausalLM(LlamaPreTrainedModel):
1099 | _tied_weights_keys = ["lm_head.weight"]
1100 |
1101 | def __init__(self, config):
1102 | super().__init__(config)
1103 | self.model = LlamaModel(config)
1104 | self.vocab_size = config.vocab_size
1105 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1106 |
1107 | # Initialize weights and apply final processing
1108 | self.post_init()
1109 |
1110 | def get_input_embeddings(self):
1111 | return self.model.embed_tokens
1112 |
1113 | def set_input_embeddings(self, value):
1114 | self.model.embed_tokens = value
1115 |
1116 | def get_output_embeddings(self):
1117 | return self.lm_head
1118 |
1119 | def set_output_embeddings(self, new_embeddings):
1120 | self.lm_head = new_embeddings
1121 |
1122 | def set_decoder(self, decoder):
1123 | self.model = decoder
1124 |
1125 | def get_decoder(self):
1126 | return self.model
1127 |
1128 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1129 | @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1130 | def forward(
1131 | self,
1132 | input_ids: torch.LongTensor = None,
1133 | attention_mask: Optional[torch.Tensor] = None,
1134 | position_ids: Optional[torch.LongTensor] = None,
1135 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1136 | inputs_embeds: Optional[torch.FloatTensor] = None,
1137 | labels: Optional[torch.LongTensor] = None,
1138 | use_cache: Optional[bool] = None,
1139 | output_attentions: Optional[bool] = None,
1140 | output_hidden_states: Optional[bool] = None,
1141 | return_dict: Optional[bool] = None,
1142 | cache_position: Optional[torch.LongTensor] = None,
1143 | ) -> Union[Tuple, CausalLMOutputWithPast]:
1144 | r"""
1145 | Args:
1146 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1147 | Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1148 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1149 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1150 |
1151 | Returns:
1152 |
1153 | Example:
1154 |
1155 | ```python
1156 | >>> from transformers import AutoTokenizer, LlamaForCausalLM
1157 |
1158 | >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
1159 | >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
1160 |
1161 | >>> prompt = "Hey, are you conscious? Can you talk to me?"
1162 | >>> inputs = tokenizer(prompt, return_tensors="pt")
1163 |
1164 | >>> # Generate
1165 | >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1166 | >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1167 | "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
1168 | ```"""
1169 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1170 | output_hidden_states = (
1171 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1172 | )
1173 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1174 |
1175 | # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1176 | outputs = self.model(
1177 | input_ids=input_ids,
1178 | attention_mask=attention_mask,
1179 | position_ids=position_ids,
1180 | past_key_values=past_key_values,
1181 | inputs_embeds=inputs_embeds,
1182 | use_cache=use_cache,
1183 | output_attentions=output_attentions,
1184 | output_hidden_states=output_hidden_states,
1185 | return_dict=return_dict,
1186 | cache_position=cache_position,
1187 | )
1188 |
1189 | hidden_states = outputs[0]
1190 | if self.config.pretraining_tp > 1:
1191 | lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1192 | logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1193 | logits = torch.cat(logits, dim=-1)
1194 | else:
1195 | logits = self.lm_head(hidden_states)
1196 | logits = logits.float()
1197 |
1198 | loss = None
1199 | if labels is not None:
1200 | # Shift so that tokens < n predict n
1201 | shift_logits = logits[..., :-1, :].contiguous()
1202 | shift_labels = labels[..., 1:].contiguous()
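# e.g. for tokens [t0, t1, t2, t3]: logits at positions 0..2 are scored against labels t1..t3,
# so every position is trained to predict the next token.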
1203 | # Flatten the tokens
1204 | loss_fct = CrossEntropyLoss()
1205 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
1206 | shift_labels = shift_labels.view(-1)
1207 | # Enable model parallelism
1208 | shift_labels = shift_labels.to(shift_logits.device)
1209 | loss = loss_fct(shift_logits, shift_labels)
1210 |
1211 | if not return_dict:
1212 | output = (logits,) + outputs[1:]
1213 | return (loss,) + output if loss is not None else output
1214 |
1215 | return CausalLMOutputWithPast(
1216 | loss=loss,
1217 | logits=logits,
1218 | past_key_values=outputs.past_key_values,
1219 | hidden_states=outputs.hidden_states,
1220 | attentions=outputs.attentions,
1221 | )
1222 |
1223 | def prepare_inputs_for_generation(
1224 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1225 | ):
1226 | past_length = 0
1227 | if past_key_values is not None:
1228 | if isinstance(past_key_values, Cache):
1229 | cache_length = past_key_values.get_seq_length()
1230 | past_length = past_key_values.seen_tokens
1231 | max_cache_length = past_key_values.get_max_length()
1232 | else:
1233 | cache_length = past_length = past_key_values[0][0].shape[2]
1234 | max_cache_length = None
1235 |
1236 | # Keep only the unprocessed tokens:
1237 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1238 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1239 | # input)
1240 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1241 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1242 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1243 | # input_ids based on the past_length.
1244 | elif past_length < input_ids.shape[1]:
1245 | input_ids = input_ids[:, past_length:]
1246 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
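# For illustration: with a legacy cache of length 5 and input_ids of shape (1, 6) (prompt plus
# one newly sampled token), case 2 applies and only input_ids[:, 5:] is fed to the model.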
1247 |
1248 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1249 | if (
1250 | max_cache_length is not None
1251 | and attention_mask is not None
1252 | and cache_length + input_ids.shape[1] > max_cache_length
1253 | ):
1254 | attention_mask = attention_mask[:, -max_cache_length:]
1255 |
1256 | position_ids = kwargs.get("position_ids", None)
1257 | if attention_mask is not None and position_ids is None:
1258 | # create position_ids on the fly for batch generation
1259 | position_ids = attention_mask.long().cumsum(-1) - 1
1260 | position_ids.masked_fill_(attention_mask == 0, 1)
1261 | if past_key_values:
1262 | position_ids = position_ids[:, -input_ids.shape[1] :]
1263 |
1264 | if self.generation_config.cache_implementation == "static":
1265 | # generation with static cache
1266 | cache_position = kwargs.get("cache_position", None)
1267 | if cache_position is None:
1268 | past_length = 0
1269 | else:
1270 | past_length = cache_position[-1] + 1
1271 | input_ids = input_ids[:, past_length:]
1272 | position_ids = position_ids[:, past_length:]
1273 |
1274 | # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
1275 | # same goes for position ids. Could also help with continued generation.
1276 | cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device)
1277 |
1278 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1279 | if inputs_embeds is not None and past_key_values is None:
1280 | model_inputs = {"inputs_embeds": inputs_embeds}
1281 | else:
1282 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
1283 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
1284 | # TODO: use `next_tokens` directly instead.
1285 | model_inputs = {"input_ids": input_ids.contiguous()}
1286 |
1287 | model_inputs.update(
1288 | {
1289 | "position_ids": position_ids.contiguous(),
1290 | "cache_position": cache_position,
1291 | "past_key_values": past_key_values,
1292 | "use_cache": kwargs.get("use_cache"),
1293 | "attention_mask": attention_mask,
1294 | }
1295 | )
1296 | return model_inputs
1297 |
1298 | @staticmethod
1299 | def _reorder_cache(past_key_values, beam_idx):
1300 | reordered_past = ()
1301 | for layer_past in past_key_values:
1302 | reordered_past += (
1303 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1304 | )
1305 | return reordered_past
1306 |
1307 |
1308 | @add_start_docstrings(
1309 | """
1310 | The LLaMa Model transformer with a sequence classification head on top (linear layer).
1311 |
1312 | [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
1313 | (e.g. GPT-2) do.
1314 |
1315 | Since it does classification on the last token, it needs to know the position of the last token. If a
1316 | `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
1317 | no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
1318 | padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
1319 | each row of the batch).
1320 | """,
1321 | LLAMA_START_DOCSTRING,
1322 | )
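# A minimal usage sketch of the last-token pooling described above (illustrative only; the
# checkpoint name and `num_labels` are assumptions, not part of this repository):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
#   tokenizer.pad_token = tokenizer.eos_token
#   model = LlamaForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf", num_labels=2)
#   model.config.pad_token_id = tokenizer.pad_token_id   # required for batched (padded) inputs
#   batch = tokenizer(["great movie", "terrible movie"], return_tensors="pt", padding=True)
#   logits = model(**batch).logits                       # pooled at the last non-pad token per row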
1323 | class LlamaForSequenceClassification(LlamaPreTrainedModel):
1324 | def __init__(self, config):
1325 | super().__init__(config)
1326 | self.num_labels = config.num_labels
1327 | self.model = LlamaModel(config)
1328 | self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1329 |
1330 | # Initialize weights and apply final processing
1331 | self.post_init()
1332 |
1333 | def get_input_embeddings(self):
1334 | return self.model.embed_tokens
1335 |
1336 | def set_input_embeddings(self, value):
1337 | self.model.embed_tokens = value
1338 |
1339 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1340 | def forward(
1341 | self,
1342 | input_ids: torch.LongTensor = None,
1343 | attention_mask: Optional[torch.Tensor] = None,
1344 | position_ids: Optional[torch.LongTensor] = None,
1345 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1346 | inputs_embeds: Optional[torch.FloatTensor] = None,
1347 | labels: Optional[torch.LongTensor] = None,
1348 | use_cache: Optional[bool] = None,
1349 | output_attentions: Optional[bool] = None,
1350 | output_hidden_states: Optional[bool] = None,
1351 | return_dict: Optional[bool] = None,
1352 | ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1353 | r"""
1354 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1355 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1356 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1357 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1358 | """
1359 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1360 |
1361 | transformer_outputs = self.model(
1362 | input_ids,
1363 | attention_mask=attention_mask,
1364 | position_ids=position_ids,
1365 | past_key_values=past_key_values,
1366 | inputs_embeds=inputs_embeds,
1367 | use_cache=use_cache,
1368 | output_attentions=output_attentions,
1369 | output_hidden_states=output_hidden_states,
1370 | return_dict=return_dict,
1371 | )
1372 | hidden_states = transformer_outputs[0]
1373 | logits = self.score(hidden_states)
1374 |
1375 | if input_ids is not None:
1376 | batch_size = input_ids.shape[0]
1377 | else:
1378 | batch_size = inputs_embeds.shape[0]
1379 |
1380 | if self.config.pad_token_id is None and batch_size != 1:
1381 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1382 | if self.config.pad_token_id is None:
1383 | sequence_lengths = -1
1384 | else:
1385 | if input_ids is not None:
1386 | # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1387 | sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1388 | sequence_lengths = sequence_lengths % input_ids.shape[-1]
1389 | sequence_lengths = sequence_lengths.to(logits.device)
1390 | else:
1391 | sequence_lengths = -1
1392 |
1393 | pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1394 |
1395 | loss = None
1396 | if labels is not None:
1397 | labels = labels.to(logits.device)
1398 | if self.config.problem_type is None:
1399 | if self.num_labels == 1:
1400 | self.config.problem_type = "regression"
1401 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1402 | self.config.problem_type = "single_label_classification"
1403 | else:
1404 | self.config.problem_type = "multi_label_classification"
1405 |
1406 | if self.config.problem_type == "regression":
1407 | loss_fct = MSELoss()
1408 | if self.num_labels == 1:
1409 | loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1410 | else:
1411 | loss = loss_fct(pooled_logits, labels)
1412 | elif self.config.problem_type == "single_label_classification":
1413 | loss_fct = CrossEntropyLoss()
1414 | loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1415 | elif self.config.problem_type == "multi_label_classification":
1416 | loss_fct = BCEWithLogitsLoss()
1417 | loss = loss_fct(pooled_logits, labels)
1418 | if not return_dict:
1419 | output = (pooled_logits,) + transformer_outputs[1:]
1420 | return ((loss,) + output) if loss is not None else output
1421 |
1422 | return SequenceClassifierOutputWithPast(
1423 | loss=loss,
1424 | logits=pooled_logits,
1425 | past_key_values=transformer_outputs.past_key_values,
1426 | hidden_states=transformer_outputs.hidden_states,
1427 | attentions=transformer_outputs.attentions,
1428 | )
1429 |
1430 |
1431 | @add_start_docstrings(
1432 | """
1433 | The Llama Model transformer with a span classification head on top for extractive question-answering tasks like
1434 | SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1435 | """,
1436 | LLAMA_START_DOCSTRING,
1437 | )
1438 | class LlamaForQuestionAnswering(LlamaPreTrainedModel):
1439 | # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama
1440 | def __init__(self, config):
1441 | super().__init__(config)
1442 | self.transformer = LlamaModel(config)
1443 | self.qa_outputs = nn.Linear(config.hidden_size, 2)
1444 |
1445 | # Initialize weights and apply final processing
1446 | self.post_init()
1447 |
1448 | def get_input_embeddings(self):
1449 | return self.transformer.embed_tokens
1450 |
1451 | def set_input_embeddings(self, value):
1452 | self.transformer.embed_tokens = value
1453 |
1454 | @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
1455 | def forward(
1456 | self,
1457 | input_ids: Optional[torch.LongTensor] = None,
1458 | attention_mask: Optional[torch.FloatTensor] = None,
1459 | position_ids: Optional[torch.LongTensor] = None,
1460 | past_key_values: Optional[List[torch.FloatTensor]] = None,
1461 | inputs_embeds: Optional[torch.FloatTensor] = None,
1462 | start_positions: Optional[torch.LongTensor] = None,
1463 | end_positions: Optional[torch.LongTensor] = None,
1464 | output_attentions: Optional[bool] = None,
1465 | output_hidden_states: Optional[bool] = None,
1466 | return_dict: Optional[bool] = None,
1467 | ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1468 | r"""
1469 | start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1470 | Labels for position (index) of the start of the labelled span for computing the token classification loss.
1471 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1472 | are not taken into account for computing the loss.
1473 | end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1474 | Labels for position (index) of the end of the labelled span for computing the token classification loss.
1475 | Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1476 | are not taken into account for computing the loss.
1477 | """
1478 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1479 |
1480 | outputs = self.transformer(
1481 | input_ids,
1482 | attention_mask=attention_mask,
1483 | position_ids=position_ids,
1484 | past_key_values=past_key_values,
1485 | inputs_embeds=inputs_embeds,
1486 | output_attentions=output_attentions,
1487 | output_hidden_states=output_hidden_states,
1488 | return_dict=return_dict,
1489 | )
1490 |
1491 | sequence_output = outputs[0]
1492 |
1493 | logits = self.qa_outputs(sequence_output)
1494 | start_logits, end_logits = logits.split(1, dim=-1)
1495 | start_logits = start_logits.squeeze(-1).contiguous()
1496 | end_logits = end_logits.squeeze(-1).contiguous()
1497 |
1498 | total_loss = None
1499 | if start_positions is not None and end_positions is not None:
1501 | # If we are on multi-GPU, the split may add an extra dimension, so squeeze it
1501 | if len(start_positions.size()) > 1:
1502 | start_positions = start_positions.squeeze(-1).to(start_logits.device)
1503 | if len(end_positions.size()) > 1:
1504 | end_positions = end_positions.squeeze(-1).to(end_logits.device)
1505 | # sometimes the start/end positions are outside our model inputs; we ignore these terms
1506 | ignored_index = start_logits.size(1)
1507 | start_positions = start_positions.clamp(0, ignored_index)
1508 | end_positions = end_positions.clamp(0, ignored_index)
1509 |
1510 | loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1511 | start_loss = loss_fct(start_logits, start_positions)
1512 | end_loss = loss_fct(end_logits, end_positions)
1513 | total_loss = (start_loss + end_loss) / 2
1514 |
1515 | if not return_dict:
1516 | output = (start_logits, end_logits) + outputs[2:]
1517 | return ((total_loss,) + output) if total_loss is not None else output
1518 |
1519 | return QuestionAnsweringModelOutput(
1520 | loss=total_loss,
1521 | start_logits=start_logits,
1522 | end_logits=end_logits,
1523 | hidden_states=outputs.hidden_states,
1524 | attentions=outputs.attentions,
1525 | )
--------------------------------------------------------------------------------