├── .gitignore
├── LICENSE
├── README.md
├── adaptive_snapkv
│   └── monkeypatch
│       ├── adaptive_llama_hijack.py
│       ├── adaptive_mistral_hijack.py
│       ├── fixed_llama_hijack.py
│       ├── fixed_mistral_hijack.py
│       ├── monkeypatch.py
│       ├── slm_llama_hijack.py
│       ├── slm_mistral_hijack.py
│       └── snapkv_utils.py
├── assets
│   └── images
│       ├── 4k_ruler_average_score.png
│       ├── LongBench_mistral.png
│       ├── LongBench_mistral_gqa.png
│       ├── head_vary.png
│       ├── main.png
│       ├── mem.png
│       └── speed.png
├── csrc
│   ├── .gitignore
│   ├── LICENSE
│   ├── build.py
│   ├── csrc
│   │   ├── cuda_api.cu
│   │   └── static_switch.h
│   ├── include
│   │   └── cuda_api.h
│   └── makefile
├── experiments
│   └── LongBench
│       ├── GQA_eval_longbench.sh
│       ├── README.md
│       ├── config
│       │   ├── dataset2maxlen.json
│       │   ├── dataset2prompt.json
│       │   ├── model2maxlen.json
│       │   └── model2path.json
│       ├── convert_to_execl.py
│       ├── eval.py
│       ├── metrics.py
│       └── pred.py
├── makefile
└── pyproject.toml
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 | experiments/LongBench/pred/*
9 | experiments/LongBench/pred_e/*
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | notebooks/test*
164 | experiments/LongBench/bak/buggy
165 | experiments/LongBench/archive
166 | experiments/LongBench/ignore
167 | experiments/LongBench/Log
168 | experiments/LongBench/*pred
169 | experiments/LongBench/runall_*.sh
170 | .vscode
171 | .idea
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yuan Feng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
24 | This project includes portions of the following open-source projects:
25 |
26 | - SnapKV: https://github.com/FasterDecoding/SnapKV
27 | - PyramidKV (MIT License): https://github.com/Zefan-Cai/PyramidKV
28 |
29 | Each of these projects retains its respective copyright and license. For more details, refer to their original LICENSE files.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaKV
2 | Adaptive budget allocation across attention heads, based on their degrees of attention concentration, improves budget utilization and thereby post-eviction generation quality.
3 |
4 |
5 |
6 |
7 |
8 | ## Updates
9 | * __[2024.11.08 GQA Support]__ In response to numerous community requests, we’ve just uploaded a new testing branch, named `test_gqa_support`, which introduces support for Grouped Query Attention (GQA) under the methods "SnapKV," "PyramidKV," "Ada-SnapKV," and "Ada-PyramidKV." You can use the `GQA_eval_longbench.sh` script for direct evaluation on the LongBench benchmark. Detailed results will be released soon.
10 | * __[2024.11.15 GQA Results]__ We have integrated GQA support for Mistral-7B-Instruct-v0.2 in SnapKV, PyramidKV, and our Ada-KV. You can try it by running the `GQA_eval_longbench.sh` script! Preliminary results on a single A800 are provided in the table below. In future updates, we will include more comprehensive GQA test results in the appendix of our [Ada-KV paper](https://arxiv.org/abs/2407.11550).
11 | 
12 | * __[2024.12.12 Community Collaboration]__ We’re collaborating with NVIDIA’s [KVpress](https://github.com/NVIDIA/kvpress) team to integrate Ada-KV into their excellent project. This effort also aims to build a foundation for future research on head-specific cache compression methods. See the [draft PR](https://github.com/NVIDIA/kvpress/pull/25) for progress. Meanwhile, we’ve conducted broader evaluations of Ada-KV (Ada-SnapKV), including preliminary results on the 4K Ruler Benchmark (context compression with and without questions). For further insights into the evaluation, please refer to the KVpress [repository](https://github.com/NVIDIA/kvpress).
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Usage
20 |
21 | ### Requirements
22 |
23 | ```
24 | transformers==4.37.2
25 | flash-attn==2.4.0
26 |
27 | datasets
28 | tiktoken
29 | jieba
30 | rouge_score
31 | ```
32 |
33 | ### Installation
34 |
35 | ```
36 | git clone https://github.com/FFY0/AdaKV
37 | cd AdaKV
38 | make i
39 | ```
40 |
41 | ### Quick Start
42 |
43 | ```python
44 | # replace modeling with adakv
45 | from adaptive_snapkv.monkeypatch.monkeypatch import replace_mistral_adaptive, replace_llama_adaptive
46 | replace_mistral_adaptive()
47 | replace_llama_adaptive()
48 |
49 | model = AutoModelForCausalLM.from_pretrained(
50 | model_name_or_path,
51 | config=config,
52 | device_map=device_map,
53 | attn_implementation="flash_attention_2",
54 | torch_dtype=torch.bfloat16,
55 | trust_remote_code=True,
56 | )
57 |
58 | # config hyperparameters
59 | compress_args = {}
60 | def config_compress(model, window_size=32, base_capacity=512, kernel_size=7, pooling="maxpool", floor_alpha=0.5, pyram_mode = False, beta = 20):
61 | model.model.config.window_size = window_size
62 | model.model.config.base_capacity = base_capacity
63 | model.model.config.kernel_size = kernel_size
64 |
65 | model.model.config.pooling = pooling
66 | model.model.config.floor_alpha = floor_alpha
67 |
68 | model.model.config.pyram_mode = pyram_mode
69 | model.model.config.pyram_beta = beta
70 | return model
71 |
72 | model = config_compress(model, **compress_args)
73 | ```
74 |
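The snippet above omits the surrounding setup for brevity. Below is a minimal sketch of the assumed context; the checkpoint name, `config`, and `device_map` values are illustrative placeholders, not prescribed by AdaKV:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# placeholder checkpoint; any Llama- or Mistral-style model covered by the hijacks should work
model_name_or_path = "mistralai/Mistral-7B-Instruct-v0.2"
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
device_map = "auto"
```

As in the Quick Start, call `replace_mistral_adaptive()` / `replace_llama_adaptive()` before `from_pretrained`, so the patched attention and model classes are in place when the model is instantiated.
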
75 | #### Flattened Storage and Flash Attention Support
76 |
77 | Because cache lengths vary across heads, we implement a flattened KV cache storage layout, combined with `flash_attn_varlen_func`, for efficient computation.
78 |
79 | ##### Regular MHA Cache Storage
80 |
81 | ```
82 | Layer i:
83 | head0: (t00, t01, t02)
84 | head1: (t10, t11, t12)
85 | head2: (t20, t21, t22)
86 |
87 | past_key_value.update():
88 |
89 | Layer i:
90 | head0: (t00, t01, t02, t03)
91 | head1: (t10, t11, t12, t13)
92 | head2: (t20, t21, t22, t23)
93 |
94 | ```
95 |
96 | Note: `tij` denotes the cache element of token j on head i.
97 |
98 | ##### Flattened Cache Storage
99 |
100 | The corresponding CUDA code can be found in [`./csrc/csrc/cuda_api.cu`](./csrc/csrc/cuda_api.cu).
101 | ```
102 | Layer i:
103 | (t00, t01, t02, t03) (t10, t11) (t20, t21, t22)
104 |
105 | past_key_value.update():
106 |
107 | Layer i:
108 | phase 0: malloc empty cache
109 | (_, _, _, _, _) (_, _, _) (_, _, _, _)
110 |
111 | phase 1: copy old value
112 | (t00, t01, t02, t03, _) (t10, t11, _) (t20, t21, t22, _)
113 |
114 | phase 2: insert new value
115 | (t00, t01, t02, t03, t04) (t10, t11, t12) (t20, t21, t22, t23)
116 | ```
117 |
118 | Details about `flash_attn_varlen_func` can be found in the [flash-attention repository](https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/flash_attn/flash_attn_interface.py#L1051).
119 |
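To make the layout concrete, here is a small, self-contained PyTorch sketch of the three update phases above, together with the cumulative-length metadata (`cu_seqlens`-style offsets) that `flash_attn_varlen_func` consumes. It only illustrates the idea: the actual copy is performed by the CUDA kernel in `csrc/csrc/cuda_api.cu`, and the names `head_lens` / `cu_klen` merely mirror the metadata used in the hijack code.

```python
import torch

head_dim = 4
# current flattened cache of one layer: head 0 keeps 4 tokens, head 1 keeps 2, head 2 keeps 3
head_lens = torch.tensor([4, 2, 3])
flat_cache = torch.randn(int(head_lens.sum()), head_dim)

# one new token arrives for every head during decoding
new_kv = torch.randn(head_lens.numel(), head_dim)

# phase 0: malloc the enlarged flattened buffer
new_head_lens = head_lens + 1
new_cache = torch.empty(int(new_head_lens.sum()), head_dim)

old_starts = torch.cat([torch.zeros(1, dtype=torch.long), head_lens.cumsum(0)[:-1]])
new_starts = torch.cat([torch.zeros(1, dtype=torch.long), new_head_lens.cumsum(0)[:-1]])

for h in range(head_lens.numel()):
    # phase 1: copy the old entries of head h into its new segment
    new_cache[new_starts[h]:new_starts[h] + head_lens[h]] = \
        flat_cache[old_starts[h]:old_starts[h] + head_lens[h]]
    # phase 2: insert the new token at the end of head h's segment
    new_cache[new_starts[h] + head_lens[h]] = new_kv[h]

# cumulative lengths per head, i.e. the cu_seqlens_k argument of flash_attn_varlen_func
cu_klen = torch.cat([torch.zeros(1, dtype=torch.long), new_head_lens.cumsum(0)])
print(cu_klen)  # tensor([0, 5, 8, 12])
```

During decoding, the hijacked attention passes this metadata (`cu_klen`, `head_lens`, `max_seqlen_k`) to `flash_attn_varlen_func`, treating every head as a separate variable-length sequence.
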
120 | ##### Peak Memory Footprint and Decoding Latency of Our Implementation
121 |
122 |
123 |
124 |
125 | ## Evaluations
126 | ### LongBench without GQA Support
127 |
128 | 
129 |
130 | ```bash
131 | cd ./experiments/LongBench
132 | bash runall.sh
133 | ```
134 |
135 |
136 | ## Citation
137 | If you find our work valuable, please cite:
138 | ```
139 | @misc{feng2024adakvoptimizingkvcache,
140 | title={Ada-KV: Optimizing KV Cache Eviction by Adaptive Budget Allocation for Efficient LLM Inference},
141 | author={Yuan Feng and Junlin Lv and Yukun Cao and Xike Xie and S. Kevin Zhou},
142 | year={2024},
143 | eprint={2407.11550},
144 | archivePrefix={arXiv},
145 | primaryClass={cs.CL},
146 | url={https://arxiv.org/abs/2407.11550},
147 | }
148 | ```
149 |
150 | ## Acknowledgement
151 |
152 | We extend our gratitude to [SnapKV](https://github.com/FasterDecoding/SnapKV) and [PyramidKV](https://github.com/Zefan-Cai/PyramidKV) for their open-source code, which has significantly facilitated the development of this project.
153 |
154 | ## Misc
155 |
156 | ### Observation
157 |
158 | Different attention heads within each layer of LLMs exhibit significant disparities in the degrees of attention concentration.
159 |
160 | Therefore, we can improve budget utilization by dynamically allocating the budget across attention heads within the same layer according to their concentration degrees.
161 |
162 | 
163 |
--------------------------------------------------------------------------------
/adaptive_snapkv/monkeypatch/adaptive_llama_hijack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import List, Optional, Tuple, Union
5 | import warnings
6 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
7 | from transformers.models.llama.modeling_llama import (
8 | apply_rotary_pos_emb,
9 | repeat_kv,
10 | )
11 | from transformers.utils import (
12 | logging,
13 | is_flash_attn_2_available,
14 | )
15 | from transformers.modeling_attn_mask_utils import (
16 | AttentionMaskConverter,
17 | _prepare_4d_attention_mask,
18 | _prepare_4d_causal_attention_mask,
19 | _prepare_4d_causal_attention_mask_for_sdpa,
20 | )
21 | from transformers.modeling_outputs import BaseModelOutputWithPast
22 | from transformers.models.llama.modeling_llama import (
23 | apply_rotary_pos_emb,
24 | repeat_kv,
25 | )
26 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
27 |
28 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_adaptive_snapkv, DynamicCacheSplitHeadFlatten
29 |
30 | logger = logging.get_logger(__name__)
31 |
32 | if is_flash_attn_2_available():
33 | from flash_attn import flash_attn_func, flash_attn_varlen_func
34 |
35 | def adaptive_LlamaModel_forward(
36 | self,
37 | input_ids: torch.LongTensor = None,
38 | attention_mask: Optional[torch.Tensor] = None,
39 | position_ids: Optional[torch.LongTensor] = None,
40 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
41 | inputs_embeds: Optional[torch.FloatTensor] = None,
42 | use_cache: Optional[bool] = None,
43 | output_attentions: Optional[bool] = None,
44 | output_hidden_states: Optional[bool] = None,
45 | return_dict: Optional[bool] = None,
46 | cache_position: Optional[torch.LongTensor] = None,
47 | ) -> Union[Tuple, BaseModelOutputWithPast]:
48 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
49 | output_hidden_states = (
50 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
51 | )
52 | use_cache = use_cache if use_cache is not None else self.config.use_cache
53 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
54 |
55 | if (input_ids is None) ^ (inputs_embeds is not None):
56 | raise ValueError(
57 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
58 | )
59 |
60 | if self.gradient_checkpointing and self.training and use_cache:
61 | logger.warning_once(
62 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
63 | )
64 | use_cache = False
65 |
66 | if inputs_embeds is None:
67 | inputs_embeds = self.embed_tokens(input_ids)
68 |
69 | # return_legacy_cache = False
70 | # if (
71 | # use_cache and not isinstance(past_key_values, Cache) and not self.training
72 | # ): # kept for BC (non `Cache` `past_key_values` inputs)
73 | # return_legacy_cache = True
74 | # past_key_values = DynamicCache.from_legacy_cache(past_key_values)
75 | # logger.warning_once(
76 | # "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
77 | # "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
78 | # )
79 |
80 | # NOTE: adakv
81 | return_legacy_cache = True
82 | past_key_values = DynamicCacheSplitHeadFlatten.from_legacy_cache(past_key_values)
83 | logger.warning_once(
84 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
85 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
86 | )
87 |
88 | if cache_position is None:
89 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
90 | cache_position = torch.arange(
91 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
92 | )
93 | if position_ids is None:
94 | position_ids = cache_position.unsqueeze(0)
95 |
96 | causal_mask = self._update_causal_mask(
97 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
98 | )
99 | hidden_states = inputs_embeds
100 |
101 | # create position embeddings to be shared across the decoder layers
102 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
103 |
104 | # decoder layers
105 | all_hidden_states = () if output_hidden_states else None
106 | all_self_attns = () if output_attentions else None
107 | next_decoder_cache = None
108 |
109 | for decoder_layer in self.layers:
110 | if output_hidden_states:
111 | all_hidden_states += (hidden_states,)
112 |
113 | if self.gradient_checkpointing and self.training:
114 | layer_outputs = self._gradient_checkpointing_func(
115 | decoder_layer.__call__,
116 | hidden_states,
117 | causal_mask,
118 | position_ids,
119 | past_key_values,
120 | output_attentions,
121 | use_cache,
122 | cache_position,
123 | position_embeddings,
124 | )
125 | else:
126 | layer_outputs = decoder_layer(
127 | hidden_states,
128 | attention_mask=causal_mask,
129 | position_ids=position_ids,
130 | past_key_value=past_key_values,
131 | output_attentions=output_attentions,
132 | use_cache=use_cache,
133 | cache_position=cache_position,
134 | position_embeddings=position_embeddings,
135 | )
136 |
137 | hidden_states = layer_outputs[0]
138 |
139 | if use_cache:
140 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
141 |
142 | if output_attentions:
143 | all_self_attns += (layer_outputs[1],)
144 |
145 | hidden_states = self.norm(hidden_states)
146 |
147 | # add hidden states from the last decoder layer
148 | if output_hidden_states:
149 | all_hidden_states += (hidden_states,)
150 |
151 | next_cache = next_decoder_cache if use_cache else None
152 | if return_legacy_cache:
153 | next_cache = next_cache.to_legacy_cache()
154 |
155 | hidden_states = hidden_states[:, -1,:].unsqueeze(1)
156 |
157 | if not return_dict:
158 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
159 | return BaseModelOutputWithPast(
160 | last_hidden_state=hidden_states,
161 | past_key_values=next_cache,
162 | hidden_states=all_hidden_states,
163 | attentions=all_self_attns,
164 | )
165 |
166 | def adaptive_llama_flash_attn2_forward(
167 | self,
168 | hidden_states: torch.Tensor,
169 | attention_mask: Optional[torch.LongTensor] = None,
170 | position_ids: Optional[torch.LongTensor] = None,
171 | past_key_value: Optional[Cache] = None,
172 | output_attentions: bool = False,
173 | use_cache: bool = False,
174 | cache_position: Optional[torch.LongTensor] = None,
175 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
176 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
177 | # NOTE: adakv
178 | init_adaptive_snapkv(self)
179 | if isinstance(past_key_value, StaticCache):
180 | raise ValueError(
181 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
182 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
183 | )
184 |
185 | output_attentions = False
186 |
187 | bsz, q_len, _ = hidden_states.size()
188 |
189 | query_states = self.q_proj(hidden_states)
190 | key_states = self.k_proj(hidden_states)
191 | value_states = self.v_proj(hidden_states)
192 |
193 | # Flash attention requires the input to have the shape
194 | # batch_size x seq_length x head_dim x hidden_dim
195 | # therefore we just need to keep the original shape
196 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
197 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
198 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
199 |
200 | if position_embeddings is None:
201 | logger.warning_once(
202 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
203 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
204 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
205 | "removed and `position_embeddings` will be mandatory."
206 | )
207 | cos, sin = self.rotary_emb(value_states, position_ids)
208 | else:
209 | cos, sin = position_embeddings
210 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
211 |
212 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
213 |
214 | is_prefill = q_len != 1
215 |
216 | if is_prefill:
217 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states)
218 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
219 |
220 | # repeat k/v heads if n_kv_heads < n_heads
221 | # [SnapKV] move to ahead
222 | key_states = repeat_kv(key_states, self.num_key_value_groups)
223 | value_states = repeat_kv(value_states, self.num_key_value_groups)
224 |
225 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
226 | # to be able to avoid many of these transpose/reshape/view.
227 | query_states = query_states.transpose(1, 2)
228 | key_states = key_states.transpose(1, 2)
229 | value_states = value_states.transpose(1, 2)
230 |
231 | dropout_rate = self.attention_dropout if self.training else 0.0
232 |
233 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
234 | # therefore the input hidden states gets silently casted in float32. Hence, we need
235 | # cast them back in the correct dtype just to be sure everything works as expected.
236 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms
237 | # in fp32. (LlamaRMSNorm handles it correctly)
238 |
239 | input_dtype = query_states.dtype
240 | if input_dtype == torch.float32:
241 | if torch.is_autocast_enabled():
242 | target_dtype = torch.get_autocast_gpu_dtype()
243 | # Handle the case where the model is quantized
244 | elif hasattr(self.config, "_pre_quantization_dtype"):
245 | target_dtype = self.config._pre_quantization_dtype
246 | else:
247 | target_dtype = self.q_proj.weight.dtype
248 |
249 | logger.warning_once(
250 | f"The input hidden states seems to be silently casted in float32, this might be related to"
251 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
252 | f" {target_dtype}."
253 | )
254 |
255 | query_states = query_states.to(target_dtype)
256 | key_states = key_states.to(target_dtype)
257 | value_states = value_states.to(target_dtype)
258 |
259 | attn_output = _flash_attention_forward(
260 | query_states,
261 | key_states,
262 | value_states,
263 | attention_mask,
264 | q_len,
265 | position_ids=position_ids,
266 | dropout=dropout_rate,
267 | sliding_window=getattr(self, "sliding_window", None),
268 | use_top_left_mask=self._flash_attn_uses_top_left_mask,
269 | is_causal=self.is_causal,
270 | )
271 |
272 | attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
273 | else:
274 | # decoding
275 | cache_kwargs["head_lens"] = self.kv_cluster.head_lens
276 | cache_kwargs["cu_klen"] = self.kv_cluster.cu_klen
277 |
278 | if not self.kv_cluster.gqa_support:
279 | key_states = repeat_kv(key_states, self.num_key_value_groups)
280 | value_states = repeat_kv(value_states, self.num_key_value_groups)
281 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
282 |
283 | # NOTE: update meta data
284 | self.kv_cluster.klen_sum += self.num_heads
285 | self.kv_cluster.max_seqlen_k += 1
286 | self.kv_cluster.cu_klen += self.kv_cluster.cu_offset
287 | self.kv_cluster.head_lens += 1
288 |
289 | if self.kv_cluster.gqa_support:
290 | query_states = query_states.view(-1, self.num_key_value_groups, self.head_dim)
291 | else:
292 | query_states = query_states.view(-1, 1, self.head_dim)
293 |
294 | key_states = key_states.view(-1,1,self.head_dim)
295 | value_states = value_states.view(-1,1,self.head_dim)
296 |
297 | cu_seqlens_q = self.kv_cluster.cu_qlen
298 | cu_seqlens_k = self.kv_cluster.cu_klen
299 | max_seqlen_q = 1
300 | max_seqlen_k = self.kv_cluster.max_seqlen_k
301 |
302 | attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens_q,
303 | cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True)
304 | # TODO: support batch size > 1
305 | assert bsz == 1
306 | attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim)
307 | attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
308 |
309 | attn_output = self.o_proj(attn_output)
310 |
311 | if not output_attentions:
312 | attn_weights = None
313 |
314 | return attn_output, attn_weights, past_key_value
315 |
316 | def prepare_inputs_for_generation_llama(
317 | self,
318 | input_ids,
319 | past_key_values=None,
320 | attention_mask=None,
321 | inputs_embeds=None,
322 | cache_position=None,
323 | position_ids=None,
324 | use_cache=True,
325 | **kwargs,
326 | ):
327 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
328 | # Exception 1: when passing input_embeds, input_ids may be missing entries
329 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
330 | if past_key_values is not None:
331 | if inputs_embeds is not None: # Exception 1
332 | input_ids = input_ids[:, -cache_position.shape[0] :]
333 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
334 | input_ids = input_ids[:, cache_position]
335 |
336 | if attention_mask is not None and position_ids is None:
337 | # create position_ids on the fly for batch generation
338 | position_ids = attention_mask.long().cumsum(-1) - 1
339 | position_ids.masked_fill_(attention_mask == 0, 1)
340 | if past_key_values:
341 | position_ids = position_ids[:, -input_ids.shape[1] :]
342 |
343 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
344 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
345 |
346 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
347 | if inputs_embeds is not None and cache_position[0] == 0:
348 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
349 | else:
350 | # The clone here is for the same reason as for `position_ids`.
351 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
352 |
353 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
354 | if model_inputs["inputs_embeds"] is not None:
355 | batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
356 | device = model_inputs["inputs_embeds"].device
357 | else:
358 | batch_size, sequence_length = model_inputs["input_ids"].shape
359 | device = model_inputs["input_ids"].device
360 |
361 | dtype = self.lm_head.weight.dtype
362 | min_dtype = torch.finfo(dtype).min
363 |
364 | attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
365 | attention_mask,
366 | sequence_length=sequence_length,
367 | target_length=past_key_values.get_max_length(),
368 | dtype=dtype,
369 | device=device,
370 | min_dtype=min_dtype,
371 | cache_position=cache_position,
372 | batch_size=batch_size,
373 | )
374 |
375 | model_inputs.update(
376 | {
377 | "position_ids": position_ids,
378 | "cache_position": cache_position,
379 | "past_key_values": past_key_values,
380 | "use_cache": use_cache,
381 | "attention_mask": attention_mask,
382 | }
383 | )
384 | return model_inputs
385 |
--------------------------------------------------------------------------------
/adaptive_snapkv/monkeypatch/adaptive_mistral_hijack.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from typing import List, Optional, Tuple, Union, Any,Dict
6 | import warnings
7 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
8 | from transformers.models.mistral.modeling_mistral import (
9 | apply_rotary_pos_emb,
10 | repeat_kv,
11 | )
12 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, \
13 | _prepare_4d_causal_attention_mask
14 | from transformers.modeling_outputs import BaseModelOutputWithPast
15 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
16 | from transformers.models.mistral.modeling_mistral import (
17 | apply_rotary_pos_emb,
18 | repeat_kv,
19 | )
20 | from transformers.utils import (
21 | logging,
22 | is_flash_attn_2_available,
23 | )
24 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_adaptive_snapkv, DynamicCacheSplitHeadFlatten
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | if is_flash_attn_2_available():
29 | from flash_attn import flash_attn_func, flash_attn_varlen_func
30 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
31 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
32 |
33 |
34 | def adaptive_MistralModel_forward(
35 | self,
36 | input_ids: torch.LongTensor = None,
37 | attention_mask: Optional[torch.Tensor] = None,
38 | position_ids: Optional[torch.LongTensor] = None,
39 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
40 | inputs_embeds: Optional[torch.FloatTensor] = None,
41 | use_cache: Optional[bool] = None,
42 | output_attentions: Optional[bool] = None,
43 | output_hidden_states: Optional[bool] = None,
44 | return_dict: Optional[bool] = None,
45 | cache_position: Optional[torch.LongTensor] = None,
46 | ) -> Union[Tuple, BaseModelOutputWithPast]:
47 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
48 | output_hidden_states = (
49 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
50 | )
51 | use_cache = use_cache if use_cache is not None else self.config.use_cache
52 |
53 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
54 |
55 | # retrieve input_ids and inputs_embeds
56 | if (input_ids is None) ^ (inputs_embeds is not None):
57 | raise ValueError(
58 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
59 | )
60 |
61 | if self.gradient_checkpointing and self.training and use_cache:
62 | logger.warning_once(
63 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
64 | )
65 | use_cache = False
66 |
67 | if inputs_embeds is None:
68 | inputs_embeds = self.embed_tokens(input_ids)
69 |
70 | past_key_values = DynamicCacheSplitHeadFlatten.from_legacy_cache(past_key_values)
71 | return_legacy_cache = True
72 |
73 | if cache_position is None:
74 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
75 | cache_position = torch.arange(
76 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
77 | )
78 |
79 | if position_ids is None:
80 | position_ids = cache_position.unsqueeze(0)
81 |
82 | causal_mask = self._update_causal_mask(
83 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
84 | )
85 |
86 | hidden_states = inputs_embeds
87 |
88 | # decoder layers
89 | all_hidden_states = () if output_hidden_states else None
90 | all_self_attns = () if output_attentions else None
91 | next_decoder_cache = None
92 |
93 | for decoder_layer in self.layers:
94 | if output_hidden_states:
95 | all_hidden_states += (hidden_states,)
96 |
97 | if self.gradient_checkpointing and self.training:
98 | layer_outputs = self._gradient_checkpointing_func(
99 | decoder_layer.__call__,
100 | hidden_states,
101 | causal_mask,
102 | position_ids,
103 | past_key_values,
104 | output_attentions,
105 | use_cache,
106 | cache_position,
107 | )
108 | else:
109 | layer_outputs = decoder_layer(
110 | hidden_states,
111 | attention_mask=causal_mask,
112 | position_ids=position_ids,
113 | past_key_value=past_key_values,
114 | output_attentions=output_attentions,
115 | use_cache=use_cache,
116 | cache_position=cache_position,
117 | )
118 |
119 | hidden_states = layer_outputs[0]
120 |
121 | if use_cache:
122 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
123 |
124 | if output_attentions:
125 | all_self_attns += (layer_outputs[1],)
126 |
127 | hidden_states = self.norm(hidden_states)
128 |
129 | # add hidden states from the last decoder layer
130 | if output_hidden_states:
131 | all_hidden_states += (hidden_states,)
132 |
133 | next_cache = next_decoder_cache if use_cache else None
134 | if return_legacy_cache:
135 | next_cache = next_cache.to_legacy_cache()
136 |
137 | if not return_dict:
138 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
139 |
140 | hidden_states = hidden_states[:, -1,:].unsqueeze(1)
141 | return BaseModelOutputWithPast(
142 | last_hidden_state=hidden_states,
143 | past_key_values=next_cache,
144 | hidden_states=all_hidden_states,
145 | attentions=all_self_attns,
146 | )
147 |
148 | def adaptive_mistral_flash_attn2_forward(
149 | self,
150 | hidden_states: torch.Tensor,
151 | attention_mask: Optional[torch.Tensor] = None,
152 | position_ids: Optional[torch.LongTensor] = None,
153 | past_key_value: Optional[Cache] = None,
154 | output_attentions: bool = False,
155 | use_cache: bool = False,
156 | cache_position: Optional[torch.LongTensor] = None,
157 | ):
158 | if isinstance(past_key_value, StaticCache):
159 | raise ValueError(
160 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
161 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
162 | )
163 | # NOTE: adakv
164 | init_adaptive_snapkv(self)
165 |
166 | output_attentions = False
167 |
168 | bsz, q_len, _ = hidden_states.size()
169 |
170 | query_states = self.q_proj(hidden_states)
171 | key_states = self.k_proj(hidden_states)
172 | value_states = self.v_proj(hidden_states)
173 |
174 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
175 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
176 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
177 |
178 | kv_seq_len = key_states.shape[-2]
179 | if past_key_value is not None:
180 | kv_seq_len += cache_position[0]
181 |
182 | cos, sin = self.rotary_emb(value_states, position_ids)
183 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
184 |
185 | dropout_rate = 0.0 if not self.training else self.attention_dropout
186 |
187 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
188 | if past_key_value is not None:
189 | # Activate slicing cache only if the config has a `sliding_window` attribute
190 | cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
191 | if (
192 | getattr(self.config, "sliding_window", None) is not None
193 | and kv_seq_len > self.config.sliding_window
194 | and cache_has_contents
195 | ):
196 | slicing_tokens = 1 - self.config.sliding_window
197 |
198 | past_key = past_key_value[self.layer_idx][0]
199 | past_value = past_key_value[self.layer_idx][1]
200 |
201 | past_key = past_key[:, :, slicing_tokens:, :].contiguous()
202 | past_value = past_value[:, :, slicing_tokens:, :].contiguous()
203 |
204 | if past_key.shape[-2] != self.config.sliding_window - 1:
205 | raise ValueError(
206 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
207 | f" {past_key.shape}"
208 | )
209 |
210 | if attention_mask is not None:
211 | attention_mask = attention_mask[:, slicing_tokens:]
212 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
213 |
214 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
215 | # therefore the input hidden states gets silently casted in float32. Hence, we need
216 | # cast them back in float16 just to be sure everything works as expected.
217 | input_dtype = query_states.dtype
218 | if input_dtype == torch.float32:
219 | if torch.is_autocast_enabled():
220 | target_dtype = torch.get_autocast_gpu_dtype()
221 | # Handle the case where the model is quantized
222 | elif hasattr(self.config, "_pre_quantization_dtype"):
223 | target_dtype = self.config._pre_quantization_dtype
224 | else:
225 | target_dtype = self.q_proj.weight.dtype
226 |
227 | logger.warning_once(
228 | f"The input hidden states seems to be silently casted in float32, this might be related to"
229 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
230 | f" {target_dtype}."
231 | )
232 |
233 | query_states = query_states.to(target_dtype)
234 | key_states = key_states.to(target_dtype)
235 | value_states = value_states.to(target_dtype)
236 |
237 | # TODO: naive for now
238 | is_prefill = q_len != 1
239 |
240 | if is_prefill:
241 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states)
242 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
243 |
244 | # repeat k/v heads if n_kv_heads < n_heads
245 | # [SnapKV] move to ahead
246 | key_states = repeat_kv(key_states, self.num_key_value_groups)
247 | value_states = repeat_kv(value_states, self.num_key_value_groups)
248 |
249 | # Reshape to the expected shape for Flash Attention
250 | query_states = query_states.transpose(1, 2)
251 | key_states = key_states.transpose(1, 2)
252 | value_states = value_states.transpose(1, 2)
253 |
254 | attn_output = _flash_attention_forward(
255 | query_states,
256 | key_states,
257 | value_states,
258 | attention_mask,
259 | q_len,
260 | position_ids=position_ids,
261 | dropout=dropout_rate,
262 | sliding_window=getattr(self.config, "sliding_window", None),
263 | use_top_left_mask=self._flash_attn_uses_top_left_mask,
264 | is_causal=self.is_causal,
265 | )
266 |
267 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
268 |
269 | else:
270 | # decoding
271 | cache_kwargs["head_lens"] = self.kv_cluster.head_lens
272 | cache_kwargs["cu_klen"] = self.kv_cluster.cu_klen
273 | # gqa_support
274 | if not self.kv_cluster.gqa_support:
275 | key_states = repeat_kv(key_states, self.num_key_value_groups)
276 | value_states = repeat_kv(value_states, self.num_key_value_groups)
277 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
278 |
279 |
280 | # NOTE: update meta data
281 | self.kv_cluster.klen_sum += self.num_heads
282 | self.kv_cluster.max_seqlen_k += 1
283 | self.kv_cluster.cu_klen += self.kv_cluster.cu_offset
284 | self.kv_cluster.head_lens += 1
285 |
286 | if self.kv_cluster.gqa_support:
287 | query_states = query_states.view(-1, self.num_key_value_groups, self.head_dim)
288 | else:
289 | query_states = query_states.view(-1, 1, self.head_dim)
290 |
291 | key_states = key_states.view(-1,1,self.head_dim)
292 | value_states = value_states.view(-1,1,self.head_dim)
293 |
294 | cu_seqlens_q = self.kv_cluster.cu_qlen
295 | cu_seqlens_k = self.kv_cluster.cu_klen
296 | max_seqlen_q = 1
297 | max_seqlen_k = self.kv_cluster.max_seqlen_k
298 |
299 | attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens_q,
300 | cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True)
301 | # TODO: support batch size > 1
302 | assert bsz == 1
303 | attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim)
304 | attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
305 |
306 | attn_output = self.o_proj(attn_output)
307 |
308 | if not output_attentions:
309 | attn_weights = None
310 |
311 | return attn_output, attn_weights, past_key_value
312 |
313 | def prepare_inputs_for_generation_mistral(
314 | self,
315 | input_ids,
316 | past_key_values=None,
317 | attention_mask=None,
318 | inputs_embeds=None,
319 | cache_position=None,
320 | position_ids=None,
321 | use_cache=True,
322 | **kwargs,
323 | ):
324 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
325 | # Exception 1: when passing input_embeds, input_ids may be missing entries
326 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
327 | if past_key_values is not None:
328 | if inputs_embeds is not None: # Exception 1
329 | input_ids = input_ids[:, -cache_position.shape[0] :]
330 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
331 | input_ids = input_ids[:, cache_position]
332 |
333 | if attention_mask is not None and position_ids is None:
334 | # create position_ids on the fly for batch generation
335 | position_ids = attention_mask.long().cumsum(-1) - 1
336 | position_ids.masked_fill_(attention_mask == 0, 1)
337 | if past_key_values:
338 | position_ids = position_ids[:, -input_ids.shape[1] :]
339 |
340 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
341 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
342 |
343 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
344 | if inputs_embeds is not None and cache_position[0] == 0:
345 | model_inputs = {"inputs_embeds": inputs_embeds}
346 | else:
347 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
348 |
349 | model_inputs.update(
350 | {
351 | "position_ids": position_ids,
352 | "cache_position": cache_position,
353 | "past_key_values": past_key_values,
354 | "use_cache": use_cache,
355 | "attention_mask": attention_mask,
356 | }
357 | )
358 | return model_inputs
359 |
--------------------------------------------------------------------------------
/adaptive_snapkv/monkeypatch/fixed_llama_hijack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import List, Optional, Tuple, Union
5 | import warnings
6 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
7 | from transformers.modeling_outputs import BaseModelOutputWithPast
8 | from transformers.models.llama.modeling_llama import (
9 | apply_rotary_pos_emb,
10 | repeat_kv,
11 | )
12 | from transformers.utils import (
13 | logging,
14 | )
15 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
16 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_snapkv
17 |
18 | logger = logging.get_logger(__name__)
19 |
20 | def fixed_LlamaModel_forward(
21 | self,
22 | input_ids: torch.LongTensor = None,
23 | attention_mask: Optional[torch.Tensor] = None,
24 | position_ids: Optional[torch.LongTensor] = None,
25 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
26 | inputs_embeds: Optional[torch.FloatTensor] = None,
27 | use_cache: Optional[bool] = None,
28 | output_attentions: Optional[bool] = None,
29 | output_hidden_states: Optional[bool] = None,
30 | return_dict: Optional[bool] = None,
31 | cache_position: Optional[torch.LongTensor] = None,
32 | ) -> Union[Tuple, BaseModelOutputWithPast]:
33 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
34 | output_hidden_states = (
35 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
36 | )
37 | use_cache = use_cache if use_cache is not None else self.config.use_cache
38 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
39 |
40 | if (input_ids is None) ^ (inputs_embeds is not None):
41 | raise ValueError(
42 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
43 | )
44 |
45 | if self.gradient_checkpointing and self.training and use_cache:
46 | logger.warning_once(
47 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
48 | )
49 | use_cache = False
50 |
51 | if inputs_embeds is None:
52 | inputs_embeds = self.embed_tokens(input_ids)
53 |
54 | return_legacy_cache = False
55 | if (
56 | use_cache and not isinstance(past_key_values, Cache) and not self.training
57 | ): # kept for BC (non `Cache` `past_key_values` inputs)
58 | return_legacy_cache = True
59 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
60 | logger.warning_once(
61 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
62 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
63 | )
64 |
65 | if cache_position is None:
66 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
67 | cache_position = torch.arange(
68 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
69 | )
70 | if position_ids is None:
71 | position_ids = cache_position.unsqueeze(0)
72 |
73 | causal_mask = self._update_causal_mask(
74 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
75 | )
76 | hidden_states = inputs_embeds
77 |
78 | # create position embeddings to be shared across the decoder layers
79 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
80 |
81 | # decoder layers
82 | all_hidden_states = () if output_hidden_states else None
83 | all_self_attns = () if output_attentions else None
84 | next_decoder_cache = None
85 |
86 | for decoder_layer in self.layers:
87 | if output_hidden_states:
88 | all_hidden_states += (hidden_states,)
89 |
90 | if self.gradient_checkpointing and self.training:
91 | layer_outputs = self._gradient_checkpointing_func(
92 | decoder_layer.__call__,
93 | hidden_states,
94 | causal_mask,
95 | position_ids,
96 | past_key_values,
97 | output_attentions,
98 | use_cache,
99 | cache_position,
100 | position_embeddings,
101 | )
102 | else:
103 | layer_outputs = decoder_layer(
104 | hidden_states,
105 | attention_mask=causal_mask,
106 | position_ids=position_ids,
107 | past_key_value=past_key_values,
108 | output_attentions=output_attentions,
109 | use_cache=use_cache,
110 | cache_position=cache_position,
111 | position_embeddings=position_embeddings,
112 | )
113 |
114 | hidden_states = layer_outputs[0]
115 |
116 | if use_cache:
117 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
118 |
119 | if output_attentions:
120 | all_self_attns += (layer_outputs[1],)
121 |
122 | hidden_states = self.norm(hidden_states)
123 |
124 | # add hidden states from the last decoder layer
125 | if output_hidden_states:
126 | all_hidden_states += (hidden_states,)
127 |
128 | next_cache = next_decoder_cache if use_cache else None
129 | if return_legacy_cache:
130 | next_cache = next_cache.to_legacy_cache()
131 |
132 | hidden_states = hidden_states[:, -1,:].unsqueeze(1)
133 |
134 | if not return_dict:
135 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
136 | return BaseModelOutputWithPast(
137 | last_hidden_state=hidden_states,
138 | past_key_values=next_cache,
139 | hidden_states=all_hidden_states,
140 | attentions=all_self_attns,
141 | )
142 |
143 |
144 | def fixed_llama_flash_attn2_forward(
145 | self,
146 | hidden_states: torch.Tensor,
147 | attention_mask: Optional[torch.LongTensor] = None,
148 | position_ids: Optional[torch.LongTensor] = None,
149 | past_key_value: Optional[Cache] = None,
150 | output_attentions: bool = False,
151 | use_cache: bool = False,
152 | cache_position: Optional[torch.LongTensor] = None,
153 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
154 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
155 | if isinstance(past_key_value, StaticCache):
156 | raise ValueError(
157 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
158 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
159 | )
160 | init_snapkv(self)
161 |
162 | output_attentions = False
163 |
164 | bsz, q_len, _ = hidden_states.size()
165 |
166 | query_states = self.q_proj(hidden_states)
167 | key_states = self.k_proj(hidden_states)
168 | value_states = self.v_proj(hidden_states)
169 |
170 | # Flash attention requires the input to have the shape
171 | # batch_size x seq_length x head_dim x hidden_dim
172 | # therefore we just need to keep the original shape
173 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
174 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
175 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
176 |
177 | if position_embeddings is None:
178 | logger.warning_once(
179 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
180 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
181 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
182 | "removed and `position_embeddings` will be mandatory."
183 | )
184 | cos, sin = self.rotary_emb(value_states, position_ids)
185 | else:
186 | cos, sin = position_embeddings
187 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
188 |
189 |
190 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
191 | if past_key_value is not None:
192 | # NOTE: decoding update
193 | if q_len == 1:
194 | # support gqa
195 | if not self.kv_cluster.gqa_support:
196 | key_states = repeat_kv(key_states, self.num_key_value_groups)
197 | value_states = repeat_kv(value_states, self.num_key_value_groups)
198 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
199 | else:
200 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states)
201 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
202 |
203 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
204 | # to be able to avoid many of these transpose/reshape/view.
205 |
206 | dropout_rate = self.attention_dropout if self.training else 0.0
207 |
208 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
209 | # therefore the input hidden states gets silently casted in float32. Hence, we need
210 | # cast them back in the correct dtype just to be sure everything works as expected.
211 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms
212 | # in fp32. (LlamaRMSNorm handles it correctly)
213 |
214 | input_dtype = query_states.dtype
215 | if input_dtype == torch.float32:
216 | if torch.is_autocast_enabled():
217 | target_dtype = torch.get_autocast_gpu_dtype()
218 | # Handle the case where the model is quantized
219 | elif hasattr(self.config, "_pre_quantization_dtype"):
220 | target_dtype = self.config._pre_quantization_dtype
221 | else:
222 | target_dtype = self.q_proj.weight.dtype
223 |
224 | logger.warning_once(
225 | f"The input hidden states seems to be silently casted in float32, this might be related to"
226 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
227 | f" {target_dtype}."
228 | )
229 |
230 | query_states = query_states.to(target_dtype)
231 | key_states = key_states.to(target_dtype)
232 | value_states = value_states.to(target_dtype)
233 |
234 | # Reshape to the expected shape for Flash Attention
235 | query_states = query_states.transpose(1, 2)
236 | key_states = key_states.transpose(1, 2)
237 | value_states = value_states.transpose(1, 2)
238 |
239 | attn_output = _flash_attention_forward(
240 | query_states,
241 | key_states,
242 | value_states,
243 | attention_mask,
244 | q_len,
245 | position_ids=position_ids,
246 | dropout=dropout_rate,
247 | sliding_window=getattr(self, "sliding_window", None),
248 | use_top_left_mask=self._flash_attn_uses_top_left_mask,
249 | is_causal=self.is_causal,
250 | )
251 |
252 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
253 | attn_output = self.o_proj(attn_output)
254 |
255 | if not output_attentions:
256 | attn_weights = None
257 |
258 | return attn_output, attn_weights, past_key_value
259 |
260 | def prepare_inputs_for_generation_llama(
261 | self,
262 | input_ids,
263 | past_key_values=None,
264 | attention_mask=None,
265 | inputs_embeds=None,
266 | cache_position=None,
267 | position_ids=None,
268 | use_cache=True,
269 | **kwargs,
270 | ):
271 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
272 | # Exception 1: when passing input_embeds, input_ids may be missing entries
273 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
274 | if past_key_values is not None:
275 | if inputs_embeds is not None: # Exception 1
276 | input_ids = input_ids[:, -cache_position.shape[0] :]
277 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
278 | input_ids = input_ids[:, cache_position]
279 |
280 | if attention_mask is not None and position_ids is None:
281 | # create position_ids on the fly for batch generation
282 | position_ids = attention_mask.long().cumsum(-1) - 1
283 | position_ids.masked_fill_(attention_mask == 0, 1)
284 | if past_key_values:
285 | position_ids = position_ids[:, -input_ids.shape[1] :]
286 |
287 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
288 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
289 |
290 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
291 | if inputs_embeds is not None and cache_position[0] == 0:
292 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
293 | else:
294 | # The clone here is for the same reason as for `position_ids`.
295 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
296 |
297 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
298 | if model_inputs["inputs_embeds"] is not None:
299 | batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
300 | device = model_inputs["inputs_embeds"].device
301 | else:
302 | batch_size, sequence_length = model_inputs["input_ids"].shape
303 | device = model_inputs["input_ids"].device
304 |
305 | dtype = self.lm_head.weight.dtype
306 | min_dtype = torch.finfo(dtype).min
307 |
308 | attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
309 | attention_mask,
310 | sequence_length=sequence_length,
311 | target_length=past_key_values.get_max_length(),
312 | dtype=dtype,
313 | device=device,
314 | min_dtype=min_dtype,
315 | cache_position=cache_position,
316 | batch_size=batch_size,
317 | )
318 |
319 | model_inputs.update(
320 | {
321 | "position_ids": position_ids,
322 | "cache_position": cache_position,
323 | "past_key_values": past_key_values,
324 | "use_cache": use_cache,
325 | "attention_mask": attention_mask,
326 | }
327 | )
328 | return model_inputs
329 |
--------------------------------------------------------------------------------
/adaptive_snapkv/monkeypatch/fixed_mistral_hijack.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import torch
3 | from typing import List, Optional, Tuple, Union
4 | import warnings
5 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
6 | from transformers.modeling_outputs import BaseModelOutputWithPast
7 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
8 | from transformers.models.mistral.modeling_mistral import (
9 | apply_rotary_pos_emb,
10 | repeat_kv,
11 | )
12 | from transformers.utils import (
13 | logging,
14 | is_flash_attn_2_available,
15 | )
16 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_snapkv
17 | from flash_attn import flash_attn_func
18 |
19 | logger = logging.get_logger(__name__)
20 |
21 | if is_flash_attn_2_available():
22 | from flash_attn import flash_attn_func, flash_attn_varlen_func
23 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
24 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
25 |
26 | def fixed_MistralModel_forward(
27 | self,
28 | input_ids: torch.LongTensor = None,
29 | attention_mask: Optional[torch.Tensor] = None,
30 | position_ids: Optional[torch.LongTensor] = None,
31 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
32 | inputs_embeds: Optional[torch.FloatTensor] = None,
33 | use_cache: Optional[bool] = None,
34 | output_attentions: Optional[bool] = None,
35 | output_hidden_states: Optional[bool] = None,
36 | return_dict: Optional[bool] = None,
37 | cache_position: Optional[torch.LongTensor] = None,
38 | ) -> Union[Tuple, BaseModelOutputWithPast]:
39 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
40 | output_hidden_states = (
41 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
42 | )
43 | use_cache = use_cache if use_cache is not None else self.config.use_cache
44 |
45 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
46 |
47 | # retrieve input_ids and inputs_embeds
48 | if (input_ids is None) ^ (inputs_embeds is not None):
49 | raise ValueError(
50 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
51 | )
52 |
53 | if self.gradient_checkpointing and self.training and use_cache:
54 | logger.warning_once(
55 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
56 | )
57 | use_cache = False
58 |
59 | if inputs_embeds is None:
60 | inputs_embeds = self.embed_tokens(input_ids)
61 |
62 | return_legacy_cache = False
63 | if use_cache and not isinstance(past_key_values, Cache) and not self.training:
64 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
65 | return_legacy_cache = True
66 | logger.warning_once(
67 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
68 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
69 | )
70 |
71 | if cache_position is None:
72 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
73 | cache_position = torch.arange(
74 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
75 | )
76 |
77 | if position_ids is None:
78 | position_ids = cache_position.unsqueeze(0)
79 |
80 | causal_mask = self._update_causal_mask(
81 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
82 | )
83 |
84 | hidden_states = inputs_embeds
85 |
86 | # decoder layers
87 | all_hidden_states = () if output_hidden_states else None
88 | all_self_attns = () if output_attentions else None
89 | next_decoder_cache = None
90 |
91 | for decoder_layer in self.layers:
92 | if output_hidden_states:
93 | all_hidden_states += (hidden_states,)
94 |
95 | if self.gradient_checkpointing and self.training:
96 | layer_outputs = self._gradient_checkpointing_func(
97 | decoder_layer.__call__,
98 | hidden_states,
99 | causal_mask,
100 | position_ids,
101 | past_key_values,
102 | output_attentions,
103 | use_cache,
104 | cache_position,
105 | )
106 | else:
107 | layer_outputs = decoder_layer(
108 | hidden_states,
109 | attention_mask=causal_mask,
110 | position_ids=position_ids,
111 | past_key_value=past_key_values,
112 | output_attentions=output_attentions,
113 | use_cache=use_cache,
114 | cache_position=cache_position,
115 | )
116 |
117 | hidden_states = layer_outputs[0]
118 |
119 | if use_cache:
120 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
121 |
122 | if output_attentions:
123 | all_self_attns += (layer_outputs[1],)
124 |
125 | hidden_states = self.norm(hidden_states)
126 |
127 | # add hidden states from the last decoder layer
128 | if output_hidden_states:
129 | all_hidden_states += (hidden_states,)
130 |
131 | next_cache = next_decoder_cache if use_cache else None
132 | if return_legacy_cache:
133 | next_cache = next_cache.to_legacy_cache()
134 |
135 | if not return_dict:
136 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
137 |
138 |     hidden_states = hidden_states[:, -1, :].unsqueeze(1)  # keep only the final position's hidden state
139 | return BaseModelOutputWithPast(
140 | last_hidden_state=hidden_states,
141 | past_key_values=next_cache,
142 | hidden_states=all_hidden_states,
143 | attentions=all_self_attns,
144 | )
145 |
146 | def fixed_mistral_flash_attn2_forward(
147 | self,
148 | hidden_states: torch.Tensor,
149 | attention_mask: Optional[torch.Tensor] = None,
150 | position_ids: Optional[torch.LongTensor] = None,
151 | past_key_value: Optional[Cache] = None,
152 | output_attentions: bool = False,
153 | use_cache: bool = False,
154 | cache_position: Optional[torch.LongTensor] = None,
155 | ):
156 | if isinstance(past_key_value, StaticCache):
157 | raise ValueError(
158 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
159 |             "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
160 | )
161 | # [SnapKV] register kv_cluster
162 | init_snapkv(self)
163 | output_attentions = False
164 |
165 | output_attentions = False
166 |
167 | bsz, q_len, _ = hidden_states.size()
168 |
169 | query_states = self.q_proj(hidden_states)
170 | key_states = self.k_proj(hidden_states)
171 | value_states = self.v_proj(hidden_states)
172 |
173 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
174 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
175 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
176 |
177 | kv_seq_len = key_states.shape[-2]
178 | if past_key_value is not None:
179 | kv_seq_len += cache_position[0]
180 |
181 | cos, sin = self.rotary_emb(value_states, position_ids)
182 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
183 |
184 | dropout_rate = 0.0 if not self.training else self.attention_dropout
185 |
186 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
187 | if past_key_value is not None:
188 |         # Activate cache slicing only if the config has a `sliding_window` attribute set
189 | cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
190 | if (
191 | getattr(self.config, "sliding_window", None) is not None
192 | and kv_seq_len > self.config.sliding_window
193 | and cache_has_contents
194 | ):
195 | slicing_tokens = 1 - self.config.sliding_window
196 |
197 | past_key = past_key_value[self.layer_idx][0]
198 | past_value = past_key_value[self.layer_idx][1]
199 |
200 | past_key = past_key[:, :, slicing_tokens:, :].contiguous()
201 | past_value = past_value[:, :, slicing_tokens:, :].contiguous()
202 |
203 | if past_key.shape[-2] != self.config.sliding_window - 1:
204 | raise ValueError(
205 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
206 | f" {past_key.shape}"
207 | )
208 |
209 | if attention_mask is not None:
210 | attention_mask = attention_mask[:, slicing_tokens:]
211 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
212 |
213 | if q_len == 1:
214 | # support gqa
215 | if not self.kv_cluster.gqa_support:
216 | key_states = repeat_kv(key_states, self.num_key_value_groups)
217 | value_states = repeat_kv(value_states, self.num_key_value_groups)
218 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
219 | else:
220 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states)
221 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
222 |
223 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
224 |         # therefore the input hidden states get silently cast to float32. Hence, we need to
225 |         # cast them back to float16 just to be sure everything works as expected.
226 | input_dtype = query_states.dtype
227 | if input_dtype == torch.float32:
228 | if torch.is_autocast_enabled():
229 | target_dtype = torch.get_autocast_gpu_dtype()
230 | # Handle the case where the model is quantized
231 | elif hasattr(self.config, "_pre_quantization_dtype"):
232 | target_dtype = self.config._pre_quantization_dtype
233 | else:
234 | target_dtype = self.q_proj.weight.dtype
235 |
236 | logger.warning_once(
237 |             f"The input hidden states seem to have been silently cast to float32; this might be related to"
238 |             f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
239 | f" {target_dtype}."
240 | )
241 |
242 | query_states = query_states.to(target_dtype)
243 | key_states = key_states.to(target_dtype)
244 | value_states = value_states.to(target_dtype)
245 |
246 |     # Reshape to the expected shape for Flash Attention
247 | query_states = query_states.transpose(1, 2)
248 | key_states = key_states.transpose(1, 2)
249 | value_states = value_states.transpose(1, 2)
250 |
251 | attn_output = _flash_attention_forward(
252 | query_states,
253 | key_states,
254 | value_states,
255 | attention_mask,
256 | q_len,
257 | position_ids=position_ids,
258 | dropout=dropout_rate,
259 | sliding_window=getattr(self.config, "sliding_window", None),
260 | use_top_left_mask=self._flash_attn_uses_top_left_mask,
261 | is_causal=self.is_causal,
262 | )
263 |
264 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
265 | attn_output = self.o_proj(attn_output)
266 |
267 | if not output_attentions:
268 | attn_weights = None
269 |
270 | return attn_output, attn_weights, past_key_value
271 |
272 |
273 | def prepare_inputs_for_generation_mistral(
274 | self,
275 | input_ids,
276 | past_key_values=None,
277 | attention_mask=None,
278 | inputs_embeds=None,
279 | cache_position=None,
280 | position_ids=None,
281 | use_cache=True,
282 | **kwargs,
283 | ):
284 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
285 | # Exception 1: when passing input_embeds, input_ids may be missing entries
286 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
287 | if past_key_values is not None:
288 | if inputs_embeds is not None: # Exception 1
289 | input_ids = input_ids[:, -cache_position.shape[0] :]
290 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
291 | input_ids = input_ids[:, cache_position]
292 |
293 | if attention_mask is not None and position_ids is None:
294 | # create position_ids on the fly for batch generation
295 | position_ids = attention_mask.long().cumsum(-1) - 1
296 | position_ids.masked_fill_(attention_mask == 0, 1)
297 | if past_key_values:
298 | position_ids = position_ids[:, -input_ids.shape[1] :]
299 |
300 |             # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
301 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
302 |
303 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
304 | if inputs_embeds is not None and cache_position[0] == 0:
305 | model_inputs = {"inputs_embeds": inputs_embeds}
306 | else:
307 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
308 |
309 | model_inputs.update(
310 | {
311 | "position_ids": position_ids,
312 | "cache_position": cache_position,
313 | "past_key_values": past_key_values,
314 | "use_cache": use_cache,
315 | "attention_mask": attention_mask,
316 | }
317 | )
318 | return model_inputs
319 |
--------------------------------------------------------------------------------
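
The patched attention forwards above all share the same prefill/decode split: during prefill (q_len > 1) the prompt KV is scored by `kv_cluster.update_kv` (using the last `window_size` queries) and only the selected entries are written to the cache, while during decoding (q_len == 1) the single new token's KV is appended as usual. The following is a minimal illustrative sketch of that control flow, not code from the repository; the cache and cluster arguments are duck-typed stand-ins.

def cache_update_sketch(past_key_value, kv_cluster, key_states, query_states,
                        value_states, layer_idx, cache_kwargs, q_len):
    if q_len == 1:
        # Decoding step: append the new token's KV to the already-compressed cache.
        return past_key_value.update(key_states, value_states, layer_idx, cache_kwargs)
    # Prefill step: compress the prompt KV before it is stored in the cache.
    key_compress, value_compress = kv_cluster.update_kv(key_states, query_states, value_states)
    past_key_value.update(key_compress, value_compress, layer_idx, cache_kwargs)
    # The full, uncompressed KV is still used for this forward pass's attention.
    return key_states, value_states
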
/adaptive_snapkv/monkeypatch/monkeypatch.py:
--------------------------------------------------------------------------------
1 | from importlib.metadata import version
2 | import warnings
3 | import transformers
4 | import transformers.models.mistral.modeling_mistral
5 | from adaptive_snapkv.monkeypatch.fixed_mistral_hijack import fixed_mistral_flash_attn2_forward, fixed_MistralModel_forward
6 | from adaptive_snapkv.monkeypatch.fixed_mistral_hijack import prepare_inputs_for_generation_mistral as fixed_prepare_inputs_for_generation_mistral
7 | from adaptive_snapkv.monkeypatch.adaptive_mistral_hijack import adaptive_mistral_flash_attn2_forward, adaptive_MistralModel_forward
8 | from adaptive_snapkv.monkeypatch.adaptive_mistral_hijack import prepare_inputs_for_generation_mistral as ada_prepare_inputs_for_generation_mistral
9 |
10 | from adaptive_snapkv.monkeypatch.fixed_llama_hijack import fixed_llama_flash_attn2_forward, fixed_LlamaModel_forward
11 | from adaptive_snapkv.monkeypatch.fixed_llama_hijack import prepare_inputs_for_generation_llama as fixed_prepare_inputs_for_generation_llama
12 | from adaptive_snapkv.monkeypatch.adaptive_llama_hijack import adaptive_llama_flash_attn2_forward, adaptive_LlamaModel_forward
13 | from adaptive_snapkv.monkeypatch.adaptive_llama_hijack import prepare_inputs_for_generation_llama as ada_prepare_inputs_for_generation_llama
14 |
15 | from adaptive_snapkv.monkeypatch.slm_llama_hijack import slm_llama_flash_attn2_forward, slm_LlamaModel_forward
16 | from adaptive_snapkv.monkeypatch.slm_mistral_hijack import slm_mistral_flash_attn2_forward, slm_MistralModel_forward
17 |
18 | def check_version():
19 |     try:
20 |         transformers_version = version("transformers")
21 |     except Exception as e:
22 |         print(f"Transformers not installed: {e}")
23 |         return  # avoid a NameError below when transformers is missing
24 |     version_list = ['4.37']
25 |     warning_flag = True
26 |     for x in version_list:
27 |         if x in transformers_version:
28 |             warning_flag = False
29 | if warning_flag:
30 | warnings.warn(f"Transformers version {transformers_version} might not be compatible with SnapKV. SnapKV is tested with Transformers version {version_list}.")
31 |
32 |
33 | # config hyperparameters
34 | def config_compress(model, window_size=32, base_capacity=1024, kernel_size=7, pooling="maxpool", floor_alpha=0.5, pyram_mode=False, beta=20, skip=0, gqa_support=False, gqa_func="mean"):
35 | model.model.config.window_size = window_size
36 | model.model.config.base_capacity = base_capacity
37 | model.model.config.kernel_size = kernel_size
38 |
39 | model.model.config.pooling = pooling
40 | model.model.config.floor_alpha = floor_alpha
41 | model.model.config.skip = skip
42 | model.model.config.normalize = None
43 |
44 | model.model.config.pyram_mode = pyram_mode
45 | model.model.config.pyram_beta = beta
46 |
47 | model.model.config.gqa_support = gqa_support
48 | model.model.config.gqa_func = gqa_func
49 |
50 | return model
51 |
52 |
53 | def replace_mistral_fixed():
54 | check_version()
55 | transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_mistral
56 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward = fixed_mistral_flash_attn2_forward
57 | transformers.models.mistral.modeling_mistral.MistralModel.forward = fixed_MistralModel_forward
58 |
59 | def replace_mistral_adaptive():
60 | check_version()
61 | transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_mistral
62 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward = adaptive_mistral_flash_attn2_forward
63 | transformers.models.mistral.modeling_mistral.MistralModel.forward = adaptive_MistralModel_forward
64 |
65 | def replace_llama_fixed():
66 | check_version()
67 | transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_llama
68 | transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = fixed_llama_flash_attn2_forward
69 | transformers.models.llama.modeling_llama.LlamaModel.forward = fixed_LlamaModel_forward
70 |
71 | def replace_llama_adaptive():
72 | check_version()
73 | transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_llama
74 | transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = adaptive_llama_flash_attn2_forward
75 | transformers.models.llama.modeling_llama.LlamaModel.forward = adaptive_LlamaModel_forward
76 |
77 |
78 | def replace_llama_slm():
79 | check_version()
80 | transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_llama
81 | transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = slm_llama_flash_attn2_forward
82 | transformers.models.llama.modeling_llama.LlamaModel.forward = slm_LlamaModel_forward
83 |
84 | def replace_mistral_slm():
85 | check_version()
86 | transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_mistral
87 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward = slm_mistral_flash_attn2_forward
88 | transformers.models.mistral.modeling_mistral.MistralModel.forward = slm_MistralModel_forward
89 |
90 |
--------------------------------------------------------------------------------
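
A hypothetical end-to-end usage sketch for the patching API above (the checkpoint name and the hyperparameter values are placeholders, not taken from the repository): call one of the `replace_*` functions so the hijacked forwards are in place, load the model with FlashAttention-2, then set the compression hyperparameters with `config_compress` before generating.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from adaptive_snapkv.monkeypatch.monkeypatch import replace_mistral_adaptive, config_compress

replace_mistral_adaptive()  # patch MistralFlashAttention2 / MistralModel / prepare_inputs_for_generation

model_path = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model = config_compress(
    model,
    window_size=32,
    base_capacity=1024,
    kernel_size=7,
    pooling="maxpool",
    floor_alpha=0.2,
    gqa_support=True,
    gqa_func="mean",
)

inputs = tokenizer("A very long document ...", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))

The other entry points (`replace_mistral_fixed`, `replace_llama_adaptive`, `replace_llama_slm`, ...) follow the same pattern.
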
/adaptive_snapkv/monkeypatch/slm_llama_hijack.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import List, Optional, Tuple, Union
5 | import warnings
6 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
7 | from transformers.modeling_outputs import BaseModelOutputWithPast
8 | from transformers.models.llama.modeling_llama import (
9 | apply_rotary_pos_emb,
10 |     repeat_kv, _prepare_4d_causal_attention_mask_with_cache_position,  # the latter is used by prepare_inputs_for_generation_llama below
11 | )
12 | from transformers.utils import (
13 | logging,
14 | )
15 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
16 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_slm
17 |
18 | logger = logging.get_logger(__name__)
19 |
20 | def slm_LlamaModel_forward(
21 | self,
22 | input_ids: torch.LongTensor = None,
23 | attention_mask: Optional[torch.Tensor] = None,
24 | position_ids: Optional[torch.LongTensor] = None,
25 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
26 | inputs_embeds: Optional[torch.FloatTensor] = None,
27 | use_cache: Optional[bool] = None,
28 | output_attentions: Optional[bool] = None,
29 | output_hidden_states: Optional[bool] = None,
30 | return_dict: Optional[bool] = None,
31 | cache_position: Optional[torch.LongTensor] = None,
32 | ) -> Union[Tuple, BaseModelOutputWithPast]:
33 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
34 | output_hidden_states = (
35 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
36 | )
37 | use_cache = use_cache if use_cache is not None else self.config.use_cache
38 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
39 |
40 | if (input_ids is None) ^ (inputs_embeds is not None):
41 | raise ValueError(
42 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
43 | )
44 |
45 | if self.gradient_checkpointing and self.training and use_cache:
46 | logger.warning_once(
47 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
48 | )
49 | use_cache = False
50 |
51 | if inputs_embeds is None:
52 | inputs_embeds = self.embed_tokens(input_ids)
53 |
54 | return_legacy_cache = False
55 | if (
56 | use_cache and not isinstance(past_key_values, Cache) and not self.training
57 | ): # kept for BC (non `Cache` `past_key_values` inputs)
58 | return_legacy_cache = True
59 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
60 | logger.warning_once(
61 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
62 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
63 | )
64 |
65 | if cache_position is None:
66 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
67 | cache_position = torch.arange(
68 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
69 | )
70 | if position_ids is None:
71 | position_ids = cache_position.unsqueeze(0)
72 |
73 | causal_mask = self._update_causal_mask(
74 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
75 | )
76 | hidden_states = inputs_embeds
77 |
78 | # create position embeddings to be shared across the decoder layers
79 | position_embeddings = self.rotary_emb(hidden_states, position_ids)
80 |
81 | # decoder layers
82 | all_hidden_states = () if output_hidden_states else None
83 | all_self_attns = () if output_attentions else None
84 | next_decoder_cache = None
85 |
86 | for decoder_layer in self.layers:
87 | if output_hidden_states:
88 | all_hidden_states += (hidden_states,)
89 |
90 | if self.gradient_checkpointing and self.training:
91 | layer_outputs = self._gradient_checkpointing_func(
92 | decoder_layer.__call__,
93 | hidden_states,
94 | causal_mask,
95 | position_ids,
96 | past_key_values,
97 | output_attentions,
98 | use_cache,
99 | cache_position,
100 | position_embeddings,
101 | )
102 | else:
103 | layer_outputs = decoder_layer(
104 | hidden_states,
105 | attention_mask=causal_mask,
106 | position_ids=position_ids,
107 | past_key_value=past_key_values,
108 | output_attentions=output_attentions,
109 | use_cache=use_cache,
110 | cache_position=cache_position,
111 | position_embeddings=position_embeddings,
112 | )
113 |
114 | hidden_states = layer_outputs[0]
115 |
116 | if use_cache:
117 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
118 |
119 | if output_attentions:
120 | all_self_attns += (layer_outputs[1],)
121 |
122 | hidden_states = self.norm(hidden_states)
123 |
124 | # add hidden states from the last decoder layer
125 | if output_hidden_states:
126 | all_hidden_states += (hidden_states,)
127 |
128 | next_cache = next_decoder_cache if use_cache else None
129 | if return_legacy_cache:
130 | next_cache = next_cache.to_legacy_cache()
131 |
132 |     hidden_states = hidden_states[:, -1, :].unsqueeze(1)  # keep only the final position's hidden state
133 |
134 | if not return_dict:
135 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
136 | return BaseModelOutputWithPast(
137 | last_hidden_state=hidden_states,
138 | past_key_values=next_cache,
139 | hidden_states=all_hidden_states,
140 | attentions=all_self_attns,
141 | )
142 |
143 |
144 | def slm_llama_flash_attn2_forward(
145 | self,
146 | hidden_states: torch.Tensor,
147 | attention_mask: Optional[torch.LongTensor] = None,
148 | position_ids: Optional[torch.LongTensor] = None,
149 | past_key_value: Optional[Cache] = None,
150 | output_attentions: bool = False,
151 | use_cache: bool = False,
152 | cache_position: Optional[torch.LongTensor] = None,
153 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
154 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
155 | if isinstance(past_key_value, StaticCache):
156 | raise ValueError(
157 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
158 |             "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
159 | )
160 | init_slm(self)
161 |
162 | output_attentions = False
163 |
164 | bsz, q_len, _ = hidden_states.size()
165 |
166 | query_states = self.q_proj(hidden_states)
167 | key_states = self.k_proj(hidden_states)
168 | value_states = self.v_proj(hidden_states)
169 |
170 |     # Flash Attention requires the input to have the shape
171 |     # batch_size x seq_length x num_heads x head_dim
172 |     # here we first build the [bsz, num_heads, q_len, head_dim] layout for RoPE and the cache update, and transpose back later
173 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
174 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
175 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
176 |
177 | if position_embeddings is None:
178 | logger.warning_once(
179 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
180 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
181 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
182 | "removed and `position_embeddings` will be mandatory."
183 | )
184 | cos, sin = self.rotary_emb(value_states, position_ids)
185 | else:
186 | cos, sin = position_embeddings
187 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
188 |
189 |
190 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
191 | if past_key_value is not None:
192 | # NOTE: decoding update
193 | if q_len == 1:
194 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
195 | else:
196 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states)
197 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
198 |
199 |     # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
200 |     # to be able to avoid many of these transpose/reshape/view operations.
201 |
202 | dropout_rate = self.attention_dropout if self.training else 0.0
203 |
204 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
205 |     # therefore the input hidden states get silently cast to float32. Hence, we need to
206 |     # cast them back to the correct dtype just to be sure everything works as expected.
207 |     # This might slow down training & inference, so it is recommended not to cast the LayerNorms
208 |     # to fp32. (LlamaRMSNorm handles it correctly)
209 |
210 | input_dtype = query_states.dtype
211 | if input_dtype == torch.float32:
212 | if torch.is_autocast_enabled():
213 | target_dtype = torch.get_autocast_gpu_dtype()
214 | # Handle the case where the model is quantized
215 | elif hasattr(self.config, "_pre_quantization_dtype"):
216 | target_dtype = self.config._pre_quantization_dtype
217 | else:
218 | target_dtype = self.q_proj.weight.dtype
219 |
220 | logger.warning_once(
221 |             f"The input hidden states seem to have been silently cast to float32; this might be related to"
222 |             f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
223 | f" {target_dtype}."
224 | )
225 |
226 | query_states = query_states.to(target_dtype)
227 | key_states = key_states.to(target_dtype)
228 | value_states = value_states.to(target_dtype)
229 |
230 |     # Reshape to the expected shape for Flash Attention
231 | query_states = query_states.transpose(1, 2)
232 | key_states = key_states.transpose(1, 2)
233 | value_states = value_states.transpose(1, 2)
234 |
235 | attn_output = _flash_attention_forward(
236 | query_states,
237 | key_states,
238 | value_states,
239 | attention_mask,
240 | q_len,
241 | position_ids=position_ids,
242 | dropout=dropout_rate,
243 | sliding_window=getattr(self, "sliding_window", None),
244 | use_top_left_mask=self._flash_attn_uses_top_left_mask,
245 | is_causal=self.is_causal,
246 | )
247 |
248 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
249 | attn_output = self.o_proj(attn_output)
250 |
251 | if not output_attentions:
252 | attn_weights = None
253 |
254 | return attn_output, attn_weights, past_key_value
255 |
256 | def prepare_inputs_for_generation_llama(
257 | self,
258 | input_ids,
259 | past_key_values=None,
260 | attention_mask=None,
261 | inputs_embeds=None,
262 | cache_position=None,
263 | position_ids=None,
264 | use_cache=True,
265 | **kwargs,
266 | ):
267 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
268 | # Exception 1: when passing input_embeds, input_ids may be missing entries
269 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
270 | if past_key_values is not None:
271 | if inputs_embeds is not None: # Exception 1
272 | input_ids = input_ids[:, -cache_position.shape[0] :]
273 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
274 | input_ids = input_ids[:, cache_position]
275 |
276 | if attention_mask is not None and position_ids is None:
277 | # create position_ids on the fly for batch generation
278 | position_ids = attention_mask.long().cumsum(-1) - 1
279 | position_ids.masked_fill_(attention_mask == 0, 1)
280 | if past_key_values:
281 | position_ids = position_ids[:, -input_ids.shape[1] :]
282 |
283 |             # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
284 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
285 |
286 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
287 | if inputs_embeds is not None and cache_position[0] == 0:
288 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
289 | else:
290 | # The clone here is for the same reason as for `position_ids`.
291 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
292 |
293 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
294 | if model_inputs["inputs_embeds"] is not None:
295 | batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
296 | device = model_inputs["inputs_embeds"].device
297 | else:
298 | batch_size, sequence_length = model_inputs["input_ids"].shape
299 | device = model_inputs["input_ids"].device
300 |
301 | dtype = self.lm_head.weight.dtype
302 | min_dtype = torch.finfo(dtype).min
303 |
304 | attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
305 | attention_mask,
306 | sequence_length=sequence_length,
307 | target_length=past_key_values.get_max_length(),
308 | dtype=dtype,
309 | device=device,
310 | min_dtype=min_dtype,
311 | cache_position=cache_position,
312 | batch_size=batch_size,
313 | )
314 |
315 | model_inputs.update(
316 | {
317 | "position_ids": position_ids,
318 | "cache_position": cache_position,
319 | "past_key_values": past_key_values,
320 | "use_cache": use_cache,
321 | "attention_mask": attention_mask,
322 | }
323 | )
324 | return model_inputs
325 |
326 |
--------------------------------------------------------------------------------
/adaptive_snapkv/monkeypatch/slm_mistral_hijack.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from numpy import diff
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from typing import List, Optional, Tuple, Union
7 | import warnings
8 | from transformers.cache_utils import Cache, DynamicCache, StaticCache
9 | from transformers.modeling_outputs import BaseModelOutputWithPast
10 | from transformers.modeling_flash_attention_utils import _flash_attention_forward
11 | from transformers.models.mistral.modeling_mistral import (
12 | apply_rotary_pos_emb,
13 | repeat_kv,
14 | )
15 | from transformers.utils import (
16 | logging,
17 | is_flash_attn_2_available,
18 | )
19 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_slm
20 | from flash_attn import flash_attn_func
21 |
22 | logger = logging.get_logger(__name__)
23 |
24 | if is_flash_attn_2_available():
25 | from flash_attn import flash_attn_func, flash_attn_varlen_func
26 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
27 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
28 |
29 | def slm_MistralModel_forward(
30 | self,
31 | input_ids: torch.LongTensor = None,
32 | attention_mask: Optional[torch.Tensor] = None,
33 | position_ids: Optional[torch.LongTensor] = None,
34 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
35 | inputs_embeds: Optional[torch.FloatTensor] = None,
36 | use_cache: Optional[bool] = None,
37 | output_attentions: Optional[bool] = None,
38 | output_hidden_states: Optional[bool] = None,
39 | return_dict: Optional[bool] = None,
40 | cache_position: Optional[torch.LongTensor] = None,
41 | ) -> Union[Tuple, BaseModelOutputWithPast]:
42 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
43 | output_hidden_states = (
44 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
45 | )
46 | use_cache = use_cache if use_cache is not None else self.config.use_cache
47 |
48 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
49 |
50 | # retrieve input_ids and inputs_embeds
51 | if (input_ids is None) ^ (inputs_embeds is not None):
52 | raise ValueError(
53 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
54 | )
55 |
56 | if self.gradient_checkpointing and self.training and use_cache:
57 | logger.warning_once(
58 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
59 | )
60 | use_cache = False
61 |
62 | if inputs_embeds is None:
63 | inputs_embeds = self.embed_tokens(input_ids)
64 |
65 | return_legacy_cache = False
66 | if use_cache and not isinstance(past_key_values, Cache) and not self.training:
67 | past_key_values = DynamicCache.from_legacy_cache(past_key_values)
68 | return_legacy_cache = True
69 | logger.warning_once(
70 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
71 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
72 | )
73 |
74 | if cache_position is None:
75 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
76 | cache_position = torch.arange(
77 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
78 | )
79 |
80 | if position_ids is None:
81 | position_ids = cache_position.unsqueeze(0)
82 |
83 | causal_mask = self._update_causal_mask(
84 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
85 | )
86 |
87 | hidden_states = inputs_embeds
88 |
89 | # decoder layers
90 | all_hidden_states = () if output_hidden_states else None
91 | all_self_attns = () if output_attentions else None
92 | next_decoder_cache = None
93 |
94 | for decoder_layer in self.layers:
95 | if output_hidden_states:
96 | all_hidden_states += (hidden_states,)
97 |
98 | if self.gradient_checkpointing and self.training:
99 | layer_outputs = self._gradient_checkpointing_func(
100 | decoder_layer.__call__,
101 | hidden_states,
102 | causal_mask,
103 | position_ids,
104 | past_key_values,
105 | output_attentions,
106 | use_cache,
107 | cache_position,
108 | )
109 | else:
110 | layer_outputs = decoder_layer(
111 | hidden_states,
112 | attention_mask=causal_mask,
113 | position_ids=position_ids,
114 | past_key_value=past_key_values,
115 | output_attentions=output_attentions,
116 | use_cache=use_cache,
117 | cache_position=cache_position,
118 | )
119 |
120 | hidden_states = layer_outputs[0]
121 |
122 | if use_cache:
123 | next_decoder_cache = layer_outputs[2 if output_attentions else 1]
124 |
125 | if output_attentions:
126 | all_self_attns += (layer_outputs[1],)
127 |
128 | hidden_states = self.norm(hidden_states)
129 |
130 | # add hidden states from the last decoder layer
131 | if output_hidden_states:
132 | all_hidden_states += (hidden_states,)
133 |
134 | next_cache = next_decoder_cache if use_cache else None
135 | if return_legacy_cache:
136 | next_cache = next_cache.to_legacy_cache()
137 |
138 | if not return_dict:
139 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
140 |
141 |     hidden_states = hidden_states[:, -1, :].unsqueeze(1)  # keep only the final position's hidden state
142 | return BaseModelOutputWithPast(
143 | last_hidden_state=hidden_states,
144 | past_key_values=next_cache,
145 | hidden_states=all_hidden_states,
146 | attentions=all_self_attns,
147 | )
148 |
149 | def slm_mistral_flash_attn2_forward(
150 | self,
151 | hidden_states: torch.Tensor,
152 | attention_mask: Optional[torch.Tensor] = None,
153 | position_ids: Optional[torch.LongTensor] = None,
154 | past_key_value: Optional[Cache] = None,
155 | output_attentions: bool = False,
156 | use_cache: bool = False,
157 | cache_position: Optional[torch.LongTensor] = None,
158 | ):
159 | if isinstance(past_key_value, StaticCache):
160 | raise ValueError(
161 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
162 |             "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
163 | )
164 | init_slm(self)
165 | output_attentions = False
166 |
167 | output_attentions = False
168 |
169 | bsz, q_len, _ = hidden_states.size()
170 |
171 | query_states = self.q_proj(hidden_states)
172 | key_states = self.k_proj(hidden_states)
173 | value_states = self.v_proj(hidden_states)
174 |
175 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
176 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
177 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
178 |
179 | kv_seq_len = key_states.shape[-2]
180 | if past_key_value is not None:
181 | kv_seq_len += cache_position[0]
182 |
183 | cos, sin = self.rotary_emb(value_states, position_ids)
184 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
185 |
186 | dropout_rate = 0.0 if not self.training else self.attention_dropout
187 |
188 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
189 | if past_key_value is not None:
190 |         # Activate cache slicing only if the config has a `sliding_window` attribute set
191 | cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
192 | if (
193 | getattr(self.config, "sliding_window", None) is not None
194 | and kv_seq_len > self.config.sliding_window
195 | and cache_has_contents
196 | ):
197 | slicing_tokens = 1 - self.config.sliding_window
198 |
199 | past_key = past_key_value[self.layer_idx][0]
200 | past_value = past_key_value[self.layer_idx][1]
201 |
202 | past_key = past_key[:, :, slicing_tokens:, :].contiguous()
203 | past_value = past_value[:, :, slicing_tokens:, :].contiguous()
204 |
205 | if past_key.shape[-2] != self.config.sliding_window - 1:
206 | raise ValueError(
207 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
208 | f" {past_key.shape}"
209 | )
210 |
211 | if attention_mask is not None:
212 | attention_mask = attention_mask[:, slicing_tokens:]
213 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
214 |
215 | if q_len == 1:
216 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
217 | else:
218 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states)
219 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
220 |
221 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons
222 |         # therefore the input hidden states get silently cast to float32. Hence, we need to
223 |         # cast them back to float16 just to be sure everything works as expected.
224 | input_dtype = query_states.dtype
225 | if input_dtype == torch.float32:
226 | if torch.is_autocast_enabled():
227 | target_dtype = torch.get_autocast_gpu_dtype()
228 | # Handle the case where the model is quantized
229 | elif hasattr(self.config, "_pre_quantization_dtype"):
230 | target_dtype = self.config._pre_quantization_dtype
231 | else:
232 | target_dtype = self.q_proj.weight.dtype
233 |
234 | logger.warning_once(
235 |             f"The input hidden states seem to have been silently cast to float32; this might be related to"
236 |             f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
237 | f" {target_dtype}."
238 | )
239 |
240 | query_states = query_states.to(target_dtype)
241 | key_states = key_states.to(target_dtype)
242 | value_states = value_states.to(target_dtype)
243 |
244 |     # Reshape to the expected shape for Flash Attention
245 | query_states = query_states.transpose(1, 2)
246 | key_states = key_states.transpose(1, 2)
247 | value_states = value_states.transpose(1, 2)
248 |
249 | attn_output = _flash_attention_forward(
250 | query_states,
251 | key_states,
252 | value_states,
253 | attention_mask,
254 | q_len,
255 | position_ids=position_ids,
256 | dropout=dropout_rate,
257 | sliding_window=getattr(self.config, "sliding_window", None),
258 | use_top_left_mask=self._flash_attn_uses_top_left_mask,
259 | is_causal=self.is_causal,
260 | )
261 |
262 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
263 | attn_output = self.o_proj(attn_output)
264 |
265 | if not output_attentions:
266 | attn_weights = None
267 |
268 | return attn_output, attn_weights, past_key_value
269 |
270 |
271 | def prepare_inputs_for_generation_mistral(
272 | self,
273 | input_ids,
274 | past_key_values=None,
275 | attention_mask=None,
276 | inputs_embeds=None,
277 | cache_position=None,
278 | position_ids=None,
279 | use_cache=True,
280 | **kwargs,
281 | ):
282 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
283 | # Exception 1: when passing input_embeds, input_ids may be missing entries
284 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
285 | if past_key_values is not None:
286 | if inputs_embeds is not None: # Exception 1
287 | input_ids = input_ids[:, -cache_position.shape[0] :]
288 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
289 | input_ids = input_ids[:, cache_position]
290 |
291 | if attention_mask is not None and position_ids is None:
292 | # create position_ids on the fly for batch generation
293 | position_ids = attention_mask.long().cumsum(-1) - 1
294 | position_ids.masked_fill_(attention_mask == 0, 1)
295 | if past_key_values:
296 | position_ids = position_ids[:, -input_ids.shape[1] :]
297 |
298 |             # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise the input `position_ids` would have varying strides during decoding. Here, simply using `.contiguous()` is not sufficient: in the batch size = 1 case, `position_ids` is already contiguous but with a varying stride, which retriggers a capture.
299 | position_ids = position_ids.clone(memory_format=torch.contiguous_format)
300 |
301 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
302 | if inputs_embeds is not None and cache_position[0] == 0:
303 | model_inputs = {"inputs_embeds": inputs_embeds}
304 | else:
305 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
306 |
307 | model_inputs.update(
308 | {
309 | "position_ids": position_ids,
310 | "cache_position": cache_position,
311 | "past_key_values": past_key_values,
312 | "use_cache": use_cache,
313 | "attention_mask": attention_mask,
314 | }
315 | )
316 | return model_inputs
317 |
--------------------------------------------------------------------------------
/adaptive_snapkv/monkeypatch/snapkv_utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import time
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import math
8 | from typing import List, Optional, Tuple, Union, Any, Dict
9 | from transformers.cache_utils import Cache, DynamicCache
10 | from flash_attn import flash_attn_func
11 | # perform qk calculation and get indices
12 | # this version will not update in inference mode
13 |
14 | class DynamicCacheSplitHeadFlatten(Cache):
15 | """
16 | Flattened version of DynamicCacheSplitHead
17 | """
18 |     def __init__(self) -> None:
19 |         # Layer-wise lists of head-wise KV tensors, stored in a flattened (per-head concatenated) layout
20 | super().__init__()
21 | self.key_cache: List[List[torch.Tensor]] = []
22 | self.value_cache: List[List[torch.Tensor]] = []
23 | self._seen_tokens = 0
24 |
25 | def __len__(self):
26 | return len(self.key_cache)
27 |
28 | def __iter__(self):
29 | for layer_idx in range(len(self)):
30 | yield (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx]))
31 |
32 | def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor],Tuple[torch.Tensor]]:
33 | if layer_idx < len(self):
34 | return (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx]))
35 | else:
36 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
37 |
38 | def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
39 | # NOTE: k, v = [head_num](bs, 1, seqlen, dim)
40 |         # each layer uses a flattened layout like:
41 | # [head_0_len + head_1_len + ..., dim]
42 | if len(self.key_cache) <= layer_idx:
43 | self.key_cache.append(key_states)
44 | self.value_cache.append(value_states)
45 | else:
46 | assert self.key_cache[layer_idx].dim() == 2
47 | bs, head, seqlen, dim = key_states.shape
48 | assert bs == 1 and seqlen == 1
49 | # NOTE: phase 2. we got [bs, head, seqlen, dim] as k, v input
50 | head_lens = cache_kwargs["head_lens"]
51 | cu_klen = cache_kwargs["cu_klen"]
52 |
53 | # TODO: wrap as a python interface
54 | from tiny_api_cuda import update_flatten_view
55 | new_key_cache = update_flatten_view(self.key_cache[layer_idx].view(-1,dim), key_states.view(-1, dim), head_lens, cu_klen)
56 | new_value_cache = update_flatten_view(self.value_cache[layer_idx].view(-1,dim), value_states.view(-1, dim), head_lens, cu_klen)
57 |
58 |
59 | self.key_cache[layer_idx] = new_key_cache
60 | self.value_cache[layer_idx] = new_value_cache
61 |
62 |
63 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
64 |
65 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
66 | if len(self.key_cache) <= layer_idx:
67 | return 0
68 |
69 |         # TODO: for now, return 1 simply to signal that the cache has content
70 | return 1
71 | # return max(map(lambda states: states.shape[-2], self.key_cache[layer_idx]))
72 |
73 | def get_max_length(self) -> Optional[int]:
74 | return None
75 |
76 | def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
77 |         """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
78 | legacy_cache = ()
79 | for layer_idx in range(len(self)):
80 | legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
81 | return legacy_cache
82 |
83 | @classmethod
84 |     def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCacheSplitHeadFlatten":
85 | """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
86 | cache = cls()
87 | if past_key_values is not None:
88 | for layer_idx in range(len(past_key_values)):
89 | key_states, value_states = past_key_values[layer_idx]
90 | cache.update(key_states, value_states, layer_idx)
91 | return cache
92 |
93 |
94 |
95 | # class DynamicCacheSplitHead(Cache):
96 | # """
97 | #     demo to illustrate the split-by-head cache update
98 | #     This class is slower than DynamicCacheSplitHeadFlatten due to frequent tensor copies
99 | # """
100 | # def __init__(self) ->None:
101 | # # Token wise List[] Head wise KV List[torch.Tensor]
102 | # super().__init__()
103 | # self.key_cache: List[List[torch.Tensor]] = []
104 | # self.value_cache: List[List[torch.Tensor]] = []
105 | # self._seen_tokens = 0
106 |
107 | # def __len__(self):
108 | # return len(self.key_cache)
109 |
110 | # def __iter__(self):
111 | # for layer_idx in range(len(self)):
112 | # yield (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx]))
113 |
114 | # def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor],Tuple[torch.Tensor]]:
115 | # if layer_idx < len(self):
116 | # return (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx]))
117 | # else:
118 | # raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
119 |
120 | # def update(
121 | # self,
122 | # key_states: List[torch.Tensor],
123 | # value_states: List[torch.Tensor],
124 | # layer_idx: int,
125 | # cache_kwargs: Optional[Dict[str, Any]] = None,
126 | # ) -> Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]]:
127 | # if layer_idx == 0:
128 | # self._seen_tokens += max(map(lambda states: states.shape[-2], key_states))
129 |
130 | # if len(self.key_cache)<=layer_idx:
131 | # self.key_cache.append(list(key_states))
132 | # self.value_cache.append(list(value_states))
133 | # else:
134 | # # tensor shape[ [bsz, seq, dim] * head_nums]
135 | # # [bsz,\sum seq,dim]
136 | # # [bsz,\sum seq+headnum,dim ]
137 | # for head_idx in range(len(key_states)):
138 | # self.key_cache[layer_idx][head_idx] = torch.cat([self.key_cache[layer_idx][head_idx],key_states[head_idx]], dim=-2)
139 | # self.value_cache[layer_idx][head_idx] = torch.cat([self.value_cache[layer_idx][head_idx],value_states[head_idx]], dim=-2)
140 | # return tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx])
141 |
142 | # def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
143 | # if len(self.key_cache) <= layer_idx:
144 | # return 0
145 | # return max(map(lambda states: states.shape[-2], self.key_cache[layer_idx]))
146 |
147 | # def get_max_length(self) -> Optional[int]:
148 | # return None
149 |
150 |
151 | # # Tuple[Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]],...]
152 | # def to_legacy_cache(self)-> Tuple[Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]],...]:
153 | # """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
154 | # legacy_cache = ()
155 | # for layer_idx in range(len(self)):
156 | # legacy_cache += ((tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx])),)
157 | # return legacy_cache
158 | # @classmethod
159 | # def from_legacy_cache(cls,past_key_values:Optional[ Tuple[Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]],...]]=None)->"DynamicCacheEachHead":
160 | # cache = cls()
161 | # if past_key_values is not None:
162 | # for layer_idx in range(len(past_key_values)):
163 | # key_states,value_states = past_key_values[layer_idx]
164 | # cache.update(list(key_states),list(value_states),layer_idx)
165 | # return cache
166 |
167 |
168 |
169 | # Copied from transformers.models.llama.modeling_llama.repeat_kv for gqa_support
170 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
171 | """
172 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
173 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
174 | """
175 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape
176 | if n_rep == 1:
177 | return hidden_states
178 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
179 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
180 |
181 | class SnapKVCluster():
182 | def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool', layer_idx = None, num_hidden_layers = None, pyram_mode = False, pyram_beta = 20,gqa_support=False,num_key_value_groups = 1, gqa_func=None):
183 | self.window_size = window_size
184 | self.max_capacity_prompt = max_capacity_prompt
185 | assert self.max_capacity_prompt - self.window_size > 0
186 | self.kernel_size = kernel_size
187 | self.pooling = pooling
188 |
189 | self.pyram_init = False
190 | self.pyram_mode = pyram_mode
191 | self.pyram_beta = pyram_beta
192 | self.layer_idx = layer_idx
193 | self.num_hidden_layers = num_hidden_layers
194 |
195 | # support gqa
196 | self.gqa_support = gqa_support
197 | self.num_key_value_groups = num_key_value_groups
198 | self.gqa_func = gqa_func
199 | if self.gqa_support:
200 | assert gqa_func is not None, "gqa_func should not be None"
201 | assert gqa_func in ['max','mean'], "currently gqa_func should be in ['max','mean']"
202 | if self.num_key_value_groups == 1:
203 | warnings.warn("gqa_support is enabled, but num_key_value_groups is 1, which means the model is not using gqa. Please check the model configuration.")
204 |
205 |
206 |
207 | def reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'):
208 | self.window_size = window_size
209 | self.max_capacity_prompt = max_capacity_prompt
210 | assert self.max_capacity_prompt - self.window_size > 0
211 | self.kernel_size = kernel_size
212 | self.pooling = pooling
213 |
214 | def update_kv(self, origin_key_states, query_states, origin_value_states):
215 |
216 | # support gqa
217 | key_states = repeat_kv(origin_key_states, self.num_key_value_groups)
218 | value_states = repeat_kv(origin_value_states, self.num_key_value_groups)
219 | # check if prefix phase
220 | assert key_states.shape[-2] == query_states.shape[-2]
221 | bsz, num_heads, q_len, head_dim = query_states.shape
222 |
223 | # compute pyramidal capacity
224 | if self.pyram_mode and not self.pyram_init:
225 | # NOTE: (max_num + min_num) / 2 == base_capacity to restrict the total capacity
226 | base_capacity = self.max_capacity_prompt - self.window_size
227 | min_num = base_capacity // self.pyram_beta
228 | max_num = base_capacity * 2 - min_num
229 |
230 |             # if max_num exceeds the prefix length outside the window (q_len - window_size), clamp it and rebalance min_num
231 | if max_num >= q_len - self.window_size:
232 | max_num = q_len - self.window_size
233 | min_num = base_capacity * 2 - max_num
234 |
235 | # NOTE: compute interval
236 | steps = (max_num - min_num) // (self.num_hidden_layers - 1)
237 |
238 | self.max_capacity_prompt = max_num - self.layer_idx * steps + self.window_size
239 | self.pyram_init = True
240 | print(f"Pyram mode adaptive capacity, layer: {self.layer_idx}, max_capacity_prompt: {self.max_capacity_prompt}, base_capacity: {self.max_capacity_prompt - self.window_size}", flush=True)
241 |
242 | if q_len < self.max_capacity_prompt:
243 | # support gqa
244 | if self.gqa_support:
245 | return origin_key_states, origin_value_states
246 | else:
247 | return key_states, value_states
248 | else:
249 | attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)
250 | mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
251 | mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
252 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
253 | mask = mask.to(attn_weights.device)
254 | attention_mask = mask[None, None, :, :]
255 |
256 | attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask
257 |
258 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
259 | attn_weights_mean = attn_weights[:, :, -self.window_size:, : -self.window_size].mean(dim = -2)
260 |
261 | # gqa_support
262 | if self.gqa_support:
263 | attn_weights_mean = attn_weights_mean.view(attn_weights_mean.shape[0], -1, self.num_key_value_groups, attn_weights_mean.shape[-1])
264 | if self.gqa_func == 'max':
265 | attn_weights_mean = attn_weights_mean.max(dim=-2).values
266 | elif self.gqa_func == 'mean':
267 | attn_weights_mean = attn_weights_mean.mean(dim=-2)
268 | else:
269 | raise ValueError('gqa_func not supported')
270 |
271 | if self.pooling == 'avgpool':
272 | attn_cache = F.avg_pool1d(attn_weights_mean, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
273 | elif self.pooling == 'maxpool':
274 | attn_cache = F.max_pool1d(attn_weights_mean, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1)
275 | else:
276 | raise ValueError('Pooling method not supported')
277 |
278 | indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices
279 | indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
280 |
281 | # support gqa
282 | if self.gqa_support:
283 | k_past_compress = origin_key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)
284 | v_past_compress = origin_value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)
285 | k_cur = origin_key_states[:, :, -self.window_size:, :]
286 | v_cur = origin_value_states[:, :, -self.window_size:, :]
287 | else:
288 | k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)
289 | v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices)
290 | k_cur = key_states[:, :, -self.window_size:, :]
291 | v_cur = value_states[:, :, -self.window_size:, :]
292 | key_states = torch.cat([k_past_compress, k_cur], dim = 2)
293 | value_states = torch.cat([v_past_compress, v_cur], dim = 2)
294 | return key_states, value_states
295 |
296 |
297 | class AdaptiveSnapKVCluster():
298 | def __init__(self, window_size = 32, kernel_size = 7, pooling = 'maxpool',base_capacity=None,floor_alpha = None,skip = None,normalize=None,
299 | layer_idx = None, num_hidden_layers = None, pyram_mode = False, pyram_beta = 20,gqa_support=False,num_key_value_groups = 1, gqa_func=None):
300 | self.window_size = window_size
301 | self.kernel_size = kernel_size
302 | self.pooling = pooling
303 | self.base_capacity = base_capacity - window_size
304 | self.floor_ratio = floor_alpha
305 | self.floor_capacity = int(self.base_capacity * self.floor_ratio)
306 | self.adaptive_capacity = self.base_capacity - self.floor_capacity
307 | self.skip = skip
308 |
309 | self.normalize = normalize
310 | self.pyram_init = False
311 | self.pyram_mode = pyram_mode
312 | self.pyram_beta = pyram_beta
313 | self.layer_idx = layer_idx
314 | self.num_hidden_layers = num_hidden_layers
315 |
316 | # NOTE: layer-wise meta-data
317 | self.head_lens = None
318 | self.max_seqlen_k = 0
319 | self.klen_sum = 0
320 | self.cu_klen = 0
321 | self.cu_offset = None
322 | self.cu_headlens = None
323 |
324 | # support gqa
325 | self.gqa_support = gqa_support
326 | self.num_key_value_groups = num_key_value_groups
327 | self.gqa_func = gqa_func
328 | if self.gqa_support:
329 | assert gqa_func is not None, "gqa_func should not be None"
330 | assert gqa_func in ['max','mean'], "currently gqa_func should be in ['max','mean']"
331 | if self.num_key_value_groups == 1:
332 | warnings.warn("gqa_support is enabled, but num_key_value_groups is 1, which means the model is not using gqa. Please check the model configuration.")
333 |
334 |
335 | def calcul_attn_sore(self, key_states, query_states):
336 | bsz, num_heads, q_len, head_dim = query_states.shape
337 | attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(
338 | head_dim)
339 | mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min,
340 | device=attn_weights.device)
341 | mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
342 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
343 | mask = mask.to(attn_weights.device)
344 | attention_mask = mask[None, None, :, :]
345 |
346 | attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask
347 |
348 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
349 | attn_weights_mean = attn_weights[:, :, -self.window_size:, : -self.window_size].mean(dim=-2)
350 |
351 | if self.gqa_support:
352 | attn_weights_mean = attn_weights_mean.view(attn_weights_mean.shape[0],num_heads//self.num_key_value_groups,self.num_key_value_groups,-1)
353 | if self.gqa_func == 'max':
354 | attn_weights_mean = attn_weights_mean.max(dim=-2).values
355 | elif self.gqa_func == 'mean':
356 | attn_weights_mean = attn_weights_mean.mean(dim=-2)
357 | else:
358 | raise ValueError('gqa_func not supported')
359 |
360 | if self.pooling == 'avgpool':
361 | attn_weights_mean_pooling = F.avg_pool1d(attn_weights_mean, kernel_size=self.kernel_size,
362 | padding=self.kernel_size // 2,
363 | stride=1)
364 | elif self.pooling == 'maxpool':
365 | attn_weights_mean_pooling = F.max_pool1d(attn_weights_mean, kernel_size=self.kernel_size,
366 | padding=self.kernel_size // 2,
367 | stride=1)
368 | else:
369 | raise ValueError('Pooling method not supported')
370 | return attn_weights_mean_pooling
371 |
372 | def update_kv(self, origin_key_states, query_states, origin_value_states):
373 | if self.gqa_support:
374 | return self.update_kv_gqa(origin_key_states, query_states, origin_value_states)
375 | else:
376 | return self.update_kv_wo_gqa(origin_key_states, query_states, origin_value_states)
377 |
378 |
379 | # update kv with gqa_support
380 | def update_kv_gqa(self, origin_key_states, query_states, origin_value_states):
381 | key_states = repeat_kv(origin_key_states, self.num_key_value_groups)
382 | # value_states = repeat_kv(origin_value_states, self.num_key_value_groups)
383 |
384 |         # prefill phase: key_states and query_states are expected to share the same sequence length
385 | _device = key_states.device
386 | bsz, num_heads, q_len, head_dim = query_states.shape
387 | attn_score= self.calcul_attn_sore(key_states,query_states)
388 | origin_heads_key_states = torch.split(origin_key_states, 1, dim=1)
389 | origin_heads_value_states = torch.split(origin_value_states, 1, dim=1)
390 |
391 | # compute pyramidal capacity
392 | if self.pyram_mode and not self.pyram_init:
393 | # NOTE: (max_num + min_num) / 2 == base_capacity to restrict the total capacity
394 | min_num = self.base_capacity // self.pyram_beta
395 | max_num = self.base_capacity * 2 - min_num
396 |
397 |             # if max_num exceeds the prefix length outside the window (q_len - window_size), clamp it and rebalance min_num
398 | if max_num >= q_len - self.window_size:
399 | max_num = q_len - self.window_size
400 | min_num = self.base_capacity * 2 - max_num
401 |
402 | # NOTE: compute interval
403 | steps = (max_num - min_num) // (self.num_hidden_layers - 1)
404 |
405 | # renew adaptive capacity
406 | self.base_capacity = max_num - self.layer_idx * steps
407 | self.floor_capacity = int(self.base_capacity * self.floor_ratio)
408 | self.adaptive_capacity = self.base_capacity - self.floor_capacity
409 | self.pyram_init = True
410 | print(f"Pyram mode adaptive capacity, layer: {self.layer_idx}, acap: {self.adaptive_capacity}, bcap: {self.base_capacity}, fcap: {self.floor_capacity}", flush=True)
411 |
412 | def init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k):
413 | # init metadata
414 | self.head_lens = torch.tensor(k_lens, dtype=torch.int32, device=_device)
415 | self.klen_sum = klen_sum
416 | self.max_seqlen_k = max_seqlen_k
417 | self.cu_headlens = torch.cumsum(self.head_lens, dim=0, dtype=torch.int32)
418 | # init varlen flash attention metadata
419 | self.cu_klen = self.cu_headlens - self.head_lens
420 | self.cu_klen = torch.cat(
421 | [self.cu_klen, torch.tensor([self.klen_sum], dtype=torch.int32, device=_device)], dim=0)
422 | # check bug
423 | self.layer_qlens = torch.ones(num_heads//self.num_key_value_groups, dtype=torch.int32,device=_device)
424 | self.qlen_sum = num_heads//self.num_key_value_groups
425 | self.cu_qlen = torch.cumsum(self.layer_qlens, dim=0, dtype=torch.int32) - self.layer_qlens
426 | self.cu_qlen = torch.cat(
427 | [self.cu_qlen, torch.tensor([self.qlen_sum], dtype=torch.int32, device=_device)], dim=0)
428 |
429 |
430 | if self.gqa_support:
431 | self.cu_offset = torch.arange(0, num_heads//self.num_key_value_groups + 1, dtype=torch.int32, device=_device)
432 | self.cu_head_offset = torch.arange(1, num_heads//self.num_key_value_groups +1, dtype=torch.int32, device=_device)
433 |
434 | else:
435 | self.cu_offset = torch.arange(0, num_heads + 1, dtype=torch.int32, device=_device)
436 | self.cu_head_offset = torch.arange(1, num_heads+1, dtype=torch.int32, device=_device)
437 |
438 |
439 | if self.base_capacity > attn_score.size(-1):
440 | init_metadata(num_heads, [q_len] * (num_heads//self.num_key_value_groups), q_len * (num_heads//self.num_key_value_groups), q_len)
441 | # not compress
442 | return origin_key_states.reshape(-1, head_dim), origin_value_states.reshape(-1, head_dim)
443 |
444 |
445 | sorted_attn_score,sorted_attn_score_indices = attn_score.sort(dim=-1,descending=True)
446 | if self.layer_idx >= self.skip:
447 | adaptive_attn_score = sorted_attn_score
448 | length = adaptive_attn_score.size(dim=-1)
449 | if self.normalize:
450 | ratio_weight = sorted_attn_score[...,:self.base_capacity].sum(dim=-1,keepdim=True)/sorted_attn_score.sum(dim=-1,keepdim=True)
451 | adaptive_attn_score = adaptive_attn_score*ratio_weight
452 | adaptive_attn_score = adaptive_attn_score.reshape(bsz,length*num_heads//self.num_key_value_groups)
453 | sorted_indices = torch.topk(adaptive_attn_score,k=num_heads*self.base_capacity//self.num_key_value_groups,dim=-1).indices
454 | sorted_indices = sorted_indices//length
455 |
456 | # floor_alpha capacity set
457 | head_adaptive_capacity = torch.zeros((bsz,num_heads//self.num_key_value_groups),device=_device,dtype = sorted_indices.dtype)
458 | head_adaptive_capacity.scatter_add_(-1,sorted_indices,torch.ones_like(sorted_indices,dtype=head_adaptive_capacity.dtype),)
459 | assert head_adaptive_capacity.sum().item() == num_heads*self.base_capacity//self.num_key_value_groups
460 | head_adaptive_capacity = torch.round(head_adaptive_capacity * (1-self.floor_ratio) + self.floor_capacity).int()
461 | else:
462 | head_adaptive_capacity = torch.ones((bsz,num_heads),device=_device,dtype = sorted_attn_score_indices.dtype) * self.base_capacity
463 | sorted_attn_score_indices = sorted_attn_score_indices.split(1,dim=1)
464 |
465 | heads_key_states = []
466 | heads_value_states = []
467 | assert bsz == 1
468 | # per head
469 |
470 | # reinit varlen metadata
471 | k_lens = []
472 | klen_sum = 0
473 | max_seqlen_k = 0
474 | self.cu_klen = 0
475 |
476 |
477 | for head_idx in range(num_heads//self.num_key_value_groups):
478 | cache_index = sorted_attn_score_indices[head_idx][...,:head_adaptive_capacity[0][head_idx]]
479 |
480 | l = cache_index.shape[-1] + self.window_size
481 | k_lens.append(l)
482 | max_seqlen_k = max(max_seqlen_k, l)
483 | klen_sum += l
484 |
485 | cache_index = cache_index.view(1, 1, -1, 1).expand(-1, -1, -1, head_dim)
486 | top_Kcache = origin_heads_key_states[head_idx].gather(dim=2,index=cache_index)
487 | top_Vcache = origin_heads_value_states[head_idx].gather(dim=2,index=cache_index)
488 | selected_k = torch.cat([top_Kcache,origin_heads_key_states[head_idx][:, :, -self.window_size:, :]],dim=2)
489 | selected_v = torch.cat([top_Vcache,origin_heads_value_states[head_idx][:, :, -self.window_size:, :]],dim=2)
490 |
491 | # NOTE: flatten view
492 | heads_key_states.append(selected_k.view(-1, head_dim))
493 | heads_value_states.append(selected_v.view(-1, head_dim))
494 |
495 | init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k)
496 |
497 | # NOTE: compose as flatten view
498 | heads_key_states = torch.cat(heads_key_states, dim=0)
499 | heads_value_states = torch.cat(heads_value_states, dim=0)
500 |
501 | return heads_key_states, heads_value_states
502 |
503 |
504 |
505 |
506 | # update without gqa_support
507 | def update_kv_wo_gqa(self, origin_key_states, query_states, origin_value_states):
508 | key_states = repeat_kv(origin_key_states, self.num_key_value_groups)
509 | value_states = repeat_kv(origin_value_states, self.num_key_value_groups)
510 |
511 |         # prefill phase: key_states and query_states are expected to share the same sequence length
512 | _device = key_states.device
513 | bsz, num_heads, q_len, head_dim = query_states.shape
514 | attn_score= self.calcul_attn_sore(key_states,query_states)
515 | origin_heads_key_states = torch.split(key_states, 1, dim=1)
516 | origin_heads_value_states = torch.split(value_states, 1, dim=1)
517 |
518 | # compute pyramidal capacity
519 | if self.pyram_mode and not self.pyram_init:
520 | # NOTE: (max_num + min_num) / 2 == base_capacity to restrict the total capacity
521 | min_num = self.base_capacity // self.pyram_beta
522 | max_num = self.base_capacity * 2 - min_num
523 |
524 |             # if max_num exceeds the prefix length outside the window (q_len - window_size), clamp it and rebalance min_num
525 | if max_num >= q_len - self.window_size:
526 | max_num = q_len - self.window_size
527 | min_num = self.base_capacity * 2 - max_num
528 |
529 | # NOTE: compute interval
530 | steps = (max_num - min_num) // (self.num_hidden_layers - 1)
531 |
532 | # renew adaptive capacity
533 | self.base_capacity = max_num - self.layer_idx * steps
534 | self.floor_capacity = int(self.base_capacity * self.floor_ratio)
535 | self.adaptive_capacity = self.base_capacity - self.floor_capacity
536 | self.pyram_init = True
537 | print(f"Pyram mode adaptive capacity, layer: {self.layer_idx}, acap: {self.adaptive_capacity}, bcap: {self.base_capacity}, fcap: {self.floor_capacity}", flush=True)
538 |
539 | def init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k):
540 | # init metadata
541 | self.head_lens = torch.tensor(k_lens, dtype=torch.int32, device=_device)
542 | self.klen_sum = klen_sum
543 | self.max_seqlen_k = max_seqlen_k
544 | self.cu_headlens = torch.cumsum(self.head_lens, dim=0, dtype=torch.int32)
545 | # init varlen flash attention metadata
546 | self.cu_klen = self.cu_headlens - self.head_lens
547 | self.cu_klen = torch.cat(
548 | [self.cu_klen, torch.tensor([self.klen_sum], dtype=torch.int32, device=_device)], dim=0)
549 | self.layer_qlens = torch.ones(num_heads, dtype=torch.int32,device=_device)
550 | self.qlen_sum = num_heads
551 | self.cu_qlen = torch.cumsum(self.layer_qlens, dim=0, dtype=torch.int32) - self.layer_qlens
552 | self.cu_qlen = torch.cat(
553 | [self.cu_qlen, torch.tensor([self.qlen_sum], dtype=torch.int32, device=_device)], dim=0)
554 | self.cu_offset = torch.arange(0, num_heads + 1, dtype=torch.int32, device=_device)
555 | self.cu_head_offset = torch.arange(1, num_heads+1, dtype=torch.int32, device=_device)
556 |
557 | if self.base_capacity > attn_score.size(-1):
558 | init_metadata(num_heads, [q_len] * num_heads, q_len * num_heads, q_len)
559 | # not compress
560 | return key_states.reshape(-1, head_dim), value_states.reshape(-1, head_dim)
561 |
562 |         # NOTE: optional re-weighting of attn_score could be applied here
563 |
564 | sorted_attn_score,sorted_attn_score_indices = attn_score.sort(dim=-1,descending=True)
565 | if self.layer_idx >= self.skip:
566 | adaptive_attn_score = sorted_attn_score
567 | length = adaptive_attn_score.size(dim=-1)
568 | if self.normalize:
569 | ratio_weight = sorted_attn_score[...,:self.base_capacity].sum(dim=-1,keepdim=True)/sorted_attn_score.sum(dim=-1,keepdim=True)
570 | adaptive_attn_score = adaptive_attn_score*ratio_weight
571 | adaptive_attn_score = adaptive_attn_score.reshape(bsz,length*num_heads)
572 | sorted_indices = torch.topk(adaptive_attn_score,k=num_heads*self.base_capacity,dim=-1).indices
573 | sorted_indices = sorted_indices//length
574 | # floor_alpha capacity set
575 | head_adaptive_capacity = torch.zeros((bsz,num_heads),device=_device,dtype = sorted_indices.dtype)
576 | head_adaptive_capacity.scatter_add_(-1,sorted_indices,torch.ones_like(sorted_indices,dtype=head_adaptive_capacity.dtype),)
577 | assert head_adaptive_capacity.sum().item() == num_heads*self.base_capacity
578 | head_adaptive_capacity = torch.round(head_adaptive_capacity * (1-self.floor_ratio) + self.floor_capacity).int()
579 | else:
580 | head_adaptive_capacity = torch.ones((bsz,num_heads),device=_device,dtype = sorted_attn_score_indices.dtype) * self.base_capacity
581 | sorted_attn_score_indices = sorted_attn_score_indices.split(1,dim=1)
582 |
583 | heads_key_states = []
584 | heads_value_states = []
585 | assert bsz == 1
586 | # per head
587 |
588 | # reinit varlen metadata
589 | k_lens = []
590 | klen_sum = 0
591 | max_seqlen_k = 0
592 | self.cu_klen = 0
593 |
594 |
595 | for head_idx in range(num_heads):
596 | cache_index = sorted_attn_score_indices[head_idx][...,:head_adaptive_capacity[0][head_idx]]
597 |
598 | l = cache_index.shape[-1] + self.window_size
599 | k_lens.append(l)
600 | max_seqlen_k = max(max_seqlen_k, l)
601 | klen_sum += l
602 |
603 | cache_index = cache_index.view(1, 1, -1, 1).expand(-1, -1, -1, head_dim)
604 | top_Kcache = origin_heads_key_states[head_idx].gather(dim=2,index=cache_index)
605 | top_Vcache = origin_heads_value_states[head_idx].gather(dim=2,index=cache_index)
606 | selected_k = torch.cat([top_Kcache,origin_heads_key_states[head_idx][:, :, -self.window_size:, :]],dim=2)
607 | selected_v = torch.cat([top_Vcache,origin_heads_value_states[head_idx][:, :, -self.window_size:, :]],dim=2)
608 |
609 | # NOTE: flatten view
610 | heads_key_states.append(selected_k.view(-1, head_dim))
611 | heads_value_states.append(selected_v.view(-1, head_dim))
612 |
613 | init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k)
614 |
615 | # NOTE: compose as flatten view
616 | heads_key_states = torch.cat(heads_key_states, dim=0)
617 | heads_value_states = torch.cat(heads_value_states, dim=0)
618 |
619 | return heads_key_states,heads_value_states
620 |
621 |
622 | def init_snapkv(self):
623 |
624 | assert hasattr(self.config, 'window_size'), "window_size not set"
625 | assert hasattr(self.config, 'kernel_size'), "kernel_size not set"
626 | assert hasattr(self.config, "pooling"), "pooling not set"
627 | assert hasattr(self.config, "base_capacity"), "base_capacity not set"
628 | # init only once
629 | if not hasattr(self, "kv_cluster"):
630 | self.kv_cluster = SnapKVCluster(
631 | window_size = self.config.window_size,
632 | max_capacity_prompt = self.config.base_capacity,
633 | kernel_size = self.config.kernel_size,
634 | pooling = self.config.pooling,
635 | layer_idx = self.layer_idx,
636 | num_hidden_layers = self.config.num_hidden_layers,
637 | pyram_mode = self.config.pyram_mode,
638 | pyram_beta = self.config.pyram_beta,
639 | gqa_support = self.config.gqa_support,
640 | num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads,
641 | gqa_func = self.config.gqa_func
642 | )
643 | if self.config.gqa_support:
644 | if self.config.model_type != "mistral":
645 |             warnings.warn("GQA support is currently only available for the Mistral-7B-v0.2 model")
646 | # if len(self.config.skip) > 0:
647 | # warnings.warn("vanilla transformer should not enable skip",self.config.skip)
648 | print(f"Compress config(Snap): window_size={self.kv_cluster.window_size}, max_capacity_prompt={self.kv_cluster.max_capacity_prompt}, kernel_size={self.kv_cluster.kernel_size}, pooling={self.kv_cluster.pooling}, pyram_mode={self.kv_cluster.pyram_mode}, beta={self.kv_cluster.pyram_beta}", flush=True)
649 |
650 | def init_adaptive_snapkv(self):
651 | assert hasattr(self.config,'window_size'),"window_size not set"
652 | assert hasattr(self.config,'kernel_size'),"kernel_size not set"
653 | assert hasattr(self.config,"pooling"),"pooling not set"
654 | assert hasattr(self.config, "base_capacity"), "base_capacity not set"
655 | assert hasattr(self.config,"floor_alpha"),"floor_alpha not set"
656 | assert self.config.floor_alpha is not None
657 |
658 |
659 | # init only once
660 | if not hasattr(self, "kv_cluster"):
661 | self.kv_cluster = AdaptiveSnapKVCluster(
662 | window_size = self.config.window_size,
663 | base_capacity=self.config.base_capacity,
664 | kernel_size = self.config.kernel_size,
665 | pooling = self.config.pooling,
666 | floor_alpha= self.config.floor_alpha,
667 | skip = self.config.skip,
668 | layer_idx = self.layer_idx,
669 | normalize = self.config.normalize,
670 | num_hidden_layers = self.config.num_hidden_layers,
671 | pyram_mode = self.config.pyram_mode,
672 | pyram_beta = self.config.pyram_beta,
673 | gqa_support = self.config.gqa_support,
674 | num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads,
675 | gqa_func = self.config.gqa_func
676 | )
677 | if self.config.gqa_support:
678 | if self.config.model_type != "mistral":
679 |             warnings.warn("GQA support is currently only available for the Mistral-7B-v0.2 model")
680 | print(f"Compress config(Ada): window_size={self.kv_cluster.window_size}, base_capacity={self.kv_cluster.base_capacity}, kernel_size={self.kv_cluster.kernel_size}, pooling={self.kv_cluster.pooling}, floor_alpha={self.kv_cluster.floor_ratio}, pyram_mode={self.kv_cluster.pyram_mode}, beta={self.kv_cluster.pyram_beta}", flush=True)
681 |
682 |
683 |
684 | class StreamingLLMKVCluster():
685 | def __init__(self, max_capacity_prompt = 256):
686 |         self.sink_token = 4
687 |         self.max_capacity_prompt = max_capacity_prompt - self.sink_token
688 |         assert self.max_capacity_prompt > 0
689 |
690 |
691 | def update_kv(self, key_states, query_states, value_states, *args, **kwargs):
692 | # check if prefix phase
693 | assert key_states.shape[-2] == query_states.shape[-2]
694 | bsz, num_heads, q_len, head_dim = query_states.shape
695 |
696 |         print(f"StreamingLLM max_capacity_prompt {self.max_capacity_prompt + self.sink_token}")
697 |
698 |         if q_len < self.max_capacity_prompt + self.sink_token:
699 | return key_states, value_states
700 | else:
701 | k_past_compress = key_states[:, :, :self.sink_token, :]
702 | v_past_compress = value_states[:, :, :self.sink_token, :]
703 | k_cur = key_states[:, :, -self.max_capacity_prompt:, :]
704 | v_cur = value_states[:, :, -self.max_capacity_prompt:, :]
705 | key_states = torch.cat([k_past_compress, k_cur], dim = 2)
706 | value_states = torch.cat([v_past_compress, v_cur], dim = 2)
707 | return key_states, value_states
708 |
709 |
710 | def init_slm(self,**kwargs):
711 |     assert hasattr(self.config, 'base_capacity'), "base_capacity not set"
712 | # init only once
713 | if not hasattr(self, "kv_cluster"):
714 | self.kv_cluster = StreamingLLMKVCluster(
715 | max_capacity_prompt = self.config.base_capacity,
716 | )
717 | print(f"Compress config(SLM): max_cap={self.config.base_capacity}")
718 |
719 |
720 |
721 |
722 |
723 |
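724 | # ---------------------------------------------------------------------------
725 | # Usage sketch (illustrative only, not executed here). These cluster classes are
726 | # not meant to be instantiated directly: the monkeypatch helpers hijack the
727 | # attention forward pass and call init_snapkv / init_adaptive_snapkv / init_slm
728 | # on each attention module, which read the compression hyper-parameters from the
729 | # model config. The exact signature of config_compress below is an assumption;
730 | # see adaptive_snapkv/monkeypatch/monkeypatch.py and experiments/LongBench/pred.py
731 | # for the real entry points.
732 | #
733 | #   from adaptive_snapkv.monkeypatch.monkeypatch import replace_mistral_adaptive, config_compress
734 | #   replace_mistral_adaptive()                 # patch attention before loading weights
735 | #   model = AutoModelForCausalLM.from_pretrained(model_path, ...)
736 | #   config_compress(model, ...)                # sets window_size, base_capacity, floor_alpha, ...
737 | #   # init_adaptive_snapkv() then runs once per layer during the first prefill.
738 | #
739 | # Pyramid mode example (pyram_beta=20, base_capacity=1000, 32 layers):
740 | # min_num = 1000 // 20 = 50, max_num = 2 * 1000 - 50 = 1950,
741 | # steps = (1950 - 50) // 31 = 61, so layer i keeps max_num - 61 * i entries
742 | # (plus the observation window), decreasing from shallow to deep layers while
743 | # the average budget stays close to base_capacity.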
--------------------------------------------------------------------------------
/assets/images/4k_ruler_average_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/4k_ruler_average_score.png
--------------------------------------------------------------------------------
/assets/images/LongBench_mistral.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/LongBench_mistral.png
--------------------------------------------------------------------------------
/assets/images/LongBench_mistral_gqa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/LongBench_mistral_gqa.png
--------------------------------------------------------------------------------
/assets/images/head_vary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/head_vary.png
--------------------------------------------------------------------------------
/assets/images/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/main.png
--------------------------------------------------------------------------------
/assets/images/mem.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/mem.png
--------------------------------------------------------------------------------
/assets/images/speed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/speed.png
--------------------------------------------------------------------------------
/csrc/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | deps/
3 | tiny_pkg.egg-info/
4 | test.py
5 |
--------------------------------------------------------------------------------
/csrc/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 66RING
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/csrc/build.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | from packaging.version import parse, Version
4 | from pathlib import Path
5 | from setuptools import setup, find_packages
6 | from torch.utils.cpp_extension import (
7 | BuildExtension,
8 | CppExtension,
9 | CUDAExtension,
10 | CUDA_HOME,
11 | )
12 |
13 | # package name managed by pip, which can be removed by `pip uninstall tiny_pkg`
14 | PACKAGE_NAME = "tiny_pkg"
15 |
16 | ext_modules = []
17 | generator_flag = []
18 | cc_flag = []
19 | cc_flag.append("-gencode")
20 | cc_flag.append("arch=compute_80,code=sm_80")
21 |
22 |
23 | # helper function to get cuda version
24 | def get_cuda_bare_metal_version(cuda_dir):
25 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
26 | output = raw_output.split()
27 | release_idx = output.index("release") + 1
28 | bare_metal_version = parse(output[release_idx].split(",")[0])
29 |
30 | return raw_output, bare_metal_version
31 |
32 |
33 | # if CUDA_HOME is not None:
34 | # _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
35 | # if bare_metal_version >= Version("11.8"):
36 | # cc_flag.append("-gencode")
37 | # cc_flag.append("arch=compute_90,code=sm_90")
38 |
39 | # ninja build does not work unless include_dirs are abs path
40 | this_dir = os.path.dirname(os.path.abspath(__file__))
41 |
42 | # cuda module
43 | ext_modules.append(
44 | CUDAExtension(
45 | # package name for import
46 | name="tiny_api_cuda",
47 | sources=[
48 | "csrc/cuda_api.cu",
49 | ],
50 | extra_compile_args={
51 | # add c compile flags
52 | "cxx": ["-O3", "-std=c++17"] + generator_flag,
53 | # add nvcc compile flags
54 | "nvcc": [
55 | "-O3",
56 | "-std=c++17",
57 | "-U__CUDA_NO_HALF_OPERATORS__",
58 | "--use_fast_math",
59 | "-lineinfo",
60 | "--ptxas-options=-v",
61 | "--ptxas-options=-O2",
62 | "-U__CUDA_NO_HALF_OPERATORS__",
63 | "-U__CUDA_NO_HALF_CONVERSIONS__",
64 | "-U__CUDA_NO_HALF2_OPERATORS__",
65 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
66 | "--expt-relaxed-constexpr",
67 | "--expt-extended-lambda",
68 | "--use_fast_math",
69 | ]
70 | + generator_flag
71 | + cc_flag,
72 | },
73 | include_dirs=[
74 | Path(this_dir) / "csrc",
75 | Path(this_dir) / "include",
76 | # Path(this_dir) / "some" / "thing" / "more",
77 | ],
78 | )
79 | )
80 |
81 | setup(
82 | name=PACKAGE_NAME,
83 | packages=find_packages(
84 | exclude=(
85 | "build",
86 | "csrc",
87 | "include",
88 | "tests",
89 | "dist",
90 | "docs",
91 | "benchmarks",
92 | "tiny_pkg.egg-info",
93 | )
94 | ),
95 | description="Tiny cuda and c api binding for pytorch.",
96 | ext_modules=ext_modules,
97 | cmdclass={ "build_ext": BuildExtension},
98 | python_requires=">=3.7",
99 | install_requires=[
100 | "torch",
101 | "einops",
102 | "packaging",
103 | "ninja",
104 | ],
105 | )
106 |
107 |
108 |
109 |
110 |
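111 | # NOTE: the extension is normally built via the accompanying makefile
112 | # (`make build`, which runs `python build.py install`); afterwards it is
113 | # importable from Python as `import tiny_api_cuda`.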
--------------------------------------------------------------------------------
/csrc/csrc/cuda_api.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <cuda_fp16.h>
5 | #include <cuda_bf16.h>
6 | #include <cstdio>
7 |
8 | #include "cuda_api.h"
9 | #include "static_switch.h"
10 |
11 | template <typename tensor_t, int kblock_size>
12 | __global__ void update_flatten_view_kernel(tensor_t* dst_ptr, tensor_t* src_ptr, tensor_t* state_ptr, int* headlens,
13 | int *cu_headlens,
14 | int dim) {
15 | // Create new tensor from cache and insert element into it.
16 |
17 | int head_idx = blockIdx.x;
18 | int thread_group = blockIdx.y;
19 | int tid = threadIdx.x + thread_group * blockDim.x;
20 | int num_threads = blockDim.x * gridDim.y;
21 |
22 | int headlen = headlens[head_idx];
23 |
24 | // get position of src, dst, insert ptr
25 | int src_cum_off = cu_headlens[head_idx] * dim;
26 | int dst_cum_off = src_cum_off + head_idx * dim;
27 | int insert_off = (cu_headlens[head_idx+1] + head_idx) * dim;
28 |
29 | auto old_cache_ptr = src_ptr + src_cum_off;
30 | auto new_cache_ptr = dst_ptr + dst_cum_off;
31 | auto insert_cache_ptr = dst_ptr + insert_off;
32 |
33 | // copy old data
34 | for (int start_addr = 0; start_addr < headlen * dim; start_addr += kblock_size * num_threads) {
35 | auto src_addr = old_cache_ptr + start_addr + tid * kblock_size;
36 | auto dst_addr = new_cache_ptr + start_addr + tid * kblock_size;
37 |
38 | // TODO: LDSM speed up with SRAM
39 | #pragma unroll
40 | for (int i = 0; i < kblock_size; i++) {
41 | if (start_addr + tid * kblock_size + i >= headlen * dim) {
42 | break;
43 | }
44 | dst_addr[i] = src_addr[i];
45 | }
46 | }
47 |
48 | // insert new data
49 | if (tid < dim) {
50 | auto insert_src_ptr = state_ptr + head_idx * dim + tid;
51 | auto insert_dst_addr = insert_cache_ptr + tid;
52 | *insert_dst_addr = *insert_src_ptr;
53 | }
54 | }
55 |
56 | torch::Tensor update_flatten_view(torch::Tensor &cache, torch::Tensor &state, torch::Tensor &headlens, torch::Tensor& cu_headlens) {
57 | TORCH_CHECK(headlens.dtype() == torch::kInt32, "expected headlens to be int32");
58 |     TORCH_CHECK(cu_headlens.dtype() == torch::kInt32, "expected cu_headlens to be int32");
59 |
60 | auto cache_shape = cache.sizes();
61 |
62 | int origin_len = cache_shape[0];
63 | int head_dim = cache_shape[1];
64 | int head_num = headlens.sizes()[0];
65 |
66 | torch::Tensor out = torch::empty({origin_len + head_num, head_dim}, cache.options());
67 |
68 | const int kblock_size = 1;
69 | const int num_threads_group = 1024;
70 | const int num_threads = 128;
71 |
72 | dim3 grid(head_num, num_threads_group);
73 |
74 |     // TODO: dispatch on head_dim? may lose performance
75 | dim3 block(num_threads);
76 |     TORCH_CHECK(num_threads >= head_dim, "num_threads must be at least head_dim");
77 |
78 | FP16_SWITCH(cache.dtype() == torch::kFloat16, [&] {
79 |         auto kernel = update_flatten_view_kernel<elem_type, kblock_size>;
80 |         kernel<<<grid, block>>>((elem_type*)out.data_ptr(), (elem_type*)cache.data_ptr(), (elem_type*)state.data_ptr(), (int*)headlens.data_ptr(), (int*)cu_headlens.data_ptr(), head_dim);
81 | });
82 |
83 | // TODO: when to use sync or torch auto
84 | // cudaDeviceSynchronize();
85 |
86 | return out;
87 | }
88 |
89 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
90 |     // m.def("function_name", &function_pointer, "function docstring")
91 |
92 | m.def("update_flatten_view", &update_flatten_view, "update flatten view cache");
93 | }
94 |
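95 | // Illustrative Python-side call of the binding above (tensor names and shapes
96 | // are examples only):
97 | //   import tiny_api_cuda
98 | //   # cache: (sum(headlens), head_dim) flattened KV, state: (head_num, head_dim) new entries
99 | //   new_cache = tiny_api_cuda.update_flatten_view(cache, state, headlens, cu_headlens)
100 | //   # headlens and cu_headlens must be int32 tensors; the output has shape
101 | //   # (sum(headlens) + head_num, head_dim)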
--------------------------------------------------------------------------------
/csrc/csrc/static_switch.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #define FP16_SWITCH(COND, ...) \
4 | [&] { \
5 | if (COND) { \
6 | using elem_type = at::Half; \
7 | return __VA_ARGS__(); \
8 | } else { \
9 | using elem_type = at::BFloat16; \
10 | return __VA_ARGS__(); \
11 | } \
12 | }()
13 |
14 |
15 |
--------------------------------------------------------------------------------
/csrc/include/cuda_api.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <torch/extension.h>
4 |
5 | #define DEBUG 1
6 |
7 | #ifdef DEBUG
8 |
9 | // NOTE: tensors must already be allocated on the device before these checks run,
10 | // e.g. via data.to("cuda") in Python
11 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
14 | #define CUDA_ERROR_CHECK(condition) \
15 | do { \
16 | cudaError_t error = condition; \
17 | if (error != cudaSuccess) { \
18 | printf("CUDA_CHECK error in line %d of file %s \
19 | : %s \n", \
20 | __LINE__, __FILE__, cudaGetErrorString(error)); \
21 | exit(EXIT_FAILURE); \
22 | } \
23 | } while (0)
24 |
25 | #else
26 |
27 | #define CHECK_CUDA(x) do { } while (0)
28 | #define CHECK_CONTIGUOUS(x) do { } while (0)
29 | #define CHECK_INPUT(x) do { } while (0)
30 | #define CUDA_ERROR_CHECK(condition) do { condition; } while (0)
31 |
32 | #endif // DEBUG
33 |
34 |
35 |
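36 | // Example (illustrative): wrap CUDA runtime calls so failures abort with
37 | // file/line information, e.g.
38 | //   CUDA_ERROR_CHECK(cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToDevice));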
--------------------------------------------------------------------------------
/csrc/makefile:
--------------------------------------------------------------------------------
1 | build:
2 | python build.py install
3 |
4 | .PHONY: build
5 |
--------------------------------------------------------------------------------
/experiments/LongBench/GQA_eval_longbench.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_NAME="Mistral-7B-Instruct-v0.2"
4 | MODEL=${MODELS_DIR}/${MODEL_NAME}
5 | DATASET=${DATASETS_DIR}/LongBench
6 | MAX_LEN=31500
7 |
8 | echo "Testing $MODEL $DATASET $MAX_LEN"
9 |
10 | scopes=(128 256 512 1024)
11 | device="4090"
12 |
13 | for scope in ${scopes[@]}; do
14 | # ada-snapkv with gqa support
15 |     # we suggest setting floor_alpha to 0.2 when using gqa
16 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode ada --compress_args_path c"$scope"_w32_k7_maxpool.json --floor_alpha 0.2 --out_name "$device"_ada_"$scope"_"$MODEL_NAME" --gqa_support
17 | # snapkv with gqa support
18 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode fix --compress_args_path c"$scope"_w32_k7_maxpool.json --out_name "$device"_fix_"$scope"_"$MODEL_NAME" --gqa_support
19 | # ada-snapkv without gqa support
20 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode ada --compress_args_path c"$scope"_w32_k7_maxpool.json --floor_alpha 0.2 --out_name "$device"_ada_"$scope"_"$MODEL_NAME"
21 | # snapkv without gqa support
22 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode fix --compress_args_path c"$scope"_w32_k7_maxpool.json --out_name "$device"_fix_"$scope"_"$MODEL_NAME"
23 | done
24 |
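25 | # NOTE: MODELS_DIR and DATASETS_DIR are expected to be exported beforehand
26 | # (paths below are illustrative):
27 | #   export MODELS_DIR=/path/to/models
28 | #   export DATASETS_DIR=/path/to/datasets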
--------------------------------------------------------------------------------
/experiments/LongBench/README.md:
--------------------------------------------------------------------------------
1 | # Minimal version of LongBench
2 | Based on SnapKV
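3 |
4 | ## Usage
5 |
6 | A typical evaluation round looks as follows (paths and output names are
7 | illustrative; `GQA_eval_longbench.sh` runs the full sweep over budgets and modes):
8 |
9 | ```bash
10 | # generate predictions with Ada-SnapKV (GQA support enabled)
11 | python pred.py -m ${MODELS_DIR}/Mistral-7B-Instruct-v0.2 --max_length 31500 -d ${DATASETS_DIR}/LongBench \
12 |     --mode ada --compress_args_path c1024_w32_k7_maxpool.json --floor_alpha 0.2 \
13 |     --out_name ada_1024_Mistral-7B-Instruct-v0.2 --gqa_support
14 |
15 | # score the predictions under ./pred/ and aggregate the results
16 | python eval.py
17 | python convert_to_execl.py
18 | ```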
--------------------------------------------------------------------------------
/experiments/LongBench/config/dataset2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "narrativeqa": 128,
3 | "qasper": 128,
4 | "multifieldqa_en": 64,
5 | "multifieldqa_zh": 64,
6 | "hotpotqa": 32,
7 | "2wikimqa": 32,
8 | "musique": 32,
9 | "dureader": 128,
10 | "gov_report": 512,
11 | "qmsum": 512,
12 | "multi_news": 512,
13 | "vcsum": 512,
14 | "trec": 64,
15 | "triviaqa": 32,
16 | "samsum": 128,
17 | "lsht": 64,
18 | "passage_count": 32,
19 | "passage_retrieval_en": 32,
20 | "passage_retrieval_zh": 32,
21 | "lcc": 64,
22 | "repobench-p": 64
23 | }
--------------------------------------------------------------------------------
/experiments/LongBench/config/dataset2prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:",
4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:",
6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:",
9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:",
10 | "gov_report": "You are given a report by a government agency. Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:",
11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:",
12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:",
13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:",
14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}",
15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}",
16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}",
17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}",
18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ",
19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ",
20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:",
21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n",
22 | "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n"
23 | }
--------------------------------------------------------------------------------
/experiments/LongBench/config/model2maxlen.json:
--------------------------------------------------------------------------------
1 | {
2 | "llama2-7b-chat-4k": 3500,
3 | "longchat-v1.5-7b-32k": 31500,
4 | "xgen-7b-8k": 7500,
5 | "internlm-7b-8k": 7500,
6 | "chatglm2-6b": 31500,
7 | "chatglm2-6b-32k": 31500,
8 | "chatglm3-6b-32k": 31500,
9 | "vicuna-v1.5-7b-16k": 15500,
10 | "mistral-7B-instruct-v0.2": 31500,
11 | "mistral-7B-instruct-v0.1": 31500,
12 | "mixtral-8x7B-instruct-v0.1": 31500,
13 | "llama-2-7B-32k-instruct": 31500,
14 | "lwm-text-chat-1m": 1048576,
15 | "lwm-text-1m": 1048576
16 | }
17 |
--------------------------------------------------------------------------------
/experiments/LongBench/config/model2path.json:
--------------------------------------------------------------------------------
1 | {
2 | "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf",
3 | "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k",
4 | "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst",
5 | "internlm-7b-8k": "internlm/internlm-chat-7b-8k",
6 | "chatglm2-6b": "THUDM/chatglm2-6b",
7 | "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
8 | "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k",
9 | "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k",
10 | "mistral-7B-instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2",
11 | "mistral-7B-instruct-v0.1": "mistralai/Mistral-7B-Instruct-v0.1",
12 | "mixtral-8x7B-instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
13 | "llama-2-7B-32k-instruct": "togethercomputer/Llama-2-7B-32K-Instruct",
14 | "lwm-text-chat-1m": "LargeWorldModel/LWM-Text-Chat-1M",
15 | "lwm-text-1m": "LargeWorldModel/LWM-Text-1M"
16 | }
17 |
--------------------------------------------------------------------------------
/experiments/LongBench/convert_to_execl.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pandas as pd
3 | import os
4 |
5 | """
6 | Utility to convert the experiment results into a single Excel file
7 | """
8 |
9 | dataframes = []
10 |
11 | for path in os.listdir('./pred/'):
12 | file_path = f'./pred/{path}/result.json'
13 |
14 | with open(file_path, 'r') as file:
15 | data = json.load(file)
16 |
17 | column_order = [
18 | "narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa",
19 | "musique", "gov_report", "qmsum", "multi_news", "trec",
20 | "triviaqa", "samsum", "passage_count", "passage_retrieval_en",
21 | "lcc", "repobench-p"
22 | ]
23 |
24 | data_renamed = {
25 | "narrativeqa": data.get("narrativeqa", -1),
26 | "qasper": data.get("qasper", -1),
27 | "multifieldqa_en": data.get("multifieldqa_en", -1),
28 | "hotpotqa": data.get("hotpotqa", -1),
29 | "2wikimqa": data.get("2wikimqa", -1),
30 | "musique": data.get("musique", -1),
31 | "gov_report": data.get("gov_report", -1),
32 | "qmsum": data.get("qmsum", -1),
33 | "multi_news": data.get("multi_news", -1),
34 | "trec": data.get("trec", -1),
35 | "triviaqa": data.get("triviaqa", -1),
36 | "samsum": data.get("samsum", -1),
37 | "passage_count": data.get("passage_count", -1),
38 | "passage_retrieval_en": data.get("passage_retrieval_en", -1),
39 | "lcc": data.get("lcc", -1),
40 | "repobench-p": data.get("repobench-p", -1)
41 | }
42 |
43 | df = pd.DataFrame([data_renamed], columns=column_order, index=[path])
44 | dataframes.append(df)
45 |
46 | result_df = pd.concat(dataframes)
47 | output_path = './pred/combined_results.xlsx'
48 | result_df.to_excel(output_path, index_label='Folder Name')
49 |
50 | print(f"Excel file saved to {output_path}")
51 |
--------------------------------------------------------------------------------
/experiments/LongBench/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import numpy as np
5 |
6 | from metrics import (
7 | qa_f1_score,
8 | rouge_zh_score,
9 | qa_f1_zh_score,
10 | rouge_score,
11 | classification_score,
12 | retrieval_score,
13 | retrieval_zh_score,
14 | count_score,
15 | code_sim_score,
16 | )
17 |
18 | """
19 | Check that each dataset contains the expected number of samples.
20 | """
21 |
22 | dataset_samples = {
23 | "narrativeqa": 200,
24 | "qasper":200,
25 | "multifieldqa_en": 150,
26 | "hotpotqa": 200,
27 | "2wikimqa": 200,
28 | "musique": 200,
29 | "gov_report": 200,
30 | "qmsum": 200,
31 | "multi_news": 200,
32 | "trec": 200,
33 | "triviaqa": 200,
34 | "samsum": 200,
35 | "passage_retrieval_en": 200,
36 | "passage_count": 200,
37 | "lcc": 500,
38 | "repobench-p": 500,
39 | }
40 | dataset2metric = {
41 | "narrativeqa": qa_f1_score,
42 | "qasper": qa_f1_score,
43 | "multifieldqa_en": qa_f1_score,
44 | "multifieldqa_zh": qa_f1_zh_score,
45 | "hotpotqa": qa_f1_score,
46 | "2wikimqa": qa_f1_score,
47 | "musique": qa_f1_score,
48 | "dureader": rouge_zh_score,
49 | "gov_report": rouge_score,
50 | "qmsum": rouge_score,
51 | "multi_news": rouge_score,
52 | "vcsum": rouge_zh_score,
53 | "trec": classification_score,
54 | "triviaqa": qa_f1_score,
55 | "samsum": rouge_score,
56 | "lsht": classification_score,
57 | "passage_retrieval_en": retrieval_score,
58 | "passage_count": count_score,
59 | "passage_retrieval_zh": retrieval_zh_score,
60 | "lcc": code_sim_score,
61 | "repobench-p": code_sim_score,
62 | }
63 |
64 | def parse_args(args=None):
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument('--model', type=str, default=None)
67 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
68 | return parser.parse_args(args)
69 |
70 | def scorer_e(dataset, predictions, answers, lengths, all_classes):
71 | scores = {"0-4k": [], "4-8k": [], "8k+": []}
72 | for (prediction, ground_truths, length) in zip(predictions, answers, lengths):
73 | score = 0.
74 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
75 | prediction = prediction.lstrip('\n').split('\n')[0]
76 | for ground_truth in ground_truths:
77 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
78 | if length < 4000:
79 | scores["0-4k"].append(score)
80 | elif length < 8000:
81 | scores["4-8k"].append(score)
82 | else:
83 | scores["8k+"].append(score)
84 | for key in scores.keys():
85 | scores[key] = round(100 * np.mean(scores[key]), 2)
86 | return scores
87 |
88 | def scorer(dataset, predictions, answers, all_classes):
89 | score_list = []
90 | total_score = 0.
91 | for (prediction, ground_truths) in zip(predictions, answers):
92 | score = 0.
93 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]:
94 | prediction = prediction.lstrip('\n').split('\n')[0]
95 | for ground_truth in ground_truths:
96 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes))
97 | score_list.append(score)
98 | total_score += score
99 | return round(100 * total_score / len(predictions), 2),score_list
100 |
101 | if __name__ == '__main__':
102 | args = parse_args()
103 |
104 | for model in os.listdir("pred/"):
105 | scores = dict()
106 | scores_list = dict()
107 | args.model = model
108 | if args.e:
109 | path = f"pred_e/{args.model}/"
110 | else:
111 | path = f"pred/{args.model}/"
112 | all_files = os.listdir(path)
113 | print("Evaluating on:", all_files)
114 | for filename in all_files:
115 | if not filename.endswith("jsonl"):
116 | continue
117 | predictions, answers, lengths = [], [], []
118 | dataset = filename.split('.')[0]
119 | with open(f"{path}{filename}", "r", encoding="utf-8") as f:
120 | line_cnt = 0
121 | for line in f:
122 | line_cnt += 1
123 | data = json.loads(line)
124 | predictions.append(data["pred"])
125 | answers.append(data["answers"])
126 | all_classes = data["all_classes"]
127 | if "length" in data:
128 | lengths.append(data["length"])
129 | dataset_name = filename.split('.')[0]
130 | target_samples = dataset_samples.get(dataset_name, 200)
131 | if line_cnt != target_samples:
132 | print(f"Error: {dataset_name} has {line_cnt} samples, expected {target_samples}")
133 | continue
134 | if args.e:
135 | score = scorer_e(dataset, predictions, answers, lengths, all_classes)
136 | else:
137 | score,score_list = scorer(dataset, predictions, answers, all_classes)
138 | # if dataset == 'qasper':
139 | # score_e = scorer_e(dataset, predictions, answers, lengths, all_classes)
140 | scores[dataset] = score
141 | scores_list[dataset] = score_list
142 | # if dataset == 'qasper':
143 | # scores[dataset + '_e'] = score_e
144 | if args.e:
145 | # out_path = f"H2O/results/{args.model}/result.json"
146 | out_path = f"pred_e/{args.model}/result.json"
147 |
148 | else:
149 | # out_path = f"H2O/results/{args.model}/result.json"
150 | out_path = f"pred/{args.model}/result.json"
151 | list_out_path = f"pred/{args.model}/list_result.json"
152 | # with open(out_path_e, "w") as f:
153 | # json.dump(score_e, f, ensure_ascii=False, indent=4)
154 | with open(out_path, "w") as f:
155 | json.dump(scores, f, ensure_ascii=False, indent=4)
156 | with open(list_out_path,"w") as f:
157 | json.dump(scores_list,f,ensure_ascii=False,indent=4)
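158 |
159 | # NOTE: expected layout is pred/<out_name>/<dataset>.jsonl produced by pred.py;
160 | # scores are written to pred/<out_name>/result.json and list_result.json.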
--------------------------------------------------------------------------------
/experiments/LongBench/metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 | import string
3 |
4 | import jieba
5 | from fuzzywuzzy import fuzz
6 | import difflib
7 |
8 | from typing import List
9 | from collections import Counter
10 | from rouge import Rouge
11 |
12 | def normalize_answer(s):
13 | """Lower text and remove punctuation, articles and extra whitespace."""
14 |
15 | def remove_articles(text):
16 | return re.sub(r"\b(a|an|the)\b", " ", text)
17 |
18 | def white_space_fix(text):
19 | return " ".join(text.split())
20 |
21 | def remove_punc(text):
22 | exclude = set(string.punctuation)
23 | return "".join(ch for ch in text if ch not in exclude)
24 |
25 | def lower(text):
26 | return text.lower()
27 |
28 | return white_space_fix(remove_articles(remove_punc(lower(s))))
29 |
30 |
31 | def normalize_zh_answer(s):
32 | """Lower text and remove punctuation, extra whitespace."""
33 |
34 | def white_space_fix(text):
35 | return "".join(text.split())
36 |
37 | def remove_punc(text):
38 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
39 | all_punctuation = set(string.punctuation + cn_punctuation)
40 | return "".join(ch for ch in text if ch not in all_punctuation)
41 |
42 | def lower(text):
43 | return text.lower()
44 |
45 | return white_space_fix(remove_punc(lower(s)))
46 |
47 | def count_score(prediction, ground_truth, **kwargs):
48 | numbers = re.findall(r"\d+", prediction)
49 | right_num = 0
50 | for number in numbers:
51 | if str(number) == str(ground_truth):
52 | right_num += 1
53 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
54 | return float(final_score)
55 |
56 | def retrieval_score(prediction, ground_truth, **kwargs):
57 | pattern = r'Paragraph (\d+)'
58 | matches = re.findall(pattern, ground_truth)
59 | ground_truth_id = matches[0]
60 | numbers = re.findall(r"\d+", prediction)
61 | right_num = 0
62 | for number in numbers:
63 | if str(number) == str(ground_truth_id):
64 | right_num += 1
65 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
66 | return float(final_score)
67 |
68 | def retrieval_zh_score(prediction, ground_truth, **kwargs):
69 | pattern = r'段落(\d+)'
70 | matches = re.findall(pattern, ground_truth)
71 | ground_truth_id = matches[0]
72 | numbers = re.findall(r"\d+", prediction)
73 | right_num = 0
74 | for number in numbers:
75 | if str(number) == str(ground_truth_id):
76 | right_num += 1
77 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers)
78 | return float(final_score)
79 |
80 | def code_sim_score(prediction, ground_truth, **kwargs):
81 | all_lines = prediction.lstrip('\n').split('\n')
82 | prediction = ""
83 | for line in all_lines:
84 | if ('`' not in line) and ('#' not in line) and ('//' not in line):
85 | prediction = line
86 | break
87 | return (fuzz.ratio(prediction, ground_truth) / 100)
88 |
89 | def classification_score(prediction, ground_truth, **kwargs):
90 | em_match_list = []
91 | all_classes = kwargs["all_classes"]
92 | for class_name in all_classes:
93 | if class_name in prediction:
94 | em_match_list.append(class_name)
95 |     for match_term in list(em_match_list):  # iterate over a copy; removing from the list mid-iteration would skip items
96 | if match_term in ground_truth and match_term != ground_truth:
97 | em_match_list.remove(match_term)
98 | if ground_truth in em_match_list:
99 | score = (1.0 / len(em_match_list))
100 | else:
101 | score = 0.0
102 | return score
103 |
104 | def rouge_score(prediction, ground_truth, **kwargs):
105 | rouge = Rouge()
106 | try:
107 | scores = rouge.get_scores([prediction], [ground_truth], avg=True)
108 |     except Exception:  # Rouge can raise on empty or degenerate inputs
109 | return 0.0
110 | return scores["rouge-l"]["f"]
111 |
112 | def rouge_zh_score(prediction, ground_truth, **kwargs):
113 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False)))
114 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False)))
115 | score = rouge_score(prediction, ground_truth)
116 | return score
117 |
118 | def f1_score(prediction, ground_truth, **kwargs):
119 | common = Counter(prediction) & Counter(ground_truth)
120 | num_same = sum(common.values())
121 | if num_same == 0:
122 | return 0
123 | precision = 1.0 * num_same / len(prediction)
124 | recall = 1.0 * num_same / len(ground_truth)
125 | f1 = (2 * precision * recall) / (precision + recall)
126 | return f1
127 |
128 | def qa_f1_score(prediction, ground_truth, **kwargs):
129 | normalized_prediction = normalize_answer(prediction)
130 | normalized_ground_truth = normalize_answer(ground_truth)
131 |
132 | prediction_tokens = normalized_prediction.split()
133 | ground_truth_tokens = normalized_ground_truth.split()
134 | return f1_score(prediction_tokens, ground_truth_tokens)
135 |
136 |
137 | def qa_f1_zh_score(prediction, ground_truth, **kwargs):
138 | prediction_tokens = list(jieba.cut(prediction, cut_all=False))
139 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False))
140 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens]
141 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens]
142 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0]
143 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0]
144 | return f1_score(prediction_tokens, ground_truth_tokens)
145 |
--------------------------------------------------------------------------------
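A quick sketch of calling the metric functions above directly. The sample strings are illustrative only, and the import assumes the interpreter is started from experiments/LongBench so that metrics.py is importable:

from metrics import qa_f1_score, rouge_score, classification_score

# Token-level F1 after English normalization: the prediction tokens {"answer", "is", "42"}
# share one token with the ground truth {"42"}, so precision = 1/3, recall = 1, F1 = 0.5.
print(qa_f1_score("The answer is 42.", "42"))

# ROUGE-L F-measure in [0, 1]; returns 0.0 if Rouge raises.
print(rouge_score("the cat sat on the mat", "the cat is on the mat"))

# Match against the label set passed via all_classes; a single clean match scores 1.0.
print(classification_score("This question is about sports.", "sports", all_classes=["sports", "politics"]))
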
/experiments/LongBench/pred.py:
--------------------------------------------------------------------------------
1 | import os
2 | from datasets import load_dataset
3 | import torch
4 | import json
5 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
6 | from tqdm import tqdm
7 | import numpy as np
8 | import random
9 | import argparse
10 | import torch.distributed as dist
11 | import torch.multiprocessing as mp
12 | import gc
13 | import time
14 |
15 | from adaptive_snapkv.monkeypatch.monkeypatch import replace_mistral_adaptive, replace_llama_adaptive, config_compress, replace_llama_fixed, replace_mistral_fixed, replace_mistral_slm, replace_llama_slm
16 |
17 | def parse_args(args=None):
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("-m", '--model_name_or_path', type=str, required=True)
20 | parser.add_argument('--max_length', type=int, required=True)
21 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E")
22 | parser.add_argument("-d", '--dataset', type=str, default="THUDM/LongBench")
23 | parser.add_argument("--out_name", type=str, required=True)
24 | parser.add_argument('--compress_args_path', type=str, default=None, help="Path to the compress args")
25 | # parser.add_argument('--adaptive', action='store_true', help="Use adaptive budgets allocation across heads")
26 |     parser.add_argument('--mode', type=str, choices=['ada', 'fix', 'test', 'slm'], help="compression mode: ada, fix, slm, or test (base, no compression)")
27 |     parser.add_argument('--floor_alpha', type=float, default=0.2, help="floor budget fraction guaranteed to each head")
28 |     parser.add_argument('--gqa_support', action='store_true', default=False, help="enable GQA support")
29 |     parser.add_argument('--gqa_func', type=str, default="mean", help="GQA aggregation function: mean or max")
30 |     parser.add_argument('--normalize', action='store_true')
31 |     parser.add_argument('--pyram', action='store_true', help="use pyram (pyramid) mode")
32 |     parser.add_argument('--pyram_beta', default=20, type=int, help="hyperparameter for pyram mode")
33 |     parser.add_argument('--budget', default=1024, type=int, help="budget size for the kv cache")
34 | 
35 |     parser.add_argument('--budget_ratio', type=float, default=0.2, help="budget ratio for the kv cache")
36 | return parser.parse_args(args)
37 |
38 | # This is the customized building prompt for chat models
39 | def build_chat(tokenizer, prompt, model_name):
40 | if "chatglm3" in model_name:
41 | prompt = tokenizer.build_chat_input(prompt)
42 | elif "chatglm" in model_name:
43 | prompt = tokenizer.build_prompt(prompt)
44 | elif "longchat" in model_name or "vicuna" in model_name:
45 | from fastchat.model import get_conversation_template
46 | conv = get_conversation_template("vicuna")
47 | conv.append_message(conv.roles[0], prompt)
48 | conv.append_message(conv.roles[1], None)
49 | prompt = conv.get_prompt()
50 | elif "llama2" in model_name:
51 | prompt = f"[INST]{prompt}[/INST]"
52 | elif "xgen" in model_name:
53 | header = (
54 | "A chat between a curious human and an artificial intelligence assistant. "
55 | "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
56 | )
57 | prompt = header + f" ### Human: {prompt}\n###"
58 | elif "internlm" in model_name:
59 | prompt = f"<|User|>:{prompt}\n<|Bot|>:"
60 | elif "llama-3" in model_name.lower() and "instruct" in model_name.lower():
61 | prompt = [{ "role": "user", "content": prompt}]
62 | prompt = tokenizer.apply_chat_template(
63 | prompt,
64 | tokenize=False,
65 | add_generation_prompt=True
66 | )
67 | return prompt
68 |
69 | def post_process(response, model_name):
70 | if "xgen" in model_name:
71 | response = response.strip().replace("Assistant:", "")
72 | elif "internlm" in model_name:
73 |         response = response.split("<eoa>")[0]
74 | return response
75 |
76 | def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name_or_path, out_path):
77 | device = "cuda"
78 | preds = []
79 | times = []
80 | with open(f"{out_path}_tmp", "w", encoding="utf-8") as f:
81 | for json_obj in tqdm(data):
82 | prompt = prompt_format.format(**json_obj)
83 |             # truncate to fit max_length (we suggest truncating in the middle, since both the left and right sides may contain crucial instructions)
84 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0]
85 | if "chatglm3" in model_name_or_path:
86 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0]
87 | if len(tokenized_prompt) > max_length:
88 | half = int(max_length/2)
89 | prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True)
90 | if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks
91 | prompt = build_chat(tokenizer, prompt, model_name_or_path)
92 | if "chatglm3" in model_name_or_path:
93 | if dataset in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]:
94 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
95 | else:
96 | input = prompt.to(device)
97 | else:
98 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device)
99 | context_length = input.input_ids.shape[-1]
100 |
101 | torch.cuda.synchronize()
102 | t = time.time()
103 | if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue
104 | output = model.generate(
105 | **input,
106 | max_new_tokens=max_gen,
107 | num_beams=1,
108 | do_sample=False,
109 | temperature=1.0,
110 | min_length=context_length+1,
111 | eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]],
112 | )[0]
113 | else:
114 | output = model.generate(
115 | **input,
116 | max_new_tokens=max_gen,
117 | num_beams=1,
118 | do_sample=False,
119 | temperature=1.0,
120 | eos_token_id=[tokenizer.eos_token_id],
121 | )[0]
122 | torch.cuda.synchronize()
123 | t = time.time() - t
124 | times.append(t)
125 |
126 | pred = tokenizer.decode(output[context_length:], skip_special_tokens=True)
127 | pred = post_process(pred, model_name_or_path)
128 | preds.append(pred)
129 |
130 |
131 | json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"], "time": t}, f, ensure_ascii=False )
132 | f.write('\n')
133 | f.flush()
134 |
135 |
136 | gc.collect()
137 | torch.cuda.empty_cache()
138 |
139 |
140 | with open(out_path, "w", encoding="utf-8") as f:
141 | for json_obj, pred, t in zip(data, preds, times):
142 | json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"], "time": t}, f, ensure_ascii=False )
143 | f.write('\n')
144 |
145 |
146 | def seed_everything(seed):
147 | torch.manual_seed(seed)
148 | torch.cuda.manual_seed(seed)
149 | np.random.seed(seed)
150 | random.seed(seed)
151 | torch.backends.cudnn.benchmark = False
152 | torch.backends.cudnn.deterministic = True
153 | torch.cuda.manual_seed_all(seed)
154 |
155 | def load_model_and_tokenizer(path):
156 | tokenizer = AutoTokenizer.from_pretrained(path,
157 | trust_remote_code=True,
158 | )
159 | model = AutoModelForCausalLM.from_pretrained(path,
160 | torch_dtype=torch.bfloat16,
161 | # TODO: hard code
162 | device_map="auto",
163 | attn_implementation="flash_attention_2",
164 | trust_remote_code=True,
165 | )
166 | model = model.eval()
167 | return model, tokenizer
168 |
169 | if __name__ == '__main__':
170 | seed_everything(42)
171 | args = parse_args()
172 | world_size = torch.cuda.device_count()
173 | mp.set_start_method('spawn', force=True)
174 |
175 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
176 | model_name_or_path = args.model_name_or_path
177 | model_name = args.model_name_or_path.split("/")[-1]
178 | # define your model
179 | max_length = args.max_length
180 | if args.e:
181 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \
182 | "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
183 | else:
184 | datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
185 | "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
186 | "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]
187 |
188 |     datasets = [  # NOTE: overrides the LongBench-E/full selection above
189 | "qasper", "narrativeqa", "multifieldqa_en", # single doc
190 | "hotpotqa", "2wikimqa", "musique", # multi doc
191 | "trec", "triviaqa", "samsum", # few-shot
192 | "gov_report", "qmsum", "multi_news", # sum
193 | "passage_count", "passage_retrieval_en", # Synthetic
194 | "lcc", "repobench-p", # code
195 | ]
196 |
197 | print(datasets)
198 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output
199 | dataset2prompt = json.load(open("config/dataset2prompt.json", "r"))
200 | dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r"))
201 | # predict on each dataset
202 | if not os.path.exists("pred"):
203 | os.makedirs("pred")
204 | if not os.path.exists("pred_e"):
205 | os.makedirs("pred_e")
206 |
207 | if args.mode == "ada":
208 | print("Ada mode")
209 | replace_mistral_adaptive()
210 | replace_llama_adaptive()
211 | elif args.mode == "fix":
212 | print("Fix mode")
213 | replace_mistral_fixed()
214 | replace_llama_fixed()
215 | elif args.mode == "slm":
216 | print("Slm mode")
217 | replace_mistral_slm()
218 | replace_llama_slm()
219 | else:
220 | print("Base mode")
221 |
222 |
223 | # NOTE: load model after replace
224 | model, tokenizer = load_model_and_tokenizer(model_name_or_path)
225 |
226 | config_compress(model, base_capacity=args.budget, pyram_mode=args.pyram, floor_alpha=args.floor_alpha, gqa_support=args.gqa_support, gqa_func=args.gqa_func)
227 |
228 | for dataset in datasets:
229 | if args.e:
230 | data = load_dataset(args.dataset, f"{dataset}_e", split='test', data_dir=f"{args.dataset}/data")
231 | if not os.path.exists(f"pred_e/{args.out_name}"):
232 | os.makedirs(f"pred_e/{args.out_name}")
233 | out_path = f"pred_e/{args.out_name}/{dataset}.jsonl"
234 | else:
235 | data = load_dataset(args.dataset, f"{dataset}", split='test', data_dir=f"{args.dataset}/data")
236 | if not os.path.exists(f"pred/{args.out_name}"):
237 | os.makedirs(f"pred/{args.out_name}")
238 | out_path = f"pred/{args.out_name}/{dataset}.jsonl"
239 |
240 | if os.path.exists(out_path):
241 |             print(f"{out_path} already exists, skipping")
242 | continue
243 |
244 | prompt_format = dataset2prompt[dataset]
245 | max_gen = dataset2maxlen[dataset]
246 | data_all = [data_sample for data_sample in data]
247 | # TODO: hard code single process, which use all gpus
248 | torch.cuda.synchronize()
249 | t = time.time()
250 | get_pred(model, tokenizer, data_all, max_length, max_gen, prompt_format, dataset, device, model_name_or_path, out_path)
251 | torch.cuda.synchronize()
252 | t = time.time() - t
253 | print(f"== {args.out_name} {dataset} Time: {t}")
254 |
--------------------------------------------------------------------------------
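The key ordering in pred.py is that the monkeypatch is applied before the model is instantiated, and config_compress is then called on the loaded model. A condensed sketch of that flow (the checkpoint name is a placeholder; the keyword arguments mirror the CLI defaults above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from adaptive_snapkv.monkeypatch.monkeypatch import replace_llama_adaptive, config_compress

replace_llama_adaptive()  # patch attention first ("load model after replace")

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
).eval()

# Same keyword arguments pred.py forwards from its CLI flags.
config_compress(model, base_capacity=1024, pyram_mode=False,
                floor_alpha=0.2, gqa_support=False, gqa_func="mean")
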
/makefile:
--------------------------------------------------------------------------------
1 | i install:
2 | cd csrc && make
3 | pip install -e .
4 |
--------------------------------------------------------------------------------
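After `make install` (which runs the csrc makefile and then installs this package in editable mode), a minimal import smoke test, as a sketch, is:

# Verifies that the editable install exposes the monkeypatch entry points used by pred.py.
from adaptive_snapkv.monkeypatch.monkeypatch import (
    replace_llama_adaptive,
    replace_mistral_adaptive,
)

replace_llama_adaptive()
replace_mistral_adaptive()
print("AdaKV attention patches applied")
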
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel", "packaging"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "adakv"
7 | version = "0.0.1"
8 |
9 | dependencies = [
10 | "rouge",
11 | "tokenizers",
12 | "rouge_score",
13 | "numpy==1.26.0",
14 | "jieba==0.42.1",
15 | "datasets",
16 | "accelerate",
17 | "tqdm==4.66.1",
18 | "icetk==0.0.7",
19 | "transformers==4.44.2",
20 | ]
21 |
22 | [tool.setuptools.packages.find]
23 | where = ["."]
24 |
--------------------------------------------------------------------------------