├── .gitignore ├── LICENSE ├── README.md ├── adaptive_snapkv └── monkeypatch │ ├── adaptive_llama_hijack.py │ ├── adaptive_mistral_hijack.py │ ├── fixed_llama_hijack.py │ ├── fixed_mistral_hijack.py │ ├── monkeypatch.py │ ├── slm_llama_hijack.py │ ├── slm_mistral_hijack.py │ └── snapkv_utils.py ├── assets └── images │ ├── 4k_ruler_average_score.png │ ├── LongBench_mistral.png │ ├── LongBench_mistral_gqa.png │ ├── head_vary.png │ ├── main.png │ ├── mem.png │ └── speed.png ├── csrc ├── .gitignore ├── LICENSE ├── build.py ├── csrc │ ├── cuda_api.cu │ └── static_switch.h ├── include │ └── cuda_api.h └── makefile ├── experiments └── LongBench │ ├── GQA_eval_longbench.sh │ ├── README.md │ ├── config │ ├── dataset2maxlen.json │ ├── dataset2prompt.json │ ├── model2maxlen.json │ └── model2path.json │ ├── convert_to_execl.py │ ├── eval.py │ ├── metrics.py │ └── pred.py ├── makefile └── pyproject.toml /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | experiments/LongBench/pred/* 9 | experiments/LongBench/pred_e/* 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | notebooks/test* 164 | experiments/LongBench/bak/buggy 165 | experiments/LongBench/archive 166 | experiments/LongBench/ignore 167 | experiments/LongBench/Log 168 | experiments/LongBench/*pred 169 | experiments/LongBench/runall_*.sh 170 | .vscode 171 | .idea 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yuan Feng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | 24 | This project includes portions of the following open-source projects: 25 | 26 | - SnapKV: https://github.com/FasterDecoding/SnapKV 27 | - PyramidKV (MIT License): https://github.com/Zefan-Cai/PyramidKV 28 | 29 | Each of these projects retains their respective copyright and license. For more details, refer to their original LICENSE files. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaKV 2 | Adaptive Budget Allocation across attention heads, based on their degrees of attention concentration, improves budget utilization and thereby the quality of generation after cache eviction. 3 |
4 | ![](./assets/images/main.png) 5 |
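To make the allocation idea concrete, the snippet below is a toy sketch of concentration-based budget allocation with a guaranteed per-head floor (mirroring the role of the `floor_alpha` hyperparameter in the Quick Start below). The function name, scoring rule, and shapes are illustrative simplifications rather than the exact procedure implemented in this repository (see `adaptive_snapkv/monkeypatch/snapkv_utils.py` for the actual logic).

```python
import torch

def allocate_head_budgets(attn_weights: torch.Tensor, base_capacity: int, floor_alpha: float = 0.5) -> torch.Tensor:
    """Toy allocation: `attn_weights` holds [num_heads, obs_window, kv_len] attention
    weights from a recent observation window; returns one integer budget per head,
    summing to num_heads * base_capacity."""
    num_heads, _, kv_len = attn_weights.shape
    total_budget = num_heads * base_capacity

    # Concentration score per head: attention mass captured by its top keys.
    mean_attn = attn_weights.mean(dim=1)  # [num_heads, kv_len]
    concentration = mean_attn.topk(min(base_capacity, kv_len), dim=-1).values.sum(dim=-1)

    # Every head keeps a guaranteed floor; the remainder is split by concentration.
    floor = int(floor_alpha * base_capacity)
    adaptive_pool = total_budget - floor * num_heads
    budgets = floor + (concentration / concentration.sum() * adaptive_pool).long()

    # Hand any rounding leftovers to the most concentrated heads.
    leftover = int(total_budget - budgets.sum())
    budgets[concentration.argsort(descending=True)[:leftover]] += 1
    return budgets

# Example: 8 heads, a base budget of 512 per head, random attention over 4096 keys.
weights = torch.softmax(torch.randn(8, 32, 4096), dim=-1)
print(allocate_head_budgets(weights, base_capacity=512))
```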
6 | 7 | 8 | ## Updates 9 | * __[2024.11.08 GQA Support]__ In response to numerous community requests, we’ve uploaded a new testing branch, `test_gqa_support`, which introduces Grouped Query Attention (GQA) support for SnapKV, PyramidKV, Ada-SnapKV, and Ada-PyramidKV. You can use the `GQA_eval_longbench.sh` script for direct evaluation on the LongBench benchmark. Detailed results will be released soon. 10 | * __[2024.11.15 GQA Results]__ We have integrated GQA support for Mistral-7B-Instruct-v0.2 in SnapKV, PyramidKV, and our Ada-KV. You can try it by running the `GQA_eval_longbench.sh` script! Preliminary results on a single A800 are provided in the table below. In future updates, we will include more comprehensive GQA test results in the appendix of our [Ada-KV paper](https://arxiv.org/abs/2407.11550). 11 | ![](./assets/images/LongBench_mistral_gqa.png) 12 | * __[2024.12.12 Community Collaboration]__ We’re collaborating with NVIDIA’s [KVpress](https://github.com/NVIDIA/kvpress) team to integrate Ada-KV into their excellent project. This effort also aims to build a foundation for future research on head-specific cache compression methods. See the [draft PR](https://github.com/NVIDIA/kvpress/pull/25) for progress. Meanwhile, we have conducted broader evaluations of Ada-KV (Ada-SnapKV), including preliminary results on the 4K Ruler benchmark (context compression with and without questions). For further insights into the evaluation, please refer to the KVpress [repository](https://github.com/NVIDIA/kvpress). 13 | 14 | 15 |
16 | ![](./assets/images/4k_ruler_average_score.png) 17 |
18 | 19 | ## Usage 20 | 21 | ### Requirements 22 | 23 | ``` 24 | transformers==4.37.2 25 | flash-attn==2.4.0 26 | 27 | datasets 28 | tiktoken 29 | jieba 30 | rouge_score 31 | ``` 32 | 33 | ### Installation 34 | 35 | ``` 36 | git clone https://github.com/FFY0/AdaKV 37 | cd AdaKV 38 | make i 39 | ``` 40 | 41 | ### Quick Start 42 | 43 | ```python 44 | # replace modeling with adakv 45 | from adaptive_snapkv.monkeypatch.monkeypatch import replace_mistral_adaptive, replace_llama_adaptive 46 | replace_mistral_adaptive() 47 | replace_llama_adaptive() 48 | 49 | model = AutoModelForCausalLM.from_pretrained( 50 | model_name_or_path, 51 | config=config, 52 | device_map=device_map, 53 | attn_implementation="flash_attention_2", 54 | torch_dtype=torch.bfloat16, 55 | trust_remote_code=True, 56 | ) 57 | 58 | # config hyperparameters 59 | compress_args = {} 60 | def config_compress(model, window_size=32, base_capacity=512, kernel_size=7, pooling="maxpool", floor_alpha=0.5, pyram_mode = False, beta = 20): 61 | model.model.config.window_size = window_size 62 | model.model.config.base_capacity = base_capacity 63 | model.model.config.kernel_size = kernel_size 64 | 65 | model.model.config.pooling = pooling 66 | model.model.config.floor_alpha = floor_alpha 67 | 68 | model.model.config.pyram_mode = pyram_mode 69 | model.model.config.pyram_beta = beta 70 | return model 71 | 72 | model = config_compress(model, **compress_args) 73 | ``` 74 | 75 | #### Flattened Storage and Flash Attention Support 76 | 77 | Because cache lengths vary across heads after adaptive eviction, we implement a flattened KV cache storage layout combined with `flash_attn_varlen_func` for efficient computation. 78 | 79 | ##### Regular MHA Cache Storage 80 | 81 | ``` 82 | Layer i: 83 | head0: (t00, t01, t02) 84 | head1: (t10, t11, t12) 85 | head2: (t20, t21, t22) 86 | 87 | past_key_value.update(): 88 | 89 | Layer i: 90 | head0: (t00, t01, t02, t03) 91 | head1: (t10, t11, t12, t13) 92 | head2: (t20, t21, t22, t23) 93 | 94 | ``` 95 | 96 | Note: `tij` denotes the cache entry of token j on head i. 97 | 98 | ##### Flattened Cache Storage 99 | 100 | The corresponding CUDA code can be found in [`./csrc/csrc/cuda_api.cu`](./csrc/csrc/cuda_api.cu). 101 | ``` 102 | Layer i: 103 | (t00, t01, t02, t03) (t10, t11) (t20, t21, t22) 104 | 105 | past_key_value.update(): 106 | 107 | Layer i: 108 | phase 0: malloc empty cache 109 | (_, _, _, _, _) (_, _, _) (_, _, _, _) 110 | 111 | phase 1: copy old value 112 | (t00, t01, t02, t03, _) (t10, t11, _) (t20, t21, t22, _) 113 | 114 | phase 2: insert new value 115 | (t00, t01, t02, t03, t04) (t10, t11, t12) (t20, t21, t22, t23) 116 | ``` 117 | 118 | Details about `flash_attn_varlen_func` can be found in the [flash-attention repository](https://github.com/Dao-AILab/flash-attention/blob/c4b9015d74bd9f638c6fd574482accf4bbbd4197/flash_attn/flash_attn_interface.py#L1051); a short sketch of how this layout feeds `flash_attn_varlen_func` during decoding is given below. 119 | 120 | ##### Peak Memory Footprint and Decoding Latency For Our Implementation: 121 |
122 | ![](./assets/images/mem.png) 123 | ![](./assets/images/speed.png)
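For reference, the self-contained sketch below shows how the flattened layout described above turns per-head cache lengths into the `cu_seqlens` metadata consumed by `flash_attn_varlen_func` at decoding time. It assumes a CUDA device with `flash-attn` installed; the tensor names (`head_lens`, `cu_klen`, `cu_qlen`) follow the hijacked attention code, while the sizes and random data are purely illustrative. The actual cache append is handled by the CUDA kernel in `./csrc/csrc/cuda_api.cu`.

```python
import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 4, 64
device, dtype = "cuda", torch.float16

# Per-head cache lengths after eviction: they differ across heads.
head_lens = torch.tensor([300, 180, 512, 96], dtype=torch.int32, device=device)

# Flattened storage: every head's entries are concatenated along a single axis,
# so the layer's cache is one (sum(head_lens), 1, head_dim) tensor per K and V.
total_k = int(head_lens.sum())
k_cache = torch.randn(total_k, 1, head_dim, device=device, dtype=dtype)
v_cache = torch.randn(total_k, 1, head_dim, device=device, dtype=dtype)

# cu_seqlens_k marks where each head's segment starts and ends on the flat axis.
cu_klen = torch.zeros(num_heads + 1, dtype=torch.int32, device=device)
cu_klen[1:] = torch.cumsum(head_lens, dim=0)

# During decoding, each head contributes exactly one new query token.
q = torch.randn(num_heads, 1, head_dim, device=device, dtype=dtype)
cu_qlen = torch.arange(num_heads + 1, dtype=torch.int32, device=device)

out = flash_attn_varlen_func(
    q, k_cache, v_cache,
    cu_seqlens_q=cu_qlen, cu_seqlens_k=cu_klen,
    max_seqlen_q=1, max_seqlen_k=int(head_lens.max()),
    causal=True,
)
print(out.shape)  # (num_heads, 1, head_dim)

# After appending the new token's K/V for every head, the metadata advances in
# place, mirroring the decoding path of the hijacked attention:
#   head_lens += 1
#   cu_klen += torch.arange(num_heads + 1, ...)  # each boundary shifts by its head offset
#   max_seqlen_k += 1
```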
124 | 125 | ## Evaluations 126 | ### LongBench without GQA Support 127 | 128 | ![](./assets/images/LongBench_mistral.png) 129 | 130 | ```bash 131 | cd ./experiments/LongBench 132 | bash runall.sh 133 | ``` 134 | 135 | 136 | ## Citation 137 | If you find our work valuable, please cite: 138 | ``` 139 | @misc{feng2024adakvoptimizingkvcache, 140 | title={Ada-KV: Optimizing KV Cache Eviction by Adaptive Budget Allocation for Efficient LLM Inference}, 141 | author={Yuan Feng and Junlin Lv and Yukun Cao and Xike Xie and S. Kevin Zhou}, 142 | year={2024}, 143 | eprint={2407.11550}, 144 | archivePrefix={arXiv}, 145 | primaryClass={cs.CL}, 146 | url={https://arxiv.org/abs/2407.11550}, 147 | } 148 | ``` 149 | 150 | ## Acknowledgement 151 | 152 | We are grateful to [SnapKV](https://github.com/FasterDecoding/SnapKV) and [PyramidKV](https://github.com/Zefan-Cai/PyramidKV) for their open-source code, which has greatly facilitated this project. 153 | 154 | ## Misc 155 | 156 | ### Observation 157 | 158 | Different attention heads within each layer of LLMs exhibit significant disparities in their degrees of attention concentration. 159 | 160 | Therefore, we can improve budget utilization by dynamically allocating the budget across the attention heads within each layer according to their concentration degrees. 161 | 162 | ![](./assets/images/head_vary.png) 163 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/adaptive_llama_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List, Optional, Tuple, Union 5 | import warnings 6 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 7 | from transformers.models.llama.modeling_llama import ( 8 | apply_rotary_pos_emb, 9 | repeat_kv, 10 | ) 11 | from transformers.utils import ( 12 | logging, 13 | is_flash_attn_2_available, 14 | ) 15 | from transformers.modeling_attn_mask_utils import ( 16 | AttentionMaskConverter, 17 | _prepare_4d_attention_mask, 18 | _prepare_4d_causal_attention_mask, 19 | _prepare_4d_causal_attention_mask_for_sdpa, 20 | ) 21 | from transformers.modeling_outputs import BaseModelOutputWithPast 22 | from transformers.models.llama.modeling_llama import ( 23 | apply_rotary_pos_emb, 24 | repeat_kv, 25 | ) 26 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 27 | 28 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_adaptive_snapkv, DynamicCacheSplitHeadFlatten 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | if is_flash_attn_2_available(): 33 | from flash_attn import flash_attn_func, flash_attn_varlen_func 34 | 35 | def adaptive_LlamaModel_forward( 36 | self, 37 | input_ids: torch.LongTensor = None, 38 | attention_mask: Optional[torch.Tensor] = None, 39 | position_ids: Optional[torch.LongTensor] = None, 40 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 41 | inputs_embeds: Optional[torch.FloatTensor] = None, 42 | use_cache: Optional[bool] = None, 43 | output_attentions: Optional[bool] = None, 44 | output_hidden_states: Optional[bool] = None, 45 | return_dict: Optional[bool] = None, 46 | cache_position: Optional[torch.LongTensor] = None, 47 | ) -> Union[Tuple, BaseModelOutputWithPast]: 48 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 49 |
output_hidden_states = ( 50 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 51 | ) 52 | use_cache = use_cache if use_cache is not None else self.config.use_cache 53 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 54 | 55 | if (input_ids is None) ^ (inputs_embeds is not None): 56 | raise ValueError( 57 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 58 | ) 59 | 60 | if self.gradient_checkpointing and self.training and use_cache: 61 | logger.warning_once( 62 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 63 | ) 64 | use_cache = False 65 | 66 | if inputs_embeds is None: 67 | inputs_embeds = self.embed_tokens(input_ids) 68 | 69 | # return_legacy_cache = False 70 | # if ( 71 | # use_cache and not isinstance(past_key_values, Cache) and not self.training 72 | # ): # kept for BC (non `Cache` `past_key_values` inputs) 73 | # return_legacy_cache = True 74 | # past_key_values = DynamicCache.from_legacy_cache(past_key_values) 75 | # logger.warning_once( 76 | # "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " 77 | # "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" 78 | # ) 79 | 80 | # NOTE: adakv 81 | return_legacy_cache = True 82 | past_key_values = DynamicCacheSplitHeadFlatten.from_legacy_cache(past_key_values) 83 | logger.warning_once( 84 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " 85 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" 86 | ) 87 | 88 | if cache_position is None: 89 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 90 | cache_position = torch.arange( 91 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 92 | ) 93 | if position_ids is None: 94 | position_ids = cache_position.unsqueeze(0) 95 | 96 | causal_mask = self._update_causal_mask( 97 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions 98 | ) 99 | hidden_states = inputs_embeds 100 | 101 | # create position embeddings to be shared across the decoder layers 102 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 103 | 104 | # decoder layers 105 | all_hidden_states = () if output_hidden_states else None 106 | all_self_attns = () if output_attentions else None 107 | next_decoder_cache = None 108 | 109 | for decoder_layer in self.layers: 110 | if output_hidden_states: 111 | all_hidden_states += (hidden_states,) 112 | 113 | if self.gradient_checkpointing and self.training: 114 | layer_outputs = self._gradient_checkpointing_func( 115 | decoder_layer.__call__, 116 | hidden_states, 117 | causal_mask, 118 | position_ids, 119 | past_key_values, 120 | output_attentions, 121 | use_cache, 122 | cache_position, 123 | position_embeddings, 124 | ) 125 | else: 126 | layer_outputs = decoder_layer( 127 | hidden_states, 128 | attention_mask=causal_mask, 129 | position_ids=position_ids, 130 | past_key_value=past_key_values, 131 | output_attentions=output_attentions, 132 | use_cache=use_cache, 133 | cache_position=cache_position, 134 | position_embeddings=position_embeddings, 135 | ) 136 | 137 | hidden_states = 
layer_outputs[0] 138 | 139 | if use_cache: 140 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 141 | 142 | if output_attentions: 143 | all_self_attns += (layer_outputs[1],) 144 | 145 | hidden_states = self.norm(hidden_states) 146 | 147 | # add hidden states from the last decoder layer 148 | if output_hidden_states: 149 | all_hidden_states += (hidden_states,) 150 | 151 | next_cache = next_decoder_cache if use_cache else None 152 | if return_legacy_cache: 153 | next_cache = next_cache.to_legacy_cache() 154 | 155 | hidden_states = hidden_states[:, -1,:].unsqueeze(1) 156 | 157 | if not return_dict: 158 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 159 | return BaseModelOutputWithPast( 160 | last_hidden_state=hidden_states, 161 | past_key_values=next_cache, 162 | hidden_states=all_hidden_states, 163 | attentions=all_self_attns, 164 | ) 165 | 166 | def adaptive_llama_flash_attn2_forward( 167 | self, 168 | hidden_states: torch.Tensor, 169 | attention_mask: Optional[torch.LongTensor] = None, 170 | position_ids: Optional[torch.LongTensor] = None, 171 | past_key_value: Optional[Cache] = None, 172 | output_attentions: bool = False, 173 | use_cache: bool = False, 174 | cache_position: Optional[torch.LongTensor] = None, 175 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 176 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 177 | # NOTE: adakv 178 | init_adaptive_snapkv(self) 179 | if isinstance(past_key_value, StaticCache): 180 | raise ValueError( 181 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " 182 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" 183 | ) 184 | 185 | output_attentions = False 186 | 187 | bsz, q_len, _ = hidden_states.size() 188 | 189 | query_states = self.q_proj(hidden_states) 190 | key_states = self.k_proj(hidden_states) 191 | value_states = self.v_proj(hidden_states) 192 | 193 | # Flash attention requires the input to have the shape 194 | # batch_size x seq_length x head_dim x hidden_dim 195 | # therefore we just need to keep the original shape 196 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 197 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 198 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 199 | 200 | if position_embeddings is None: 201 | logger.warning_once( 202 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 203 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 204 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " 205 | "removed and `position_embeddings` will be mandatory." 
206 | ) 207 | cos, sin = self.rotary_emb(value_states, position_ids) 208 | else: 209 | cos, sin = position_embeddings 210 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 211 | 212 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 213 | 214 | is_prefill = q_len != 1 215 | 216 | if is_prefill: 217 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states) 218 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs) 219 | 220 | # repeat k/v heads if n_kv_heads < n_heads 221 | # [SnapKV] move to ahead 222 | key_states = repeat_kv(key_states, self.num_key_value_groups) 223 | value_states = repeat_kv(value_states, self.num_key_value_groups) 224 | 225 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache 226 | # to be able to avoid many of these transpose/reshape/view. 227 | query_states = query_states.transpose(1, 2) 228 | key_states = key_states.transpose(1, 2) 229 | value_states = value_states.transpose(1, 2) 230 | 231 | dropout_rate = self.attention_dropout if self.training else 0.0 232 | 233 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 234 | # therefore the input hidden states gets silently casted in float32. Hence, we need 235 | # cast them back in the correct dtype just to be sure everything works as expected. 236 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 237 | # in fp32. (LlamaRMSNorm handles it correctly) 238 | 239 | input_dtype = query_states.dtype 240 | if input_dtype == torch.float32: 241 | if torch.is_autocast_enabled(): 242 | target_dtype = torch.get_autocast_gpu_dtype() 243 | # Handle the case where the model is quantized 244 | elif hasattr(self.config, "_pre_quantization_dtype"): 245 | target_dtype = self.config._pre_quantization_dtype 246 | else: 247 | target_dtype = self.q_proj.weight.dtype 248 | 249 | logger.warning_once( 250 | f"The input hidden states seems to be silently casted in float32, this might be related to" 251 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 252 | f" {target_dtype}." 
253 | ) 254 | 255 | query_states = query_states.to(target_dtype) 256 | key_states = key_states.to(target_dtype) 257 | value_states = value_states.to(target_dtype) 258 | 259 | attn_output = _flash_attention_forward( 260 | query_states, 261 | key_states, 262 | value_states, 263 | attention_mask, 264 | q_len, 265 | position_ids=position_ids, 266 | dropout=dropout_rate, 267 | sliding_window=getattr(self, "sliding_window", None), 268 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 269 | is_causal=self.is_causal, 270 | ) 271 | 272 | attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() 273 | else: 274 | # decoding 275 | cache_kwargs["head_lens"] = self.kv_cluster.head_lens 276 | cache_kwargs["cu_klen"] = self.kv_cluster.cu_klen 277 | 278 | if not self.kv_cluster.gqa_support: 279 | key_states = repeat_kv(key_states, self.num_key_value_groups) 280 | value_states = repeat_kv(value_states, self.num_key_value_groups) 281 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 282 | 283 | # NOTE: update meta data 284 | self.kv_cluster.klen_sum += self.num_heads 285 | self.kv_cluster.max_seqlen_k += 1 286 | self.kv_cluster.cu_klen += self.kv_cluster.cu_offset 287 | self.kv_cluster.head_lens += 1 288 | 289 | if self.kv_cluster.gqa_support: 290 | query_states = query_states.view(-1, self.num_key_value_groups, self.head_dim) 291 | else: 292 | query_states = query_states.view(-1, 1, self.head_dim) 293 | 294 | key_states = key_states.view(-1,1,self.head_dim) 295 | value_states = value_states.view(-1,1,self.head_dim) 296 | 297 | cu_seqlens_q = self.kv_cluster.cu_qlen 298 | cu_seqlens_k = self.kv_cluster.cu_klen 299 | max_seqlen_q = 1 300 | max_seqlen_k = self.kv_cluster.max_seqlen_k 301 | 302 | attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens_q, 303 | cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True) 304 | # TODO: support batch size > 1 305 | assert bsz == 1 306 | attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim) 307 | attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) 308 | 309 | attn_output = self.o_proj(attn_output) 310 | 311 | if not output_attentions: 312 | attn_weights = None 313 | 314 | return attn_output, attn_weights, past_key_value 315 | 316 | def prepare_inputs_for_generation_llama( 317 | self, 318 | input_ids, 319 | past_key_values=None, 320 | attention_mask=None, 321 | inputs_embeds=None, 322 | cache_position=None, 323 | position_ids=None, 324 | use_cache=True, 325 | **kwargs, 326 | ): 327 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 328 | # Exception 1: when passing input_embeds, input_ids may be missing entries 329 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 330 | if past_key_values is not None: 331 | if inputs_embeds is not None: # Exception 1 332 | input_ids = input_ids[:, -cache_position.shape[0] :] 333 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) 334 | input_ids = input_ids[:, cache_position] 335 | 336 | if attention_mask is not None and position_ids is None: 337 | # create position_ids on the fly for batch generation 338 | position_ids = attention_mask.long().cumsum(-1) - 1 339 | position_ids.masked_fill_(attention_mask == 0, 1) 340 | if past_key_values: 341 | position_ids = position_ids[:, -input_ids.shape[1] :] 342 | 343 | # This `clone` 
call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 344 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 345 | 346 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 347 | if inputs_embeds is not None and cache_position[0] == 0: 348 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} 349 | else: 350 | # The clone here is for the same reason as for `position_ids`. 351 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} 352 | 353 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 354 | if model_inputs["inputs_embeds"] is not None: 355 | batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape 356 | device = model_inputs["inputs_embeds"].device 357 | else: 358 | batch_size, sequence_length = model_inputs["input_ids"].shape 359 | device = model_inputs["input_ids"].device 360 | 361 | dtype = self.lm_head.weight.dtype 362 | min_dtype = torch.finfo(dtype).min 363 | 364 | attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( 365 | attention_mask, 366 | sequence_length=sequence_length, 367 | target_length=past_key_values.get_max_length(), 368 | dtype=dtype, 369 | device=device, 370 | min_dtype=min_dtype, 371 | cache_position=cache_position, 372 | batch_size=batch_size, 373 | ) 374 | 375 | model_inputs.update( 376 | { 377 | "position_ids": position_ids, 378 | "cache_position": cache_position, 379 | "past_key_values": past_key_values, 380 | "use_cache": use_cache, 381 | "attention_mask": attention_mask, 382 | } 383 | ) 384 | return model_inputs 385 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/adaptive_mistral_hijack.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from typing import List, Optional, Tuple, Union, Any,Dict 6 | import warnings 7 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 8 | from transformers.models.mistral.modeling_mistral import ( 9 | apply_rotary_pos_emb, 10 | repeat_kv, 11 | ) 12 | from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask_for_sdpa, \ 13 | _prepare_4d_causal_attention_mask 14 | from transformers.modeling_outputs import BaseModelOutputWithPast 15 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 16 | from transformers.models.mistral.modeling_mistral import ( 17 | apply_rotary_pos_emb, 18 | repeat_kv, 19 | ) 20 | from transformers.utils import ( 21 | logging, 22 | is_flash_attn_2_available, 23 | ) 24 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_adaptive_snapkv, DynamicCacheSplitHeadFlatten 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | if is_flash_attn_2_available(): 29 | from flash_attn import flash_attn_func, flash_attn_varlen_func 30 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 31 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 32 | 33 | 34 | def adaptive_MistralModel_forward( 
35 | self, 36 | input_ids: torch.LongTensor = None, 37 | attention_mask: Optional[torch.Tensor] = None, 38 | position_ids: Optional[torch.LongTensor] = None, 39 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 40 | inputs_embeds: Optional[torch.FloatTensor] = None, 41 | use_cache: Optional[bool] = None, 42 | output_attentions: Optional[bool] = None, 43 | output_hidden_states: Optional[bool] = None, 44 | return_dict: Optional[bool] = None, 45 | cache_position: Optional[torch.LongTensor] = None, 46 | ) -> Union[Tuple, BaseModelOutputWithPast]: 47 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 48 | output_hidden_states = ( 49 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 50 | ) 51 | use_cache = use_cache if use_cache is not None else self.config.use_cache 52 | 53 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 54 | 55 | # retrieve input_ids and inputs_embeds 56 | if (input_ids is None) ^ (inputs_embeds is not None): 57 | raise ValueError( 58 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 59 | ) 60 | 61 | if self.gradient_checkpointing and self.training and use_cache: 62 | logger.warning_once( 63 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 64 | ) 65 | use_cache = False 66 | 67 | if inputs_embeds is None: 68 | inputs_embeds = self.embed_tokens(input_ids) 69 | 70 | past_key_values = DynamicCacheSplitHeadFlatten.from_legacy_cache(past_key_values) 71 | return_legacy_cache = True 72 | 73 | if cache_position is None: 74 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 75 | cache_position = torch.arange( 76 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 77 | ) 78 | 79 | if position_ids is None: 80 | position_ids = cache_position.unsqueeze(0) 81 | 82 | causal_mask = self._update_causal_mask( 83 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions 84 | ) 85 | 86 | hidden_states = inputs_embeds 87 | 88 | # decoder layers 89 | all_hidden_states = () if output_hidden_states else None 90 | all_self_attns = () if output_attentions else None 91 | next_decoder_cache = None 92 | 93 | for decoder_layer in self.layers: 94 | if output_hidden_states: 95 | all_hidden_states += (hidden_states,) 96 | 97 | if self.gradient_checkpointing and self.training: 98 | layer_outputs = self._gradient_checkpointing_func( 99 | decoder_layer.__call__, 100 | hidden_states, 101 | causal_mask, 102 | position_ids, 103 | past_key_values, 104 | output_attentions, 105 | use_cache, 106 | cache_position, 107 | ) 108 | else: 109 | layer_outputs = decoder_layer( 110 | hidden_states, 111 | attention_mask=causal_mask, 112 | position_ids=position_ids, 113 | past_key_value=past_key_values, 114 | output_attentions=output_attentions, 115 | use_cache=use_cache, 116 | cache_position=cache_position, 117 | ) 118 | 119 | hidden_states = layer_outputs[0] 120 | 121 | if use_cache: 122 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 123 | 124 | if output_attentions: 125 | all_self_attns += (layer_outputs[1],) 126 | 127 | hidden_states = self.norm(hidden_states) 128 | 129 | # add hidden states from the last decoder layer 130 | if output_hidden_states: 131 | all_hidden_states += (hidden_states,) 132 | 133 | next_cache = 
next_decoder_cache if use_cache else None 134 | if return_legacy_cache: 135 | next_cache = next_cache.to_legacy_cache() 136 | 137 | if not return_dict: 138 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 139 | 140 | hidden_states = hidden_states[:, -1,:].unsqueeze(1) 141 | return BaseModelOutputWithPast( 142 | last_hidden_state=hidden_states, 143 | past_key_values=next_cache, 144 | hidden_states=all_hidden_states, 145 | attentions=all_self_attns, 146 | ) 147 | 148 | def adaptive_mistral_flash_attn2_forward( 149 | self, 150 | hidden_states: torch.Tensor, 151 | attention_mask: Optional[torch.Tensor] = None, 152 | position_ids: Optional[torch.LongTensor] = None, 153 | past_key_value: Optional[Cache] = None, 154 | output_attentions: bool = False, 155 | use_cache: bool = False, 156 | cache_position: Optional[torch.LongTensor] = None, 157 | ): 158 | if isinstance(past_key_value, StaticCache): 159 | raise ValueError( 160 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " 161 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" 162 | ) 163 | # NOTE: adakv 164 | init_adaptive_snapkv(self) 165 | 166 | output_attentions = False 167 | 168 | bsz, q_len, _ = hidden_states.size() 169 | 170 | query_states = self.q_proj(hidden_states) 171 | key_states = self.k_proj(hidden_states) 172 | value_states = self.v_proj(hidden_states) 173 | 174 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 175 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 176 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 177 | 178 | kv_seq_len = key_states.shape[-2] 179 | if past_key_value is not None: 180 | kv_seq_len += cache_position[0] 181 | 182 | cos, sin = self.rotary_emb(value_states, position_ids) 183 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 184 | 185 | dropout_rate = 0.0 if not self.training else self.attention_dropout 186 | 187 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 188 | if past_key_value is not None: 189 | # Activate slicing cache only if the config has a value `sliding_windows` attribute 190 | cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 191 | if ( 192 | getattr(self.config, "sliding_window", None) is not None 193 | and kv_seq_len > self.config.sliding_window 194 | and cache_has_contents 195 | ): 196 | slicing_tokens = 1 - self.config.sliding_window 197 | 198 | past_key = past_key_value[self.layer_idx][0] 199 | past_value = past_key_value[self.layer_idx][1] 200 | 201 | past_key = past_key[:, :, slicing_tokens:, :].contiguous() 202 | past_value = past_value[:, :, slicing_tokens:, :].contiguous() 203 | 204 | if past_key.shape[-2] != self.config.sliding_window - 1: 205 | raise ValueError( 206 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" 207 | f" {past_key.shape}" 208 | ) 209 | 210 | if attention_mask is not None: 211 | attention_mask = attention_mask[:, slicing_tokens:] 212 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) 213 | 214 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 215 | # therefore the input hidden states gets silently casted in float32. 
Hence, we need 216 | # cast them back in float16 just to be sure everything works as expected. 217 | input_dtype = query_states.dtype 218 | if input_dtype == torch.float32: 219 | if torch.is_autocast_enabled(): 220 | target_dtype = torch.get_autocast_gpu_dtype() 221 | # Handle the case where the model is quantized 222 | elif hasattr(self.config, "_pre_quantization_dtype"): 223 | target_dtype = self.config._pre_quantization_dtype 224 | else: 225 | target_dtype = self.q_proj.weight.dtype 226 | 227 | logger.warning_once( 228 | f"The input hidden states seems to be silently casted in float32, this might be related to" 229 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 230 | f" {target_dtype}." 231 | ) 232 | 233 | query_states = query_states.to(target_dtype) 234 | key_states = key_states.to(target_dtype) 235 | value_states = value_states.to(target_dtype) 236 | 237 | # TODO: naive for now 238 | is_prefill = q_len != 1 239 | 240 | if is_prefill: 241 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states) 242 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs) 243 | 244 | # repeat k/v heads if n_kv_heads < n_heads 245 | # [SnapKV] move to ahead 246 | key_states = repeat_kv(key_states, self.num_key_value_groups) 247 | value_states = repeat_kv(value_states, self.num_key_value_groups) 248 | 249 | # Reashape to the expected shape for Flash Attention 250 | query_states = query_states.transpose(1, 2) 251 | key_states = key_states.transpose(1, 2) 252 | value_states = value_states.transpose(1, 2) 253 | 254 | attn_output = _flash_attention_forward( 255 | query_states, 256 | key_states, 257 | value_states, 258 | attention_mask, 259 | q_len, 260 | position_ids=position_ids, 261 | dropout=dropout_rate, 262 | sliding_window=getattr(self.config, "sliding_window", None), 263 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 264 | is_causal=self.is_causal, 265 | ) 266 | 267 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() 268 | 269 | else: 270 | # decoding 271 | cache_kwargs["head_lens"] = self.kv_cluster.head_lens 272 | cache_kwargs["cu_klen"] = self.kv_cluster.cu_klen 273 | # gqa_support 274 | if not self.kv_cluster.gqa_support: 275 | key_states = repeat_kv(key_states, self.num_key_value_groups) 276 | value_states = repeat_kv(value_states, self.num_key_value_groups) 277 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 278 | 279 | 280 | # NOTE: update meta data 281 | self.kv_cluster.klen_sum += self.num_heads 282 | self.kv_cluster.max_seqlen_k += 1 283 | self.kv_cluster.cu_klen += self.kv_cluster.cu_offset 284 | self.kv_cluster.head_lens += 1 285 | 286 | if self.kv_cluster.gqa_support: 287 | query_states = query_states.view(-1, self.num_key_value_groups, self.head_dim) 288 | else: 289 | query_states = query_states.view(-1, 1, self.head_dim) 290 | 291 | key_states = key_states.view(-1,1,self.head_dim) 292 | value_states = value_states.view(-1,1,self.head_dim) 293 | 294 | cu_seqlens_q = self.kv_cluster.cu_qlen 295 | cu_seqlens_k = self.kv_cluster.cu_klen 296 | max_seqlen_q = 1 297 | max_seqlen_k = self.kv_cluster.max_seqlen_k 298 | 299 | attn_output = flash_attn_varlen_func(query_states, key_states, value_states, cu_seqlens_q, 300 | cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True) 301 | # TODO: support batch size > 1 302 | assert bsz == 1 
303 | attn_output = attn_output.reshape(bsz, self.num_heads, q_len, self.head_dim) 304 | attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) 305 | 306 | attn_output = self.o_proj(attn_output) 307 | 308 | if not output_attentions: 309 | attn_weights = None 310 | 311 | return attn_output, attn_weights, past_key_value 312 | 313 | def prepare_inputs_for_generation_mistral( 314 | self, 315 | input_ids, 316 | past_key_values=None, 317 | attention_mask=None, 318 | inputs_embeds=None, 319 | cache_position=None, 320 | position_ids=None, 321 | use_cache=True, 322 | **kwargs, 323 | ): 324 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 325 | # Exception 1: when passing input_embeds, input_ids may be missing entries 326 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 327 | if past_key_values is not None: 328 | if inputs_embeds is not None: # Exception 1 329 | input_ids = input_ids[:, -cache_position.shape[0] :] 330 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) 331 | input_ids = input_ids[:, cache_position] 332 | 333 | if attention_mask is not None and position_ids is None: 334 | # create position_ids on the fly for batch generation 335 | position_ids = attention_mask.long().cumsum(-1) - 1 336 | position_ids.masked_fill_(attention_mask == 0, 1) 337 | if past_key_values: 338 | position_ids = position_ids[:, -input_ids.shape[1] :] 339 | 340 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
341 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 342 | 343 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 344 | if inputs_embeds is not None and cache_position[0] == 0: 345 | model_inputs = {"inputs_embeds": inputs_embeds} 346 | else: 347 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases 348 | 349 | model_inputs.update( 350 | { 351 | "position_ids": position_ids, 352 | "cache_position": cache_position, 353 | "past_key_values": past_key_values, 354 | "use_cache": use_cache, 355 | "attention_mask": attention_mask, 356 | } 357 | ) 358 | return model_inputs 359 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/fixed_llama_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List, Optional, Tuple, Union 5 | import warnings 6 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 7 | from transformers.modeling_outputs import BaseModelOutputWithPast 8 | from transformers.models.llama.modeling_llama import ( 9 | apply_rotary_pos_emb, 10 | repeat_kv, 11 | ) 12 | from transformers.utils import ( 13 | logging, 14 | ) 15 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 16 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_snapkv 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | def fixed_LlamaModel_forward( 21 | self, 22 | input_ids: torch.LongTensor = None, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | position_ids: Optional[torch.LongTensor] = None, 25 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 26 | inputs_embeds: Optional[torch.FloatTensor] = None, 27 | use_cache: Optional[bool] = None, 28 | output_attentions: Optional[bool] = None, 29 | output_hidden_states: Optional[bool] = None, 30 | return_dict: Optional[bool] = None, 31 | cache_position: Optional[torch.LongTensor] = None, 32 | ) -> Union[Tuple, BaseModelOutputWithPast]: 33 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 34 | output_hidden_states = ( 35 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 36 | ) 37 | use_cache = use_cache if use_cache is not None else self.config.use_cache 38 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 39 | 40 | if (input_ids is None) ^ (inputs_embeds is not None): 41 | raise ValueError( 42 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 43 | ) 44 | 45 | if self.gradient_checkpointing and self.training and use_cache: 46 | logger.warning_once( 47 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 48 | ) 49 | use_cache = False 50 | 51 | if inputs_embeds is None: 52 | inputs_embeds = self.embed_tokens(input_ids) 53 | 54 | return_legacy_cache = False 55 | if ( 56 | use_cache and not isinstance(past_key_values, Cache) and not self.training 57 | ): # kept for BC (non `Cache` `past_key_values` inputs) 58 | return_legacy_cache = True 59 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 60 | logger.warning_once( 61 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" 62 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" 63 | ) 64 | 65 | if cache_position is None: 66 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 67 | cache_position = torch.arange( 68 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 69 | ) 70 | if position_ids is None: 71 | position_ids = cache_position.unsqueeze(0) 72 | 73 | causal_mask = self._update_causal_mask( 74 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions 75 | ) 76 | hidden_states = inputs_embeds 77 | 78 | # create position embeddings to be shared across the decoder layers 79 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 80 | 81 | # decoder layers 82 | all_hidden_states = () if output_hidden_states else None 83 | all_self_attns = () if output_attentions else None 84 | next_decoder_cache = None 85 | 86 | for decoder_layer in self.layers: 87 | if output_hidden_states: 88 | all_hidden_states += (hidden_states,) 89 | 90 | if self.gradient_checkpointing and self.training: 91 | layer_outputs = self._gradient_checkpointing_func( 92 | decoder_layer.__call__, 93 | hidden_states, 94 | causal_mask, 95 | position_ids, 96 | past_key_values, 97 | output_attentions, 98 | use_cache, 99 | cache_position, 100 | position_embeddings, 101 | ) 102 | else: 103 | layer_outputs = decoder_layer( 104 | hidden_states, 105 | attention_mask=causal_mask, 106 | position_ids=position_ids, 107 | past_key_value=past_key_values, 108 | output_attentions=output_attentions, 109 | use_cache=use_cache, 110 | cache_position=cache_position, 111 | position_embeddings=position_embeddings, 112 | ) 113 | 114 | hidden_states = layer_outputs[0] 115 | 116 | if use_cache: 117 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 118 | 119 | if output_attentions: 120 | all_self_attns += (layer_outputs[1],) 121 | 122 | hidden_states = self.norm(hidden_states) 123 | 124 | # add hidden states from the last decoder layer 125 | if output_hidden_states: 126 | all_hidden_states += (hidden_states,) 127 | 128 | next_cache = next_decoder_cache if use_cache else None 129 | if return_legacy_cache: 130 | next_cache = next_cache.to_legacy_cache() 131 | 132 | hidden_states = hidden_states[:, -1,:].unsqueeze(1) 133 | 134 | if not return_dict: 135 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 136 | return BaseModelOutputWithPast( 137 | last_hidden_state=hidden_states, 138 | past_key_values=next_cache, 139 | hidden_states=all_hidden_states, 140 | attentions=all_self_attns, 141 | ) 142 | 143 | 144 | def fixed_llama_flash_attn2_forward( 145 | self, 146 | hidden_states: torch.Tensor, 147 | attention_mask: Optional[torch.LongTensor] = None, 148 | position_ids: Optional[torch.LongTensor] = None, 149 | past_key_value: Optional[Cache] = None, 150 | output_attentions: bool = False, 151 | use_cache: bool = False, 152 | cache_position: Optional[torch.LongTensor] = None, 153 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 154 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 155 | if isinstance(past_key_value, StaticCache): 156 | raise ValueError( 157 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " 158 | "make sure to use `sdpa` in the mean time, and open 
an issue at https://github.com/huggingface/transformers" 159 | ) 160 | init_snapkv(self) 161 | 162 | output_attentions = False 163 | 164 | bsz, q_len, _ = hidden_states.size() 165 | 166 | query_states = self.q_proj(hidden_states) 167 | key_states = self.k_proj(hidden_states) 168 | value_states = self.v_proj(hidden_states) 169 | 170 | # Flash attention requires the input to have the shape 171 | # batch_size x seq_length x head_dim x hidden_dim 172 | # therefore we just need to keep the original shape 173 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 174 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 175 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 176 | 177 | if position_embeddings is None: 178 | logger.warning_once( 179 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 180 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 181 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " 182 | "removed and `position_embeddings` will be mandatory." 183 | ) 184 | cos, sin = self.rotary_emb(value_states, position_ids) 185 | else: 186 | cos, sin = position_embeddings 187 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 188 | 189 | 190 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 191 | if past_key_value is not None: 192 | # NOTE: decoding update 193 | if q_len == 1: 194 | # support gqa 195 | if not self.kv_cluster.gqa_support: 196 | key_states = repeat_kv(key_states, self.num_key_value_groups) 197 | value_states = repeat_kv(value_states, self.num_key_value_groups) 198 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 199 | else: 200 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states) 201 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs) 202 | 203 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache 204 | # to be able to avoid many of these transpose/reshape/view. 205 | 206 | dropout_rate = self.attention_dropout if self.training else 0.0 207 | 208 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 209 | # therefore the input hidden states gets silently casted in float32. Hence, we need 210 | # cast them back in the correct dtype just to be sure everything works as expected. 211 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 212 | # in fp32. 
(LlamaRMSNorm handles it correctly) 213 | 214 | input_dtype = query_states.dtype 215 | if input_dtype == torch.float32: 216 | if torch.is_autocast_enabled(): 217 | target_dtype = torch.get_autocast_gpu_dtype() 218 | # Handle the case where the model is quantized 219 | elif hasattr(self.config, "_pre_quantization_dtype"): 220 | target_dtype = self.config._pre_quantization_dtype 221 | else: 222 | target_dtype = self.q_proj.weight.dtype 223 | 224 | logger.warning_once( 225 | f"The input hidden states seems to be silently casted in float32, this might be related to" 226 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 227 | f" {target_dtype}." 228 | ) 229 | 230 | query_states = query_states.to(target_dtype) 231 | key_states = key_states.to(target_dtype) 232 | value_states = value_states.to(target_dtype) 233 | 234 | # Reashape to the expected shape for Flash Attention 235 | query_states = query_states.transpose(1, 2) 236 | key_states = key_states.transpose(1, 2) 237 | value_states = value_states.transpose(1, 2) 238 | 239 | attn_output = _flash_attention_forward( 240 | query_states, 241 | key_states, 242 | value_states, 243 | attention_mask, 244 | q_len, 245 | position_ids=position_ids, 246 | dropout=dropout_rate, 247 | sliding_window=getattr(self, "sliding_window", None), 248 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 249 | is_causal=self.is_causal, 250 | ) 251 | 252 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 253 | attn_output = self.o_proj(attn_output) 254 | 255 | if not output_attentions: 256 | attn_weights = None 257 | 258 | return attn_output, attn_weights, past_key_value 259 | 260 | def prepare_inputs_for_generation_llama( 261 | self, 262 | input_ids, 263 | past_key_values=None, 264 | attention_mask=None, 265 | inputs_embeds=None, 266 | cache_position=None, 267 | position_ids=None, 268 | use_cache=True, 269 | **kwargs, 270 | ): 271 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 272 | # Exception 1: when passing input_embeds, input_ids may be missing entries 273 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 274 | if past_key_values is not None: 275 | if inputs_embeds is not None: # Exception 1 276 | input_ids = input_ids[:, -cache_position.shape[0] :] 277 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) 278 | input_ids = input_ids[:, cache_position] 279 | 280 | if attention_mask is not None and position_ids is None: 281 | # create position_ids on the fly for batch generation 282 | position_ids = attention_mask.long().cumsum(-1) - 1 283 | position_ids.masked_fill_(attention_mask == 0, 1) 284 | if past_key_values: 285 | position_ids = position_ids[:, -input_ids.shape[1] :] 286 | 287 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
288 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 289 | 290 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 291 | if inputs_embeds is not None and cache_position[0] == 0: 292 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} 293 | else: 294 | # The clone here is for the same reason as for `position_ids`. 295 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} 296 | 297 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 298 | if model_inputs["inputs_embeds"] is not None: 299 | batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape 300 | device = model_inputs["inputs_embeds"].device 301 | else: 302 | batch_size, sequence_length = model_inputs["input_ids"].shape 303 | device = model_inputs["input_ids"].device 304 | 305 | dtype = self.lm_head.weight.dtype 306 | min_dtype = torch.finfo(dtype).min 307 | 308 | attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( 309 | attention_mask, 310 | sequence_length=sequence_length, 311 | target_length=past_key_values.get_max_length(), 312 | dtype=dtype, 313 | device=device, 314 | min_dtype=min_dtype, 315 | cache_position=cache_position, 316 | batch_size=batch_size, 317 | ) 318 | 319 | model_inputs.update( 320 | { 321 | "position_ids": position_ids, 322 | "cache_position": cache_position, 323 | "past_key_values": past_key_values, 324 | "use_cache": use_cache, 325 | "attention_mask": attention_mask, 326 | } 327 | ) 328 | return model_inputs 329 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/fixed_mistral_hijack.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import torch 3 | from typing import List, Optional, Tuple, Union 4 | import warnings 5 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 6 | from transformers.modeling_outputs import BaseModelOutputWithPast 7 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 8 | from transformers.models.mistral.modeling_mistral import ( 9 | apply_rotary_pos_emb, 10 | repeat_kv, 11 | ) 12 | from transformers.utils import ( 13 | logging, 14 | is_flash_attn_2_available, 15 | ) 16 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_snapkv 17 | from flash_attn import flash_attn_func 18 | 19 | logger = logging.get_logger(__name__) 20 | 21 | if is_flash_attn_2_available(): 22 | from flash_attn import flash_attn_func, flash_attn_varlen_func 23 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 24 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 25 | 26 | def fixed_MistralModel_forward( 27 | self, 28 | input_ids: torch.LongTensor = None, 29 | attention_mask: Optional[torch.Tensor] = None, 30 | position_ids: Optional[torch.LongTensor] = None, 31 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 32 | inputs_embeds: Optional[torch.FloatTensor] = None, 33 | use_cache: Optional[bool] = None, 34 | output_attentions: Optional[bool] = None, 35 | output_hidden_states: Optional[bool] = None, 36 | return_dict: Optional[bool] = None, 37 | cache_position: Optional[torch.LongTensor] = None, 38 | ) -> Union[Tuple, BaseModelOutputWithPast]: 39 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 40 
| output_hidden_states = ( 41 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 42 | ) 43 | use_cache = use_cache if use_cache is not None else self.config.use_cache 44 | 45 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 46 | 47 | # retrieve input_ids and inputs_embeds 48 | if (input_ids is None) ^ (inputs_embeds is not None): 49 | raise ValueError( 50 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 51 | ) 52 | 53 | if self.gradient_checkpointing and self.training and use_cache: 54 | logger.warning_once( 55 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 56 | ) 57 | use_cache = False 58 | 59 | if inputs_embeds is None: 60 | inputs_embeds = self.embed_tokens(input_ids) 61 | 62 | return_legacy_cache = False 63 | if use_cache and not isinstance(past_key_values, Cache) and not self.training: 64 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 65 | return_legacy_cache = True 66 | logger.warning_once( 67 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " 68 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" 69 | ) 70 | 71 | if cache_position is None: 72 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 73 | cache_position = torch.arange( 74 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 75 | ) 76 | 77 | if position_ids is None: 78 | position_ids = cache_position.unsqueeze(0) 79 | 80 | causal_mask = self._update_causal_mask( 81 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions 82 | ) 83 | 84 | hidden_states = inputs_embeds 85 | 86 | # decoder layers 87 | all_hidden_states = () if output_hidden_states else None 88 | all_self_attns = () if output_attentions else None 89 | next_decoder_cache = None 90 | 91 | for decoder_layer in self.layers: 92 | if output_hidden_states: 93 | all_hidden_states += (hidden_states,) 94 | 95 | if self.gradient_checkpointing and self.training: 96 | layer_outputs = self._gradient_checkpointing_func( 97 | decoder_layer.__call__, 98 | hidden_states, 99 | causal_mask, 100 | position_ids, 101 | past_key_values, 102 | output_attentions, 103 | use_cache, 104 | cache_position, 105 | ) 106 | else: 107 | layer_outputs = decoder_layer( 108 | hidden_states, 109 | attention_mask=causal_mask, 110 | position_ids=position_ids, 111 | past_key_value=past_key_values, 112 | output_attentions=output_attentions, 113 | use_cache=use_cache, 114 | cache_position=cache_position, 115 | ) 116 | 117 | hidden_states = layer_outputs[0] 118 | 119 | if use_cache: 120 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 121 | 122 | if output_attentions: 123 | all_self_attns += (layer_outputs[1],) 124 | 125 | hidden_states = self.norm(hidden_states) 126 | 127 | # add hidden states from the last decoder layer 128 | if output_hidden_states: 129 | all_hidden_states += (hidden_states,) 130 | 131 | next_cache = next_decoder_cache if use_cache else None 132 | if return_legacy_cache: 133 | next_cache = next_cache.to_legacy_cache() 134 | 135 | if not return_dict: 136 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 137 | 138 | hidden_states = 
hidden_states[:, -1,:].unsqueeze(1) 139 | return BaseModelOutputWithPast( 140 | last_hidden_state=hidden_states, 141 | past_key_values=next_cache, 142 | hidden_states=all_hidden_states, 143 | attentions=all_self_attns, 144 | ) 145 | 146 | def fixed_mistral_flash_attn2_forward( 147 | self, 148 | hidden_states: torch.Tensor, 149 | attention_mask: Optional[torch.Tensor] = None, 150 | position_ids: Optional[torch.LongTensor] = None, 151 | past_key_value: Optional[Cache] = None, 152 | output_attentions: bool = False, 153 | use_cache: bool = False, 154 | cache_position: Optional[torch.LongTensor] = None, 155 | ): 156 | if isinstance(past_key_value, StaticCache): 157 | raise ValueError( 158 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " 159 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" 160 | ) 161 | # [SnapKV] register kv_cluster 162 | init_snapkv(self) 163 | output_attentions = False 164 | 165 | output_attentions = False 166 | 167 | bsz, q_len, _ = hidden_states.size() 168 | 169 | query_states = self.q_proj(hidden_states) 170 | key_states = self.k_proj(hidden_states) 171 | value_states = self.v_proj(hidden_states) 172 | 173 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 174 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 175 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 176 | 177 | kv_seq_len = key_states.shape[-2] 178 | if past_key_value is not None: 179 | kv_seq_len += cache_position[0] 180 | 181 | cos, sin = self.rotary_emb(value_states, position_ids) 182 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 183 | 184 | dropout_rate = 0.0 if not self.training else self.attention_dropout 185 | 186 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 187 | if past_key_value is not None: 188 | # Activate slicing cache only if the config has a value `sliding_windows` attribute 189 | cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 190 | if ( 191 | getattr(self.config, "sliding_window", None) is not None 192 | and kv_seq_len > self.config.sliding_window 193 | and cache_has_contents 194 | ): 195 | slicing_tokens = 1 - self.config.sliding_window 196 | 197 | past_key = past_key_value[self.layer_idx][0] 198 | past_value = past_key_value[self.layer_idx][1] 199 | 200 | past_key = past_key[:, :, slicing_tokens:, :].contiguous() 201 | past_value = past_value[:, :, slicing_tokens:, :].contiguous() 202 | 203 | if past_key.shape[-2] != self.config.sliding_window - 1: 204 | raise ValueError( 205 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" 206 | f" {past_key.shape}" 207 | ) 208 | 209 | if attention_mask is not None: 210 | attention_mask = attention_mask[:, slicing_tokens:] 211 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) 212 | 213 | if q_len == 1: 214 | # support gqa 215 | if not self.kv_cluster.gqa_support: 216 | key_states = repeat_kv(key_states, self.num_key_value_groups) 217 | value_states = repeat_kv(value_states, self.num_key_value_groups) 218 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 219 | else: 220 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, 
query_states, value_states) 221 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs) 222 | 223 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 224 | # therefore the input hidden states gets silently casted in float32. Hence, we need 225 | # cast them back in float16 just to be sure everything works as expected. 226 | input_dtype = query_states.dtype 227 | if input_dtype == torch.float32: 228 | if torch.is_autocast_enabled(): 229 | target_dtype = torch.get_autocast_gpu_dtype() 230 | # Handle the case where the model is quantized 231 | elif hasattr(self.config, "_pre_quantization_dtype"): 232 | target_dtype = self.config._pre_quantization_dtype 233 | else: 234 | target_dtype = self.q_proj.weight.dtype 235 | 236 | logger.warning_once( 237 | f"The input hidden states seems to be silently casted in float32, this might be related to" 238 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 239 | f" {target_dtype}." 240 | ) 241 | 242 | query_states = query_states.to(target_dtype) 243 | key_states = key_states.to(target_dtype) 244 | value_states = value_states.to(target_dtype) 245 | 246 | # Reashape to the expected shape for Flash Attention 247 | query_states = query_states.transpose(1, 2) 248 | key_states = key_states.transpose(1, 2) 249 | value_states = value_states.transpose(1, 2) 250 | 251 | attn_output = _flash_attention_forward( 252 | query_states, 253 | key_states, 254 | value_states, 255 | attention_mask, 256 | q_len, 257 | position_ids=position_ids, 258 | dropout=dropout_rate, 259 | sliding_window=getattr(self.config, "sliding_window", None), 260 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 261 | is_causal=self.is_causal, 262 | ) 263 | 264 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() 265 | attn_output = self.o_proj(attn_output) 266 | 267 | if not output_attentions: 268 | attn_weights = None 269 | 270 | return attn_output, attn_weights, past_key_value 271 | 272 | 273 | def prepare_inputs_for_generation_mistral( 274 | self, 275 | input_ids, 276 | past_key_values=None, 277 | attention_mask=None, 278 | inputs_embeds=None, 279 | cache_position=None, 280 | position_ids=None, 281 | use_cache=True, 282 | **kwargs, 283 | ): 284 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 285 | # Exception 1: when passing input_embeds, input_ids may be missing entries 286 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 287 | if past_key_values is not None: 288 | if inputs_embeds is not None: # Exception 1 289 | input_ids = input_ids[:, -cache_position.shape[0] :] 290 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) 291 | input_ids = input_ids[:, cache_position] 292 | 293 | if attention_mask is not None and position_ids is None: 294 | # create position_ids on the fly for batch generation 295 | position_ids = attention_mask.long().cumsum(-1) - 1 296 | position_ids.masked_fill_(attention_mask == 0, 1) 297 | if past_key_values: 298 | position_ids = position_ids[:, -input_ids.shape[1] :] 299 | 300 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. 
Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 301 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 302 | 303 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 304 | if inputs_embeds is not None and cache_position[0] == 0: 305 | model_inputs = {"inputs_embeds": inputs_embeds} 306 | else: 307 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases 308 | 309 | model_inputs.update( 310 | { 311 | "position_ids": position_ids, 312 | "cache_position": cache_position, 313 | "past_key_values": past_key_values, 314 | "use_cache": use_cache, 315 | "attention_mask": attention_mask, 316 | } 317 | ) 318 | return model_inputs 319 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/monkeypatch.py: -------------------------------------------------------------------------------- 1 | from importlib.metadata import version 2 | import warnings 3 | import transformers 4 | import transformers.models.mistral.modeling_mistral 5 | from adaptive_snapkv.monkeypatch.fixed_mistral_hijack import fixed_mistral_flash_attn2_forward, fixed_MistralModel_forward 6 | from adaptive_snapkv.monkeypatch.fixed_mistral_hijack import prepare_inputs_for_generation_mistral as fixed_prepare_inputs_for_generation_mistral 7 | from adaptive_snapkv.monkeypatch.adaptive_mistral_hijack import adaptive_mistral_flash_attn2_forward,adaptive_MistralModel_forward 8 | from adaptive_snapkv.monkeypatch.adaptive_mistral_hijack import prepare_inputs_for_generation_mistral as ada_prepare_inputs_for_generation_mistral 9 | 10 | from adaptive_snapkv.monkeypatch.fixed_llama_hijack import fixed_llama_flash_attn2_forward, fixed_LlamaModel_forward 11 | from adaptive_snapkv.monkeypatch.fixed_llama_hijack import prepare_inputs_for_generation_llama as fixed_prepare_inputs_for_generation_llama 12 | from adaptive_snapkv.monkeypatch.adaptive_llama_hijack import adaptive_llama_flash_attn2_forward,adaptive_LlamaModel_forward 13 | from adaptive_snapkv.monkeypatch.adaptive_llama_hijack import prepare_inputs_for_generation_llama as ada_prepare_inputs_for_generation_llama 14 | 15 | from adaptive_snapkv.monkeypatch.slm_llama_hijack import slm_llama_flash_attn2_forward, slm_LlamaModel_forward 16 | from adaptive_snapkv.monkeypatch.slm_mistral_hijack import slm_mistral_flash_attn2_forward, slm_MistralModel_forward 17 | 18 | def check_version(): 19 | try: 20 | transformers_version = version("transformers") 21 | except Exception as e: 22 | print(f"Transformers not installed: {e}") 23 | version_list = ['4.37'] 24 | warning_flag = True 25 | for x in version_list: 26 | if x in transformers_version: 27 | warning_flag = False 28 | break 29 | if warning_flag: 30 | warnings.warn(f"Transformers version {transformers_version} might not be compatible with SnapKV. 
SnapKV is tested with Transformers version {version_list}.") 31 | 32 | 33 | # config hyperparameters 34 | def config_compress(model, window_size=32, base_capacity=1024, kernel_size=7, pooling="maxpool", floor_alpha=0.5, pyram_mode = False, beta = 20, skip=0, gqa_support=False,gqa_func="mean"): 35 | model.model.config.window_size = window_size 36 | model.model.config.base_capacity = base_capacity 37 | model.model.config.kernel_size = kernel_size 38 | 39 | model.model.config.pooling = pooling 40 | model.model.config.floor_alpha = floor_alpha 41 | model.model.config.skip = skip 42 | model.model.config.normalize = None 43 | 44 | model.model.config.pyram_mode = pyram_mode 45 | model.model.config.pyram_beta = beta 46 | 47 | model.model.config.gqa_support = gqa_support 48 | model.model.config.gqa_func = gqa_func 49 | 50 | return model 51 | 52 | 53 | def replace_mistral_fixed(): 54 | check_version() 55 | transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_mistral 56 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward = fixed_mistral_flash_attn2_forward 57 | transformers.models.mistral.modeling_mistral.MistralModel.forward = fixed_MistralModel_forward 58 | 59 | def replace_mistral_adaptive(): 60 | check_version() 61 | transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_mistral 62 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward = adaptive_mistral_flash_attn2_forward 63 | transformers.models.mistral.modeling_mistral.MistralModel.forward = adaptive_MistralModel_forward 64 | 65 | def replace_llama_fixed(): 66 | check_version() 67 | transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_llama 68 | transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = fixed_llama_flash_attn2_forward 69 | transformers.models.llama.modeling_llama.LlamaModel.forward = fixed_LlamaModel_forward 70 | 71 | def replace_llama_adaptive(): 72 | check_version() 73 | transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = ada_prepare_inputs_for_generation_llama 74 | transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = adaptive_llama_flash_attn2_forward 75 | transformers.models.llama.modeling_llama.LlamaModel.forward = adaptive_LlamaModel_forward 76 | 77 | 78 | def replace_llama_slm(): 79 | check_version() 80 | transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_llama 81 | transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward = slm_llama_flash_attn2_forward 82 | transformers.models.llama.modeling_llama.LlamaModel.forward = slm_LlamaModel_forward 83 | 84 | def replace_mistral_slm(): 85 | check_version() 86 | transformers.models.mistral.modeling_mistral.MistralForCausalLM.prepare_inputs_for_generation = fixed_prepare_inputs_for_generation_mistral 87 | transformers.models.mistral.modeling_mistral.MistralFlashAttention2.forward = slm_mistral_flash_attn2_forward 88 | transformers.models.mistral.modeling_mistral.MistralModel.forward = slm_MistralModel_forward 89 | 90 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/slm_llama_hijack.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import List, Optional, Tuple, Union 5 | import warnings 6 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 7 | from transformers.modeling_outputs import BaseModelOutputWithPast 8 | from transformers.models.llama.modeling_llama import ( 9 | apply_rotary_pos_emb, 10 | repeat_kv, 11 | ) 12 | from transformers.utils import ( 13 | logging, 14 | ) 15 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 16 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_slm 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | def slm_LlamaModel_forward( 21 | self, 22 | input_ids: torch.LongTensor = None, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | position_ids: Optional[torch.LongTensor] = None, 25 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 26 | inputs_embeds: Optional[torch.FloatTensor] = None, 27 | use_cache: Optional[bool] = None, 28 | output_attentions: Optional[bool] = None, 29 | output_hidden_states: Optional[bool] = None, 30 | return_dict: Optional[bool] = None, 31 | cache_position: Optional[torch.LongTensor] = None, 32 | ) -> Union[Tuple, BaseModelOutputWithPast]: 33 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 34 | output_hidden_states = ( 35 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 36 | ) 37 | use_cache = use_cache if use_cache is not None else self.config.use_cache 38 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 39 | 40 | if (input_ids is None) ^ (inputs_embeds is not None): 41 | raise ValueError( 42 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 43 | ) 44 | 45 | if self.gradient_checkpointing and self.training and use_cache: 46 | logger.warning_once( 47 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 48 | ) 49 | use_cache = False 50 | 51 | if inputs_embeds is None: 52 | inputs_embeds = self.embed_tokens(input_ids) 53 | 54 | return_legacy_cache = False 55 | if ( 56 | use_cache and not isinstance(past_key_values, Cache) and not self.training 57 | ): # kept for BC (non `Cache` `past_key_values` inputs) 58 | return_legacy_cache = True 59 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 60 | logger.warning_once( 61 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. 
" 62 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" 63 | ) 64 | 65 | if cache_position is None: 66 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 67 | cache_position = torch.arange( 68 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 69 | ) 70 | if position_ids is None: 71 | position_ids = cache_position.unsqueeze(0) 72 | 73 | causal_mask = self._update_causal_mask( 74 | attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions 75 | ) 76 | hidden_states = inputs_embeds 77 | 78 | # create position embeddings to be shared across the decoder layers 79 | position_embeddings = self.rotary_emb(hidden_states, position_ids) 80 | 81 | # decoder layers 82 | all_hidden_states = () if output_hidden_states else None 83 | all_self_attns = () if output_attentions else None 84 | next_decoder_cache = None 85 | 86 | for decoder_layer in self.layers: 87 | if output_hidden_states: 88 | all_hidden_states += (hidden_states,) 89 | 90 | if self.gradient_checkpointing and self.training: 91 | layer_outputs = self._gradient_checkpointing_func( 92 | decoder_layer.__call__, 93 | hidden_states, 94 | causal_mask, 95 | position_ids, 96 | past_key_values, 97 | output_attentions, 98 | use_cache, 99 | cache_position, 100 | position_embeddings, 101 | ) 102 | else: 103 | layer_outputs = decoder_layer( 104 | hidden_states, 105 | attention_mask=causal_mask, 106 | position_ids=position_ids, 107 | past_key_value=past_key_values, 108 | output_attentions=output_attentions, 109 | use_cache=use_cache, 110 | cache_position=cache_position, 111 | position_embeddings=position_embeddings, 112 | ) 113 | 114 | hidden_states = layer_outputs[0] 115 | 116 | if use_cache: 117 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 118 | 119 | if output_attentions: 120 | all_self_attns += (layer_outputs[1],) 121 | 122 | hidden_states = self.norm(hidden_states) 123 | 124 | # add hidden states from the last decoder layer 125 | if output_hidden_states: 126 | all_hidden_states += (hidden_states,) 127 | 128 | next_cache = next_decoder_cache if use_cache else None 129 | if return_legacy_cache: 130 | next_cache = next_cache.to_legacy_cache() 131 | 132 | hidden_states = hidden_states[:, -1,:].unsqueeze(1) 133 | 134 | if not return_dict: 135 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 136 | return BaseModelOutputWithPast( 137 | last_hidden_state=hidden_states, 138 | past_key_values=next_cache, 139 | hidden_states=all_hidden_states, 140 | attentions=all_self_attns, 141 | ) 142 | 143 | 144 | def slm_llama_flash_attn2_forward( 145 | self, 146 | hidden_states: torch.Tensor, 147 | attention_mask: Optional[torch.LongTensor] = None, 148 | position_ids: Optional[torch.LongTensor] = None, 149 | past_key_value: Optional[Cache] = None, 150 | output_attentions: bool = False, 151 | use_cache: bool = False, 152 | cache_position: Optional[torch.LongTensor] = None, 153 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 154 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 155 | if isinstance(past_key_value, StaticCache): 156 | raise ValueError( 157 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " 158 | "make sure to use `sdpa` in the mean time, and open an 
issue at https://github.com/huggingface/transformers" 159 | ) 160 | init_slm(self) 161 | 162 | output_attentions = False 163 | 164 | bsz, q_len, _ = hidden_states.size() 165 | 166 | query_states = self.q_proj(hidden_states) 167 | key_states = self.k_proj(hidden_states) 168 | value_states = self.v_proj(hidden_states) 169 | 170 | # Flash attention requires the input to have the shape 171 | # batch_size x seq_length x head_dim x hidden_dim 172 | # therefore we just need to keep the original shape 173 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 174 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 175 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 176 | 177 | if position_embeddings is None: 178 | logger.warning_once( 179 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 180 | "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " 181 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " 182 | "removed and `position_embeddings` will be mandatory." 183 | ) 184 | cos, sin = self.rotary_emb(value_states, position_ids) 185 | else: 186 | cos, sin = position_embeddings 187 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 188 | 189 | 190 | cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} 191 | if past_key_value is not None: 192 | # NOTE: decoding update 193 | if q_len == 1: 194 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 195 | else: 196 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states) 197 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs) 198 | 199 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache 200 | # to be able to avoid many of these transpose/reshape/view. 201 | 202 | dropout_rate = self.attention_dropout if self.training else 0.0 203 | 204 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 205 | # therefore the input hidden states gets silently casted in float32. Hence, we need 206 | # cast them back in the correct dtype just to be sure everything works as expected. 207 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 208 | # in fp32. (LlamaRMSNorm handles it correctly) 209 | 210 | input_dtype = query_states.dtype 211 | if input_dtype == torch.float32: 212 | if torch.is_autocast_enabled(): 213 | target_dtype = torch.get_autocast_gpu_dtype() 214 | # Handle the case where the model is quantized 215 | elif hasattr(self.config, "_pre_quantization_dtype"): 216 | target_dtype = self.config._pre_quantization_dtype 217 | else: 218 | target_dtype = self.q_proj.weight.dtype 219 | 220 | logger.warning_once( 221 | f"The input hidden states seems to be silently casted in float32, this might be related to" 222 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 223 | f" {target_dtype}." 
224 | ) 225 | 226 | query_states = query_states.to(target_dtype) 227 | key_states = key_states.to(target_dtype) 228 | value_states = value_states.to(target_dtype) 229 | 230 | # Reashape to the expected shape for Flash Attention 231 | query_states = query_states.transpose(1, 2) 232 | key_states = key_states.transpose(1, 2) 233 | value_states = value_states.transpose(1, 2) 234 | 235 | attn_output = _flash_attention_forward( 236 | query_states, 237 | key_states, 238 | value_states, 239 | attention_mask, 240 | q_len, 241 | position_ids=position_ids, 242 | dropout=dropout_rate, 243 | sliding_window=getattr(self, "sliding_window", None), 244 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 245 | is_causal=self.is_causal, 246 | ) 247 | 248 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 249 | attn_output = self.o_proj(attn_output) 250 | 251 | if not output_attentions: 252 | attn_weights = None 253 | 254 | return attn_output, attn_weights, past_key_value 255 | 256 | def prepare_inputs_for_generation_llama( 257 | self, 258 | input_ids, 259 | past_key_values=None, 260 | attention_mask=None, 261 | inputs_embeds=None, 262 | cache_position=None, 263 | position_ids=None, 264 | use_cache=True, 265 | **kwargs, 266 | ): 267 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 268 | # Exception 1: when passing input_embeds, input_ids may be missing entries 269 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 270 | if past_key_values is not None: 271 | if inputs_embeds is not None: # Exception 1 272 | input_ids = input_ids[:, -cache_position.shape[0] :] 273 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) 274 | input_ids = input_ids[:, cache_position] 275 | 276 | if attention_mask is not None and position_ids is None: 277 | # create position_ids on the fly for batch generation 278 | position_ids = attention_mask.long().cumsum(-1) - 1 279 | position_ids.masked_fill_(attention_mask == 0, 1) 280 | if past_key_values: 281 | position_ids = position_ids[:, -input_ids.shape[1] :] 282 | 283 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 284 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 285 | 286 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 287 | if inputs_embeds is not None and cache_position[0] == 0: 288 | model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} 289 | else: 290 | # The clone here is for the same reason as for `position_ids`. 
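# Editorial note: the `StaticCache` branch a few lines below calls
# `_prepare_4d_causal_attention_mask_with_cache_position`, which is never imported in this module.
# The branch should be unreachable here, since the patched flash-attention forward rejects
# `StaticCache`, but exercising it would require importing that helper (recent transformers
# releases expose a function of that name in the Llama modeling module).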
291 | model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} 292 | 293 | if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: 294 | if model_inputs["inputs_embeds"] is not None: 295 | batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape 296 | device = model_inputs["inputs_embeds"].device 297 | else: 298 | batch_size, sequence_length = model_inputs["input_ids"].shape 299 | device = model_inputs["input_ids"].device 300 | 301 | dtype = self.lm_head.weight.dtype 302 | min_dtype = torch.finfo(dtype).min 303 | 304 | attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( 305 | attention_mask, 306 | sequence_length=sequence_length, 307 | target_length=past_key_values.get_max_length(), 308 | dtype=dtype, 309 | device=device, 310 | min_dtype=min_dtype, 311 | cache_position=cache_position, 312 | batch_size=batch_size, 313 | ) 314 | 315 | model_inputs.update( 316 | { 317 | "position_ids": position_ids, 318 | "cache_position": cache_position, 319 | "past_key_values": past_key_values, 320 | "use_cache": use_cache, 321 | "attention_mask": attention_mask, 322 | } 323 | ) 324 | return model_inputs 325 | 326 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/slm_mistral_hijack.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from numpy import diff 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from typing import List, Optional, Tuple, Union 7 | import warnings 8 | from transformers.cache_utils import Cache, DynamicCache, StaticCache 9 | from transformers.modeling_outputs import BaseModelOutputWithPast 10 | from transformers.modeling_flash_attention_utils import _flash_attention_forward 11 | from transformers.models.mistral.modeling_mistral import ( 12 | apply_rotary_pos_emb, 13 | repeat_kv, 14 | ) 15 | from transformers.utils import ( 16 | logging, 17 | is_flash_attn_2_available, 18 | ) 19 | from adaptive_snapkv.monkeypatch.snapkv_utils import init_slm 20 | from flash_attn import flash_attn_func 21 | 22 | logger = logging.get_logger(__name__) 23 | 24 | if is_flash_attn_2_available(): 25 | from flash_attn import flash_attn_func, flash_attn_varlen_func 26 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 27 | _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) 28 | 29 | def slm_MistralModel_forward( 30 | self, 31 | input_ids: torch.LongTensor = None, 32 | attention_mask: Optional[torch.Tensor] = None, 33 | position_ids: Optional[torch.LongTensor] = None, 34 | past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, 35 | inputs_embeds: Optional[torch.FloatTensor] = None, 36 | use_cache: Optional[bool] = None, 37 | output_attentions: Optional[bool] = None, 38 | output_hidden_states: Optional[bool] = None, 39 | return_dict: Optional[bool] = None, 40 | cache_position: Optional[torch.LongTensor] = None, 41 | ) -> Union[Tuple, BaseModelOutputWithPast]: 42 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 43 | output_hidden_states = ( 44 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 45 | ) 46 | use_cache = use_cache if use_cache is not None else self.config.use_cache 47 | 48 | return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict 49 | 50 | # retrieve input_ids and inputs_embeds 51 | if (input_ids is None) ^ (inputs_embeds is not None): 52 | raise ValueError( 53 | "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" 54 | ) 55 | 56 | if self.gradient_checkpointing and self.training and use_cache: 57 | logger.warning_once( 58 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 59 | ) 60 | use_cache = False 61 | 62 | if inputs_embeds is None: 63 | inputs_embeds = self.embed_tokens(input_ids) 64 | 65 | return_legacy_cache = False 66 | if use_cache and not isinstance(past_key_values, Cache) and not self.training: 67 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 68 | return_legacy_cache = True 69 | logger.warning_once( 70 | "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " 71 | "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" 72 | ) 73 | 74 | if cache_position is None: 75 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 76 | cache_position = torch.arange( 77 | past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device 78 | ) 79 | 80 | if position_ids is None: 81 | position_ids = cache_position.unsqueeze(0) 82 | 83 | causal_mask = self._update_causal_mask( 84 | attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions 85 | ) 86 | 87 | hidden_states = inputs_embeds 88 | 89 | # decoder layers 90 | all_hidden_states = () if output_hidden_states else None 91 | all_self_attns = () if output_attentions else None 92 | next_decoder_cache = None 93 | 94 | for decoder_layer in self.layers: 95 | if output_hidden_states: 96 | all_hidden_states += (hidden_states,) 97 | 98 | if self.gradient_checkpointing and self.training: 99 | layer_outputs = self._gradient_checkpointing_func( 100 | decoder_layer.__call__, 101 | hidden_states, 102 | causal_mask, 103 | position_ids, 104 | past_key_values, 105 | output_attentions, 106 | use_cache, 107 | cache_position, 108 | ) 109 | else: 110 | layer_outputs = decoder_layer( 111 | hidden_states, 112 | attention_mask=causal_mask, 113 | position_ids=position_ids, 114 | past_key_value=past_key_values, 115 | output_attentions=output_attentions, 116 | use_cache=use_cache, 117 | cache_position=cache_position, 118 | ) 119 | 120 | hidden_states = layer_outputs[0] 121 | 122 | if use_cache: 123 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 124 | 125 | if output_attentions: 126 | all_self_attns += (layer_outputs[1],) 127 | 128 | hidden_states = self.norm(hidden_states) 129 | 130 | # add hidden states from the last decoder layer 131 | if output_hidden_states: 132 | all_hidden_states += (hidden_states,) 133 | 134 | next_cache = next_decoder_cache if use_cache else None 135 | if return_legacy_cache: 136 | next_cache = next_cache.to_legacy_cache() 137 | 138 | if not return_dict: 139 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 140 | 141 | hidden_states = hidden_states[:, -1,:].unsqueeze(1) 142 | return BaseModelOutputWithPast( 143 | last_hidden_state=hidden_states, 144 | past_key_values=next_cache, 145 | hidden_states=all_hidden_states, 146 | attentions=all_self_attns, 147 | ) 148 | 149 | def slm_mistral_flash_attn2_forward( 150 | 
self, 151 | hidden_states: torch.Tensor, 152 | attention_mask: Optional[torch.Tensor] = None, 153 | position_ids: Optional[torch.LongTensor] = None, 154 | past_key_value: Optional[Cache] = None, 155 | output_attentions: bool = False, 156 | use_cache: bool = False, 157 | cache_position: Optional[torch.LongTensor] = None, 158 | ): 159 | if isinstance(past_key_value, StaticCache): 160 | raise ValueError( 161 | "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " 162 | "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" 163 | ) 164 | init_slm(self) 165 | output_attentions = False 166 | 167 | output_attentions = False 168 | 169 | bsz, q_len, _ = hidden_states.size() 170 | 171 | query_states = self.q_proj(hidden_states) 172 | key_states = self.k_proj(hidden_states) 173 | value_states = self.v_proj(hidden_states) 174 | 175 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 176 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 177 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 178 | 179 | kv_seq_len = key_states.shape[-2] 180 | if past_key_value is not None: 181 | kv_seq_len += cache_position[0] 182 | 183 | cos, sin = self.rotary_emb(value_states, position_ids) 184 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) 185 | 186 | dropout_rate = 0.0 if not self.training else self.attention_dropout 187 | 188 | cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models 189 | if past_key_value is not None: 190 | # Activate slicing cache only if the config has a value `sliding_windows` attribute 191 | cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 192 | if ( 193 | getattr(self.config, "sliding_window", None) is not None 194 | and kv_seq_len > self.config.sliding_window 195 | and cache_has_contents 196 | ): 197 | slicing_tokens = 1 - self.config.sliding_window 198 | 199 | past_key = past_key_value[self.layer_idx][0] 200 | past_value = past_key_value[self.layer_idx][1] 201 | 202 | past_key = past_key[:, :, slicing_tokens:, :].contiguous() 203 | past_value = past_value[:, :, slicing_tokens:, :].contiguous() 204 | 205 | if past_key.shape[-2] != self.config.sliding_window - 1: 206 | raise ValueError( 207 | f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" 208 | f" {past_key.shape}" 209 | ) 210 | 211 | if attention_mask is not None: 212 | attention_mask = attention_mask[:, slicing_tokens:] 213 | attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) 214 | 215 | if q_len == 1: 216 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 217 | else: 218 | key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states) 219 | past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs) 220 | 221 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 222 | # therefore the input hidden states gets silently casted in float32. Hence, we need 223 | # cast them back in float16 just to be sure everything works as expected. 
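# Editorial sketch (not part of the original file) of the dtype fallback implemented below: the
# target dtype is taken from the active autocast context if there is one, otherwise from the
# pre-quantization config, and finally from the projection weights themselves, e.g.
#   with torch.autocast("cuda", dtype=torch.bfloat16):
#       torch.is_autocast_enabled()       # True
#       torch.get_autocast_gpu_dtype()    # torch.bfloat16 -> q/k/v are cast back to bf16 for flash-attn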
224 | input_dtype = query_states.dtype 225 | if input_dtype == torch.float32: 226 | if torch.is_autocast_enabled(): 227 | target_dtype = torch.get_autocast_gpu_dtype() 228 | # Handle the case where the model is quantized 229 | elif hasattr(self.config, "_pre_quantization_dtype"): 230 | target_dtype = self.config._pre_quantization_dtype 231 | else: 232 | target_dtype = self.q_proj.weight.dtype 233 | 234 | logger.warning_once( 235 | f"The input hidden states seems to be silently casted in float32, this might be related to" 236 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 237 | f" {target_dtype}." 238 | ) 239 | 240 | query_states = query_states.to(target_dtype) 241 | key_states = key_states.to(target_dtype) 242 | value_states = value_states.to(target_dtype) 243 | 244 | # Reashape to the expected shape for Flash Attention 245 | query_states = query_states.transpose(1, 2) 246 | key_states = key_states.transpose(1, 2) 247 | value_states = value_states.transpose(1, 2) 248 | 249 | attn_output = _flash_attention_forward( 250 | query_states, 251 | key_states, 252 | value_states, 253 | attention_mask, 254 | q_len, 255 | position_ids=position_ids, 256 | dropout=dropout_rate, 257 | sliding_window=getattr(self.config, "sliding_window", None), 258 | use_top_left_mask=self._flash_attn_uses_top_left_mask, 259 | is_causal=self.is_causal, 260 | ) 261 | 262 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() 263 | attn_output = self.o_proj(attn_output) 264 | 265 | if not output_attentions: 266 | attn_weights = None 267 | 268 | return attn_output, attn_weights, past_key_value 269 | 270 | 271 | def prepare_inputs_for_generation_mistral( 272 | self, 273 | input_ids, 274 | past_key_values=None, 275 | attention_mask=None, 276 | inputs_embeds=None, 277 | cache_position=None, 278 | position_ids=None, 279 | use_cache=True, 280 | **kwargs, 281 | ): 282 | # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens 283 | # Exception 1: when passing input_embeds, input_ids may be missing entries 284 | # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here 285 | if past_key_values is not None: 286 | if inputs_embeds is not None: # Exception 1 287 | input_ids = input_ids[:, -cache_position.shape[0] :] 288 | elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) 289 | input_ids = input_ids[:, cache_position] 290 | 291 | if attention_mask is not None and position_ids is None: 292 | # create position_ids on the fly for batch generation 293 | position_ids = attention_mask.long().cumsum(-1) - 1 294 | position_ids.masked_fill_(attention_mask == 0, 1) 295 | if past_key_values: 296 | position_ids = position_ids[:, -input_ids.shape[1] :] 297 | 298 | # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
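# Editorial walk-through (illustrative numbers only): on a decode step with 10 tokens already in
# the cache, generate() hands this function the full 11-token sequence and
# `cache_position = tensor([10])`; since 11 != 1, the slice at the top keeps only
# `input_ids[:, cache_position]` (shape [1, 1]), the position_ids slice above keeps the single
# newest position, and the clone below re-materialises it with a fixed layout, so the patched
# attention takes the q_len == 1 cache-append path instead of re-running compression.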
299 | position_ids = position_ids.clone(memory_format=torch.contiguous_format) 300 | 301 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 302 | if inputs_embeds is not None and cache_position[0] == 0: 303 | model_inputs = {"inputs_embeds": inputs_embeds} 304 | else: 305 | model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases 306 | 307 | model_inputs.update( 308 | { 309 | "position_ids": position_ids, 310 | "cache_position": cache_position, 311 | "past_key_values": past_key_values, 312 | "use_cache": use_cache, 313 | "attention_mask": attention_mask, 314 | } 315 | ) 316 | return model_inputs 317 | -------------------------------------------------------------------------------- /adaptive_snapkv/monkeypatch/snapkv_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import time 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import math 8 | from typing import List, Optional, Tuple, Union, Any,Dict 9 | from transformers.cache_utils import Cache, DynamicCache 10 | from flash_attn import flash_attn_func 11 | # perform qk calculation and get indices 12 | # this version will not update in inference mode 13 | 14 | class DynamicCacheSplitHeadFlatten(Cache): 15 | """ 16 | Flattened version of DynamicCacheSplitHead 17 | """ 18 | def __init__(self) ->None: 19 | # Token wise List[] Head wise KV List[torch.Tensor] 20 | super().__init__() 21 | self.key_cache: List[List[torch.Tensor]] = [] 22 | self.value_cache: List[List[torch.Tensor]] = [] 23 | self._seen_tokens = 0 24 | 25 | def __len__(self): 26 | return len(self.key_cache) 27 | 28 | def __iter__(self): 29 | for layer_idx in range(len(self)): 30 | yield (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx])) 31 | 32 | def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor],Tuple[torch.Tensor]]: 33 | if layer_idx < len(self): 34 | return (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx])) 35 | else: 36 | raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") 37 | 38 | def update(self, key_states, value_states, layer_idx, cache_kwargs=None): 39 | # NOTE: k, v = [head_num](bs, 1, seqlen, dim) 40 | # each layer is a flatten layout like: 41 | # [head_0_len + head_1_len + ..., dim] 42 | if len(self.key_cache) <= layer_idx: 43 | self.key_cache.append(key_states) 44 | self.value_cache.append(value_states) 45 | else: 46 | assert self.key_cache[layer_idx].dim() == 2 47 | bs, head, seqlen, dim = key_states.shape 48 | assert bs == 1 and seqlen == 1 49 | # NOTE: phase 2. 
we got [bs, head, seqlen, dim] as k, v input 50 | head_lens = cache_kwargs["head_lens"] 51 | cu_klen = cache_kwargs["cu_klen"] 52 | 53 | # TODO: wrap as a python interface 54 | from tiny_api_cuda import update_flatten_view 55 | new_key_cache = update_flatten_view(self.key_cache[layer_idx].view(-1,dim), key_states.view(-1, dim), head_lens, cu_klen) 56 | new_value_cache = update_flatten_view(self.value_cache[layer_idx].view(-1,dim), value_states.view(-1, dim), head_lens, cu_klen) 57 | 58 | 59 | self.key_cache[layer_idx] = new_key_cache 60 | self.value_cache[layer_idx] = new_value_cache 61 | 62 | 63 | return self.key_cache[layer_idx], self.value_cache[layer_idx] 64 | 65 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 66 | if len(self.key_cache) <= layer_idx: 67 | return 0 68 | 69 | # TODO: return 1 to means has content for now 70 | return 1 71 | # return max(map(lambda states: states.shape[-2], self.key_cache[layer_idx])) 72 | 73 | def get_max_length(self) -> Optional[int]: 74 | return None 75 | 76 | def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: 77 | """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" 78 | legacy_cache = () 79 | for layer_idx in range(len(self)): 80 | legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) 81 | return legacy_cache 82 | 83 | @classmethod 84 | def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCacheEachHead": 85 | """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" 86 | cache = cls() 87 | if past_key_values is not None: 88 | for layer_idx in range(len(past_key_values)): 89 | key_states, value_states = past_key_values[layer_idx] 90 | cache.update(key_states, value_states, layer_idx) 91 | return cache 92 | 93 | 94 | 95 | # class DynamicCacheSplitHead(Cache): 96 | # """ 97 | # demo for illustrate the splited cache update 98 | # This class is slower than DynamicCacheSplitHeadFlatten, due to the frequent tensor copy 99 | # """ 100 | # def __init__(self) ->None: 101 | # # Token wise List[] Head wise KV List[torch.Tensor] 102 | # super().__init__() 103 | # self.key_cache: List[List[torch.Tensor]] = [] 104 | # self.value_cache: List[List[torch.Tensor]] = [] 105 | # self._seen_tokens = 0 106 | 107 | # def __len__(self): 108 | # return len(self.key_cache) 109 | 110 | # def __iter__(self): 111 | # for layer_idx in range(len(self)): 112 | # yield (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx])) 113 | 114 | # def __getitem__(self, layer_idx: int) -> Tuple[Tuple[torch.Tensor],Tuple[torch.Tensor]]: 115 | # if layer_idx < len(self): 116 | # return (tuple(self.key_cache[layer_idx]),tuple(self.value_cache[layer_idx])) 117 | # else: 118 | # raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") 119 | 120 | # def update( 121 | # self, 122 | # key_states: List[torch.Tensor], 123 | # value_states: List[torch.Tensor], 124 | # layer_idx: int, 125 | # cache_kwargs: Optional[Dict[str, Any]] = None, 126 | # ) -> Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]]: 127 | # if layer_idx == 0: 128 | # self._seen_tokens += max(map(lambda states: states.shape[-2], key_states)) 129 | 130 | # if len(self.key_cache)<=layer_idx: 131 | # self.key_cache.append(list(key_states)) 132 | # self.value_cache.append(list(value_states)) 133 | # else: 134 | # # tensor shape[ [bsz, seq, dim] * head_nums] 135 | # # [bsz,\sum seq,dim] 136 | # # 
[bsz,\sum seq+headnum,dim ] 137 | # for head_idx in range(len(key_states)): 138 | # self.key_cache[layer_idx][head_idx] = torch.cat([self.key_cache[layer_idx][head_idx],key_states[head_idx]], dim=-2) 139 | # self.value_cache[layer_idx][head_idx] = torch.cat([self.value_cache[layer_idx][head_idx],value_states[head_idx]], dim=-2) 140 | # return tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx]) 141 | 142 | # def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 143 | # if len(self.key_cache) <= layer_idx: 144 | # return 0 145 | # return max(map(lambda states: states.shape[-2], self.key_cache[layer_idx])) 146 | 147 | # def get_max_length(self) -> Optional[int]: 148 | # return None 149 | 150 | 151 | # # Tuple[Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]],...] 152 | # def to_legacy_cache(self)-> Tuple[Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]],...]: 153 | # """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" 154 | # legacy_cache = () 155 | # for layer_idx in range(len(self)): 156 | # legacy_cache += ((tuple(self.key_cache[layer_idx]), tuple(self.value_cache[layer_idx])),) 157 | # return legacy_cache 158 | # @classmethod 159 | # def from_legacy_cache(cls,past_key_values:Optional[ Tuple[Tuple[Tuple[torch.Tensor,...],Tuple[torch.Tensor,...]],...]]=None)->"DynamicCacheEachHead": 160 | # cache = cls() 161 | # if past_key_values is not None: 162 | # for layer_idx in range(len(past_key_values)): 163 | # key_states,value_states = past_key_values[layer_idx] 164 | # cache.update(list(key_states),list(value_states),layer_idx) 165 | # return cache 166 | 167 | 168 | 169 | # Copied from transformers.models.llama.modeling_llama.repeat_kv for gqa_support 170 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 171 | """ 172 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 173 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 174 | """ 175 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 176 | if n_rep == 1: 177 | return hidden_states 178 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 179 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 180 | 181 | class SnapKVCluster(): 182 | def __init__(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool', layer_idx = None, num_hidden_layers = None, pyram_mode = False, pyram_beta = 20,gqa_support=False,num_key_value_groups = 1, gqa_func=None): 183 | self.window_size = window_size 184 | self.max_capacity_prompt = max_capacity_prompt 185 | assert self.max_capacity_prompt - self.window_size > 0 186 | self.kernel_size = kernel_size 187 | self.pooling = pooling 188 | 189 | self.pyram_init = False 190 | self.pyram_mode = pyram_mode 191 | self.pyram_beta = pyram_beta 192 | self.layer_idx = layer_idx 193 | self.num_hidden_layers = num_hidden_layers 194 | 195 | # support gqa 196 | self.gqa_support = gqa_support 197 | self.num_key_value_groups = num_key_value_groups 198 | self.gqa_func = gqa_func 199 | if self.gqa_support: 200 | assert gqa_func is not None, "gqa_func should not be None" 201 | assert gqa_func in ['max','mean'], "currently gqa_func should be in ['max','mean']" 202 | if self.num_key_value_groups == 1: 203 | warnings.warn("gqa_support is enabled, but num_key_value_groups is 1, which means the model is not using gqa. Please check the model configuration.") 204 | 205 | 206 | 207 | def reset(self, window_size = 64, max_capacity_prompt = 256 + 64, kernel_size = 5, pooling = 'avgpool'): 208 | self.window_size = window_size 209 | self.max_capacity_prompt = max_capacity_prompt 210 | assert self.max_capacity_prompt - self.window_size > 0 211 | self.kernel_size = kernel_size 212 | self.pooling = pooling 213 | 214 | def update_kv(self, origin_key_states, query_states, origin_value_states): 215 | 216 | # support gqa 217 | key_states = repeat_kv(origin_key_states, self.num_key_value_groups) 218 | value_states = repeat_kv(origin_value_states, self.num_key_value_groups) 219 | # check if prefix phase 220 | assert key_states.shape[-2] == query_states.shape[-2] 221 | bsz, num_heads, q_len, head_dim = query_states.shape 222 | 223 | # compute pyramidal capacity 224 | if self.pyram_mode and not self.pyram_init: 225 | # NOTE: (max_num + min_num) / 2 == base_capacity to restrict the total capacity 226 | base_capacity = self.max_capacity_prompt - self.window_size 227 | min_num = base_capacity // self.pyram_beta 228 | max_num = base_capacity * 2 - min_num 229 | 230 | # if the max_num is larger than the query length, we need to adjust the max_num 231 | if max_num >= q_len - self.window_size: 232 | max_num = q_len - self.window_size 233 | min_num = base_capacity * 2 - max_num 234 | 235 | # NOTE: compute interval 236 | steps = (max_num - min_num) // (self.num_hidden_layers - 1) 237 | 238 | self.max_capacity_prompt = max_num - self.layer_idx * steps + self.window_size 239 | self.pyram_init = True 240 | print(f"Pyram mode adaptive capacity, layer: {self.layer_idx}, max_capacity_prompt: {self.max_capacity_prompt}, base_capacity: {self.max_capacity_prompt - self.window_size}", flush=True) 241 | 242 | if q_len < self.max_capacity_prompt: 243 | # support gqa 244 | if self.gqa_support: 245 | return origin_key_states, 
origin_value_states 246 | else: 247 | return key_states, value_states 248 | else: 249 | attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim) 250 | mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device) 251 | mask_cond = torch.arange(mask.size(-1), device=attn_weights.device) 252 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 253 | mask = mask.to(attn_weights.device) 254 | attention_mask = mask[None, None, :, :] 255 | 256 | attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask 257 | 258 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 259 | attn_weights_mean = attn_weights[:, :, -self.window_size:, : -self.window_size].mean(dim = -2) 260 | 261 | # gqa_support 262 | if self.gqa_support: 263 | attn_weights_mean = attn_weights_mean.view(attn_weights_mean.shape[0], -1, self.num_key_value_groups, attn_weights_mean.shape[-1]) 264 | if self.gqa_func == 'max': 265 | attn_weights_mean = attn_weights_mean.max(dim=-2).values 266 | elif self.gqa_func == 'mean': 267 | attn_weights_mean = attn_weights_mean.mean(dim=-2) 268 | else: 269 | raise ValueError('gqa_func not supported') 270 | 271 | if self.pooling == 'avgpool': 272 | attn_cache = F.avg_pool1d(attn_weights_mean, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1) 273 | elif self.pooling == 'maxpool': 274 | attn_cache = F.max_pool1d(attn_weights_mean, kernel_size = self.kernel_size, padding=self.kernel_size//2, stride=1) 275 | else: 276 | raise ValueError('Pooling method not supported') 277 | 278 | indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices 279 | indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim) 280 | 281 | # support gqa 282 | if self.gqa_support: 283 | k_past_compress = origin_key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) 284 | v_past_compress = origin_value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) 285 | k_cur = origin_key_states[:, :, -self.window_size:, :] 286 | v_cur = origin_value_states[:, :, -self.window_size:, :] 287 | else: 288 | k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) 289 | v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim = 2, index = indices) 290 | k_cur = key_states[:, :, -self.window_size:, :] 291 | v_cur = value_states[:, :, -self.window_size:, :] 292 | key_states = torch.cat([k_past_compress, k_cur], dim = 2) 293 | value_states = torch.cat([v_past_compress, v_cur], dim = 2) 294 | return key_states, value_states 295 | 296 | 297 | class AdaptiveSnapKVCluster(): 298 | def __init__(self, window_size = 32, kernel_size = 7, pooling = 'maxpool',base_capacity=None,floor_alpha = None,skip = None,normalize=None, 299 | layer_idx = None, num_hidden_layers = None, pyram_mode = False, pyram_beta = 20,gqa_support=False,num_key_value_groups = 1, gqa_func=None): 300 | self.window_size = window_size 301 | self.kernel_size = kernel_size 302 | self.pooling = pooling 303 | self.base_capacity = base_capacity - window_size 304 | self.floor_ratio = floor_alpha 305 | self.floor_capacity = int(self.base_capacity * self.floor_ratio) 306 | self.adaptive_capacity = self.base_capacity - self.floor_capacity 307 | self.skip = skip 308 | 309 | self.normalize = normalize 310 | self.pyram_init = False 311 | self.pyram_mode = 
pyram_mode 312 | self.pyram_beta = pyram_beta 313 | self.layer_idx = layer_idx 314 | self.num_hidden_layers = num_hidden_layers 315 | 316 | # NOTE: layer-wise meta-data 317 | self.head_lens = None 318 | self.max_seqlen_k = 0 319 | self.klen_sum = 0 320 | self.cu_klen = 0 321 | self.cu_offset = None 322 | self.cu_headlens = None 323 | 324 | # support gqa 325 | self.gqa_support = gqa_support 326 | self.num_key_value_groups = num_key_value_groups 327 | self.gqa_func = gqa_func 328 | if self.gqa_support: 329 | assert gqa_func is not None, "gqa_func should not be None" 330 | assert gqa_func in ['max','mean'], "currently gqa_func should be in ['max','mean']" 331 | if self.num_key_value_groups == 1: 332 | warnings.warn("gqa_support is enabled, but num_key_value_groups is 1, which means the model is not using gqa. Please check the model configuration.") 333 | 334 | 335 | def calcul_attn_sore(self, key_states, query_states): 336 | bsz, num_heads, q_len, head_dim = query_states.shape 337 | attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt( 338 | head_dim) 339 | mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, 340 | device=attn_weights.device) 341 | mask_cond = torch.arange(mask.size(-1), device=attn_weights.device) 342 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 343 | mask = mask.to(attn_weights.device) 344 | attention_mask = mask[None, None, :, :] 345 | 346 | attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask 347 | 348 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 349 | attn_weights_mean = attn_weights[:, :, -self.window_size:, : -self.window_size].mean(dim=-2) 350 | 351 | if self.gqa_support: 352 | attn_weights_mean = attn_weights_mean.view(attn_weights_mean.shape[0],num_heads//self.num_key_value_groups,self.num_key_value_groups,-1) 353 | if self.gqa_func == 'max': 354 | attn_weights_mean = attn_weights_mean.max(dim=-2).values 355 | elif self.gqa_func == 'mean': 356 | attn_weights_mean = attn_weights_mean.mean(dim=-2) 357 | else: 358 | raise ValueError('gqa_func not supported') 359 | 360 | if self.pooling == 'avgpool': 361 | attn_weights_mean_pooling = F.avg_pool1d(attn_weights_mean, kernel_size=self.kernel_size, 362 | padding=self.kernel_size // 2, 363 | stride=1) 364 | elif self.pooling == 'maxpool': 365 | attn_weights_mean_pooling = F.max_pool1d(attn_weights_mean, kernel_size=self.kernel_size, 366 | padding=self.kernel_size // 2, 367 | stride=1) 368 | else: 369 | raise ValueError('Pooling method not supported') 370 | return attn_weights_mean_pooling 371 | 372 | def update_kv(self, origin_key_states, query_states, origin_value_states): 373 | if self.gqa_support: 374 | return self.update_kv_gqa(origin_key_states, query_states, origin_value_states) 375 | else: 376 | return self.update_kv_wo_gqa(origin_key_states, query_states, origin_value_states) 377 | 378 | 379 | # update kv with gqa_support 380 | def update_kv_gqa(self, origin_key_states, query_states, origin_value_states): 381 | key_states = repeat_kv(origin_key_states, self.num_key_value_groups) 382 | # value_states = repeat_kv(origin_value_states, self.num_key_value_groups) 383 | 384 | # check if prefix phase assert key_states.shape[-2] == query_states.shape[-2] 385 | _device = key_states.device 386 | bsz, num_heads, q_len, head_dim = query_states.shape 387 | attn_score= self.calcul_attn_sore(key_states,query_states) 388 | 
origin_heads_key_states = torch.split(origin_key_states, 1, dim=1) 389 | origin_heads_value_states = torch.split(origin_value_states, 1, dim=1) 390 | 391 | # compute pyramidal capacity 392 | if self.pyram_mode and not self.pyram_init: 393 | # NOTE: (max_num + min_num) / 2 == base_capacity to restrict the total capacity 394 | min_num = self.base_capacity // self.pyram_beta 395 | max_num = self.base_capacity * 2 - min_num 396 | 397 | # if the max_num is larger than the query length, we need to adjust the max_num 398 | if max_num >= q_len - self.window_size: 399 | max_num = q_len - self.window_size 400 | min_num = self.base_capacity * 2 - max_num 401 | 402 | # NOTE: compute interval 403 | steps = (max_num - min_num) // (self.num_hidden_layers - 1) 404 | 405 | # renew adaptive capacity 406 | self.base_capacity = max_num - self.layer_idx * steps 407 | self.floor_capacity = int(self.base_capacity * self.floor_ratio) 408 | self.adaptive_capacity = self.base_capacity - self.floor_capacity 409 | self.pyram_init = True 410 | print(f"Pyram mode adaptive capacity, layer: {self.layer_idx}, acap: {self.adaptive_capacity}, bcap: {self.base_capacity}, fcap: {self.floor_capacity}", flush=True) 411 | 412 | def init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k): 413 | # init metadata 414 | self.head_lens = torch.tensor(k_lens, dtype=torch.int32, device=_device) 415 | self.klen_sum = klen_sum 416 | self.max_seqlen_k = max_seqlen_k 417 | self.cu_headlens = torch.cumsum(self.head_lens, dim=0, dtype=torch.int32) 418 | # init varlen flash attention metadata 419 | self.cu_klen = self.cu_headlens - self.head_lens 420 | self.cu_klen = torch.cat( 421 | [self.cu_klen, torch.tensor([self.klen_sum], dtype=torch.int32, device=_device)], dim=0) 422 | # check bug 423 | self.layer_qlens = torch.ones(num_heads//self.num_key_value_groups, dtype=torch.int32,device=_device) 424 | self.qlen_sum = num_heads//self.num_key_value_groups 425 | self.cu_qlen = torch.cumsum(self.layer_qlens, dim=0, dtype=torch.int32) - self.layer_qlens 426 | self.cu_qlen = torch.cat( 427 | [self.cu_qlen, torch.tensor([self.qlen_sum], dtype=torch.int32, device=_device)], dim=0) 428 | 429 | 430 | if self.gqa_support: 431 | self.cu_offset = torch.arange(0, num_heads//self.num_key_value_groups + 1, dtype=torch.int32, device=_device) 432 | self.cu_head_offset = torch.arange(1, num_heads//self.num_key_value_groups +1, dtype=torch.int32, device=_device) 433 | 434 | else: 435 | self.cu_offset = torch.arange(0, num_heads + 1, dtype=torch.int32, device=_device) 436 | self.cu_head_offset = torch.arange(1, num_heads+1, dtype=torch.int32, device=_device) 437 | 438 | 439 | if self.base_capacity > attn_score.size(-1): 440 | init_metadata(num_heads, [q_len] * (num_heads//self.num_key_value_groups), q_len * (num_heads//self.num_key_value_groups), q_len) 441 | # not compress 442 | return origin_key_states.reshape(-1, head_dim), origin_value_states.reshape(-1, head_dim) 443 | 444 | 445 | sorted_attn_score,sorted_attn_score_indices = attn_score.sort(dim=-1,descending=True) 446 | if self.layer_idx >= self.skip: 447 | adaptive_attn_score = sorted_attn_score 448 | length = adaptive_attn_score.size(dim=-1) 449 | if self.normalize: 450 | ratio_weight = sorted_attn_score[...,:self.base_capacity].sum(dim=-1,keepdim=True)/sorted_attn_score.sum(dim=-1,keepdim=True) 451 | adaptive_attn_score = adaptive_attn_score*ratio_weight 452 | adaptive_attn_score = adaptive_attn_score.reshape(bsz,length*num_heads//self.num_key_value_groups) 453 | sorted_indices = 
torch.topk(adaptive_attn_score,k=num_heads*self.base_capacity//self.num_key_value_groups,dim=-1).indices 454 | sorted_indices = sorted_indices//length 455 | 456 | # floor_alpha capacity set 457 | head_adaptive_capacity = torch.zeros((bsz,num_heads//self.num_key_value_groups),device=_device,dtype = sorted_indices.dtype) 458 | head_adaptive_capacity.scatter_add_(-1,sorted_indices,torch.ones_like(sorted_indices,dtype=head_adaptive_capacity.dtype),) 459 | assert head_adaptive_capacity.sum().item() == num_heads*self.base_capacity//self.num_key_value_groups 460 | head_adaptive_capacity = torch.round(head_adaptive_capacity * (1-self.floor_ratio) + self.floor_capacity).int() 461 | else: 462 | head_adaptive_capacity = torch.ones((bsz,num_heads),device=_device,dtype = sorted_attn_score_indices.dtype) * self.base_capacity 463 | sorted_attn_score_indices = sorted_attn_score_indices.split(1,dim=1) 464 | 465 | heads_key_states = [] 466 | heads_value_states = [] 467 | assert bsz == 1 468 | # per head 469 | 470 | # reinit varlen metadata 471 | k_lens = [] 472 | klen_sum = 0 473 | max_seqlen_k = 0 474 | self.cu_klen = 0 475 | 476 | 477 | for head_idx in range(num_heads//self.num_key_value_groups): 478 | cache_index = sorted_attn_score_indices[head_idx][...,:head_adaptive_capacity[0][head_idx]] 479 | 480 | l = cache_index.shape[-1] + self.window_size 481 | k_lens.append(l) 482 | max_seqlen_k = max(max_seqlen_k, l) 483 | klen_sum += l 484 | 485 | cache_index = cache_index.view(1, 1, -1, 1).expand(-1, -1, -1, head_dim) 486 | top_Kcache = origin_heads_key_states[head_idx].gather(dim=2,index=cache_index) 487 | top_Vcache = origin_heads_value_states[head_idx].gather(dim=2,index=cache_index) 488 | selected_k = torch.cat([top_Kcache,origin_heads_key_states[head_idx][:, :, -self.window_size:, :]],dim=2) 489 | selected_v = torch.cat([top_Vcache,origin_heads_value_states[head_idx][:, :, -self.window_size:, :]],dim=2) 490 | 491 | # NOTE: flatten view 492 | heads_key_states.append(selected_k.view(-1, head_dim)) 493 | heads_value_states.append(selected_v.view(-1, head_dim)) 494 | 495 | init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k) 496 | 497 | # NOTE: compose as flatten view 498 | heads_key_states = torch.cat(heads_key_states, dim=0) 499 | heads_value_states = torch.cat(heads_value_states, dim=0) 500 | 501 | return heads_key_states, heads_value_states 502 | 503 | 504 | 505 | 506 | # update without gqa_support 507 | def update_kv_wo_gqa(self, origin_key_states, query_states, origin_value_states): 508 | key_states = repeat_kv(origin_key_states, self.num_key_value_groups) 509 | value_states = repeat_kv(origin_value_states, self.num_key_value_groups) 510 | 511 | # check if prefix phase assert key_states.shape[-2] == query_states.shape[-2] 512 | _device = key_states.device 513 | bsz, num_heads, q_len, head_dim = query_states.shape 514 | attn_score= self.calcul_attn_sore(key_states,query_states) 515 | origin_heads_key_states = torch.split(key_states, 1, dim=1) 516 | origin_heads_value_states = torch.split(value_states, 1, dim=1) 517 | 518 | # compute pyramidal capacity 519 | if self.pyram_mode and not self.pyram_init: 520 | # NOTE: (max_num + min_num) / 2 == base_capacity to restrict the total capacity 521 | min_num = self.base_capacity // self.pyram_beta 522 | max_num = self.base_capacity * 2 - min_num 523 | 524 | # if the max_num is larger than the query length, we need to adjust the max_num 525 | if max_num >= q_len - self.window_size: 526 | max_num = q_len - self.window_size 527 | min_num = self.base_capacity * 
2 - max_num 528 | 529 | # NOTE: compute interval 530 | steps = (max_num - min_num) // (self.num_hidden_layers - 1) 531 | 532 | # renew adaptive capacity 533 | self.base_capacity = max_num - self.layer_idx * steps 534 | self.floor_capacity = int(self.base_capacity * self.floor_ratio) 535 | self.adaptive_capacity = self.base_capacity - self.floor_capacity 536 | self.pyram_init = True 537 | print(f"Pyram mode adaptive capacity, layer: {self.layer_idx}, acap: {self.adaptive_capacity}, bcap: {self.base_capacity}, fcap: {self.floor_capacity}", flush=True) 538 | 539 | def init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k): 540 | # init metadata 541 | self.head_lens = torch.tensor(k_lens, dtype=torch.int32, device=_device) 542 | self.klen_sum = klen_sum 543 | self.max_seqlen_k = max_seqlen_k 544 | self.cu_headlens = torch.cumsum(self.head_lens, dim=0, dtype=torch.int32) 545 | # init varlen flash attention metadata 546 | self.cu_klen = self.cu_headlens - self.head_lens 547 | self.cu_klen = torch.cat( 548 | [self.cu_klen, torch.tensor([self.klen_sum], dtype=torch.int32, device=_device)], dim=0) 549 | self.layer_qlens = torch.ones(num_heads, dtype=torch.int32,device=_device) 550 | self.qlen_sum = num_heads 551 | self.cu_qlen = torch.cumsum(self.layer_qlens, dim=0, dtype=torch.int32) - self.layer_qlens 552 | self.cu_qlen = torch.cat( 553 | [self.cu_qlen, torch.tensor([self.qlen_sum], dtype=torch.int32, device=_device)], dim=0) 554 | self.cu_offset = torch.arange(0, num_heads + 1, dtype=torch.int32, device=_device) 555 | self.cu_head_offset = torch.arange(1, num_heads+1, dtype=torch.int32, device=_device) 556 | 557 | if self.base_capacity > attn_score.size(-1): 558 | init_metadata(num_heads, [q_len] * num_heads, q_len * num_heads, q_len) 559 | # not compress 560 | return key_states.reshape(-1, head_dim), value_states.reshape(-1, head_dim) 561 | 562 | # if you need to weight the attn_score 563 | pass 564 | sorted_attn_score,sorted_attn_score_indices = attn_score.sort(dim=-1,descending=True) 565 | if self.layer_idx >= self.skip: 566 | adaptive_attn_score = sorted_attn_score 567 | length = adaptive_attn_score.size(dim=-1) 568 | if self.normalize: 569 | ratio_weight = sorted_attn_score[...,:self.base_capacity].sum(dim=-1,keepdim=True)/sorted_attn_score.sum(dim=-1,keepdim=True) 570 | adaptive_attn_score = adaptive_attn_score*ratio_weight 571 | adaptive_attn_score = adaptive_attn_score.reshape(bsz,length*num_heads) 572 | sorted_indices = torch.topk(adaptive_attn_score,k=num_heads*self.base_capacity,dim=-1).indices 573 | sorted_indices = sorted_indices//length 574 | # floor_alpha capacity set 575 | head_adaptive_capacity = torch.zeros((bsz,num_heads),device=_device,dtype = sorted_indices.dtype) 576 | head_adaptive_capacity.scatter_add_(-1,sorted_indices,torch.ones_like(sorted_indices,dtype=head_adaptive_capacity.dtype),) 577 | assert head_adaptive_capacity.sum().item() == num_heads*self.base_capacity 578 | head_adaptive_capacity = torch.round(head_adaptive_capacity * (1-self.floor_ratio) + self.floor_capacity).int() 579 | else: 580 | head_adaptive_capacity = torch.ones((bsz,num_heads),device=_device,dtype = sorted_attn_score_indices.dtype) * self.base_capacity 581 | sorted_attn_score_indices = sorted_attn_score_indices.split(1,dim=1) 582 | 583 | heads_key_states = [] 584 | heads_value_states = [] 585 | assert bsz == 1 586 | # per head 587 | 588 | # reinit varlen metadata 589 | k_lens = [] 590 | klen_sum = 0 591 | max_seqlen_k = 0 592 | self.cu_klen = 0 593 | 594 | 595 | for head_idx in range(num_heads): 
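# --- Added explanatory note (not part of the original source), kept as comments so the file stays valid. ---
# A hedged sketch of what the per-head loop below does: each head keeps its top
# `head_adaptive_capacity[0][head_idx]` scored prefix positions plus the last `window_size`
# observation tokens; the resulting per-head lengths are collected in `k_lens` and later turned
# into the cumulative offsets (`cu_klen`, `cu_headlens`) consumed by the varlen flash-attention
# path and the flatten-view CUDA kernel.
# Illustrative arithmetic (assumed numbers, not from the source): with num_heads = 32,
# base_capacity = 128 and floor_alpha = 0.2, floor_capacity = int(128 * 0.2) = 25, and the
# scatter_add vote counts sum to 32 * 128 = 4096. After
#     head_adaptive_capacity = round(votes * (1 - 0.2) + 25)
# the per-head budgets still sum to roughly 0.8 * 4096 + 32 * 25 ≈ 4077, i.e. the layer-wide
# budget is approximately preserved while heads with more attention mass receive larger shares.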
596 | cache_index = sorted_attn_score_indices[head_idx][...,:head_adaptive_capacity[0][head_idx]] 597 | 598 | l = cache_index.shape[-1] + self.window_size 599 | k_lens.append(l) 600 | max_seqlen_k = max(max_seqlen_k, l) 601 | klen_sum += l 602 | 603 | cache_index = cache_index.view(1, 1, -1, 1).expand(-1, -1, -1, head_dim) 604 | top_Kcache = origin_heads_key_states[head_idx].gather(dim=2,index=cache_index) 605 | top_Vcache = origin_heads_value_states[head_idx].gather(dim=2,index=cache_index) 606 | selected_k = torch.cat([top_Kcache,origin_heads_key_states[head_idx][:, :, -self.window_size:, :]],dim=2) 607 | selected_v = torch.cat([top_Vcache,origin_heads_value_states[head_idx][:, :, -self.window_size:, :]],dim=2) 608 | 609 | # NOTE: flatten view 610 | heads_key_states.append(selected_k.view(-1, head_dim)) 611 | heads_value_states.append(selected_v.view(-1, head_dim)) 612 | 613 | init_metadata(num_heads, k_lens, klen_sum, max_seqlen_k) 614 | 615 | # NOTE: compose as flatten view 616 | heads_key_states = torch.cat(heads_key_states, dim=0) 617 | heads_value_states = torch.cat(heads_value_states, dim=0) 618 | 619 | return heads_key_states,heads_value_states 620 | 621 | 622 | def init_snapkv(self): 623 | 624 | assert hasattr(self.config, 'window_size'), "window_size not set" 625 | assert hasattr(self.config, 'kernel_size'), "kernel_size not set" 626 | assert hasattr(self.config, "pooling"), "pooling not set" 627 | assert hasattr(self.config, "base_capacity"), "base_capacity not set" 628 | # init only once 629 | if not hasattr(self, "kv_cluster"): 630 | self.kv_cluster = SnapKVCluster( 631 | window_size = self.config.window_size, 632 | max_capacity_prompt = self.config.base_capacity, 633 | kernel_size = self.config.kernel_size, 634 | pooling = self.config.pooling, 635 | layer_idx = self.layer_idx, 636 | num_hidden_layers = self.config.num_hidden_layers, 637 | pyram_mode = self.config.pyram_mode, 638 | pyram_beta = self.config.pyram_beta, 639 | gqa_support = self.config.gqa_support, 640 | num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads, 641 | gqa_func = self.config.gqa_func 642 | ) 643 | if self.config.gqa_support: 644 | if self.config.model_type != "mistral": 645 | warnings.warn("GQA currently supports only for mistral-7B-v0.2 model") 646 | # if len(self.config.skip) > 0: 647 | # warnings.warn("vanilla transformer should not enable skip",self.config.skip) 648 | print(f"Compress config(Snap): window_size={self.kv_cluster.window_size}, max_capacity_prompt={self.kv_cluster.max_capacity_prompt}, kernel_size={self.kv_cluster.kernel_size}, pooling={self.kv_cluster.pooling}, pyram_mode={self.kv_cluster.pyram_mode}, beta={self.kv_cluster.pyram_beta}", flush=True) 649 | 650 | def init_adaptive_snapkv(self): 651 | assert hasattr(self.config,'window_size'),"window_size not set" 652 | assert hasattr(self.config,'kernel_size'),"kernel_size not set" 653 | assert hasattr(self.config,"pooling"),"pooling not set" 654 | assert hasattr(self.config, "base_capacity"), "base_capacity not set" 655 | assert hasattr(self.config,"floor_alpha"),"floor_alpha not set" 656 | assert self.config.floor_alpha is not None 657 | 658 | 659 | # init only once 660 | if not hasattr(self, "kv_cluster"): 661 | self.kv_cluster = AdaptiveSnapKVCluster( 662 | window_size = self.config.window_size, 663 | base_capacity=self.config.base_capacity, 664 | kernel_size = self.config.kernel_size, 665 | pooling = self.config.pooling, 666 | floor_alpha= self.config.floor_alpha, 667 | skip = self.config.skip, 
668 | layer_idx = self.layer_idx, 669 | normalize = self.config.normalize, 670 | num_hidden_layers = self.config.num_hidden_layers, 671 | pyram_mode = self.config.pyram_mode, 672 | pyram_beta = self.config.pyram_beta, 673 | gqa_support = self.config.gqa_support, 674 | num_key_value_groups = self.config.num_attention_heads // self.config.num_key_value_heads, 675 | gqa_func = self.config.gqa_func 676 | ) 677 | if self.config.gqa_support: 678 | if self.config.model_type != "mistral": 679 | warnings.warn("GQA currently supports only for mistral-7B-v0.2 model") 680 | print(f"Compress config(Ada): window_size={self.kv_cluster.window_size}, base_capacity={self.kv_cluster.base_capacity}, kernel_size={self.kv_cluster.kernel_size}, pooling={self.kv_cluster.pooling}, floor_alpha={self.kv_cluster.floor_ratio}, pyram_mode={self.kv_cluster.pyram_mode}, beta={self.kv_cluster.pyram_beta}", flush=True) 681 | 682 | 683 | 684 | class StreamingLLMKVCluster(): 685 | def __init__(self, max_capacity_prompt = 256): 686 | self.max_capacity_prompt = max_capacity_prompt - 4 687 | self.sink_token = 4 688 | assert self.max_capacity_prompt - 4 > 0 689 | 690 | 691 | def update_kv(self, key_states, query_states, value_states, *args, **kwargs): 692 | # check if prefix phase 693 | assert key_states.shape[-2] == query_states.shape[-2] 694 | bsz, num_heads, q_len, head_dim = query_states.shape 695 | 696 | print(f"StreamingLLM max_capacity_prompt {self.max_capacity_prompt + 4}") 697 | 698 | if q_len < self.max_capacity_prompt + 4: 699 | return key_states, value_states 700 | else: 701 | k_past_compress = key_states[:, :, :self.sink_token, :] 702 | v_past_compress = value_states[:, :, :self.sink_token, :] 703 | k_cur = key_states[:, :, -self.max_capacity_prompt:, :] 704 | v_cur = value_states[:, :, -self.max_capacity_prompt:, :] 705 | key_states = torch.cat([k_past_compress, k_cur], dim = 2) 706 | value_states = torch.cat([v_past_compress, v_cur], dim = 2) 707 | return key_states, value_states 708 | 709 | 710 | def init_slm(self,**kwargs): 711 | assert hasattr(self.config, 'window_size'), "window_size not set" 712 | # init only once 713 | if not hasattr(self, "kv_cluster"): 714 | self.kv_cluster = StreamingLLMKVCluster( 715 | max_capacity_prompt = self.config.base_capacity, 716 | ) 717 | print(f"Compress config(SLM): max_cap={self.config.base_capacity}") 718 | 719 | 720 | 721 | 722 | 723 | -------------------------------------------------------------------------------- /assets/images/4k_ruler_average_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/4k_ruler_average_score.png -------------------------------------------------------------------------------- /assets/images/LongBench_mistral.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/LongBench_mistral.png -------------------------------------------------------------------------------- /assets/images/LongBench_mistral_gqa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/LongBench_mistral_gqa.png -------------------------------------------------------------------------------- /assets/images/head_vary.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/head_vary.png -------------------------------------------------------------------------------- /assets/images/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/main.png -------------------------------------------------------------------------------- /assets/images/mem.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/mem.png -------------------------------------------------------------------------------- /assets/images/speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FFY0/AdaKV/66b95be2d9e3c5e113874ee6d5370086838c2e16/assets/images/speed.png -------------------------------------------------------------------------------- /csrc/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | deps/ 3 | tiny_pkg.egg-info/ 4 | test.py 5 | -------------------------------------------------------------------------------- /csrc/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 66RING 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /csrc/build.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | from packaging.version import parse, Version 4 | from pathlib import Path 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import ( 7 | BuildExtension, 8 | CppExtension, 9 | CUDAExtension, 10 | CUDA_HOME, 11 | ) 12 | 13 | # package name managed by pip, which can be remove by `pip uninstall tiny_pkg` 14 | PACKAGE_NAME = "tiny_pkg" 15 | 16 | ext_modules = [] 17 | generator_flag = [] 18 | cc_flag = [] 19 | cc_flag.append("-gencode") 20 | cc_flag.append("arch=compute_80,code=sm_80") 21 | 22 | 23 | # helper function to get cuda version 24 | def get_cuda_bare_metal_version(cuda_dir): 25 | raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) 26 | output = raw_output.split() 27 | release_idx = output.index("release") + 1 28 | bare_metal_version = parse(output[release_idx].split(",")[0]) 29 | 30 | return raw_output, bare_metal_version 31 | 32 | 33 | # if CUDA_HOME is not None: 34 | # _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) 35 | # if bare_metal_version >= Version("11.8"): 36 | # cc_flag.append("-gencode") 37 | # cc_flag.append("arch=compute_90,code=sm_90") 38 | 39 | # ninja build does not work unless include_dirs are abs path 40 | this_dir = os.path.dirname(os.path.abspath(__file__)) 41 | 42 | # cuda module 43 | ext_modules.append( 44 | CUDAExtension( 45 | # package name for import 46 | name="tiny_api_cuda", 47 | sources=[ 48 | "csrc/cuda_api.cu", 49 | ], 50 | extra_compile_args={ 51 | # add c compile flags 52 | "cxx": ["-O3", "-std=c++17"] + generator_flag, 53 | # add nvcc compile flags 54 | "nvcc": [ 55 | "-O3", 56 | "-std=c++17", 57 | "-U__CUDA_NO_HALF_OPERATORS__", 58 | "--use_fast_math", 59 | "-lineinfo", 60 | "--ptxas-options=-v", 61 | "--ptxas-options=-O2", 62 | "-U__CUDA_NO_HALF_OPERATORS__", 63 | "-U__CUDA_NO_HALF_CONVERSIONS__", 64 | "-U__CUDA_NO_HALF2_OPERATORS__", 65 | "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", 66 | "--expt-relaxed-constexpr", 67 | "--expt-extended-lambda", 68 | "--use_fast_math", 69 | ] 70 | + generator_flag 71 | + cc_flag, 72 | }, 73 | include_dirs=[ 74 | Path(this_dir) / "csrc", 75 | Path(this_dir) / "include", 76 | # Path(this_dir) / "some" / "thing" / "more", 77 | ], 78 | ) 79 | ) 80 | 81 | setup( 82 | name=PACKAGE_NAME, 83 | packages=find_packages( 84 | exclude=( 85 | "build", 86 | "csrc", 87 | "include", 88 | "tests", 89 | "dist", 90 | "docs", 91 | "benchmarks", 92 | "tiny_pkg.egg-info", 93 | ) 94 | ), 95 | description="Tiny cuda and c api binding for pytorch.", 96 | ext_modules=ext_modules, 97 | cmdclass={ "build_ext": BuildExtension}, 98 | python_requires=">=3.7", 99 | install_requires=[ 100 | "torch", 101 | "einops", 102 | "packaging", 103 | "ninja", 104 | ], 105 | ) 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /csrc/csrc/cuda_api.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cuda_api.h" 9 | #include "static_switch.h" 10 | 11 | template 12 | __global__ void update_flatten_view_kernel(tensor_t* dst_ptr, tensor_t* src_ptr, tensor_t* state_ptr, int* headlens, 13 | int *cu_headlens, 14 | int dim) { 15 | // Create new tensor from cache and insert element into 
it. 16 | 17 | int head_idx = blockIdx.x; 18 | int thread_group = blockIdx.y; 19 | int tid = threadIdx.x + thread_group * blockDim.x; 20 | int num_threads = blockDim.x * gridDim.y; 21 | 22 | int headlen = headlens[head_idx]; 23 | 24 | // get position of src, dst, insert ptr 25 | int src_cum_off = cu_headlens[head_idx] * dim; 26 | int dst_cum_off = src_cum_off + head_idx * dim; 27 | int insert_off = (cu_headlens[head_idx+1] + head_idx) * dim; 28 | 29 | auto old_cache_ptr = src_ptr + src_cum_off; 30 | auto new_cache_ptr = dst_ptr + dst_cum_off; 31 | auto insert_cache_ptr = dst_ptr + insert_off; 32 | 33 | // copy old data 34 | for (int start_addr = 0; start_addr < headlen * dim; start_addr += kblock_size * num_threads) { 35 | auto src_addr = old_cache_ptr + start_addr + tid * kblock_size; 36 | auto dst_addr = new_cache_ptr + start_addr + tid * kblock_size; 37 | 38 | // TODO: LDSM speed up with SRAM 39 | #pragma unroll 40 | for (int i = 0; i < kblock_size; i++) { 41 | if (start_addr + tid * kblock_size + i >= headlen * dim) { 42 | break; 43 | } 44 | dst_addr[i] = src_addr[i]; 45 | } 46 | } 47 | 48 | // insert new data 49 | if (tid < dim) { 50 | auto insert_src_ptr = state_ptr + head_idx * dim + tid; 51 | auto insert_dst_addr = insert_cache_ptr + tid; 52 | *insert_dst_addr = *insert_src_ptr; 53 | } 54 | } 55 | 56 | torch::Tensor update_flatten_view(torch::Tensor &cache, torch::Tensor &state, torch::Tensor &headlens, torch::Tensor& cu_headlens) { 57 | TORCH_CHECK(headlens.dtype() == torch::kInt32, "expected headlens to be int32"); 58 | TORCH_CHECK(cu_headlens.dtype() == torch::kInt32, "expected cu_dst_pos to be int32"); 59 | 60 | auto cache_shape = cache.sizes(); 61 | 62 | int origin_len = cache_shape[0]; 63 | int head_dim = cache_shape[1]; 64 | int head_num = headlens.sizes()[0]; 65 | 66 | torch::Tensor out = torch::empty({origin_len + head_num, head_dim}, cache.options()); 67 | 68 | const int kblock_size = 1; 69 | const int num_threads_group = 1024; 70 | const int num_threads = 128; 71 | 72 | dim3 grid(head_num, num_threads_group); 73 | 74 | // TODO: dispatch with head_dim?? may loss performance 75 | dim3 block(num_threads); 76 | TORCH_CHECK(num_threads >= head_dim, "num threads should larger than head dim"); 77 | 78 | FP16_SWITCH(cache.dtype() == torch::kFloat16, [&] { 79 | auto kernel = update_flatten_view_kernel; 80 | kernel<<>>((elem_type*)out.data_ptr(), (elem_type*)cache.data_ptr(), (elem_type*)state.data_ptr(), (int*)headlens.data_ptr(), (int*)cu_headlens.data_ptr(), head_dim); 81 | }); 82 | 83 | // TODO: when to use sync or torch auto 84 | // cudaDeviceSynchronize(); 85 | 86 | return out; 87 | } 88 | 89 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 90 | // m.def("package_name", &function_name, "function_docstring"") 91 | 92 | m.def("update_flatten_view", &update_flatten_view, "update flatten view cache"); 93 | } 94 | -------------------------------------------------------------------------------- /csrc/csrc/static_switch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define FP16_SWITCH(COND, ...) 
\ 4 | [&] { \ 5 | if (COND) { \ 6 | using elem_type = at::Half; \ 7 | return __VA_ARGS__(); \ 8 | } else { \ 9 | using elem_type = at::BFloat16; \ 10 | return __VA_ARGS__(); \ 11 | } \ 12 | }() 13 | 14 | 15 | -------------------------------------------------------------------------------- /csrc/include/cuda_api.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #define DEBUG 1 6 | 7 | #ifdef DEBUG 8 | 9 | // NOTE:tensor malloc as device before we call 10 | // e.g. data.to("cuda") in python 11 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 14 | #define CUDA_ERROR_CHECK(condition) \ 15 | do { \ 16 | cudaError_t error = condition; \ 17 | if (error != cudaSuccess) { \ 18 | printf("CUDA_CHECK error in line %d of file %s \ 19 | : %s \n", \ 20 | __LINE__, __FILE__, cudaGetErrorString(error)); \ 21 | exit(EXIT_FAILURE); \ 22 | } \ 23 | } while (0) 24 | 25 | #else 26 | 27 | #define CHECK_CUDA(x) do { } while (0) 28 | #define CHECK_CONTIGUOUS(x) do { } while (0) 29 | #define CHECK_INPUT(x) do { } while (0) 30 | #define CUDA_ERROR_CHECK(condition) do { condition; } while (0) 31 | 32 | #endif // DEBUG 33 | 34 | 35 | -------------------------------------------------------------------------------- /csrc/makefile: -------------------------------------------------------------------------------- 1 | build: 2 | python build.py install 3 | 4 | .PHONY: build 5 | -------------------------------------------------------------------------------- /experiments/LongBench/GQA_eval_longbench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME="Mistral-7B-Instruct-v0.2" 4 | MODEL=${MODELS_DIR}/${MODEL_NAME} 5 | DATASET=${DATASETS_DIR}/LongBench 6 | MAX_LEN=31500 7 | 8 | echo "Testing $MODEL $DATASET $MAX_LEN" 9 | 10 | scopes=(128 256 512 1024) 11 | device="4090" 12 | 13 | for scope in ${scopes[@]}; do 14 | # ada-snapkv with gqa support 15 | # we suggest setting floor_alpha to 0.2 while using gqa 16 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode ada --compress_args_path c"$scope"_w32_k7_maxpool.json --floor_alpha 0.2 --out_name "$device"_ada_"$scope"_"$MODEL_NAME" --gqa_support 17 | # snapkv with gqa support 18 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode fix --compress_args_path c"$scope"_w32_k7_maxpool.json --out_name "$device"_fix_"$scope"_"$MODEL_NAME" --gqa_support 19 | # ada-snapkv without gqa support 20 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode ada --compress_args_path c"$scope"_w32_k7_maxpool.json --floor_alpha 0.2 --out_name "$device"_ada_"$scope"_"$MODEL_NAME" 21 | # snapkv without gqa support 22 | python pred.py -m $MODEL --max_length $MAX_LEN -d $DATASET --mode fix --compress_args_path c"$scope"_w32_k7_maxpool.json --out_name "$device"_fix_"$scope"_"$MODEL_NAME" 23 | done 24 | -------------------------------------------------------------------------------- /experiments/LongBench/README.md: -------------------------------------------------------------------------------- 1 | # Minimal version of LongBench 2 | Based on SnapKV -------------------------------------------------------------------------------- /experiments/LongBench/config/dataset2maxlen.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": 128, 3 | "qasper": 128, 4 | "multifieldqa_en": 64, 5 | "multifieldqa_zh": 64, 6 | "hotpotqa": 32, 7 | "2wikimqa": 32, 8 | "musique": 32, 9 | "dureader": 128, 10 | "gov_report": 512, 11 | "qmsum": 512, 12 | "multi_news": 512, 13 | "vcsum": 512, 14 | "trec": 64, 15 | "triviaqa": 32, 16 | "samsum": 128, 17 | "lsht": 64, 18 | "passage_count": 32, 19 | "passage_retrieval_en": 32, 20 | "passage_retrieval_zh": 32, 21 | "lcc": 64, 22 | "repobench-p": 64 23 | } -------------------------------------------------------------------------------- /experiments/LongBench/config/dataset2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 10 | "gov_report": "You are given a report by a government agency. 
Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 22 | "repobench-p": "Please complete the code given below. 
\n{context}{input}Next line of code:\n" 23 | } -------------------------------------------------------------------------------- /experiments/LongBench/config/model2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": 3500, 3 | "longchat-v1.5-7b-32k": 31500, 4 | "xgen-7b-8k": 7500, 5 | "internlm-7b-8k": 7500, 6 | "chatglm2-6b": 31500, 7 | "chatglm2-6b-32k": 31500, 8 | "chatglm3-6b-32k": 31500, 9 | "vicuna-v1.5-7b-16k": 15500, 10 | "mistral-7B-instruct-v0.2": 31500, 11 | "mistral-7B-instruct-v0.1": 31500, 12 | "mixtral-8x7B-instruct-v0.1": 31500, 13 | "llama-2-7B-32k-instruct": 31500, 14 | "lwm-text-chat-1m": 1048576, 15 | "lwm-text-1m": 1048576 16 | } 17 | -------------------------------------------------------------------------------- /experiments/LongBench/config/model2path.json: -------------------------------------------------------------------------------- 1 | { 2 | "llama2-7b-chat-4k": "meta-llama/Llama-2-7b-chat-hf", 3 | "longchat-v1.5-7b-32k": "lmsys/longchat-7b-v1.5-32k", 4 | "xgen-7b-8k": "Salesforce/xgen-7b-8k-inst", 5 | "internlm-7b-8k": "internlm/internlm-chat-7b-8k", 6 | "chatglm2-6b": "THUDM/chatglm2-6b", 7 | "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k", 8 | "chatglm3-6b-32k": "THUDM/chatglm3-6b-32k", 9 | "vicuna-v1.5-7b-16k": "lmsys/vicuna-7b-v1.5-16k", 10 | "mistral-7B-instruct-v0.2": "mistralai/Mistral-7B-Instruct-v0.2", 11 | "mistral-7B-instruct-v0.1": "mistralai/Mistral-7B-Instruct-v0.1", 12 | "mixtral-8x7B-instruct-v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1", 13 | "llama-2-7B-32k-instruct": "togethercomputer/Llama-2-7B-32K-Instruct", 14 | "lwm-text-chat-1m": "LargeWorldModel/LWM-Text-Chat-1M", 15 | "lwm-text-1m": "LargeWorldModel/LWM-Text-1M" 16 | } 17 | -------------------------------------------------------------------------------- /experiments/LongBench/convert_to_execl.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | import os 4 | 5 | """ 6 | Utils to convert the results of the experiments to an excel file 7 | """ 8 | 9 | dataframes = [] 10 | 11 | for path in os.listdir('./pred/'): 12 | file_path = f'./pred/{path}/result.json' 13 | 14 | with open(file_path, 'r') as file: 15 | data = json.load(file) 16 | 17 | column_order = [ 18 | "narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", 19 | "musique", "gov_report", "qmsum", "multi_news", "trec", 20 | "triviaqa", "samsum", "passage_count", "passage_retrieval_en", 21 | "lcc", "repobench-p" 22 | ] 23 | 24 | data_renamed = { 25 | "narrativeqa": data.get("narrativeqa", -1), 26 | "qasper": data.get("qasper", -1), 27 | "multifieldqa_en": data.get("multifieldqa_en", -1), 28 | "hotpotqa": data.get("hotpotqa", -1), 29 | "2wikimqa": data.get("2wikimqa", -1), 30 | "musique": data.get("musique", -1), 31 | "gov_report": data.get("gov_report", -1), 32 | "qmsum": data.get("qmsum", -1), 33 | "multi_news": data.get("multi_news", -1), 34 | "trec": data.get("trec", -1), 35 | "triviaqa": data.get("triviaqa", -1), 36 | "samsum": data.get("samsum", -1), 37 | "passage_count": data.get("passage_count", -1), 38 | "passage_retrieval_en": data.get("passage_retrieval_en", -1), 39 | "lcc": data.get("lcc", -1), 40 | "repobench-p": data.get("repobench-p", -1) 41 | } 42 | 43 | df = pd.DataFrame([data_renamed], columns=column_order, index=[path]) 44 | dataframes.append(df) 45 | 46 | result_df = pd.concat(dataframes) 47 | output_path = './pred/combined_results.xlsx' 48 | 
result_df.to_excel(output_path, index_label='Folder Name') 49 | 50 | print(f"Excel file saved to {output_path}") 51 | -------------------------------------------------------------------------------- /experiments/LongBench/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import numpy as np 5 | 6 | from metrics import ( 7 | qa_f1_score, 8 | rouge_zh_score, 9 | qa_f1_zh_score, 10 | rouge_score, 11 | classification_score, 12 | retrieval_score, 13 | retrieval_zh_score, 14 | count_score, 15 | code_sim_score, 16 | ) 17 | 18 | """ 19 | check the sample num of each dataset 20 | """ 21 | 22 | dataset_samples = { 23 | "narrativeqa": 200, 24 | "qasper":200, 25 | "multifieldqa_en": 150, 26 | "hotpotqa": 200, 27 | "2wikimqa": 200, 28 | "musique": 200, 29 | "gov_report": 200, 30 | "qmsum": 200, 31 | "multi_news": 200, 32 | "trec": 200, 33 | "triviaqa": 200, 34 | "samsum": 200, 35 | "passage_retrieval_en": 200, 36 | "passage_count": 200, 37 | "lcc": 500, 38 | "repobench-p": 500, 39 | } 40 | dataset2metric = { 41 | "narrativeqa": qa_f1_score, 42 | "qasper": qa_f1_score, 43 | "multifieldqa_en": qa_f1_score, 44 | "multifieldqa_zh": qa_f1_zh_score, 45 | "hotpotqa": qa_f1_score, 46 | "2wikimqa": qa_f1_score, 47 | "musique": qa_f1_score, 48 | "dureader": rouge_zh_score, 49 | "gov_report": rouge_score, 50 | "qmsum": rouge_score, 51 | "multi_news": rouge_score, 52 | "vcsum": rouge_zh_score, 53 | "trec": classification_score, 54 | "triviaqa": qa_f1_score, 55 | "samsum": rouge_score, 56 | "lsht": classification_score, 57 | "passage_retrieval_en": retrieval_score, 58 | "passage_count": count_score, 59 | "passage_retrieval_zh": retrieval_zh_score, 60 | "lcc": code_sim_score, 61 | "repobench-p": code_sim_score, 62 | } 63 | 64 | def parse_args(args=None): 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--model', type=str, default=None) 67 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 68 | return parser.parse_args(args) 69 | 70 | def scorer_e(dataset, predictions, answers, lengths, all_classes): 71 | scores = {"0-4k": [], "4-8k": [], "8k+": []} 72 | for (prediction, ground_truths, length) in zip(predictions, answers, lengths): 73 | score = 0. 74 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 75 | prediction = prediction.lstrip('\n').split('\n')[0] 76 | for ground_truth in ground_truths: 77 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 78 | if length < 4000: 79 | scores["0-4k"].append(score) 80 | elif length < 8000: 81 | scores["4-8k"].append(score) 82 | else: 83 | scores["8k+"].append(score) 84 | for key in scores.keys(): 85 | scores[key] = round(100 * np.mean(scores[key]), 2) 86 | return scores 87 | 88 | def scorer(dataset, predictions, answers, all_classes): 89 | score_list = [] 90 | total_score = 0. 91 | for (prediction, ground_truths) in zip(predictions, answers): 92 | score = 0. 
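# Added explanatory comment (not in the original): each prediction is scored against every
# reference answer with the dataset-specific metric and the maximum is kept, so one matching
# reference is enough for full credit on that sample; for the few-shot sets (trec, triviaqa,
# samsum, lsht) only the first generated line is scored, stripping any trailing continuation.
# Hedged usage sketch with made-up inputs:
#     scorer("hotpotqa", ["Paris"], [["Paris", "the city of Paris"]], all_classes=None)
# would return (100.0, [1.0]) because qa_f1_score("Paris", "Paris") == 1.0.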
93 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 94 | prediction = prediction.lstrip('\n').split('\n')[0] 95 | for ground_truth in ground_truths: 96 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 97 | score_list.append(score) 98 | total_score += score 99 | return round(100 * total_score / len(predictions), 2),score_list 100 | 101 | if __name__ == '__main__': 102 | args = parse_args() 103 | 104 | for model in os.listdir("pred/"): 105 | scores = dict() 106 | scores_list = dict() 107 | args.model = model 108 | if args.e: 109 | path = f"pred_e/{args.model}/" 110 | else: 111 | path = f"pred/{args.model}/" 112 | all_files = os.listdir(path) 113 | print("Evaluating on:", all_files) 114 | for filename in all_files: 115 | if not filename.endswith("jsonl"): 116 | continue 117 | predictions, answers, lengths = [], [], [] 118 | dataset = filename.split('.')[0] 119 | with open(f"{path}{filename}", "r", encoding="utf-8") as f: 120 | line_cnt = 0 121 | for line in f: 122 | line_cnt += 1 123 | data = json.loads(line) 124 | predictions.append(data["pred"]) 125 | answers.append(data["answers"]) 126 | all_classes = data["all_classes"] 127 | if "length" in data: 128 | lengths.append(data["length"]) 129 | dataset_name = filename.split('.')[0] 130 | target_samples = dataset_samples.get(dataset_name, 200) 131 | if line_cnt != target_samples: 132 | print(f"Error: {dataset_name} has {line_cnt} samples, expected {target_samples}") 133 | continue 134 | if args.e: 135 | score = scorer_e(dataset, predictions, answers, lengths, all_classes) 136 | else: 137 | score,score_list = scorer(dataset, predictions, answers, all_classes) 138 | # if dataset == 'qasper': 139 | # score_e = scorer_e(dataset, predictions, answers, lengths, all_classes) 140 | scores[dataset] = score 141 | scores_list[dataset] = score_list 142 | # if dataset == 'qasper': 143 | # scores[dataset + '_e'] = score_e 144 | if args.e: 145 | # out_path = f"H2O/results/{args.model}/result.json" 146 | out_path = f"pred_e/{args.model}/result.json" 147 | 148 | else: 149 | # out_path = f"H2O/results/{args.model}/result.json" 150 | out_path = f"pred/{args.model}/result.json" 151 | list_out_path = f"pred/{args.model}/list_result.json" 152 | # with open(out_path_e, "w") as f: 153 | # json.dump(score_e, f, ensure_ascii=False, indent=4) 154 | with open(out_path, "w") as f: 155 | json.dump(scores, f, ensure_ascii=False, indent=4) 156 | with open(list_out_path,"w") as f: 157 | json.dump(scores_list,f,ensure_ascii=False,indent=4) -------------------------------------------------------------------------------- /experiments/LongBench/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | import jieba 5 | from fuzzywuzzy import fuzz 6 | import difflib 7 | 8 | from typing import List 9 | from collections import Counter 10 | from rouge import Rouge 11 | 12 | def normalize_answer(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | 15 | def remove_articles(text): 16 | return re.sub(r"\b(a|an|the)\b", " ", text) 17 | 18 | def white_space_fix(text): 19 | return " ".join(text.split()) 20 | 21 | def remove_punc(text): 22 | exclude = set(string.punctuation) 23 | return "".join(ch for ch in text if ch not in exclude) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 29 | 30 | 31 | def normalize_zh_answer(s): 32 | """Lower text and remove 
punctuation, extra whitespace.""" 33 | 34 | def white_space_fix(text): 35 | return "".join(text.split()) 36 | 37 | def remove_punc(text): 38 | cn_punctuation = "!?。。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 39 | all_punctuation = set(string.punctuation + cn_punctuation) 40 | return "".join(ch for ch in text if ch not in all_punctuation) 41 | 42 | def lower(text): 43 | return text.lower() 44 | 45 | return white_space_fix(remove_punc(lower(s))) 46 | 47 | def count_score(prediction, ground_truth, **kwargs): 48 | numbers = re.findall(r"\d+", prediction) 49 | right_num = 0 50 | for number in numbers: 51 | if str(number) == str(ground_truth): 52 | right_num += 1 53 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 54 | return float(final_score) 55 | 56 | def retrieval_score(prediction, ground_truth, **kwargs): 57 | pattern = r'Paragraph (\d+)' 58 | matches = re.findall(pattern, ground_truth) 59 | ground_truth_id = matches[0] 60 | numbers = re.findall(r"\d+", prediction) 61 | right_num = 0 62 | for number in numbers: 63 | if str(number) == str(ground_truth_id): 64 | right_num += 1 65 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 66 | return float(final_score) 67 | 68 | def retrieval_zh_score(prediction, ground_truth, **kwargs): 69 | pattern = r'段落(\d+)' 70 | matches = re.findall(pattern, ground_truth) 71 | ground_truth_id = matches[0] 72 | numbers = re.findall(r"\d+", prediction) 73 | right_num = 0 74 | for number in numbers: 75 | if str(number) == str(ground_truth_id): 76 | right_num += 1 77 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 78 | return float(final_score) 79 | 80 | def code_sim_score(prediction, ground_truth, **kwargs): 81 | all_lines = prediction.lstrip('\n').split('\n') 82 | prediction = "" 83 | for line in all_lines: 84 | if ('`' not in line) and ('#' not in line) and ('//' not in line): 85 | prediction = line 86 | break 87 | return (fuzz.ratio(prediction, ground_truth) / 100) 88 | 89 | def classification_score(prediction, ground_truth, **kwargs): 90 | em_match_list = [] 91 | all_classes = kwargs["all_classes"] 92 | for class_name in all_classes: 93 | if class_name in prediction: 94 | em_match_list.append(class_name) 95 | for match_term in em_match_list: 96 | if match_term in ground_truth and match_term != ground_truth: 97 | em_match_list.remove(match_term) 98 | if ground_truth in em_match_list: 99 | score = (1.0 / len(em_match_list)) 100 | else: 101 | score = 0.0 102 | return score 103 | 104 | def rouge_score(prediction, ground_truth, **kwargs): 105 | rouge = Rouge() 106 | try: 107 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 108 | except: 109 | return 0.0 110 | return scores["rouge-l"]["f"] 111 | 112 | def rouge_zh_score(prediction, ground_truth, **kwargs): 113 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 114 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 115 | score = rouge_score(prediction, ground_truth) 116 | return score 117 | 118 | def f1_score(prediction, ground_truth, **kwargs): 119 | common = Counter(prediction) & Counter(ground_truth) 120 | num_same = sum(common.values()) 121 | if num_same == 0: 122 | return 0 123 | precision = 1.0 * num_same / len(prediction) 124 | recall = 1.0 * num_same / len(ground_truth) 125 | f1 = (2 * precision * recall) / (precision + recall) 126 | return f1 127 | 128 | def qa_f1_score(prediction, ground_truth, **kwargs): 129 | normalized_prediction = 
normalize_answer(prediction) 130 | normalized_ground_truth = normalize_answer(ground_truth) 131 | 132 | prediction_tokens = normalized_prediction.split() 133 | ground_truth_tokens = normalized_ground_truth.split() 134 | return f1_score(prediction_tokens, ground_truth_tokens) 135 | 136 | 137 | def qa_f1_zh_score(prediction, ground_truth, **kwargs): 138 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 139 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 140 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] 141 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] 142 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 143 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 144 | return f1_score(prediction_tokens, ground_truth_tokens) 145 | -------------------------------------------------------------------------------- /experiments/LongBench/pred.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datasets import load_dataset 3 | import torch 4 | import json 5 | from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig 6 | from tqdm import tqdm 7 | import numpy as np 8 | import random 9 | import argparse 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | import gc 13 | import time 14 | 15 | from adaptive_snapkv.monkeypatch.monkeypatch import replace_mistral_adaptive, replace_llama_adaptive, config_compress, replace_llama_fixed, replace_mistral_fixed,replace_mistral_slm, replace_llama_slm 16 | 17 | def parse_args(args=None): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("-m", '--model_name_or_path', type=str, required=True) 20 | parser.add_argument('--max_length', type=int, required=True) 21 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 22 | parser.add_argument("-d", '--dataset', type=str, default="THUDM/LongBench") 23 | parser.add_argument("--out_name", type=str, required=True) 24 | parser.add_argument('--compress_args_path', type=str, default=None, help="Path to the compress args") 25 | # parser.add_argument('--adaptive', action='store_true', help="Use adaptive budgets allocation across heads") 26 | parser.add_argument('--mode', type=str, choices=['ada', 'fix', 'test', "slm"], help="Ada mode, fix mode or normal") 27 | parser.add_argument('--floor_alpha',type=float,default=0.2,help="floor_alpha budgets for each head") 28 | parser.add_argument('--gqa_support',action='store_true', default=False, help="init gqa_support") 29 | parser.add_argument('--gqa_func',type=str, default="mean", help="gqa operation:optional max mean") 30 | parser.add_argument('--normalize',action='store_true') 31 | parser.add_argument('--pyram',action='store_true',help="using pyram mode") 32 | parser.add_argument('--pyram_beta',default=20,type=int, help="hyper parameter for pyram") 33 | parser.add_argument('--budget',default=1024, type=int, help="budget size for kv cache") 34 | 35 | parser.add_argument('--budget_ratio',type=float,default=0.2,help="floor_alpha budgets for each head") 36 | return parser.parse_args(args) 37 | 38 | # This is the customized building prompt for chat models 39 | def build_chat(tokenizer, prompt, model_name): 40 | if "chatglm3" in model_name: 41 | prompt = tokenizer.build_chat_input(prompt) 42 | elif "chatglm" in model_name: 43 | prompt = tokenizer.build_prompt(prompt) 44 | elif 
"longchat" in model_name or "vicuna" in model_name: 45 | from fastchat.model import get_conversation_template 46 | conv = get_conversation_template("vicuna") 47 | conv.append_message(conv.roles[0], prompt) 48 | conv.append_message(conv.roles[1], None) 49 | prompt = conv.get_prompt() 50 | elif "llama2" in model_name: 51 | prompt = f"[INST]{prompt}[/INST]" 52 | elif "xgen" in model_name: 53 | header = ( 54 | "A chat between a curious human and an artificial intelligence assistant. " 55 | "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" 56 | ) 57 | prompt = header + f" ### Human: {prompt}\n###" 58 | elif "internlm" in model_name: 59 | prompt = f"<|User|>:{prompt}\n<|Bot|>:" 60 | elif "llama-3" in model_name.lower() and "instruct" in model_name.lower(): 61 | prompt = [{ "role": "user", "content": prompt}] 62 | prompt = tokenizer.apply_chat_template( 63 | prompt, 64 | tokenize=False, 65 | add_generation_prompt=True 66 | ) 67 | return prompt 68 | 69 | def post_process(response, model_name): 70 | if "xgen" in model_name: 71 | response = response.strip().replace("Assistant:", "") 72 | elif "internlm" in model_name: 73 | response = response.split("")[0] 74 | return response 75 | 76 | def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device, model_name_or_path, out_path): 77 | device = "cuda" 78 | preds = [] 79 | times = [] 80 | with open(f"{out_path}_tmp", "w", encoding="utf-8") as f: 81 | for json_obj in tqdm(data): 82 | prompt = prompt_format.format(**json_obj) 83 | # truncate to fit max_length (we suggest truncate in the middle, since the left and right side may contain crucial instructions) 84 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] 85 | if "chatglm3" in model_name_or_path: 86 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt", add_special_tokens=False).input_ids[0] 87 | if len(tokenized_prompt) > max_length: 88 | half = int(max_length/2) 89 | prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) 90 | if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without build prompts on these tasks 91 | prompt = build_chat(tokenizer, prompt, model_name_or_path) 92 | if "chatglm3" in model_name_or_path: 93 | if dataset in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: 94 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) 95 | else: 96 | input = prompt.to(device) 97 | else: 98 | input = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) 99 | context_length = input.input_ids.shape[-1] 100 | 101 | torch.cuda.synchronize() 102 | t = time.time() 103 | if dataset == "samsum": # prevent illegal output on samsum (model endlessly repeat "\nDialogue"), might be a prompting issue 104 | output = model.generate( 105 | **input, 106 | max_new_tokens=max_gen, 107 | num_beams=1, 108 | do_sample=False, 109 | temperature=1.0, 110 | min_length=context_length+1, 111 | eos_token_id=[tokenizer.eos_token_id, tokenizer.encode("\n", add_special_tokens=False)[-1]], 112 | )[0] 113 | else: 114 | output = model.generate( 115 | **input, 116 | max_new_tokens=max_gen, 117 | num_beams=1, 118 | do_sample=False, 119 | temperature=1.0, 120 | eos_token_id=[tokenizer.eos_token_id], 121 | )[0] 122 | torch.cuda.synchronize() 123 | t = time.time() - t 124 | times.append(t) 125 | 126 | pred = 
tokenizer.decode(output[context_length:], skip_special_tokens=True) 127 | pred = post_process(pred, model_name_or_path) 128 | preds.append(pred) 129 | 130 | 131 | json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"], "time": t}, f, ensure_ascii=False ) 132 | f.write('\n') 133 | f.flush() 134 | 135 | 136 | gc.collect() 137 | torch.cuda.empty_cache() 138 | 139 | 140 | with open(out_path, "w", encoding="utf-8") as f: 141 | for json_obj, pred, t in zip(data, preds, times): 142 | json.dump({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"], "time": t}, f, ensure_ascii=False ) 143 | f.write('\n') 144 | 145 | 146 | def seed_everything(seed): 147 | torch.manual_seed(seed) 148 | torch.cuda.manual_seed(seed) 149 | np.random.seed(seed) 150 | random.seed(seed) 151 | torch.backends.cudnn.benchmark = False 152 | torch.backends.cudnn.deterministic = True 153 | torch.cuda.manual_seed_all(seed) 154 | 155 | def load_model_and_tokenizer(path): 156 | tokenizer = AutoTokenizer.from_pretrained(path, 157 | trust_remote_code=True, 158 | ) 159 | model = AutoModelForCausalLM.from_pretrained(path, 160 | torch_dtype=torch.bfloat16, 161 | # TODO: hard code 162 | device_map="auto", 163 | attn_implementation="flash_attention_2", 164 | trust_remote_code=True, 165 | ) 166 | model = model.eval() 167 | return model, tokenizer 168 | 169 | if __name__ == '__main__': 170 | seed_everything(42) 171 | args = parse_args() 172 | world_size = torch.cuda.device_count() 173 | mp.set_start_method('spawn', force=True) 174 | 175 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 176 | model_name_or_path = args.model_name_or_path 177 | model_name = args.model_name_or_path.split("/")[-1] 178 | # define your model 179 | max_length = args.max_length 180 | if args.e: 181 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", \ 182 | "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] 183 | else: 184 | datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ 185 | "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ 186 | "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] 187 | 188 | datasets = [ 189 | "qasper", "narrativeqa", "multifieldqa_en", # single doc 190 | "hotpotqa", "2wikimqa", "musique", # multi doc 191 | "trec", "triviaqa", "samsum", # few-shot 192 | "gov_report", "qmsum", "multi_news", # sum 193 | "passage_count", "passage_retrieval_en", # Synthetic 194 | "lcc", "repobench-p", # code 195 | ] 196 | 197 | print(datasets) 198 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output 199 | dataset2prompt = json.load(open("config/dataset2prompt.json", "r")) 200 | dataset2maxlen = json.load(open("config/dataset2maxlen.json", "r")) 201 | # predict on each dataset 202 | if not os.path.exists("pred"): 203 | os.makedirs("pred") 204 | if not os.path.exists("pred_e"): 205 | os.makedirs("pred_e") 206 | 207 | if args.mode == "ada": 208 | print("Ada mode") 209 | replace_mistral_adaptive() 210 | replace_llama_adaptive() 211 | elif args.mode == "fix": 212 | print("Fix mode") 213 | replace_mistral_fixed() 214 | replace_llama_fixed() 215 | elif args.mode == "slm": 216 | print("Slm mode") 217 | 
replace_mistral_slm() 218 | replace_llama_slm() 219 | else: 220 | print("Base mode") 221 | 222 | 223 | # NOTE: load model after replace 224 | model, tokenizer = load_model_and_tokenizer(model_name_or_path) 225 | 226 | config_compress(model, base_capacity=args.budget, pyram_mode=args.pyram, floor_alpha=args.floor_alpha, gqa_support=args.gqa_support, gqa_func=args.gqa_func) 227 | 228 | for dataset in datasets: 229 | if args.e: 230 | data = load_dataset(args.dataset, f"{dataset}_e", split='test', data_dir=f"{args.dataset}/data") 231 | if not os.path.exists(f"pred_e/{args.out_name}"): 232 | os.makedirs(f"pred_e/{args.out_name}") 233 | out_path = f"pred_e/{args.out_name}/{dataset}.jsonl" 234 | else: 235 | data = load_dataset(args.dataset, f"{dataset}", split='test', data_dir=f"{args.dataset}/data") 236 | if not os.path.exists(f"pred/{args.out_name}"): 237 | os.makedirs(f"pred/{args.out_name}") 238 | out_path = f"pred/{args.out_name}/{dataset}.jsonl" 239 | 240 | if os.path.exists(out_path): 241 | print(f"{out_path} exists skip") 242 | continue 243 | 244 | prompt_format = dataset2prompt[dataset] 245 | max_gen = dataset2maxlen[dataset] 246 | data_all = [data_sample for data_sample in data] 247 | # TODO: hard code single process, which use all gpus 248 | torch.cuda.synchronize() 249 | t = time.time() 250 | get_pred(model, tokenizer, data_all, max_length, max_gen, prompt_format, dataset, device, model_name_or_path, out_path) 251 | torch.cuda.synchronize() 252 | t = time.time() - t 253 | print(f"== {args.out_name} {dataset} Time: {t}") 254 | -------------------------------------------------------------------------------- /makefile: -------------------------------------------------------------------------------- 1 | i install: 2 | cd csrc && make 3 | pip install -e . 4 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "packaging"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "adakv" 7 | version = "0.0.1" 8 | 9 | dependencies = [ 10 | "rouge", 11 | "tokenizers", 12 | "rouge_score", 13 | "numpy==1.26.0", 14 | "jieba==0.42.1", 15 | "datasets", 16 | "accelerate", 17 | "tqdm==4.66.1", 18 | "icetk==0.0.7", 19 | "transformers==4.44.2", 20 | ] 21 | 22 | [tool.setuptools.packages.find] 23 | where = ["."] 24 | --------------------------------------------------------------------------------
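
The monkeypatch, model loading and cache-budget configuration only meet inside pred.py, so a minimal usage sketch may help. It is not part of the repository: it reuses only names visible above (the replace_*_adaptive hijacks, the load-after-patch ordering, and the keyword arguments of the config_compress call in pred.py); the checkpoint path and budget values are placeholders.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from adaptive_snapkv.monkeypatch.monkeypatch import (
    replace_llama_adaptive,
    replace_mistral_adaptive,
    config_compress,
)

# Install the adaptive-attention hijack before the model is built,
# mirroring the "load model after replace" note in pred.py.
replace_llama_adaptive()
replace_mistral_adaptive()

path = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
).eval()

# Keyword names follow the config_compress(...) call in pred.py; values are illustrative.
config_compress(model, base_capacity=1024, pyram_mode=False, floor_alpha=0.2,
                gqa_support=False, gqa_func="mean")

# Generation then proceeds as in get_pred(), e.g.
# model.generate(**inputs, max_new_tokens=64, num_beams=1, do_sample=False)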
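
get_pred() shortens over-long prompts by keeping max_length/2 tokens from each end, on the stated assumption that the crucial instructions sit at the head and tail of the prompt. A standalone sketch of that step; the helper name is not from the repository:

# Head/tail truncation as done in get_pred(): keep max_length // 2 tokens from each end.
def truncate_middle(token_ids, max_length):
    if len(token_ids) <= max_length:
        return token_ids
    half = max_length // 2
    return token_ids[:half] + token_ids[-half:]

print(truncate_middle(list(range(10)), 6))  # [0, 1, 2, 7, 8, 9]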
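
pred.py streams one JSON object per line into pred/{out_name}/{dataset}.jsonl (plus a *_tmp copy while generating), and eval.py later reads back the pred, answers, all_classes and length fields. A sketch of one such record, with made-up values:

import json

# Field names match the json.dump(...) call in get_pred(); the values here are invented.
record = {
    "pred": "Paris",              # model output after post_process()
    "answers": ["Paris"],         # reference answers copied from the dataset sample
    "all_classes": None,          # label set; only classification tasks use it
    "length": 8192,               # context length, used by scorer_e for bucketing
    "time": 3.2,                  # per-sample generation time recorded by pred.py
}
line = json.dumps(record, ensure_ascii=False)
print(json.loads(line)["pred"])   # eval.py consumes the file line by line like this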
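
The metric functions in metrics.py are plain (prediction, ground_truth) callables, which is what eval.py's dataset2metric dispatch relies on. A small sanity check, assuming it is run from experiments/LongBench with the rouge, jieba and fuzzywuzzy dependencies installed:

from metrics import qa_f1_score, rouge_score, classification_score

# Token-level F1 after normalize_answer() strips articles and punctuation.
print(qa_f1_score("The answer is Paris.", "Paris"))   # 0.5: one of three prediction tokens matches

# ROUGE-L F1 via the rouge package; rouge_score() returns 0.0 if scoring raises.
print(rouge_score("the cat sat on the mat", "the cat sat"))

# Classification metrics additionally need the candidate label set.
print(classification_score("It reads like sports news", "sports",
                           all_classes=["sports", "politics", "tech"]))  # 1.0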