├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── blog ├── README.md └── images │ ├── Attention.png │ ├── Keyformer_decoding.gif │ ├── Keyformer_overview.gif │ ├── accuracy.png │ ├── attn_weights.png │ ├── keyformer_decoding.gif │ ├── long_context_accuracy.png │ ├── motivation.png │ ├── performance.png │ ├── sparsity.png │ ├── speedup.png │ └── uneven_score.gif ├── conda-env.yml ├── conversation ├── README.md ├── data │ └── data_cache │ │ └── data_cache ├── dataset.py ├── dataset_download │ ├── download_orca.py │ └── download_soda.py ├── dialogue.py ├── run_conversation_task.sh └── utils.py ├── lm_eval_harness ├── README.md ├── evaluate_task_result.py ├── experiments.sh ├── generate_task_data.py ├── model_eval_data │ └── model_eval_data ├── output │ └── output ├── run.sh ├── run_lm_eval_harness.py ├── task_data │ └── task_data └── tasks │ ├── __init__.py │ ├── eval_harness.py │ └── util.py ├── models ├── README.md ├── cerebras-keyformer-lib │ ├── config.json │ ├── configuration_gpt2.py │ ├── modeling_gpt2.py │ └── modeling_gpt2_lm_eval_harness.py ├── gptj-keyformer-lib │ ├── config.json │ ├── configuration_gptj.py │ ├── modeling_gptj.py │ └── modeling_gptj_lm_eval_harness.py ├── model_download │ └── download_model.py └── mpt-keyformer-lib │ ├── attention.py │ ├── attention_keyformer_new_bias.py │ ├── attention_keyformer_old_bias.py │ ├── attention_lm_eval_harness.py │ ├── attention_streaming_llm.py │ ├── blocks.py │ ├── config.json │ ├── configuration_mpt.py │ └── modeling_mpt.py ├── requirements.txt └── summarization ├── README.md ├── __pycache__ ├── dataset.cpython-310.pyc ├── dataset.cpython-39.pyc ├── utils.cpython-310.pyc └── utils.cpython-39.pyc ├── data └── data_cache │ ├── .._data_data_cache_cnn_dailymail_3.0.0_0.0.0_d3306fda2d44efc3.lock │ └── data_cache ├── dataset.py ├── dataset_download ├── download_cnndm.py ├── download_longbench.py └── download_xsum.py ├── dataset_synthetic_long_context.py ├── out_model.score ├── out_model.summary ├── output.summary ├── run_summarization_task.sh ├── summarize.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.bin 2 | models/model/ 3 | plots/ 4 | **/*.pdf 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Thank you for your interest in contributing to Keyformer! We are in the middle of preparing the next steps to invite community contributions and make keyformer an integral part of the LLM architecture. 2 | 3 | You can still report any issues you observe with the usage of keyformer. For Issue reporting, we request you to please follow these guidelines - 4 | 5 | - If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it. If not, please file a new issue, providing as much relevant information as possible. 6 | 7 | ### Thank You 8 | Finally, thank you for your interest in contributing Keyformer and helping this make a great tool for everyone. 9 | 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference 2 | 3 | **Keyformer** proposes KV Cache reduction through key token identification, without the need 4 | for fine-tuning. 5 | 6 | This repository contains the source code implementation of **Keyformer**. 7 | 8 | This source code is available under the [Apache 2.0 License](LICENSE). 9 | 10 | To get a quick overview of how **keyformer** reduces the KV Cache size and speeds up LLM inference, we invite you to browse through our blog: 11 | 12 | - [🎫 Blog](blog/README.md) - Quick summary of Keyformer 13 | 14 | But if you'd rather see **keyformer** in action first, please continue reading. 15 |

16 | ## ⚙️ Environment Setup 17 | 18 | #### conda Environment 19 | You can create a conda environment using the `conda-env.yml` file. 20 | ```bash 21 | conda env create --file=conda-env.yml 22 | conda activate keyformer-env 23 | ``` 24 |

25 | ## 💻 System Requirements 26 | 27 | Right now, **keyformer** has been tested on a 1xA100 node for `mosaicml/mpt-7b` and 28 | `EleutherAI/gpt-j-6B`, with model weights in `FP32` and evaluation done using the `FP16` data format. 29 | 30 | We welcome contributions to port and evaluate **keyformer** with different 31 | data formats, targeting lower model memory requirements and a smaller overall KV cache size. 32 |

33 | ## 🏁 Model Download and Integration with Keyformer 34 | 35 | Note: **Keyformer** has been evaluated on the following models, and this tutorial is restricted to their use: 36 | `['cerebras/Cerebras-GPT-6.7B', 37 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']` 38 | 39 | Clone the relevant model from huggingface into the `models` directory. For instance, to clone `mosaicml/mpt-7b` into `models/mpt-7b-keyformer`, do the following: 40 | 41 | ```bash 42 | cd models 43 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mosaicml/mpt-7b mpt-7b-keyformer 44 | cd .. 45 | ``` 46 | 47 | Then, download the model weights (`*.bin`) using the `download_model.py` script, passing the model name as the argument. 48 | 49 | ```bash 50 | cd models/model_download 51 | python3 download_model.py --model_name mosaicml/mpt-7b 52 | cd ../../ 53 | ``` 54 | 55 | Please keep in mind that supported arguments for `model_name` are restricted to `['cerebras/Cerebras-GPT-6.7B', 56 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']` 57 | 58 | This step downloads the model parameters into a `model` folder created inside the `model_download` directory. 59 | The downloaded model parameters must then be moved to the directory where the model files are available. 60 | 61 | To move the model parameters from `models/model_download/model/` to `models/mpt-7b-keyformer`: 62 | 63 | ```bash 64 | mv models/model_download/model/* models/mpt-7b-keyformer/. 65 | ``` 66 | 67 | ### Integrating Keyformer 68 | Copy the **Keyformer** source files from `models/mpt-keyformer-lib` into `models/mpt-7b-keyformer`: 69 | 70 | ```bash 71 | cp -r models/mpt-keyformer-lib/* models/mpt-7b-keyformer/ 72 | ``` 73 | 74 | This sets up `models/mpt-7b-keyformer` to be used with the various tasks described below. 75 | 76 | Similarly, find the keyformer-lib files for other supported models in their respective directories in `models/`. 77 |
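As a quick sanity check of the integration, you can try loading the prepared model directly from the local directory. The snippet below is only an illustrative sketch, not one of the repository's scripts: the directory name follows the example above, `float16` matches the evaluation data format mentioned earlier, and a CUDA device is assumed. The task-specific scripts described below remain the intended way to run Keyformer.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "models/mpt-7b-keyformer"  # local directory prepared above

tokenizer = AutoTokenizer.from_pretrained(model_dir)

# trust_remote_code is required because the Keyformer modeling files live
# inside the model directory rather than in the transformers library.
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).to("cuda")

inputs = tokenizer("Keyformer reduces the KV cache by", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```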

78 | ## Using Keyformer for KV Cache reduction 79 | 80 | After setting up a model with `keyformer`, let's run it on a downstream task of interest. In the case of `keyformer`, we provide examples of how to run the downstream tasks of **summarization** and **conversation**. Further, for the purposes of evaluation, we provide a step-by-step tutorial on how to use **lm_eval_harness**. 81 | 82 | Depending on your task of interest, please refer to the following links for more details: 83 | 84 | - [📖 Summarization](summarization/README.md) - Running the Summarization task with Keyformer 85 | - [💬 Conversation](conversation/README.md) - Running the Conversation task with Keyformer 86 | - [📊 LM Eval Harness](lm_eval_harness/README.md) - Using LM Eval Harness for evaluation with Keyformer 87 |

88 | ## TODO 89 | 90 | - [ ] Instructions to integrate keyformer with any model from huggingface 91 | 92 | - [ ] Using keyformer with quantized models 93 |

94 | 95 | ## Thank You 96 | 97 | Keyformer uses open source components available on [huggingface](https://huggingface.co/), including [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b), [cerebras/Cerebras-GPT-6.7B](https://huggingface.co/cerebras/Cerebras-GPT-6.7B), and 98 | [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6b). 99 |

100 | 101 | ## 📝 Citation 102 | ``` 103 | @article{2024keyformer, 104 | title={Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference}, 105 | author={Adnan, Muhammad and Arunkumar, Akhil and Jain, Gaurav and Nair, Prashant and Soloveychik, Ilya and Kamath, Purushotham}, 106 | journal={Proceedings of Machine Learning and Systems}, 107 | volume={7}, 108 | year={2024} 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /blog/README.md: -------------------------------------------------------------------------------- 1 | # Keyformer: KV Cache reduction through attention sparsification for Efficient Generative Inference 2 | **Muhammad Adnan*1**, **Gaurav Jain2**, **Akhil Arunkumar2**, **Prashant J. Nair1**, **Ilya Soloveychik2**, **Purushotham Kamath2** 3 | 4 | * Work done when the author was an intern at d-Matrix. 5 | 6 | 1 Department of Electrical and Computer Engineering, The University of British Columbia, Vancouver, BC, Canada. 7 | 2 d-Matrix, Santa Clara, California, USA 8 | 9 | **TL;DR:** Generative AI inference is often bottlenecked by a growing KV cache. Numerous strategies have been proposed to compress the KV cache and allow longer inference-time context lengths, but most of these techniques require fine-tuning, and in some cases even pre-training. We introduce **Keyformer**, a novel inference-time token-discarding technique that reduces KV cache size to improve overall inference latency and token generation throughput while preserving accuracy. Keyformer capitalizes on the observation that during generative inference, approximately 90% of the attention weight is concentrated on a select subset of tokens called **key tokens**, and it discards the *irrelevant* tokens to reduce the overall size of the KV cache. Thus, by employing **Keyformer**, we are able to reduce the required KV cache size by 50%, reduce latency by up to 2.1x, and boost token generation throughput by 2.4x, all while preserving the model’s accuracy. Further, we are able to support larger batch sizes that would otherwise result in **Out-Of-Memory** errors. 10 | 11 | ## How Keyformer works 12 | 13 | The attention mechanism exhibits varying amounts of sparsity across a model's many decoder layers. As seen in Figure 1 (Left), attention sparsity varies significantly even across models of the same size, all on the same CNN/DailyMail summarization task. Figure 1 (Right), through a cumulative distribution function (CDF), shows how the attention score is concentrated within a small number of tokens during text generation. This highlights the importance of certain key tokens during token generation and, more importantly, the relative irrelevance of the majority of tokens. 14 | 15 |

16 | 17 |
18 | Figure 1: (Left) Default attention sparsity for different models across layers. (Right) CDF of attention score for different models with 90% of attention score dedicated to 40% of tokens. 19 |

20 | 21 | In this work, **Keyformer**, we exploit this inherent sparsity within the decoder layers by identifying key tokens while still emphasizing recent tokens. We further adapt to this token-discarding behavior by changing the score function and applying regularization to the unnormalized logits used for key token identification. 22 | 23 | ### What do we do for Regularization - Gumbel Distribution 24 | 25 | Once we have identified and discarded the irrelevant tokens, it is important that we normalize the score function to account for this change. In that regard, we use the Gumbel distribution, which enables our model to remain robust and adaptive. As an implementation strategy, we maintain a constant KV cache size *k* and remove the remaining *n − k* tokens from the context to avoid unwanted memory growth. 26 | 27 | ### Bias Towards Initial Tokens 28 | 29 | Previous research has indicated a bias towards initial tokens. For instance, [StreamingLLM](https://arxiv.org/abs/2309.17453) highlights the importance of initial tokens as attention sinks, particularly in streaming applications. Similarly, [H2O](https://arxiv.org/abs/2306.14048) utilizes an accumulated attention score as a score function, which results in a predisposition 30 | towards initial tokens due to the cumulative effect during decoding iterations. To exploit this bias towards initial tokens and effectively model the distribution of maximum values (key tokens), we propose introducing a distribution that is skewed towards initial tokens while simultaneously featuring an asymmetric profile. This asymmetry introduces a pronounced right tail, which is characteristic of tokens typically drawn from the recent context window. Our choice of distribution is inspired by the [Gumbel distribution](https://arxiv.org/pdf/1502.02708.pdf). 31 | 32 |

33 |
34 | Figure 2: Overview of Keyformer during multiple phases. In the prompt processing phase, the KV cache holds n tokens, and Keyformer adds noise to the logits to identify key tokens. It selects w tokens from the recent window and k − w key tokens from the remaining n − w tokens, keeping k tokens in the KV cache. In the text generation phase, each decoding step operates with k tokens in the KV cache, the remaining tokens having been discarded in previous iterations. 35 |

36 | 37 | ### Keyformer Score Function 38 | 39 | To overcome the limitations of an uneven score distribution and the resulting key token identification, we introduce a novel Keyformer score function $f_{\theta(keyformer)}$. This score function incorporates the Gumbel distribution into the unnormalized logits. However, the discarded tokens are not incorporated in any way in forming the probability distribution that underlies the score function. To address this oversight and account for the discarded tokens, we introduce a temperature parameter denoted as $\tau$, as shown in the equation below. 40 | 41 | ```math 42 | f_{\theta(keyformer)}(x_i) = e^{\frac{x_i + \zeta_i}{\tau}} / \sum_{j=1}^{k} e^{\frac{x_j + \zeta_j}{\tau}}, \ \ \ \ i=1,2,\dots,k 43 | ``` 44 |
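In the equation above, $x_i$ are the unnormalized attention logits and $\zeta_i$ is the noise added to them, drawn from the Gumbel distribution. The following is a minimal PyTorch sketch of the resulting key-token selection; it is illustrative only (the actual implementation lives in the `*-keyformer-lib` model files). For simplicity it uses standard Gumbel noise, a fixed temperature, and normalization over all tokens, whereas in practice the temperature is varied between `tau_init` and `tau_end` as exposed by the task scripts.

```python
import torch

def keyformer_select(logits, k, w, tau):
    """Pick the indices of k tokens to keep out of n, given one head's
    unnormalized attention logits of shape [n]."""
    n = logits.shape[-1]

    # Gumbel noise: zeta = -log(-log(u)) with u ~ Uniform(0, 1)
    u = torch.rand_like(logits).clamp(1e-9, 1 - 1e-9)
    zeta = -torch.log(-torch.log(u))

    # Keyformer score: softmax of the noised logits at temperature tau
    score = torch.softmax((logits + zeta) / tau, dim=-1)

    # Always keep the w most recent tokens ...
    recent = torch.arange(n - w, n, device=logits.device)

    # ... and pick the remaining k - w key tokens from the older n - w
    # tokens by score; the other n - k tokens are discarded.
    key = torch.topk(score[: n - w], k - w).indices

    return torch.sort(torch.cat([key, recent])).values
```

For example, with `n = 8` logits, `k = 4`, and `w = 2`, the function keeps the last two positions plus the two highest-scoring earlier positions, mirroring the prompt-processing step shown in Figure 2.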

45 |
46 | Figure 3: Design of Keyformer for a single decoder layer. 47 |

48 | 49 | ## Key Results 50 | 51 | We evaluate Keyformer across three significant model families: [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b), [Cerebras-GPT](https://huggingface.co/cerebras/Cerebras-GPT-6.7B), and [MPT](https://huggingface.co/mosaicml/mpt-7b), and on two representative text generation tasks: summarization, using the [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) dataset from [HELM](https://crfm.stanford.edu/helm/latest/), and conversation, using the [SODA](https://huggingface.co/datasets/allenai/soda) dataset. The GPT-J model is fine-tuned for summarization, while Cerebras-GPT and MPT are pretrained models. For the conversation task, we used the [MPT-chat](https://huggingface.co/mosaicml/mpt-7b-chat) version of the MPT model, which is fine-tuned for dialogue generation. Figure 4 shows that Keyformer achieves the baseline accuracy with 70% of the prompt KV cache for the summarization task across different models, and with 90% of the prompt KV cache for the conversation task, while the other baselines fail to reach the baseline accuracy. 52 | 53 |

54 |
55 | Figure 4: Accuracy comparison of Full Attention, Window Attention, H2O and Keyformer with varying KV cache size. Solid black line shows Full Attention without discarding any token and full KV cache. Red dotted line shows 99% accuracy mark. 56 |

57 | 58 | For long-context scenarios, we turned to the [GovReport](https://huggingface.co/datasets/ccdv/govreport-summarization) dataset for extended document summarization. For this task, we employed the [MPT-storywriter](https://huggingface.co/mosaicml/mpt-7b-storywriter) version of the MPT model, which is fine-tuned for writing fictional stories with a context length of 65k tokens and the ability to generate content as long as 84k tokens. 59 | 60 |

61 | 62 |
63 | Figure 5: (Left) Long context summarization using the MPT-7B-storywriter model on the GovReport dataset with a sequence length 64 | of 8k. (Right) Speedup of Keyformer with 50% KV cache reduction. 65 |

66 | 67 | Figure 5 shows that for long context summarization, Keyformer achieves baseline accuracy with 50% of prompt KV cache, improving the inference latency by 2.1x and token generation throughput by upto 2.4x. 68 | 69 | ## Get Started with Keyformer 70 | We have implemented Keyformer for multiple autoregressive models and provided respective model cards to run different tasks. Please find detailed instructions to use Keyformer [here](../README.md). 71 | 72 | ## Citation 73 | ``` 74 | @article{2024keyformer, 75 | title={Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference}, 76 | author={Adnan, Muhammad and Arunkumar, Akhil and Jain, Gaurav and Nair, Prashant and Soloveychik, Ilya and Kamath, Purushotham}, 77 | journal={Proceedings of Machine Learning and Systems}, 78 | volume={7}, 79 | year={2024} 80 | } 81 | ``` -------------------------------------------------------------------------------- /blog/images/Attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/Attention.png -------------------------------------------------------------------------------- /blog/images/Keyformer_decoding.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/Keyformer_decoding.gif -------------------------------------------------------------------------------- /blog/images/Keyformer_overview.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/Keyformer_overview.gif -------------------------------------------------------------------------------- /blog/images/accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/accuracy.png -------------------------------------------------------------------------------- /blog/images/attn_weights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/attn_weights.png -------------------------------------------------------------------------------- /blog/images/keyformer_decoding.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/keyformer_decoding.gif -------------------------------------------------------------------------------- /blog/images/long_context_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/long_context_accuracy.png -------------------------------------------------------------------------------- /blog/images/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/motivation.png 
-------------------------------------------------------------------------------- /blog/images/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/performance.png -------------------------------------------------------------------------------- /blog/images/sparsity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/sparsity.png -------------------------------------------------------------------------------- /blog/images/speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/speedup.png -------------------------------------------------------------------------------- /blog/images/uneven_score.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/uneven_score.gif -------------------------------------------------------------------------------- /conda-env.yml: -------------------------------------------------------------------------------- 1 | name: keyformer-env 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.01.10=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.2=h6a678d5_6 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=1.1.1t=h7f8727e_0 15 | - pip=23.0.1=py39h06a4308_0 16 | - python=3.9.16=h7a1cb2a_2 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=66.0.0=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.38.4=py39h06a4308_0 22 | - xz=5.4.2=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - absl-py==1.4.0 26 | - accelerate==0.18.0 27 | - aiohttp==3.8.4 28 | - aiosignal==1.3.1 29 | - appdirs==1.4.4 30 | - async-timeout==4.0.2 31 | - attrs==23.1.0 32 | - certifi==2022.12.7 33 | - charset-normalizer==3.1.0 34 | - click==8.1.3 35 | - cmake==3.26.3 36 | - datasets==2.14.5 37 | - dill==0.3.6 38 | - docker-pycreds==0.4.0 39 | - evaluate==0.4.0 40 | - filelock==3.12.0 41 | - fire==0.5.0 42 | - frozenlist==1.3.3 43 | - fsspec==2023.4.0 44 | - gitdb==4.0.10 45 | - gitpython==3.1.31 46 | - huggingface-hub==0.14.1 47 | - idna==3.4 48 | - jinja2==3.1.2 49 | - joblib==1.2.0 50 | - lit==16.0.2 51 | - markupsafe==2.1.2 52 | - mpmath==1.3.0 53 | - multidict==6.0.4 54 | - multiprocess==0.70.14 55 | - networkx==3.1 56 | - nltk==3.8.1 57 | - numpy==1.24.3 58 | - nvidia-cublas-cu11==11.10.3.66 59 | - nvidia-cuda-cupti-cu11==11.7.101 60 | - nvidia-cuda-nvrtc-cu11==11.7.99 61 | - nvidia-cuda-runtime-cu11==11.7.99 62 | - nvidia-cudnn-cu11==8.5.0.96 63 | - nvidia-cufft-cu11==10.9.0.58 64 | - nvidia-curand-cu11==10.2.10.91 65 | - nvidia-cusolver-cu11==11.4.0.1 66 | - nvidia-cusparse-cu11==11.7.4.91 67 | - nvidia-nccl-cu11==2.14.3 68 | - nvidia-nvtx-cu11==11.7.91 69 | - openai==0.27.6 70 | - packaging==23.1 71 | - pandas==2.0.1 72 | - pathtools==0.1.2 73 | - protobuf==4.22.4 74 | - psutil==5.9.5 75 | - pyarrow==12.0.0 76 | - python-dateutil==2.8.2 77 | - pytz==2023.3 78 | - 
pyyaml==6.0 79 | - regex==2023.5.5 80 | - requests==2.30.0 81 | - responses==0.18.0 82 | - rouge-score==0.1.2 83 | - sentencepiece==0.1.99 84 | - sentry-sdk==1.21.1 85 | - setproctitle==1.3.2 86 | - simplejson==3.19.1 87 | - six==1.16.0 88 | - smmap==5.0.0 89 | - sympy==1.11.1 90 | - termcolor==2.3.0 91 | - tokenizers==0.13.3 92 | - torch==2.0.0 93 | - tqdm==4.65.0 94 | - transformers==4.28.1 95 | - triton==2.0.0 96 | - typing-extensions==4.5.0 97 | - tzdata==2023.3 98 | - urllib3==2.0.2 99 | - wandb==0.15.1 100 | - xxhash==3.2.0 101 | - yarl==1.9.2 102 | - einops 103 | - lm-eval 104 | - ftfy 105 | prefix: /nfs/tattafos/miniconda3/envs/summarization-env 106 | -------------------------------------------------------------------------------- /conversation/README.md: -------------------------------------------------------------------------------- 1 | # 💬 Conversation Task 2 | 3 | ## Download dataset 4 | ```bash 5 | cd dataset_download 6 | python download_soda.py 7 | ``` 8 | 9 | ## Getting started with Conversation Task 10 | 11 | To get started with conversation task, follow use below inputs 12 | 13 | ``` 14 | python summarize.py --model_name \ 15 | --dataset_path \ 16 | --save_path \ 17 | --score_path \ 18 | --model_path \ 19 | --attentions_path 20 | --device cuda \ # Device 21 | --task conversation \ # Task 22 | --bs 1 \ # Batch Size 23 | --dtype float16 \ # Data type of model parameters 24 | --causal_lm \ # Causal language model for summarization 25 | --early_stopping \ # Enable early stopping while generation 26 | --output_summaries_only \ # Output only summary - No prompt 27 | --output_sequence_scores \ # Output sequence scores and enable for output attentions 28 | --save_attentions \ # Save attention weights - Token Generation only 29 | --save_prompt_attentions \ # If enabled only prompt attention weights are stored 30 | --padding_side left \ # Padding side 31 | --beam 4 \ # Beam Size 32 | --model_parallelize \ # Parallelize Model across all available GPUs 33 | --keyformer \ # Enable Keyformer 34 | --kv_cache 60 \ # KV cache percentage of prompt length 35 | --recent 30 \ # Recent window percentage 36 | --tau_init 1 \ # Initial temperature parameter for Gumbel 37 | --tau_end 2 \ # End temperature parameter for Gumbel 38 | --no_repeat_ngram_size 0 \ 39 | --repetition_penalty 1 \ 40 | --max_tokenizer_length 1920 \ # Maximum prompt size 41 | --max_new_tokens 128 \ # Maximum newly generated tokens 42 | --min_gen_length 30 \ # Minimum newly generated tokens 43 | --num_return_sequences 1 \ # Number os return summaries per input 44 | --seed 12345 \ # Random seed value for radom samples 45 | --n_obs 1000 \ # Number of input samples 46 | ``` 47 | 48 | Note: For data type of FP16, do not use model.half() instead utilize dtype in model creation [Link](https://stackoverflow.com/questions/69994731/what-is-the-difference-between-cuda-amp-and-model-half) 49 | -------------------------------------------------------------------------------- /conversation/data/data_cache/data_cache: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/conversation/data/data_cache/data_cache -------------------------------------------------------------------------------- /conversation/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | from datasets import load_dataset, load_from_disk 
6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from torch.nn.functional import pad 8 | from torch.utils.data import DataLoader 9 | from typing import Optional, Dict, Sequence 10 | import utils 11 | import io 12 | import copy 13 | import sys 14 | 15 | PROMPT_DICT = { 16 | "prompt_input": ( 17 | "Below is an instruction that describes a task, paired with an input that provides further context. " 18 | "Write a response that appropriately completes the request.\n\n" 19 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 20 | ), 21 | "prompt_no_input": ( 22 | "Below is an instruction that describes a task. " 23 | "Write a response that appropriately completes the request.\n\n" 24 | "### Instruction:\n{instruction}\n\n### Response:" 25 | ), 26 | } 27 | 28 | 29 | class Dataset(torch.utils.data.Dataset): 30 | """Characterizes a dataset for PyTorch""" 31 | 32 | def __init__( 33 | self, 34 | dataset_path, 35 | tokenizer, 36 | model_name, 37 | return_tensors, 38 | truncation, 39 | padding, 40 | max_length=None, 41 | total_count_override=None, 42 | perf_count_override=None, 43 | ): 44 | self.dataset = "soda" 45 | self.dataset_path = dataset_path 46 | self.model_name = model_name 47 | self.tokenizer = tokenizer 48 | self.return_tensors = return_tensors 49 | self.truncation = truncation 50 | self.padding = padding 51 | self.max_length = max_length 52 | 53 | self.list_data_dict = utils.jload(self.dataset_path) 54 | 55 | self.examples = [example for example in self.list_data_dict] 56 | 57 | # Getting random samples for evaluation 58 | if total_count_override > 0: 59 | self.rand_samples = np.random.randint( 60 | len(self.examples), size=(total_count_override) 61 | ).tolist() 62 | self.examples = [self.examples[i] for i in self.rand_samples] 63 | self.count = total_count_override 64 | else: 65 | self.count = len(self.examples) 66 | 67 | self.perf_count = perf_count_override or self.count 68 | 69 | def encode_samples(self): 70 | print("Encoding Samples") 71 | 72 | total_samples = self.count 73 | 74 | source_encoded_input_ids = [] 75 | source_encoded_attn_masks = [] 76 | 77 | for i in range(total_samples): 78 | source_encoded = self.tokenizer( 79 | self.sources[i], 80 | return_tensors=self.return_tensors, 81 | padding=self.padding, 82 | truncation=self.truncation, 83 | max_length=self.max_length, 84 | ) 85 | source_encoded_input_ids.append(source_encoded.input_ids) 86 | source_encoded_attn_masks.append(source_encoded.attention_mask) 87 | 88 | return source_encoded_input_ids, source_encoded_attn_masks 89 | 90 | def __len__(self): 91 | return self.count 92 | 93 | def __getitem__(self, index): 94 | return self.examples[index] 95 | -------------------------------------------------------------------------------- /conversation/dataset_download/download_orca.py: -------------------------------------------------------------------------------- 1 | # experiment config 2 | dataset_id = "Open-Orca/OpenOrca" 3 | 4 | from datasets import load_dataset, concatenate_datasets 5 | from transformers import AutoTokenizer 6 | import numpy as np 7 | import os 8 | import simplejson as json 9 | import sys 10 | 11 | save_dataset_path = os.environ.get("DATASET_ORCA_PATH", "../data") 12 | 13 | # Check whether the specified path exists or not 14 | isExist = os.path.exists(save_dataset_path) 15 | if not isExist: 16 | # Create a new directory because it does not exist 17 | os.makedirs(save_dataset_path) 18 | 19 | # Load dataset from the hub 20 | print("Loading Dataset!!") 21 | dataset = 
load_dataset(dataset_id, cache_dir="../data/data_cache") 22 | print("Dataset Loaded!!") 23 | 24 | 25 | def preprocess_function(sample): 26 | # create list of samples 27 | inputs = [] 28 | 29 | for i in range(0, len(sample["id"])): 30 | x = dict() 31 | x["instruction"] = sample["system_prompt"][i] 32 | x["input"] = sample["question"][i] 33 | x["output"] = sample["response"][i] 34 | 35 | inputs.append(x) 36 | model_inputs = dict() 37 | model_inputs["text"] = inputs 38 | 39 | return model_inputs 40 | 41 | 42 | # process dataset 43 | tokenized_dataset = dataset.map(preprocess_function, batched=True) 44 | 45 | # save dataset to disk 46 | 47 | with open(os.path.join(save_dataset_path, "orca_train.json"), "w") as write_f: 48 | json.dump(tokenized_dataset["train"]["text"], write_f, indent=4, ensure_ascii=False) 49 | 50 | 51 | print("Dataset saved in ", save_dataset_path) 52 | -------------------------------------------------------------------------------- /conversation/dataset_download/download_soda.py: -------------------------------------------------------------------------------- 1 | # experiment config 2 | dataset_id = "allenai/soda" 3 | 4 | from datasets import load_dataset, concatenate_datasets 5 | from transformers import AutoTokenizer 6 | import numpy as np 7 | import os 8 | import simplejson as json 9 | import sys 10 | 11 | save_dataset_path = os.environ.get("DATASET_SODA_PATH", "../data") 12 | 13 | # Check whether the specified path exists or not 14 | isExist = os.path.exists(save_dataset_path) 15 | if not isExist: 16 | # Create a new directory because it does not exist 17 | os.makedirs(save_dataset_path) 18 | 19 | # Load dataset from the hub 20 | dataset = load_dataset(dataset_id, cache_dir="../data/data_cache") 21 | 22 | 23 | def preprocess_function(sample): 24 | # create list of samples 25 | inputs = [] 26 | 27 | for i in range(0, len(sample["head"])): 28 | x = dict() 29 | x["head"] = sample["head"][i] 30 | x["relation"] = sample["relation"][i] 31 | x["tail"] = sample["tail"][i] 32 | x["literal"] = sample["literal"][i] 33 | x["narrative"] = sample["narrative"][i] 34 | x["dialogue"] = sample["dialogue"][i] 35 | x["speakers"] = sample["speakers"][i] 36 | x["PersonX"] = sample["PersonX"][i] 37 | x["PersonY"] = sample["PersonY"][i] 38 | x["PersonZ"] = sample["PersonZ"][i] 39 | x["original_index"] = sample["original_index"][i] 40 | x["split"] = sample["split"][i] 41 | x["head_answer"] = sample["head_answer"][i] 42 | x["pmi_head_answer"] = sample["pmi_head_answer"][i] 43 | x["relation_tail_answer"] = sample["relation_tail_answer"][i] 44 | x["pmi_relation_tail_answer"] = sample["pmi_relation_tail_answer"][i] 45 | 46 | inputs.append(x) 47 | model_inputs = dict() 48 | model_inputs["text"] = inputs 49 | 50 | return model_inputs 51 | 52 | 53 | # process dataset 54 | tokenized_dataset = dataset.map( 55 | preprocess_function, batched=True, remove_columns=list(dataset["train"].features) 56 | ) 57 | 58 | # save dataset to disk 59 | 60 | with open(os.path.join(save_dataset_path, "soda_eval.json"), "w") as write_f: 61 | json.dump( 62 | tokenized_dataset["validation"]["text"], write_f, indent=4, ensure_ascii=False 63 | ) 64 | 65 | 66 | print("Dataset saved in ", save_dataset_path) 67 | -------------------------------------------------------------------------------- /conversation/run_conversation_task.sh: -------------------------------------------------------------------------------- 1 | python dialogue.py --model_name \ 2 | --model_path \ 3 | --dataset_path \ 4 | --save_path \ 5 | --score_path \ 6 | 
--device cuda \ 7 | --task conversation \ 8 | --bs 1 \ 9 | --dtype bfloat16 \ 10 | --causal_lm \ 11 | --early_stopping \ 12 | --output_summaries_only \ 13 | --padding_side left \ 14 | --beam 4 \ 15 | --model_parallelize \ 16 | --keyformer \ 17 | --kv_cache 60 \ 18 | --recent 30 \ 19 | --tau_init 1 \ 20 | --tau_end 2 \ 21 | --no_repeat_ngram_size 0 \ 22 | --repetition_penalty 1 \ 23 | --max_tokenizer_length 1920 \ 24 | --max_new_tokens 128 \ 25 | --min_gen_length 30 \ 26 | --num_return_sequences 1 \ 27 | --seed 12345 \ 28 | --n_obs 1000 -------------------------------------------------------------------------------- /conversation/utils.py: -------------------------------------------------------------------------------- 1 | """From Huggingface Transformers.""" 2 | import itertools 3 | import json 4 | import linecache 5 | import os 6 | import pickle 7 | import warnings 8 | from logging import getLogger 9 | from pathlib import Path 10 | from typing import Callable, Dict, Iterable, List 11 | import io 12 | 13 | import numpy as np 14 | import torch 15 | import nltk 16 | from rouge_score import rouge_scorer, scoring 17 | 18 | from torch import nn 19 | from torch.utils.data import Dataset, Sampler 20 | 21 | 22 | def _make_r_io_base(f, mode: str): 23 | if not isinstance(f, io.IOBase): 24 | f = open(f, mode=mode) 25 | return f 26 | 27 | 28 | def jload(f, mode="r"): 29 | """Load a .json file into a dictionary.""" 30 | f = _make_r_io_base(f, mode) 31 | jdict = json.load(f) 32 | f.close() 33 | return jdict 34 | 35 | 36 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 37 | """From fairseq""" 38 | if target.dim() == lprobs.dim() - 1: 39 | target = target.unsqueeze(-1) 40 | nll_loss = -lprobs.gather(dim=-1, index=target) 41 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 42 | if ignore_index is not None: 43 | pad_mask = target.eq(ignore_index) 44 | nll_loss.masked_fill_(pad_mask, 0.0) 45 | smooth_loss.masked_fill_(pad_mask, 0.0) 46 | else: 47 | nll_loss = nll_loss.squeeze(-1) 48 | smooth_loss = smooth_loss.squeeze(-1) 49 | 50 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 
51 | smooth_loss = smooth_loss.sum() 52 | eps_i = epsilon / lprobs.size(-1) 53 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 54 | return loss, nll_loss 55 | 56 | 57 | def encode_line( 58 | tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt" 59 | ): 60 | extra_kw = ( 61 | {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 62 | ) 63 | return tokenizer( 64 | [line], 65 | max_length=max_length, 66 | padding="max_length" if pad_to_max_length else None, 67 | truncation=True, 68 | return_tensors=return_tensors, 69 | **extra_kw, 70 | ) 71 | 72 | 73 | def lmap(f: Callable, x: Iterable) -> List: 74 | """list(map(f, x))""" 75 | return list(map(f, x)) 76 | 77 | 78 | def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict: 79 | """Uses sacrebleu's corpus_bleu implementation.""" 80 | return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score} 81 | 82 | 83 | def trim_batch( 84 | input_ids, 85 | pad_token_id, 86 | attention_mask=None, 87 | ): 88 | """Remove columns that are populated exclusively by pad_token_id""" 89 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 90 | if attention_mask is None: 91 | return input_ids[:, keep_column_mask] 92 | else: 93 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 94 | 95 | 96 | class Seq2SeqDataset(Dataset): 97 | """class Seq2SeqDataset""" 98 | 99 | def __init__( 100 | self, 101 | tokenizer, 102 | data_dir, 103 | max_source_length, 104 | max_target_length, 105 | type_path="train", 106 | n_obs=None, 107 | src_lang=None, 108 | tgt_lang=None, 109 | prefix="", 110 | ): 111 | super().__init__() 112 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 113 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 114 | self.src_lens = self.get_char_lens(self.src_file) 115 | self.max_source_length = max_source_length 116 | self.max_target_length = max_target_length 117 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 118 | self.tokenizer = tokenizer 119 | self.prefix = prefix 120 | if n_obs is not None: 121 | self.src_lens = self.src_lens[:n_obs] 122 | self.pad_token_id = self.tokenizer.pad_token_id 123 | self.src_lang = src_lang 124 | self.tgt_lang = tgt_lang 125 | 126 | def __len__(self): 127 | return len(self.src_lens) 128 | 129 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 130 | index = index + 1 # linecache starts at 1 131 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 132 | "\n" 133 | ) 134 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 135 | assert source_line, f"empty source line for index {index}" 136 | assert tgt_line, f"empty tgt line for index {index}" 137 | source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length) 138 | target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length) 139 | 140 | source_ids = source_inputs["input_ids"].squeeze() 141 | target_ids = target_inputs["input_ids"].squeeze() 142 | src_mask = source_inputs["attention_mask"].squeeze() 143 | return { 144 | "input_ids": source_ids, 145 | "attention_mask": src_mask, 146 | "decoder_input_ids": target_ids, 147 | } 148 | 149 | @staticmethod 150 | def get_char_lens(data_file): 151 | return [len(x) for x in Path(data_file).open().readlines()] 152 | 153 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 154 | """Reorganize batch data.""" 155 | input_ids = torch.stack([x["input_ids"] for x in batch]) 156 | masks = 
torch.stack([x["attention_mask"] for x in batch]) 157 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 158 | pad_token_id = self.pad_token_id 159 | y = trim_batch(target_ids, pad_token_id) 160 | source_ids, source_mask = trim_batch( 161 | input_ids, pad_token_id, attention_mask=masks 162 | ) 163 | batch = { 164 | "input_ids": source_ids, 165 | "attention_mask": source_mask, 166 | "decoder_input_ids": y, 167 | } 168 | return batch 169 | 170 | def make_sortish_sampler(self, batch_size): 171 | return SortishSampler(self.src_lens, batch_size) 172 | 173 | 174 | class TranslationDataset(Seq2SeqDataset): 175 | """A dataset that calls prepare_seq2seq_batch.""" 176 | 177 | def __init__(self, *args, **kwargs): 178 | super().__init__(*args, **kwargs) 179 | if self.max_source_length != self.max_target_length: 180 | warnings.warn( 181 | f"Mbart is using sequence lengths {self.max_source_length}, \ 182 | {self.max_target_length}. " 183 | f"Imbalanced sequence lengths may be undesired for \ 184 | translation tasks" 185 | ) 186 | 187 | def __getitem__(self, index) -> Dict[str, str]: 188 | index = index + 1 # linecache starts at 1 189 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 190 | "\n" 191 | ) 192 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 193 | assert source_line, f"empty source line for index {index}" 194 | assert tgt_line, f"empty tgt line for index {index}" 195 | return { 196 | "tgt_texts": tgt_line, 197 | "src_texts": source_line, 198 | } 199 | 200 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 201 | batch_encoding = self.tokenizer.prepare_seq2seq_batch( 202 | [x["src_texts"] for x in batch], 203 | src_lang=self.src_lang, 204 | tgt_texts=[x["tgt_texts"] for x in batch], 205 | tgt_lang=self.tgt_lang, 206 | max_length=self.max_source_length, 207 | max_target_length=self.max_target_length, 208 | ) 209 | return batch_encoding.data 210 | 211 | 212 | class SortishSampler(Sampler): 213 | """ 214 | Go through the text data by order of src length with a bit of randomness. 215 | From fastai repo. 216 | """ 217 | 218 | def __init__(self, data, batch_size): 219 | self.data, self.bs = data, batch_size 220 | 221 | def key(self, i): 222 | return self.data[i] 223 | 224 | def __len__(self) -> int: 225 | return len(self.data) 226 | 227 | def __iter__(self): 228 | idxs = np.random.permutation(len(self.data)) 229 | sz = self.bs * 50 230 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] 231 | sort_idx = np.concatenate( 232 | [sorted(s, key=self.key, reverse=True) for s in ck_idx] 233 | ) 234 | sz = self.bs 235 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] 236 | max_ck = np.argmax( 237 | [self.key(ck[0]) for ck in ck_idx] 238 | ) # find the chunk with the largest key, 239 | ck_idx[0], ck_idx[max_ck] = ( 240 | ck_idx[max_ck], 241 | ck_idx[0], 242 | ) # then make sure it goes first. 
243 | sort_idx = ( 244 | np.concatenate(np.random.permutation(ck_idx[1:])) 245 | if len(ck_idx) > 1 246 | else np.array([], dtype=np.int) 247 | ) 248 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) 249 | return iter(sort_idx) 250 | 251 | 252 | logger = getLogger(__name__) 253 | 254 | 255 | def use_task_specific_params(model, task): 256 | """Update config with summarization specific params.""" 257 | task_specific_params = model.config.task_specific_params 258 | 259 | if task_specific_params is not None: 260 | pars = task_specific_params.get(task, {}) 261 | print("Task Params: ", pars) 262 | logger.info(f"using task specific params for {task}: {pars}") 263 | model.config.update(pars) 264 | 265 | 266 | def pickle_load(path): 267 | """pickle.load(path)""" 268 | with open(path, "rb") as f: 269 | return pickle.load(f) 270 | 271 | 272 | def pickle_save(obj, path): 273 | """pickle.dump(obj, path)""" 274 | with open(path, "wb") as f: 275 | return pickle.dump(obj, f) 276 | 277 | 278 | def flatten_list(summary_ids: List[List]): 279 | return [x for x in itertools.chain.from_iterable(summary_ids)] 280 | 281 | 282 | def save_git_info(folder_path: str) -> None: 283 | """Save git information to output_dir/git_log.json""" 284 | repo_infos = get_git_info() 285 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 286 | 287 | 288 | def save_json(content, path): 289 | with open(path, "w") as f: 290 | json.dump(content, f, indent=4) 291 | 292 | 293 | def load_json(path): 294 | with open(path) as f: 295 | return json.load(f) 296 | 297 | 298 | def get_git_info(): 299 | repo = git.Repo(search_parent_directories=True) 300 | repo_infos = { 301 | "repo_id": str(repo), 302 | "repo_sha": str(repo.head.object.hexsha), 303 | "repo_branch": str(repo.active_branch), 304 | } 305 | return repo_infos 306 | 307 | 308 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] 309 | 310 | 311 | def calculate_rouge( 312 | output_lns: List[str], reference_lns: List[str], use_stemmer=True 313 | ) -> Dict: 314 | """Calculate rouge scores""" 315 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) 316 | aggregator = scoring.BootstrapAggregator() 317 | 318 | for reference_ln, output_ln in zip(reference_lns, output_lns): 319 | scores = scorer.score(reference_ln, output_ln) 320 | aggregator.add_scores(scores) 321 | 322 | result = aggregator.aggregate() 323 | return {k: v.mid.fmeasure for k, v in result.items()} 324 | 325 | 326 | def postprocess_text(preds, targets): 327 | preds = [pred.strip() for pred in preds] 328 | targets = [target.strip() for target in targets] 329 | 330 | # rougeLSum expects newline after each sentence 331 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 332 | targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] 333 | 334 | return preds, targets 335 | 336 | 337 | def freeze_params(model: nn.Module): 338 | for par in model.parameters(): 339 | par.requires_grad = False 340 | 341 | 342 | def grad_status(model: nn.Module) -> Iterable: 343 | return (par.requires_grad for par in model.parameters()) 344 | 345 | 346 | def any_requires_grad(model: nn.Module) -> bool: 347 | return any(grad_status(model)) 348 | 349 | 350 | def assert_all_frozen(model): 351 | model_grads: List[bool] = list(grad_status(model)) 352 | n_require_grad = sum(lmap(int, model_grads)) 353 | npars = len(model_grads) 354 | assert not any( 355 | model_grads 356 | ), f"{n_require_grad/npars:.1%} of {npars} weights require grad" 357 | 358 | 359 | def assert_not_all_frozen(model): 360 | 
model_grads: List[bool] = list(grad_status(model)) 361 | npars = len(model_grads) 362 | assert any(model_grads), f"none of {npars} weights require grad" 363 | -------------------------------------------------------------------------------- /lm_eval_harness/README.md: -------------------------------------------------------------------------------- 1 | # 📊 LM Eval Harness Tasks 2 | 3 | Evaluation of **Keyformer** on [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) framework tasks. 4 | 5 | ## Generate Task data 6 | ```python 7 | python -u generate_task_data.py \ 8 | --output-file ./task_data/${task}-${shots}.jsonl \ 9 | --task-name ${task} \ 10 | --num-fewshot ${shots} 11 | ``` 12 | 13 | ## Generate Output data with model 14 | #### Full Attention 15 | ```python 16 | python -u run_lm_eval_harness.py \ 17 | --input-path ./task_data/${task}-${shots}.jsonl \ 18 | --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \ 19 | --model-name ${model_name} \ 20 | --model-path ${model_path} \ 21 | --dtype ${dtype} \ 22 | --kv_cache ${kv_cache} \ 23 | --recent ${recent} 24 | ``` 25 | #### Keyformer 26 | ```python 27 | python -u run_lm_eval_harness.py \ 28 | --input-path ./task_data/${task}-${shots}.jsonl \ 29 | --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \ 30 | --model-name ${model_name} \ 31 | --model-path ${model_path} \ 32 | --dtype ${dtype} \ 33 | --keyformer \ 34 | --kv_cache ${kv_cache} \ 35 | --recent ${recent} 36 | ``` 37 | 38 | ## Evaluate the performance 39 | ```python 40 | python -u evaluate_task_result.py \ 41 | --result-file ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \ 42 | --output-file ./output/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \ 43 | --task-name ${task} \ 44 | --num-fewshot ${shots} \ 45 | --model-name ${model_name} 46 | ``` -------------------------------------------------------------------------------- /lm_eval_harness/evaluate_task_result.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from lm_eval import evaluator, tasks 6 | from tasks import EvalHarnessAdaptor 7 | 8 | 9 | def json_to_key(obj): 10 | return json.dumps(obj) 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser( 15 | prog="ProgramName", 16 | description="What the program does", 17 | epilog="Text at the bottom of help", 18 | ) 19 | 20 | parser.add_argument("--result-file", type=str, default="result.jsonl") 21 | parser.add_argument("--output-file", type=str, default="output.jsonl") 22 | parser.add_argument("--task-name", type=str, default="hellaswag") 23 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 24 | parser.add_argument("--debug", action="store_true", default=False) 25 | parser.add_argument("--num-fewshot", type=int, default=0) 26 | args = parser.parse_args() 27 | 28 | os.environ["MODEL_NAME"] = args.model_name 29 | """ 30 | if args.model_type == 'opt': 31 | os.environ['MODEL_NAME'] = "facebook/opt-66b" 32 | elif args.model_type == 'bloom': 33 | os.environ['MODEL_NAME'] = "bigscience/bloom" 34 | elif args.model_type == 'gpt_neox': 35 | os.environ['MODEL_NAME'] = "EleutherAI/gpt-neox-20b" 36 | elif args.model_type == 'llama': 37 | os.environ['MODEL_NAME'] = "huggyllama/llama-7b" 38 | else: 39 | assert False 40 | """ 41 | seq = 1024 42 | total_batch = 1 43 | 
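    # seq / total_batch mirror the values used in generate_task_data.py; with
    # pe == "fixed" (set just below), the EvalHarnessAdaptor further down is
    # built with shrink=False, so cached results are scored at the full padded
    # sequence length rather than being halved by shrink_seq (tasks/util.py).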
pe = "fixed" 44 | 45 | class RealRunner: 46 | def __init__(self, args): 47 | self.results = {} 48 | 49 | with open(args.result_file, "r") as f: 50 | for line in f: 51 | if line.strip() == "": 52 | continue 53 | 54 | item = json.loads(line) 55 | 56 | request = item["request"] 57 | result = item["result"] 58 | 59 | self.results[json_to_key(request)] = result 60 | 61 | print(f"{len(self.results)} items in the cache") 62 | 63 | def eval(self, batch): 64 | from tasks.eval_harness import tokenizer 65 | 66 | mask_loss = [] 67 | each_correct = [] 68 | 69 | for i, text in enumerate(batch["text"]): 70 | request = { 71 | "best_of": 1, 72 | "echo": True, 73 | "logprobs": 1, 74 | "max_tokens": 0, 75 | "model": "x", 76 | "n": 1, 77 | "prompt": text, 78 | "request_type": "language-model-inference", 79 | "stop": None, 80 | "temperature": 0, 81 | "top_p": 1, 82 | } 83 | 84 | key = json_to_key(request) 85 | 86 | correct = True 87 | 88 | if key in self.results: 89 | result = self.results[key] 90 | 91 | token_logprobs = result["choices"][0]["logprobs"]["token_logprobs"] 92 | tokens = result["choices"][0]["logprobs"]["tokens"] 93 | top_logprobs = result["choices"][0]["logprobs"]["top_logprobs"] 94 | assert token_logprobs[0] is None 95 | 96 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 97 | 98 | obs = batch["obs"][i] 99 | target = batch["target"][i] 100 | eval_mask = batch["eval_mask"][i] 101 | 102 | n_positive = 0 103 | sum_lobprob = 0 104 | if args.debug: 105 | print(target) 106 | for i, mask in enumerate(eval_mask): 107 | try: 108 | if i + 1 >= len(tokens): 109 | break 110 | 111 | if mask == True: 112 | if args.debug: 113 | print( 114 | tokens[i + 1], 115 | next(iter(top_logprobs[i + 1].keys())), 116 | ) 117 | correct = correct and ( 118 | tokens[i + 1] 119 | == next(iter(top_logprobs[i + 1].keys())) 120 | ) 121 | sum_lobprob += token_logprobs[i + 1] 122 | n_positive += 1 123 | except Exception as e: 124 | raise e 125 | 126 | # avg_logprob = sum(token_logprobs[1:]) / (len(token_logprobs) - 1) 127 | avg_logprob = sum_lobprob / n_positive 128 | 129 | mask_loss.append(-avg_logprob) 130 | 131 | each_correct.append(correct) 132 | 133 | else: 134 | assert False 135 | 136 | out = { 137 | "mask_loss": mask_loss, 138 | "each_correct": each_correct, 139 | } 140 | 141 | return out 142 | 143 | t = RealRunner(args) 144 | 145 | adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed") 146 | 147 | results = evaluator.evaluate( 148 | adaptor, 149 | tasks.get_task_dict( 150 | [ 151 | args.task_name 152 | # "lambada_openai", 153 | # "piqa", 154 | # "hellaswag", 155 | # "winogrande", 156 | # "mathqa", 157 | # "pubmedqa", 158 | # "boolq", 159 | # "cb", 160 | # "copa", 161 | # "multirc", 162 | # "record", 163 | # "wic", 164 | # "wsc", 165 | ] 166 | ), 167 | False, 168 | args.num_fewshot, 169 | None, 170 | ) 171 | 172 | dumped = json.dumps(results, indent=2) 173 | print(dumped) 174 | 175 | with open(args.output_file, "w") as outfile: 176 | outfile.write(dumped) 177 | -------------------------------------------------------------------------------- /lm_eval_harness/experiments.sh: -------------------------------------------------------------------------------- 1 | ## =================== MPT-7B ==================== 2 | #bash run.sh <1/0 (1 for Keyformer)> 3 | # ==> Full Attention 4 | bash run.sh openbookqa 0 mosaicml/mpt-7b ./MPT-7B mosaicml-mpt-7b float16 0 60 30 5 | # ==> Keyformer 6 | bash run.sh openbookqa 0 mosaicml/mpt-7b ./MPT-7B mosaicml-mpt-7b float16 1 60 30 7 | 8 | 9 | 10 | 
-------------------------------------------------------------------------------- /lm_eval_harness/generate_task_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | 6 | from lm_eval import evaluator, tasks 7 | from tasks import EvalHarnessAdaptor 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser( 12 | prog="ProgramName", 13 | description="What the program does", 14 | epilog="Text at the bottom of help", 15 | ) 16 | 17 | parser.add_argument("--output-file", type=str, default="input.jsonl") 18 | parser.add_argument("--task-name", type=str, default="hellaswag") 19 | parser.add_argument("--num-fewshot", type=int, default=0) 20 | args = parser.parse_args() 21 | 22 | if os.path.isfile(args.output_file): 23 | print(f"{args.task_name} data with {args.num_fewshot} shots already exist!!") 24 | sys.exit() 25 | 26 | seq = 1024 27 | total_batch = 1 28 | pe = "fixed" 29 | 30 | with open(args.output_file, "w") as f: 31 | pass 32 | 33 | class DryRunner: 34 | def eval(self, batch): 35 | with open(args.output_file, "a") as f: 36 | for text in batch["text"]: 37 | item = { 38 | "best_of": 1, 39 | "echo": True, 40 | "logprobs": 1, 41 | "max_tokens": 0, 42 | "model": "x", 43 | "n": 1, 44 | "prompt": text, 45 | "request_type": "language-model-inference", 46 | "stop": None, 47 | "temperature": 0, 48 | "top_p": 1, 49 | } 50 | 51 | f.write(json.dumps(item) + "\n") 52 | 53 | out = { 54 | "mask_loss": [1.0] * len(batch), 55 | "each_correct": [True] * len(batch), 56 | } 57 | return out 58 | 59 | t = DryRunner() 60 | adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed") 61 | results = evaluator.evaluate( 62 | adaptor, 63 | tasks.get_task_dict( 64 | [ 65 | args.task_name 66 | # "lambada_openai", 67 | # "piqa", 68 | # "hellaswag", 69 | # "winogrande", 70 | # "mathqa", 71 | # "pubmedqa", 72 | # "boolq", 73 | # "cb", 74 | # "copa", 75 | # "multirc", 76 | # "record", 77 | # "wic", 78 | # "wsc", 79 | ] 80 | ), 81 | False, 82 | args.num_fewshot, 83 | None, 84 | ) 85 | print(f"Finished Generating {args.task_name} data with {args.num_fewshot} shots!!") 86 | -------------------------------------------------------------------------------- /lm_eval_harness/model_eval_data/model_eval_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/lm_eval_harness/model_eval_data/model_eval_data -------------------------------------------------------------------------------- /lm_eval_harness/output/output: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/lm_eval_harness/output/output -------------------------------------------------------------------------------- /lm_eval_harness/run.sh: -------------------------------------------------------------------------------- 1 | # ## Obtain inference data 2 | task=$1 3 | shots=$2 4 | model_name=$3 5 | model_path=$4 6 | model_type=$5 7 | dtype=$6 8 | keyformer=$7 9 | kv_cache=$8 10 | recent=$9 11 | 12 | echo " " 13 | echo "--------------------------------------------" 14 | echo "Task ${task} data with ${shots} shots generation" 15 | echo "--------------------------------------------" 16 | 17 | python -u generate_task_data.py --output-file ./task_data/${task}-${shots}.jsonl --task-name 
${task} --num-fewshot ${shots} 18 | 19 | ## Inference, and generate output json file 20 | echo " " 21 | echo "--------------------------------------------" 22 | echo "LM Eval Harness for model ${model_name}" 23 | echo "--------------------------------------------" 24 | 25 | if [ $keyformer -eq 0 ] 26 | then 27 | echo "====> Full Attention!!" 28 | python -u run_lm_eval_harness.py --input-path ./task_data/${task}-${shots}.jsonl --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --model-name ${model_name} --model-path ${model_path} --dtype ${dtype} --kv_cache ${kv_cache} --recent ${recent} 29 | else 30 | echo "====> Keyformer enabled with ${kv_cache}% KV Cache with ${recent}% recent tokens!!" 31 | python -u run_lm_eval_harness.py --input-path ./task_data/${task}-${shots}.jsonl --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --model-name ${model_name} --model-path ${model_path} --dtype ${dtype} --keyformer --kv_cache ${kv_cache} --recent ${recent} 32 | fi 33 | 34 | ## Evaluate results 35 | echo " " 36 | echo "--------------------------------------------" 37 | echo "Results Generation" 38 | echo "--------------------------------------------" 39 | python -u evaluate_task_result.py --result-file ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --output-file ./output/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --task-name ${task} --num-fewshot ${shots} --model-name ${model_name} -------------------------------------------------------------------------------- /lm_eval_harness/run_lm_eval_harness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json, tqdm 3 | import torch 4 | import copy 5 | import sys 6 | 7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser( 11 | prog="ProgramName", 12 | description="What the program does", 13 | epilog="Text at the bottom of help", 14 | ) 15 | 16 | parser.add_argument("--input-path", type=str, default=None) 17 | parser.add_argument("--output-path", type=str, default=None) 18 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 19 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 20 | parser.add_argument("--model-type", type=str, default="opt") 21 | parser.add_argument("--cache-dir", type=str, default="./") 22 | parser.add_argument( 23 | "--dtype", 24 | default="float32", 25 | help="data type of the model, choose from float16, bfloat16 and float32", 26 | ) 27 | ############################################### Keyformer ################################################################# 28 | parser.add_argument( 29 | "--keyformer", action="store_true", help="Keyformer enabled - reduced KV cache" 30 | ) 31 | parser.add_argument( 32 | "--kv_cache", 33 | type=float, 34 | default=60, 35 | required=False, 36 | help="KV cache percentage for Keyformer", 37 | ) 38 | parser.add_argument( 39 | "--recent", 40 | type=float, 41 | default=30, 42 | required=False, 43 | help="Recent window percentage for Keyformer", 44 | ) 45 | parser.add_argument( 46 | "--tau_init", 47 | type=float, 48 | default=1.0, 49 | required=False, 50 | help="Initial temperature", 51 | ) 52 | parser.add_argument( 53 | "--tau_end", type=float, default=2.0, required=False, help="Final temperature" 54 | ) 55 | 
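    # The Keyformer knobs above are interpreted as follows (see the attention
    # modules): kv_cache keeps that percentage of the sequence in the KV cache,
    # recent reserves that percentage of the kept tokens for the most recent
    # positions, and tau_init / tau_end set the start and end temperature of the
    # Gumbel-softmax score function (tau_delta = tau_end - tau_init below).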
########################################################################################################################### 56 | parser.add_argument( 57 | "--token_discard", 58 | action="store_true", 59 | help="When enable token discarded below threshold", 60 | ) 61 | parser.add_argument( 62 | "--sparse_threshold", 63 | type=float, 64 | default=0.0, 65 | required=False, 66 | help="Threshold for sparsification", 67 | ) 68 | parser.add_argument( 69 | "--prompt_sparse_threshold", 70 | type=float, 71 | default=0.0, 72 | required=False, 73 | help="Threshold for prompt sparsification", 74 | ) 75 | parser.add_argument( 76 | "--sparse_itr", 77 | type=int, 78 | default=1, 79 | required=False, 80 | help="Tokene Generation Iteration for updating token discard mask", 81 | ) 82 | 83 | args = parser.parse_args() 84 | 85 | input_path = args.input_path 86 | output_path = args.output_path 87 | model_name = args.model_name 88 | model_path = args.model_path 89 | dtype = args.dtype 90 | keyformer = args.keyformer 91 | kv_cache = args.kv_cache 92 | recent = args.recent 93 | tau_init = args.tau_init 94 | tau_end = args.tau_end 95 | 96 | if dtype == "bfloat16": 97 | amp_enabled = True 98 | amp_dtype = torch.bfloat16 99 | elif dtype == "float16": 100 | amp_enabled = True 101 | amp_dtype = torch.float16 102 | else: 103 | amp_enabled = False 104 | amp_dtype = torch.float32 105 | 106 | if model_name.split("/")[0] == "mosaicml": 107 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 108 | config.attn_config["attn_impl"] = "torch" 109 | config.init_device = "cuda:0" 110 | config.use_cache = True 111 | config.torch_dtype = dtype 112 | config.keyformer_config["keyformer"] = keyformer 113 | config.keyformer_config["kv_cache"] = kv_cache 114 | config.keyformer_config["recent"] = recent 115 | config.keyformer_config["tau_init"] = tau_init 116 | config.keyformer_config["tau_delta"] = tau_end - tau_init 117 | model = AutoModelForCausalLM.from_pretrained( 118 | model_path, config=config, torch_dtype=amp_dtype, trust_remote_code=True 119 | ) 120 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 121 | 122 | elif model_name == "EleutherAI/gpt-j-6B": 123 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 124 | config.keyformer_config["keyformer"] = keyformer 125 | config.keyformer_config["kv_cache"] = kv_cache 126 | config.keyformer_config["recent"] = recent 127 | config.keyformer_config["tau_init"] = tau_init 128 | config.keyformer_config["tau_delta"] = tau_end - tau_init 129 | model = AutoModelForCausalLM.from_pretrained( 130 | model_path, config=config, torch_dtype=amp_dtype, trust_remote_code=True 131 | ) 132 | tokenizer = AutoTokenizer.from_pretrained(model_name) 133 | 134 | else: 135 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 136 | config.keyformer_config["keyformer"] = keyformer 137 | config.keyformer_config["kv_cache"] = kv_cache 138 | config.keyformer_config["recent"] = recent 139 | config.keyformer_config["tau_init"] = tau_init 140 | config.keyformer_config["tau_delta"] = tau_end - tau_init 141 | model = AutoModelForCausalLM.from_pretrained( 142 | model_path, config=config, torch_dtype=amp_dtype, trust_remote_code=True 143 | ) 144 | tokenizer = AutoTokenizer.from_pretrained(model_name) 145 | 146 | model.eval().cuda() 147 | 148 | requests = [] 149 | with open(input_path, "r") as f: 150 | for line in f: 151 | if line.strip() != "": 152 | requests.append(json.loads(line)) 153 | 154 | results = [] 155 | with torch.no_grad(): 
156 | for request in tqdm.tqdm(requests): 157 | result = {"request": request, "result": {}} 158 | prompt = request["prompt"] 159 | input_ids = tokenizer( 160 | prompt, add_special_tokens=False, return_tensors="pt" 161 | ).input_ids.to(model.device) 162 | 163 | logits = model(input_ids).logits.log_softmax(dim=-1) 164 | values, indices = logits.squeeze(0).topk(dim=-1, k=1) 165 | tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0)) 166 | 167 | gold_indices = input_ids[:, 1:] # skip first 168 | logprobs = [None] + torch.gather( 169 | logits, -1, gold_indices.unsqueeze(-1) 170 | ).squeeze(-1).squeeze(0).detach().cpu().tolist() 171 | top_logprobs = [None] + [ 172 | {tokenizer.convert_ids_to_tokens(i.item()): v.item()} 173 | for v, i in zip(values.squeeze(-1), indices.squeeze(-1)) 174 | ] 175 | 176 | result["result"] = { 177 | "choices": [ 178 | { 179 | "text": prompt, 180 | "logprobs": { 181 | "tokens": tokens, 182 | "token_logprobs": logprobs, 183 | "top_logprobs": top_logprobs, 184 | "text_offset": [], 185 | }, 186 | "finish_reason": "length", 187 | } 188 | ], 189 | "request_time": {"batch_time": 0, "batch_size": 1}, 190 | } 191 | 192 | results.append(result) 193 | 194 | with open(output_path, "w") as f: 195 | for result in results: 196 | f.write(json.dumps(result) + "\n") 197 | -------------------------------------------------------------------------------- /lm_eval_harness/task_data/task_data: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/lm_eval_harness/task_data/task_data -------------------------------------------------------------------------------- /lm_eval_harness/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from tasks.eval_harness import EvalHarnessAdaptor -------------------------------------------------------------------------------- /lm_eval_harness/tasks/eval_harness.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import os 4 | import transformers 5 | from lm_eval.base import LM 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | from tasks.util import sample_batch, shrink_seq 10 | import multiprocessing 11 | import ftfy 12 | 13 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 14 | 15 | tokenizer = None 16 | 17 | 18 | def process_init(): 19 | global tokenizer 20 | model_name = os.environ.get("MODEL_NAME", "facebook/opt-1.3b") 21 | 22 | if model_name.split("/")[0] == "mosaicml": 23 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") 24 | tokenizer.model_max_length = int(1e30) 25 | tokenizer.pad_token = "<|endoftext|>" 26 | elif model_name == "EleutherAI/gpt-j-6B": 27 | tokenizer = AutoTokenizer.from_pretrained(model_name) 28 | tokenizer.model_max_length = int(1e30) 29 | tokenizer.pad_token = "<|endoftext|>" 30 | else: 31 | tokenizer = AutoTokenizer.from_pretrained(model_name) 32 | tokenizer.model_max_length = int(1e30) 33 | tokenizer.pad_token = "<|endoftext|>" 34 | 35 | 36 | def process_request(x, seq): 37 | global tokenizer 38 | 39 | ctx, cont = x 40 | # ctx_tokens = tokenizer.encode("<|endoftext|>" + ftfy.fix_text(ctx, normalization="NFKC")) 41 | ctx_text = ftfy.fix_text(ctx, normalization="NFKC") 42 | cont_text = ftfy.fix_text(cont, normalization="NFKC") 43 | all_text = ctx_text + cont_text 44 | 45 | ctx_tokens = tokenizer(ctx_text, 
add_special_tokens=False)["input_ids"] 46 | cont_tokens = tokenizer(cont_text, add_special_tokens=False)["input_ids"] 47 | 48 | all_tokens = ctx_tokens + cont_tokens 49 | all_tokens = np.array(all_tokens)[-seq:] # truncate sequence at seq length 50 | 51 | provided_ctx = len(all_tokens) - 1 52 | pad_amount = seq - provided_ctx 53 | 54 | return { 55 | "obs": np.pad( 56 | all_tokens[:-1], ((0, pad_amount),), constant_values=tokenizer.pad_token_id 57 | ), 58 | "target": np.pad( 59 | all_tokens[1:], ((0, pad_amount),), constant_values=tokenizer.pad_token_id 60 | ), 61 | "ctx_length": seq, 62 | "eval_mask": np.logical_and( 63 | np.arange(0, seq) > len(all_tokens) - len(cont_tokens) - 2, 64 | np.arange(0, seq) < len(all_tokens) - 1, 65 | ), 66 | "prompt": ctx_text, 67 | "target": cont_text, 68 | "text": all_text, 69 | } 70 | 71 | 72 | class EvalHarnessAdaptor(LM): 73 | def greedy_until(self, requests): 74 | raise Exception("unimplemented") 75 | 76 | def loglikelihood_rolling(self, requests): 77 | raise Exception("unimplemented") 78 | 79 | def __init__(self, tpu_cluster, seq, batch, shrink, min_seq=None): 80 | super().__init__() 81 | self.tpu = tpu_cluster 82 | self.seq = seq 83 | self.batch = batch 84 | self.shrink = shrink 85 | self.min_seq = min_seq 86 | 87 | self.pool = multiprocessing.Pool(processes=1, initializer=process_init) 88 | # self.pool = multiprocessing.Pool(initializer=process_init) 89 | process_init() 90 | 91 | def convert_requests(self, requests): 92 | return self.pool.imap(partial(process_request, seq=self.seq), requests) 93 | 94 | def loglikelihood(self, requests): 95 | output = [] 96 | 97 | r = self.convert_requests(requests) 98 | zero_example = process_request(requests[0], self.seq) 99 | 100 | for b in tqdm( 101 | sample_batch(r, self.batch, zero_example), 102 | desc="LM eval harness", 103 | total=len(requests) // self.batch, 104 | ): 105 | if self.shrink: 106 | b = shrink_seq(b, min_seq=self.min_seq) 107 | 108 | out = self.tpu.eval(b) 109 | 110 | for loss, correct in zip(out["mask_loss"], out["each_correct"]): 111 | output.append((float(-loss), bool(correct))) 112 | 113 | return output 114 | -------------------------------------------------------------------------------- /lm_eval_harness/tasks/util.py: -------------------------------------------------------------------------------- 1 | from itertools import zip_longest 2 | 3 | import numpy as np 4 | 5 | 6 | def grouper(n, iterable, fillvalue): 7 | "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx" 8 | args = [iter(iterable)] * n 9 | return zip_longest(fillvalue=fillvalue, *args) 10 | 11 | 12 | # divide the seq length by 2 until it would truncate actual context 13 | def shrink_seq(examples, min_seq=None): 14 | length = examples["obs"].shape[-1] 15 | 16 | new_length = length // 2 17 | 18 | if min_seq is not None: 19 | if new_length < min_seq: 20 | return examples 21 | 22 | max_length = np.max(examples["eval_mask"] * np.arange(0, length)) + 1 23 | 24 | if max_length < new_length: 25 | examples["obs"] = examples["obs"][:, :new_length] 26 | examples["target"] = examples["target"][:, :new_length] 27 | examples["eval_mask"] = examples["eval_mask"][:, :new_length] 28 | 29 | return shrink_seq(examples, min_seq=min_seq) 30 | else: 31 | return examples 32 | 33 | 34 | def sample_batch(examples, bs, zero_example_shape): 35 | zero_example = { 36 | "obs": np.zeros_like(zero_example_shape["obs"]), 37 | "target": np.zeros_like(zero_example_shape["target"]), 38 | "eval_mask": np.zeros_like(zero_example_shape["eval_mask"]), 39 | "ctx_length": 0, 40 | 
} 41 | 42 | for batch in grouper(bs, examples, zero_example): 43 | batch_flattened = { 44 | "obs": [], 45 | "target": [], 46 | "eval_mask": [], 47 | "ctx_length": [], 48 | "text": [], 49 | } 50 | 51 | for sample in batch: 52 | batch_flattened["obs"].append(sample["obs"]) 53 | batch_flattened["target"].append(sample["target"]) 54 | batch_flattened["eval_mask"].append(sample["eval_mask"]) 55 | batch_flattened["ctx_length"].append(sample["ctx_length"]) 56 | batch_flattened["text"].append(sample["text"]) 57 | 58 | batch_flattened["obs"] = np.array(batch_flattened["obs"]) 59 | batch_flattened["target"] = np.array(batch_flattened["target"]) 60 | batch_flattened["eval_mask"] = np.array(batch_flattened["eval_mask"]) 61 | batch_flattened["ctx_length"] = np.array(batch_flattened["ctx_length"]) 62 | 63 | yield batch_flattened 64 | -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # 🏁 Models 2 | 3 | ## LLM Model Checkpoint 4 | Note - **Keyformer** has been evaluated on the following models and we restrict this tutorial to the use of these - 5 | `['cerebras/Cerebras-GPT-6.7B', 6 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']` 7 | 8 | Clone the relevant model from huggingface into the `models` directory. For instance, to clone `mosaicml/mpt-7b` into `models/mpt-7b-keyformer`, do the following - 9 | 10 | ```bash 11 | cd models 12 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mosaicml/mpt-7b mpt-7b-keyformer 13 | cd .. 14 | ``` 15 | 16 | Then, download the model weights (`*.bin`) using `download_model.py` file and use the model_name as the argument. 17 | 18 | ```bash 19 | cd models/model_download 20 | python3 download_model.py --model_name mosaicml/mpt-7b 21 | cd ../../ 22 | ``` 23 | 24 | Please keep in mind that supported arguments for `model_name` are restricted to `['cerebras/Cerebras-GPT-6.7B', 25 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']` 26 | 27 | This step will download the model parameters by creating a `model` folder inside `model_download` directory. 28 | Downloaded model parameters are required to be moved to the same directory where model files are available. 29 | 30 | To move the model parameters from `models/model_download/model/` to `models/mpt-7b-keyformer` 31 | 32 | ```bash 33 | mv models/model_download/model/* models/mpt-7b-keyformer/ 34 | ``` 35 | 36 | ### Integrating Keyformer 37 | Copy the **Keyformer** source files from `models/mpt-keyformer-lib` to `models/mpt-7b-keyformer` 38 | 39 | ```bash 40 | cp -r models/mpt-keyformer-lib models/mpt-7b-keyformer 41 | ``` 42 | 43 | This sets up `models/mpt-7b-keyformer` to be used with the various tasks described below. 44 | 45 | Similarly, find the keyformer-lib files for other supported models in their respective directories in `models/`. 46 |
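
As a quick sanity check, the patched checkpoint can be loaded the same way the task scripts load it (see `lm_eval_harness/run_lm_eval_harness.py`). The sketch below assumes the `models/mpt-7b-keyformer` layout prepared above; the KV cache and recent-window percentages simply match the values used in `lm_eval_harness/experiments.sh`:

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "models/mpt-7b-keyformer"  # directory prepared in the steps above
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.attn_config["attn_impl"] = "torch"     # MPT-specific setting used by the task scripts
config.keyformer_config["keyformer"] = True   # enable Keyformer KV cache reduction
config.keyformer_config["kv_cache"] = 60      # keep 60% of the KV cache
config.keyformer_config["recent"] = 30        # 30% of the kept tokens form the recent window
model = AutoModelForCausalLM.from_pretrained(
    model_path, config=config, trust_remote_code=True
)
```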

47 | ## LLM Model Cards 48 | 49 | We have provided model cards for below models. 50 | - [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b) 51 | - [MPT](https://huggingface.co/mosaicml/mpt-7b) 52 | - [Cerebras-GPT](https://huggingface.co/cerebras/Cerebras-GPT-6.7B) 53 | -------------------------------------------------------------------------------- /models/cerebras-keyformer-lib/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "cerebras/Cerebras-GPT-6.7B", 3 | "activation_function": "gelu", 4 | "architectures": [ 5 | "GPT2LMHeadModel" 6 | ], 7 | "auto_map": { 8 | "AutoConfig": "configuration_gpt2.GPT2Config", 9 | "AutoModelForCausalLM": "modeling_gpt2.GPT2LMHeadModel" 10 | }, 11 | "keyformer_config": { 12 | "keyformer": false, 13 | "kv_cache": 60, 14 | "recent": 30, 15 | "tau_init": 1.0, 16 | "tau_delta": 0.01 17 | }, 18 | "attn_pdrop": 0.0, 19 | "bos_token_id": 50256, 20 | "embd_pdrop": 0.0, 21 | "eos_token_id": 50256, 22 | "initializer_range": 0.02, 23 | "layer_norm_epsilon": 1e-05, 24 | "model_type": "gpt2", 25 | "n_embd": 4096, 26 | "n_head": 32, 27 | "n_inner": 16384, 28 | "n_layer": 32, 29 | "n_positions": 2048, 30 | "reorder_and_upcast_attn": false, 31 | "resid_pdrop": 0.0, 32 | "scale_attn_by_inverse_layer_idx": false, 33 | "scale_attn_weights": true, 34 | "summary_activation": null, 35 | "summary_first_dropout": 0.1, 36 | "summary_proj_to_labels": true, 37 | "summary_type": "cls_index", 38 | "summary_use_proj": true, 39 | "torch_dtype": "float32", 40 | "transformers_version": "4.28.1", 41 | "use_cache": true, 42 | "vocab_size": 50257 43 | } 44 | -------------------------------------------------------------------------------- /models/cerebras-keyformer-lib/configuration_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """ OpenAI GPT-2 configuration""" 17 | from collections import OrderedDict 18 | from typing import Any, List, Mapping, Optional 19 | 20 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available 21 | from transformers import PretrainedConfig 22 | from transformers.onnx import OnnxConfigWithPast, PatchingSpec 23 | from transformers.utils import logging 24 | 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = { 29 | "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json", 30 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json", 31 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json", 32 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json", 33 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json", 34 | } 35 | keyformer_config_defaults: Dict = { 36 | "keyformer": False, 37 | "kv_cache": 60, 38 | "recent": 30, 39 | "tau_init": 1.0, 40 | "tau_delta": 0.01, 41 | } 42 | 43 | 44 | class GPT2Config(PretrainedConfig): 45 | """ 46 | This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to 47 | instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a 48 | configuration with the defaults will yield a similar configuration to that of the GPT-2 49 | [gpt2](https://huggingface.co/gpt2) architecture. 50 | 51 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 52 | documentation from [`PretrainedConfig`] for more information. 53 | 54 | 55 | Args: 56 | vocab_size (`int`, *optional*, defaults to 50257): 57 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 58 | `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. 59 | n_positions (`int`, *optional*, defaults to 1024): 60 | The maximum sequence length that this model might ever be used with. Typically set this to something large 61 | just in case (e.g., 512 or 1024 or 2048). 62 | n_embd (`int`, *optional*, defaults to 768): 63 | Dimensionality of the embeddings and hidden states. 64 | n_layer (`int`, *optional*, defaults to 12): 65 | Number of hidden layers in the Transformer encoder. 66 | n_head (`int`, *optional*, defaults to 12): 67 | Number of attention heads for each attention layer in the Transformer encoder. 68 | n_inner (`int`, *optional*, defaults to None): 69 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd 70 | activation_function (`str`, *optional*, defaults to `"gelu"`): 71 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. 72 | resid_pdrop (`float`, *optional*, defaults to 0.1): 73 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 74 | embd_pdrop (`float`, *optional*, defaults to 0.1): 75 | The dropout ratio for the embeddings. 76 | attn_pdrop (`float`, *optional*, defaults to 0.1): 77 | The dropout ratio for the attention. 78 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): 79 | The epsilon to use in the layer normalization layers. 80 | initializer_range (`float`, *optional*, defaults to 0.02): 81 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
82 | summary_type (`string`, *optional*, defaults to `"cls_index"`): 83 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 84 | [`TFGPT2DoubleHeadsModel`]. 85 | 86 | Has to be one of the following options: 87 | 88 | - `"last"`: Take the last token hidden state (like XLNet). 89 | - `"first"`: Take the first token hidden state (like BERT). 90 | - `"mean"`: Take the mean of all tokens hidden states. 91 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 92 | - `"attn"`: Not implemented now, use multi-head attention. 93 | summary_use_proj (`bool`, *optional*, defaults to `True`): 94 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 95 | [`TFGPT2DoubleHeadsModel`]. 96 | 97 | Whether or not to add a projection after the vector extraction. 98 | summary_activation (`str`, *optional*): 99 | Argument used when doing sequence summary. Used in for the multiple choice head in 100 | [`GPT2DoubleHeadsModel`]. 101 | 102 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. 103 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`): 104 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 105 | [`TFGPT2DoubleHeadsModel`]. 106 | 107 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. 108 | summary_first_dropout (`float`, *optional*, defaults to 0.1): 109 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 110 | [`TFGPT2DoubleHeadsModel`]. 111 | 112 | The dropout ratio to be used after the projection and activation. 113 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 114 | Scale attention weights by dividing by sqrt(hidden_size).. 115 | use_cache (`bool`, *optional*, defaults to `True`): 116 | Whether or not the model should return the last key/values attentions (not used by all models). 117 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): 118 | Whether to additionally scale attention weights by `1 / layer_idx + 1`. 119 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): 120 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention 121 | dot-product/softmax to float() when training with mixed precision. 122 | keyformer_config (Dict): A dictionary used to configure the keyformer parameters: 123 | keyformer (bool): When enabled, keyformer based KV cache reduction 124 | kv_cache (float): KV cache percentage for Keyformer. 125 | recent (float): Recent window percentage for Keyformer. 126 | tau_init (float): Initial temperature for Keyformer score function calculation. 127 | tau_delta (float): Delta temperature change for Keyformer score function. 
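            As an illustration only (the numeric values mirror the defaults above),
            Keyformer could be enabled by overriding this dictionary when
            constructing the config:

            ```python
            >>> config = GPT2Config(
            ...     keyformer_config={
            ...         "keyformer": True,
            ...         "kv_cache": 60,
            ...         "recent": 30,
            ...         "tau_init": 1.0,
            ...         "tau_delta": 0.01,
            ...     }
            ... )
            ```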
128 | 129 | Example: 130 | 131 | ```python 132 | >>> from transformers import GPT2Config, GPT2Model 133 | 134 | >>> # Initializing a GPT2 configuration 135 | >>> configuration = GPT2Config() 136 | 137 | >>> # Initializing a model (with random weights) from the configuration 138 | >>> model = GPT2Model(configuration) 139 | 140 | >>> # Accessing the model configuration 141 | >>> configuration = model.config 142 | ```""" 143 | 144 | model_type = "gpt2" 145 | keys_to_ignore_at_inference = ["past_key_values"] 146 | attribute_map = { 147 | "hidden_size": "n_embd", 148 | "max_position_embeddings": "n_positions", 149 | "num_attention_heads": "n_head", 150 | "num_hidden_layers": "n_layer", 151 | } 152 | 153 | def __init__( 154 | self, 155 | vocab_size=50257, 156 | n_positions=1024, 157 | n_embd=768, 158 | n_layer=12, 159 | n_head=12, 160 | n_inner=None, 161 | activation_function="gelu_new", 162 | resid_pdrop=0.1, 163 | embd_pdrop=0.1, 164 | attn_pdrop=0.1, 165 | layer_norm_epsilon=1e-5, 166 | initializer_range=0.02, 167 | summary_type="cls_index", 168 | summary_use_proj=True, 169 | summary_activation=None, 170 | summary_proj_to_labels=True, 171 | summary_first_dropout=0.1, 172 | scale_attn_weights=True, 173 | use_cache=True, 174 | bos_token_id=50256, 175 | eos_token_id=50256, 176 | scale_attn_by_inverse_layer_idx=False, 177 | reorder_and_upcast_attn=False, 178 | keyformer_config: Dict = keyformer_config_defaults, 179 | **kwargs, 180 | ): 181 | self.vocab_size = vocab_size 182 | self.n_positions = n_positions 183 | self.n_embd = n_embd 184 | self.n_layer = n_layer 185 | self.n_head = n_head 186 | self.n_inner = n_inner 187 | self.activation_function = activation_function 188 | self.resid_pdrop = resid_pdrop 189 | self.embd_pdrop = embd_pdrop 190 | self.attn_pdrop = attn_pdrop 191 | self.layer_norm_epsilon = layer_norm_epsilon 192 | self.initializer_range = initializer_range 193 | self.summary_type = summary_type 194 | self.summary_use_proj = summary_use_proj 195 | self.summary_activation = summary_activation 196 | self.summary_first_dropout = summary_first_dropout 197 | self.summary_proj_to_labels = summary_proj_to_labels 198 | self.scale_attn_weights = scale_attn_weights 199 | self.use_cache = use_cache 200 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx 201 | self.reorder_and_upcast_attn = reorder_and_upcast_attn 202 | # ====== Keyformer ====== 203 | self.keyformer_config = keyformer_config 204 | # ======================= 205 | self.bos_token_id = bos_token_id 206 | self.eos_token_id = eos_token_id 207 | 208 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 209 | 210 | 211 | class GPT2OnnxConfig(OnnxConfigWithPast): 212 | def __init__( 213 | self, 214 | config: PretrainedConfig, 215 | task: str = "default", 216 | patching_specs: List[PatchingSpec] = None, 217 | use_past: bool = False, 218 | ): 219 | super().__init__( 220 | config, task=task, patching_specs=patching_specs, use_past=use_past 221 | ) 222 | if not getattr(self._config, "pad_token_id", None): 223 | # TODO: how to do that better? 
224 | self._config.pad_token_id = 0 225 | 226 | @property 227 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 228 | common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) 229 | if self.use_past: 230 | self.fill_with_past_key_values_(common_inputs, direction="inputs") 231 | common_inputs["attention_mask"] = { 232 | 0: "batch", 233 | 1: "past_sequence + sequence", 234 | } 235 | else: 236 | common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} 237 | 238 | return common_inputs 239 | 240 | @property 241 | def num_layers(self) -> int: 242 | return self._config.n_layer 243 | 244 | @property 245 | def num_attention_heads(self) -> int: 246 | return self._config.n_head 247 | 248 | def generate_dummy_inputs( 249 | self, 250 | tokenizer: PreTrainedTokenizer, 251 | batch_size: int = -1, 252 | seq_length: int = -1, 253 | is_pair: bool = False, 254 | framework: Optional[TensorType] = None, 255 | ) -> Mapping[str, Any]: 256 | common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( 257 | tokenizer, 258 | batch_size=batch_size, 259 | seq_length=seq_length, 260 | is_pair=is_pair, 261 | framework=framework, 262 | ) 263 | 264 | # We need to order the input in the way they appears in the forward() 265 | ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) 266 | 267 | # Need to add the past_keys 268 | if self.use_past: 269 | if not is_torch_available(): 270 | raise ValueError( 271 | "Cannot generate dummy past_keys inputs without PyTorch installed." 272 | ) 273 | else: 274 | import torch 275 | 276 | batch, seqlen = common_inputs["input_ids"].shape 277 | # Not using the same length for past_key_values 278 | past_key_values_length = seqlen + 2 279 | past_shape = ( 280 | batch, 281 | self.num_attention_heads, 282 | past_key_values_length, 283 | self._config.hidden_size // self.num_attention_heads, 284 | ) 285 | ordered_inputs["past_key_values"] = [ 286 | (torch.zeros(past_shape), torch.zeros(past_shape)) 287 | for _ in range(self.num_layers) 288 | ] 289 | 290 | ordered_inputs["attention_mask"] = common_inputs["attention_mask"] 291 | if self.use_past: 292 | mask_dtype = ordered_inputs["attention_mask"].dtype 293 | ordered_inputs["attention_mask"] = torch.cat( 294 | [ 295 | ordered_inputs["attention_mask"], 296 | torch.ones(batch, past_key_values_length, dtype=mask_dtype), 297 | ], 298 | dim=1, 299 | ) 300 | 301 | return ordered_inputs 302 | 303 | @property 304 | def default_onnx_opset(self) -> int: 305 | return 13 306 | -------------------------------------------------------------------------------- /models/gptj-keyformer-lib/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "EleutherAI/gpt-j-6b", 3 | "activation_function": "gelu_new", 4 | "architectures": [ 5 | "GPTJForCausalLM" 6 | ], 7 | "auto_map": { 8 | "AutoConfig": "configuration_gptj.GPTJConfig", 9 | "AutoModelForCausalLM": "modeling_gptj.GPTJForCausalLM" 10 | }, 11 | "keyformer_config": { 12 | "keyformer": false, 13 | "kv_cache": 60, 14 | "recent": 30, 15 | "tau_init": 1.0, 16 | "tau_delta": 0.01 17 | }, 18 | "attn_pdrop": 0.0, 19 | "bos_token_id": 50256, 20 | "embd_pdrop": 0.0, 21 | "eos_token_id": 50256, 22 | "gradient_checkpointing": false, 23 | "initializer_range": 0.02, 24 | "layer_norm_epsilon": 1e-05, 25 | "model_type": "gptj", 26 | "n_embd": 4096, 27 | "n_head": 16, 28 | "n_inner": null, 29 | "n_layer": 28, 30 | "n_positions": 2048, 31 | "resid_pdrop": 0.0, 32 | "rotary": true, 33 | "rotary_dim": 64, 34 | 
"scale_attn_weights": true, 35 | "summary_activation": null, 36 | "summary_first_dropout": 0.1, 37 | "summary_proj_to_labels": true, 38 | "summary_type": "cls_index", 39 | "summary_use_proj": true, 40 | "task_specific_params": { 41 | "text-generation": { 42 | "do_sample": true, 43 | "max_length": 50, 44 | "temperature": 1.0 45 | } 46 | }, 47 | "tie_word_embeddings": false, 48 | "tokenizer_class": "GPT2Tokenizer", 49 | "torch_dtype": "float32", 50 | "transformers_version": "4.28.0.dev0", 51 | "use_cache": true, 52 | "vocab_size": 50401 53 | } 54 | -------------------------------------------------------------------------------- /models/gptj-keyformer-lib/configuration_gptj.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ GPT-J model configuration""" 16 | from collections import OrderedDict 17 | from typing import Any, List, Mapping, Optional 18 | 19 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available 20 | from transformers import PretrainedConfig 21 | from transformers.onnx import OnnxConfigWithPast, PatchingSpec 22 | from transformers.utils import logging 23 | 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP = { 28 | "EleutherAI/gpt-j-6B": "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json", 29 | # See all GPT-J models at https://huggingface.co/models?filter=gpt_j 30 | } 31 | keyformer_config_defaults: Dict = { 32 | "keyformer": False, 33 | "kv_cache": 60, 34 | "recent": 30, 35 | "tau_init": 1.0, 36 | "tau_delta": 0.01, 37 | } 38 | 39 | 40 | class GPTJConfig(PretrainedConfig): 41 | r""" 42 | This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J 43 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 44 | defaults will yield a similar configuration to that of the GPT-J 45 | [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from 46 | [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] 47 | for more information. 48 | 49 | Args: 50 | vocab_size (`int`, *optional*, defaults to 50400): 51 | Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the 52 | `inputs_ids` passed when calling [`GPTJModel`]. 53 | n_positions (`int`, *optional*, defaults to 2048): 54 | The maximum sequence length that this model might ever be used with. Typically set this to something large 55 | just in case (e.g., 512 or 1024 or 2048). 56 | n_embd (`int`, *optional*, defaults to 4096): 57 | Dimensionality of the embeddings and hidden states. 
58 | n_layer (`int`, *optional*, defaults to 28): 59 | Number of hidden layers in the Transformer encoder. 60 | n_head (`int`, *optional*, defaults to 16): 61 | Number of attention heads for each attention layer in the Transformer encoder. 62 | rotary_dim (`int`, *optional*, defaults to 64): 63 | Number of dimensions in the embedding that Rotary Position Embedding is applied to. 64 | n_inner (`int`, *optional*, defaults to None): 65 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd 66 | activation_function (`str`, *optional*, defaults to `"gelu_new"`): 67 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. 68 | resid_pdrop (`float`, *optional*, defaults to 0.1): 69 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 70 | embd_pdrop (`int`, *optional*, defaults to 0.1): 71 | The dropout ratio for the embeddings. 72 | attn_pdrop (`float`, *optional*, defaults to 0.1): 73 | The dropout ratio for the attention. 74 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): 75 | The epsilon to use in the layer normalization layers. 76 | initializer_range (`float`, *optional*, defaults to 0.02): 77 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 78 | use_cache (`bool`, *optional*, defaults to `True`): 79 | Whether or not the model should return the last key/values attentions (not used by all models). 80 | keyformer_config (Dict): A dictionary used to configure the keyformer parameters: 81 | keyformer (bool): When enabled, keyformer based KV cache reduction 82 | kv_cache (float): KV cache percentage for Keyformer. 83 | recent (float): Recent window percentage for Keyformer. 84 | tau_init (float): Initial temperature for Keyformer score function calculation. 85 | tau_delta (float): Delta temperature change for Keyformer score function. 
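            For illustration, these fields can also be toggled on an existing config
            object, which is how the evaluation scripts enable Keyformer:

            ```python
            >>> config = GPTJConfig()
            >>> config.keyformer_config["keyformer"] = True
            >>> config.keyformer_config["kv_cache"] = 60
            >>> config.keyformer_config["recent"] = 30
            ```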
86 | 87 | Example: 88 | 89 | ```python 90 | >>> from transformers import GPTJModel, GPTJConfig 91 | 92 | >>> # Initializing a GPT-J 6B configuration 93 | >>> configuration = GPTJConfig() 94 | 95 | >>> # Initializing a model from the configuration 96 | >>> model = GPTJModel(configuration) 97 | 98 | >>> # Accessing the model configuration 99 | >>> configuration = model.config 100 | ```""" 101 | model_type = "gptj" 102 | attribute_map = { 103 | "max_position_embeddings": "n_positions", 104 | "hidden_size": "n_embd", 105 | "num_attention_heads": "n_head", 106 | "num_hidden_layers": "n_layer", 107 | } 108 | 109 | def __init__( 110 | self, 111 | vocab_size=50400, 112 | n_positions=2048, 113 | n_embd=4096, 114 | n_layer=28, 115 | n_head=16, 116 | rotary_dim=64, 117 | n_inner=None, 118 | activation_function="gelu_new", 119 | resid_pdrop=0.0, 120 | embd_pdrop=0.0, 121 | attn_pdrop=0.0, 122 | layer_norm_epsilon=1e-5, 123 | initializer_range=0.02, 124 | use_cache=True, 125 | bos_token_id=50256, 126 | eos_token_id=50256, 127 | tie_word_embeddings=False, 128 | keyformer_config: Dict = keyformer_config_defaults, 129 | **kwargs, 130 | ): 131 | self.vocab_size = vocab_size 132 | self.n_positions = n_positions 133 | self.n_embd = n_embd 134 | self.n_layer = n_layer 135 | self.n_head = n_head 136 | self.n_inner = n_inner 137 | self.rotary_dim = rotary_dim 138 | self.activation_function = activation_function 139 | self.resid_pdrop = resid_pdrop 140 | self.embd_pdrop = embd_pdrop 141 | self.attn_pdrop = attn_pdrop 142 | self.layer_norm_epsilon = layer_norm_epsilon 143 | self.initializer_range = initializer_range 144 | self.use_cache = use_cache 145 | # ====== Keyformer ====== 146 | self.keyformer_config = keyformer_config 147 | # ======================= 148 | self.bos_token_id = bos_token_id 149 | self.eos_token_id = eos_token_id 150 | 151 | super().__init__( 152 | bos_token_id=bos_token_id, 153 | eos_token_id=eos_token_id, 154 | tie_word_embeddings=tie_word_embeddings, 155 | **kwargs, 156 | ) 157 | 158 | 159 | # Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig 160 | class GPTJOnnxConfig(OnnxConfigWithPast): 161 | def __init__( 162 | self, 163 | config: PretrainedConfig, 164 | task: str = "default", 165 | patching_specs: List[PatchingSpec] = None, 166 | use_past: bool = False, 167 | ): 168 | super().__init__( 169 | config, task=task, patching_specs=patching_specs, use_past=use_past 170 | ) 171 | if not getattr(self._config, "pad_token_id", None): 172 | # TODO: how to do that better? 
173 | self._config.pad_token_id = 0 174 | 175 | @property 176 | def inputs(self) -> Mapping[str, Mapping[int, str]]: 177 | common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) 178 | if self.use_past: 179 | self.fill_with_past_key_values_(common_inputs, direction="inputs") 180 | common_inputs["attention_mask"] = { 181 | 0: "batch", 182 | 1: "past_sequence + sequence", 183 | } 184 | else: 185 | common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} 186 | 187 | return common_inputs 188 | 189 | @property 190 | def num_layers(self) -> int: 191 | return self._config.n_layer 192 | 193 | @property 194 | def num_attention_heads(self) -> int: 195 | return self._config.n_head 196 | 197 | def generate_dummy_inputs( 198 | self, 199 | tokenizer: PreTrainedTokenizer, 200 | batch_size: int = -1, 201 | seq_length: int = -1, 202 | is_pair: bool = False, 203 | framework: Optional[TensorType] = None, 204 | ) -> Mapping[str, Any]: 205 | common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( 206 | tokenizer, 207 | batch_size=batch_size, 208 | seq_length=seq_length, 209 | is_pair=is_pair, 210 | framework=framework, 211 | ) 212 | 213 | # We need to order the input in the way they appears in the forward() 214 | ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) 215 | 216 | # Need to add the past_keys 217 | if self.use_past: 218 | if not is_torch_available(): 219 | raise ValueError( 220 | "Cannot generate dummy past_keys inputs without PyTorch installed." 221 | ) 222 | else: 223 | import torch 224 | 225 | batch, seqlen = common_inputs["input_ids"].shape 226 | # Not using the same length for past_key_values 227 | past_key_values_length = seqlen + 2 228 | past_shape = ( 229 | batch, 230 | self.num_attention_heads, 231 | past_key_values_length, 232 | self._config.hidden_size // self.num_attention_heads, 233 | ) 234 | ordered_inputs["past_key_values"] = [ 235 | (torch.zeros(past_shape), torch.zeros(past_shape)) 236 | for _ in range(self.num_layers) 237 | ] 238 | 239 | ordered_inputs["attention_mask"] = common_inputs["attention_mask"] 240 | if self.use_past: 241 | mask_dtype = ordered_inputs["attention_mask"].dtype 242 | ordered_inputs["attention_mask"] = torch.cat( 243 | [ 244 | ordered_inputs["attention_mask"], 245 | torch.ones(batch, past_key_values_length, dtype=mask_dtype), 246 | ], 247 | dim=1, 248 | ) 249 | 250 | return ordered_inputs 251 | 252 | @property 253 | def default_onnx_opset(self) -> int: 254 | return 13 255 | -------------------------------------------------------------------------------- /models/model_download/download_model.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import torch 5 | import transformers 6 | from transformers import AutoModelForCausalLM 7 | 8 | model_path = os.path.join(os.getcwd(), "model") 9 | 10 | if not os.path.exists(os.path.dirname(model_path)): 11 | os.mkdir(model_path) 12 | 13 | def download_model(model_name): 14 | model = AutoModelForCausalLM.from_pretrained( 15 | model_name, device_map="auto", torchscript=True, trust_remote_code=True 16 | ) # torchscript will force `return_dict=False` to avoid jit errors 17 | print("Loaded model") 18 | 19 | model.save_pretrained(model_path) 20 | 21 | print("Model downloaded and Saved in : ", model_path) 22 | 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | "--model_name", type=str, default=None, 28 | help="like cerebras/Cerebras-GPT-6.7B, 
etc." 29 | ) 30 | args = parser.parse_args() 31 | 32 | return args 33 | 34 | if __name__ == '__main__': 35 | args = get_args() 36 | 37 | # This check will be relaxed as more models are supported. 38 | supported_models = ['cerebras/Cerebras-GPT-6.7B', 39 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B'] 40 | if args.model_name not in supported_models: 41 | raise Exception(f'Unsupported Model Name,. Only models in ' \ 42 | f'{supported_models} are supported.') 43 | 44 | download_model(args.model_name) -------------------------------------------------------------------------------- /models/mpt-keyformer-lib/attention_lm_eval_harness.py: -------------------------------------------------------------------------------- 1 | """Attention layers.""" 2 | import sys 3 | import math 4 | import warnings 5 | import copy 6 | from typing import Optional 7 | from functools import reduce 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | from einops import rearrange 12 | from packaging import version 13 | from torch import nn 14 | import numpy as np 15 | from .norm import LPLayerNorm 16 | 17 | 18 | def _reset_is_causal( 19 | num_query_tokens: int, num_key_tokens: int, original_is_causal: bool 20 | ): 21 | if original_is_causal and num_query_tokens != num_key_tokens: 22 | if num_query_tokens != 1: 23 | raise NotImplementedError( 24 | "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." 25 | ) 26 | else: 27 | return False 28 | return original_is_causal 29 | 30 | 31 | def keyformer_mask(self, attn_weights, key_tokens, tau_init): 32 | # attn_weights (BS, head, query, keys) 33 | dtype_attn_weights = attn_weights.dtype 34 | seq_length = attn_weights.shape[-1] 35 | padding_length = 0 36 | 37 | offset = torch.finfo(attn_weights.dtype).min 38 | tmp_attn = nn.functional.gumbel_softmax( 39 | attn_weights, dim=-1, tau=tau_init, hard=False 40 | ).to(dtype_attn_weights) 41 | 42 | accumulated_score = torch.sum( 43 | tmp_attn[:, :, padding_length : key_tokens + padding_length, :], dim=-2 44 | ) # (head, keys) 45 | accumulated_score[:, :, key_tokens + padding_length :] = 0 46 | accumulated_score[:, :, :padding_length] = 0 47 | 48 | mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool) 49 | mask_bottom[ 50 | :, 51 | :, 52 | padding_length : key_tokens + padding_length, 53 | padding_length : key_tokens + padding_length, 54 | ] = True 55 | 56 | for token_index in range(key_tokens + padding_length, seq_length): 57 | tmp_attn_index = nn.functional.gumbel_softmax( 58 | attn_weights[:, :, token_index, :], dim=-1, tau=tau_init, hard=False 59 | ).to(dtype_attn_weights) 60 | _, tmp_topk_index = accumulated_score.topk(k=key_tokens - 1, dim=-1) 61 | zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool) 62 | mask_bottom_index = zeros_index.scatter( 63 | -1, tmp_topk_index, True 64 | ) # (head, keys) 65 | mask_bottom_index[:, :, token_index] = True 66 | 67 | mask_bottom[:, :, token_index, :] = mask_bottom_index 68 | accumulated_attention_score += tmp_attn_index 69 | accumulated_attention_score = accumulated_score * mask_bottom_index 70 | 71 | return mask_bottom 72 | 73 | 74 | def scaled_multihead_dot_product_attention( 75 | query, 76 | key, 77 | value, 78 | n_heads, 79 | past_key_value=None, 80 | softmax_scale=None, 81 | attn_bias=None, 82 | key_padding_mask=None, 83 | is_causal=False, 84 | dropout_p=0.0, 85 | keyformer=False, 86 | score_fn=None, 87 | token_discard_mask=None, 88 | token_discard_idx=None, 89 | req_tokens=None, 90 
| kv_cache=60.0, 91 | recent=30.0, 92 | tau_init=1.0, 93 | tau_delta=0.01, 94 | itr_count=0, 95 | training=False, 96 | needs_weights=False, 97 | multiquery=False, 98 | ): 99 | q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) 100 | kv_n_heads = 1 if multiquery else n_heads 101 | k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) 102 | v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) 103 | 104 | if past_key_value is not None: 105 | if len(past_key_value) != 0: 106 | k = torch.cat([past_key_value[0], k], dim=3) 107 | v = torch.cat([past_key_value[1], v], dim=2) 108 | past_key_value = (k, v) 109 | 110 | (b, h, s_q, d) = q.shape 111 | s_k = k.size(-1) 112 | 113 | if softmax_scale is None: 114 | softmax_scale = 1 / math.sqrt(d) 115 | attn_weight = q.matmul(k) * softmax_scale 116 | 117 | _s_q = max(0, attn_bias.size(2) - s_q) 118 | _s_k = max(0, attn_bias.size(3) - s_k) 119 | attn_bias = attn_bias[:, :, _s_q:, _s_k:] 120 | if ( 121 | attn_bias.size(-1) != 1 122 | and attn_bias.size(-1) != s_k 123 | or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) 124 | ): 125 | raise RuntimeError( 126 | f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." 127 | ) 128 | attn_weight = attn_weight + attn_bias 129 | 130 | if key_padding_mask is not None: 131 | if attn_bias is not None: 132 | warnings.warn( 133 | "Propagating key_padding_mask to the attention module " 134 | + "and applying it within the attention module can cause " 135 | + "unnecessary computation/memory usage. Consider integrating " 136 | + "into attn_bias once and passing that to each attention " 137 | + "module instead." 138 | ) 139 | attn_weight = attn_weight.masked_fill( 140 | ~key_padding_mask.view((b, 1, 1, s_k)), torch.finfo(q.dtype).min 141 | ) 142 | 143 | min_val = torch.finfo(q.dtype).min 144 | 145 | if is_causal and (not q.size(2) == 1): 146 | s = max(s_q, s_k) 147 | causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) 148 | causal_mask = causal_mask.tril() 149 | causal_mask = causal_mask.to(torch.bool) 150 | causal_mask = ~causal_mask 151 | causal_mask = causal_mask[-s_q:, -s_k:] 152 | attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) 153 | # Initialize score_fn 154 | score_fn = None 155 | # Initialize token discard mask during prompt 156 | token_discard_mask = None 157 | # Initialize itr_count during prompt 158 | itr_count = 0 159 | # Initialize required tokens during prompt 160 | req_tokens = None 161 | 162 | (b, h, s, d) = attn_weight.shape 163 | # ============================================================== 164 | 165 | # ============================================================== 166 | if keyformer: 167 | # Keyformer 168 | req_tokens = int((attn_weight.shape[-1] * kv_cache) / 100) 169 | recent_tokens = int((req_tokens * recent) / 100) 170 | prompt_tokens = req_tokens - recent_tokens 171 | req_tokens = (recent_tokens, prompt_tokens) 172 | 173 | if req_tokens[1] > 0: 174 | mask_bottom = keyformer_mask(attn_weight, req_tokens[1], tau_init) 175 | else: 176 | mask_bottom = torch.zeros_like(attn_weight, dtype=torch.bool) 177 | 178 | ones = torch.ones_like(attn_weight, dtype=torch.bool) 179 | ones = torch.triu(ones, diagonal=-req_tokens[0]) 180 | mask_bottom = torch.logical_or(mask_bottom, ones) 181 | 182 | mask_bottom = torch.tril(mask_bottom, diagonal=0) 183 | 184 | attn_weight[~mask_bottom] = torch.tensor(torch.finfo(attn_weight.dtype).min) 185 | # ============================================================== 186 | 187 |
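# With Keyformer enabled, every position outside `mask_bottom` (outside the recent window
# and the top-scoring prompt tokens ranked by the accumulated Gumbel-softmax score) has
# just been set to the dtype minimum, so the softmax below assigns those discarded
# positions near-zero attention probability and they no longer contribute to the output.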
attn_weight = torch.softmax(attn_weight, dim=-1) 188 | if dropout_p: 189 | attn_weight = torch.nn.functional.dropout( 190 | attn_weight, p=dropout_p, training=training, inplace=True 191 | ) 192 | 193 | sparsity = past_key_value[0].shape 194 | 195 | out = attn_weight.matmul(v) 196 | out = rearrange(out, "b h s d -> b s (h d)") 197 | 198 | itr_count = itr_count + 1 199 | 200 | if needs_weights: 201 | return ( 202 | out, 203 | attn_weight, 204 | sparsity, 205 | past_key_value, 206 | score_fn, 207 | token_discard_mask, 208 | token_discard_idx, 209 | req_tokens, 210 | itr_count, 211 | ) 212 | return ( 213 | out, 214 | None, 215 | None, 216 | past_key_value, 217 | score_fn, 218 | token_discard_mask, 219 | token_discard_idx, 220 | req_tokens, 221 | itr_count, 222 | ) 223 | 224 | 225 | def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): 226 | for tensor in tensors: 227 | if tensor.dtype not in valid_dtypes: 228 | raise TypeError( 229 | f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." 230 | ) 231 | if not tensor.is_cuda: 232 | raise TypeError( 233 | f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." 234 | ) 235 | 236 | 237 | def flash_attn_fn( 238 | query, 239 | key, 240 | value, 241 | n_heads, 242 | past_key_value=None, 243 | softmax_scale=None, 244 | attn_bias=None, 245 | key_padding_mask=None, 246 | is_causal=False, 247 | dropout_p=0.0, 248 | training=False, 249 | needs_weights=False, 250 | multiquery=False, 251 | ): 252 | try: 253 | from flash_attn import bert_padding, flash_attn_interface 254 | except: 255 | raise RuntimeError("Please install flash-attn==1.0.3.post0") 256 | check_valid_inputs(query, key, value) 257 | if past_key_value is not None: 258 | if len(past_key_value) != 0: 259 | key = torch.cat([past_key_value[0], key], dim=1) 260 | value = torch.cat([past_key_value[1], value], dim=1) 261 | past_key_value = (key, value) 262 | if attn_bias is not None: 263 | _s_q = max(0, attn_bias.size(2) - query.size(1)) 264 | _s_k = max(0, attn_bias.size(3) - key.size(1)) 265 | attn_bias = attn_bias[:, :, _s_q:, _s_k:] 266 | if attn_bias is not None: 267 | raise NotImplementedError(f"attn_bias not implemented for flash attn.") 268 | (batch_size, seqlen) = query.shape[:2] 269 | if key_padding_mask is None: 270 | key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) 271 | query_padding_mask = key_padding_mask[:, -query.size(1) :] 272 | (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( 273 | query, query_padding_mask 274 | ) 275 | query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) 276 | (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( 277 | key, key_padding_mask 278 | ) 279 | key_unpad = rearrange( 280 | key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads 281 | ) 282 | (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) 283 | value_unpad = rearrange( 284 | value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads 285 | ) 286 | if multiquery: 287 | key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) 288 | value_unpad = value_unpad.expand( 289 | value_unpad.size(0), n_heads, value_unpad.size(-1) 290 | ) 291 | dropout_p = dropout_p if training else 0.0 292 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) 293 | output_unpad = flash_attn_interface.flash_attn_unpadded_func( 294 | query_unpad, 295 | key_unpad, 296 | value_unpad, 297 | cu_seqlens_q, 298 | 
cu_seqlens_k, 299 | max_seqlen_q, 300 | max_seqlen_k, 301 | dropout_p, 302 | softmax_scale=softmax_scale, 303 | causal=reset_is_causal, 304 | return_attn_probs=needs_weights, 305 | ) 306 | output = bert_padding.pad_input( 307 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen 308 | ) 309 | return (output, None, past_key_value) 310 | 311 | 312 | def triton_flash_attn_fn( 313 | query, 314 | key, 315 | value, 316 | n_heads, 317 | past_key_value=None, 318 | softmax_scale=None, 319 | attn_bias=None, 320 | key_padding_mask=None, 321 | is_causal=False, 322 | dropout_p=0.0, 323 | training=False, 324 | needs_weights=False, 325 | multiquery=False, 326 | ): 327 | try: 328 | from .flash_attn_triton import flash_attn_func 329 | except: 330 | _installed = False 331 | if version.parse(torch.__version__) < version.parse("2.0.0"): 332 | _installed = True 333 | try: 334 | from flash_attn.flash_attn_triton import flash_attn_func 335 | except: 336 | _installed = False 337 | if not _installed: 338 | raise RuntimeError( 339 | "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." 340 | ) 341 | check_valid_inputs(query, key, value) 342 | if past_key_value is not None: 343 | if len(past_key_value) != 0: 344 | key = torch.cat([past_key_value[0], key], dim=1) 345 | value = torch.cat([past_key_value[1], value], dim=1) 346 | past_key_value = (key, value) 347 | if attn_bias is not None: 348 | _s_q = max(0, attn_bias.size(2) - query.size(1)) 349 | _s_k = max(0, attn_bias.size(3) - key.size(1)) 350 | attn_bias = attn_bias[:, :, _s_q:, _s_k:] 351 | if dropout_p: 352 | raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") 353 | if needs_weights: 354 | raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") 355 | if key_padding_mask is not None: 356 | warnings.warn( 357 | "Propagating key_padding_mask to the attention module " 358 | + "and applying it within the attention module can cause " 359 | + "unnecessary computation/memory usage. Consider integrating " 360 | + "into attn_bias once and passing that to each attention " 361 | + "module instead." 
362 | ) 363 | (b_size, s_k) = key_padding_mask.shape[:2] 364 | if attn_bias is None: 365 | attn_bias = query.new_zeros(b_size, 1, 1, s_k) 366 | attn_bias = attn_bias.masked_fill( 367 | ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min 368 | ) 369 | query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) 370 | key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) 371 | value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) 372 | if multiquery: 373 | key = key.expand(*key.shape[:2], n_heads, key.size(-1)) 374 | value = value.expand(*value.shape[:2], n_heads, value.size(-1)) 375 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) 376 | attn_output = flash_attn_func( 377 | query, key, value, attn_bias, reset_is_causal, softmax_scale 378 | ) 379 | output = attn_output.view(*attn_output.shape[:2], -1) 380 | return (output, None, past_key_value) 381 | 382 | 383 | class MultiheadAttention(nn.Module): 384 | """Multi-head self attention. 385 | 386 | Using torch or triton attention implemetation enables user to also use 387 | additive bias. 388 | """ 389 | 390 | def __init__( 391 | self, 392 | d_model: int, 393 | n_heads: int, 394 | attn_impl: str = "triton", 395 | clip_qkv: Optional[float] = None, 396 | qk_ln: bool = False, 397 | softmax_scale: Optional[float] = None, 398 | attn_pdrop: float = 0.0, 399 | keyformer: bool = False, 400 | kv_cache: float = 60.0, 401 | recent: float = 30.0, 402 | tau_init: float = 1.0, 403 | tau_delta: float = 0.01, 404 | low_precision_layernorm: bool = False, 405 | verbose: int = 0, 406 | device: Optional[str] = None, 407 | ): 408 | super().__init__() 409 | self.attn_impl = attn_impl 410 | self.clip_qkv = clip_qkv 411 | self.qk_ln = qk_ln 412 | self.d_model = d_model 413 | self.n_heads = n_heads 414 | self.softmax_scale = softmax_scale 415 | if self.softmax_scale is None: 416 | self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) 417 | self.attn_dropout_p = attn_pdrop 418 | # ==================== Keyformer ========================= 419 | self.keyformer = keyformer 420 | self.kv_cache = kv_cache 421 | self.recent = recent 422 | self.tau_init = tau_init 423 | self.tau_delta = tau_delta 424 | # ======================================================== 425 | self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device) 426 | fuse_splits = (d_model, 2 * d_model) 427 | self.Wqkv._fused = (0, fuse_splits) 428 | if self.qk_ln: 429 | layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm 430 | self.q_ln = layernorm_class(self.d_model, device=device) 431 | self.k_ln = layernorm_class(self.d_model, device=device) 432 | if self.attn_impl == "flash": 433 | self.attn_fn = flash_attn_fn 434 | elif self.attn_impl == "triton": 435 | self.attn_fn = triton_flash_attn_fn 436 | if verbose: 437 | warnings.warn( 438 | "While `attn_impl: triton` can be faster than `attn_impl: flash` " 439 | + "it uses more memory. When training larger models this can trigger " 440 | + "alloc retries which hurts performance. If encountered, we recommend " 441 | + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." 442 | ) 443 | elif self.attn_impl == "torch": 444 | self.attn_fn = scaled_multihead_dot_product_attention 445 | if torch.cuda.is_available() and verbose: 446 | warnings.warn( 447 | "Using `attn_impl: torch`. 
If your model does not use `alibi` or " 448 | + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " 449 | + "we recommend using `attn_impl: triton`." 450 | ) 451 | else: 452 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") 453 | self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) 454 | self.out_proj._is_residual = True 455 | 456 | def forward( 457 | self, 458 | x, 459 | past_key_value=None, 460 | attn_bias=None, 461 | attention_mask=None, 462 | is_causal=True, 463 | needs_weights=False, 464 | score_fn=None, 465 | token_discard_mask=None, 466 | token_discard_idx=None, 467 | req_tokens=None, 468 | itr_count=0, 469 | ): 470 | qkv = self.Wqkv(x) 471 | if self.clip_qkv: 472 | qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) 473 | (query, key, value) = qkv.chunk(3, dim=2) 474 | key_padding_mask = attention_mask 475 | if self.qk_ln: 476 | dtype = query.dtype 477 | query = self.q_ln(query).to(dtype) 478 | key = self.k_ln(key).to(dtype) 479 | ( 480 | context, 481 | attn_weights, 482 | sparsity, 483 | past_key_value, 484 | score_fn, 485 | token_discard_mask, 486 | token_discard_idx, 487 | req_tokens, 488 | itr_count, 489 | ) = self.attn_fn( 490 | query, 491 | key, 492 | value, 493 | self.n_heads, 494 | past_key_value=past_key_value, 495 | softmax_scale=self.softmax_scale, 496 | attn_bias=attn_bias, 497 | key_padding_mask=key_padding_mask, 498 | is_causal=is_causal, 499 | dropout_p=self.attn_dropout_p, 500 | keyformer=self.keyformer, 501 | score_fn=score_fn, 502 | token_discard_mask=token_discard_mask, 503 | token_discard_idx=token_discard_idx, 504 | req_tokens=req_tokens, 505 | kv_cache=self.kv_cache, 506 | recent=self.recent, 507 | tau_init=self.tau_init, 508 | tau_delta=self.tau_delta, 509 | itr_count=itr_count, 510 | training=self.training, 511 | needs_weights=needs_weights, 512 | ) 513 | return ( 514 | self.out_proj(context), 515 | attn_weights, 516 | sparsity, 517 | past_key_value, 518 | score_fn, 519 | token_discard_mask, 520 | token_discard_idx, 521 | req_tokens, 522 | itr_count, 523 | ) 524 | 525 | 526 | def attn_bias_shape( 527 | attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id 528 | ): 529 | if attn_impl == "flash": 530 | return None 531 | elif attn_impl in ["torch", "triton"]: 532 | if alibi: 533 | if (prefix_lm or not causal) or use_sequence_id: 534 | return (1, n_heads, seq_len, seq_len) 535 | return (1, n_heads, 1, seq_len) 536 | elif prefix_lm or use_sequence_id: 537 | return (1, 1, seq_len, seq_len) 538 | return None 539 | else: 540 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") 541 | 542 | 543 | def build_attn_bias( 544 | attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 545 | ): 546 | if attn_impl == "flash": 547 | return None 548 | elif attn_impl in ["torch", "triton"]: 549 | if alibi: 550 | (device, dtype) = (attn_bias.device, attn_bias.dtype) 551 | attn_bias = attn_bias.add( 552 | build_alibi_bias( 553 | n_heads, 554 | seq_len, 555 | full=not causal, 556 | alibi_bias_max=alibi_bias_max, 557 | device=device, 558 | dtype=dtype, 559 | ) 560 | ) 561 | return attn_bias 562 | else: 563 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") 564 | 565 | 566 | def gen_slopes(n_heads, alibi_bias_max=8, device=None): 567 | _n_heads = 2 ** math.ceil(math.log2(n_heads)) 568 | m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) 569 | m = m.mul(alibi_bias_max / _n_heads) 570 | slopes = 1.0 / torch.pow(2, m) 571 | if _n_heads != 
n_heads: 572 | slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] 573 | return slopes.view(1, n_heads, 1, 1) 574 | 575 | 576 | def build_alibi_bias( 577 | n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None 578 | ): 579 | alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( 580 | 1, 1, 1, seq_len 581 | ) 582 | if full: 583 | alibi_bias = alibi_bias - torch.arange( 584 | 1 - seq_len, 1, dtype=torch.int32, device=device 585 | ).view(1, 1, seq_len, 1) 586 | alibi_bias = alibi_bias.abs().mul(-1) 587 | slopes = gen_slopes(n_heads, alibi_bias_max, device=device) 588 | alibi_bias = alibi_bias * slopes 589 | return alibi_bias.to(dtype=dtype) 590 | 591 | 592 | ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention} 593 | -------------------------------------------------------------------------------- /models/mpt-keyformer-lib/attention_streaming_llm.py: -------------------------------------------------------------------------------- 1 | """Attention layers.""" 2 | import sys 3 | import math 4 | import warnings 5 | import copy 6 | from typing import Optional 7 | from functools import reduce 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.parameter import Parameter 11 | from einops import rearrange 12 | from packaging import version 13 | from torch import nn 14 | import numpy as np 15 | from .norm import LPLayerNorm 16 | 17 | 18 | def _reset_is_causal( 19 | num_query_tokens: int, num_key_tokens: int, original_is_causal: bool 20 | ): 21 | if original_is_causal and num_query_tokens != num_key_tokens: 22 | if num_query_tokens != 1: 23 | raise NotImplementedError( 24 | "MPT does not support query and key with different number of tokens, unless number of query tokens is 1." 25 | ) 26 | else: 27 | return False 28 | return original_is_causal 29 | 30 | 31 | def scaled_multihead_dot_product_attention( 32 | query, 33 | key, 34 | value, 35 | n_heads, 36 | past_key_value=None, 37 | softmax_scale=None, 38 | attn_bias=None, 39 | key_padding_mask=None, 40 | is_causal=False, 41 | dropout_p=0.0, 42 | keyformer=False, 43 | score_fn=None, 44 | token_discard_mask=None, 45 | token_discard_idx=None, 46 | req_tokens=None, 47 | kv_cache=60.0, 48 | recent=30.0, 49 | tau_init=1.0, 50 | tau_delta=0.01, 51 | itr_count=0, 52 | training=False, 53 | needs_weights=False, 54 | multiquery=False, 55 | ): 56 | q = rearrange(query, "b s (h d) -> b h s d", h=n_heads) 57 | kv_n_heads = 1 if multiquery else n_heads 58 | k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads) 59 | v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads) 60 | 61 | if past_key_value is not None: 62 | if len(past_key_value) != 0: 63 | k = torch.cat([past_key_value[0], k], dim=3) 64 | v = torch.cat([past_key_value[1], v], dim=2) 65 | past_key_value = (k, v) 66 | 67 | (b, h, s_q, d) = q.shape 68 | s_k = k.size(-1) 69 | 70 | if softmax_scale is None: 71 | softmax_scale = 1 / math.sqrt(d) 72 | attn_weight = q.matmul(k) * softmax_scale 73 | 74 | _s_q = max(0, attn_bias.size(2) - s_q) 75 | _s_k = max(0, attn_bias.size(3) - s_k) 76 | attn_bias = attn_bias[:, :, _s_q:, _s_k:] 77 | if ( 78 | attn_bias.size(-1) != 1 79 | and attn_bias.size(-1) != s_k 80 | or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q) 81 | ): 82 | raise RuntimeError( 83 | f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}." 
84 | ) 85 | attn_weight = attn_weight + attn_bias 86 | 87 | if key_padding_mask is not None: 88 | if attn_bias is not None: 89 | warnings.warn( 90 | "Propagating key_padding_mask to the attention module " 91 | + "and applying it within the attention module can cause " 92 | + "unnecessary computation/memory usage. Consider integrating " 93 | + "into attn_bias once and passing that to each attention " 94 | + "module instead." 95 | ) 96 | attn_weight = attn_weight.masked_fill( 97 | ~key_padding_mask.view((b, 1, 1, s_k)), torch.finfo(q.dtype).min 98 | ) 99 | 100 | min_val = torch.finfo(q.dtype).min 101 | 102 | if is_causal and (not q.size(2) == 1): 103 | s = max(s_q, s_k) 104 | causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16) 105 | causal_mask = causal_mask.tril() 106 | causal_mask = causal_mask.to(torch.bool) 107 | causal_mask = ~causal_mask 108 | causal_mask = causal_mask[-s_q:, -s_k:] 109 | attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val) 110 | # Initialize score_fn 111 | score_fn = None 112 | # Initialize token discard mask during prompt 113 | token_discard_mask = None 114 | # Initialize itr_count during prompt 115 | itr_count = 0 116 | # Initialize required tokens during prompt 117 | req_tokens = None 118 | 119 | (b, h, s, d) = attn_weight.shape 120 | 121 | attn_weight = torch.softmax(attn_weight, dim=-1) 122 | if dropout_p: 123 | attn_weight = torch.nn.functional.dropout( 124 | attn_weight, p=dropout_p, training=training, inplace=True 125 | ) 126 | 127 | # ============================================================== 128 | # ==> StreamingLLM (Attention sinks + Recent Window) 129 | if is_causal: 130 | # Recent + Attention Sink 131 | if not (q.size(2) == 1): 132 | req_tokens = int((attn_weight.shape[-1] * kv_cache) / 100) 133 | prompt_tokens = 4 # Attention Sinks 134 | recent_tokens = req_tokens - prompt_tokens 135 | req_tokens = (recent_tokens, prompt_tokens) 136 | 137 | token_mask = attn_weight.new_ones(b, h, 1, d, dtype=torch.float16) 138 | token_mask = token_mask.to(torch.bool) 139 | 140 | # activate most recent k-cache 141 | token_mask[:, :, :, : -req_tokens[0]] = False 142 | # activate attention sinks 143 | token_mask[:, :, :, : req_tokens[1]] = True 144 | 145 | sparse_len = req_tokens[0] + req_tokens[1] 146 | 147 | # ==> Reducing KV cache 148 | v_new = v[token_mask.squeeze()] 149 | v_new = rearrange(v_new, "(b h s) d -> b h s d", h=kv_n_heads, s=sparse_len) 150 | 151 | k_new = k.transpose(2, 3)[token_mask.squeeze()] 152 | k_new = rearrange(k_new, "(b h s) d -> b h s d", h=kv_n_heads, s=sparse_len) 153 | k_new = k_new.transpose(2, 3) 154 | 155 | # ==> Updating KV cache 156 | past_key_value = (k_new, v_new) 157 | # ================================================================================ 158 | sparsity = past_key_value[0].shape 159 | 160 | out = attn_weight.matmul(v) 161 | out = rearrange(out, "b h s d -> b s (h d)") 162 | 163 | itr_count = itr_count + 1 164 | 165 | if needs_weights: 166 | return ( 167 | out, 168 | attn_weight, 169 | sparsity, 170 | past_key_value, 171 | score_fn, 172 | token_discard_mask, 173 | token_discard_idx, 174 | req_tokens, 175 | itr_count, 176 | ) 177 | return ( 178 | out, 179 | None, 180 | None, 181 | past_key_value, 182 | score_fn, 183 | token_discard_mask, 184 | token_discard_idx, 185 | req_tokens, 186 | itr_count, 187 | ) 188 | 189 | 190 | def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]): 191 | for tensor in tensors: 192 | if tensor.dtype not in valid_dtypes: 193 | raise TypeError( 194 |
f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}." 195 | ) 196 | if not tensor.is_cuda: 197 | raise TypeError( 198 | f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})." 199 | ) 200 | 201 | 202 | def flash_attn_fn( 203 | query, 204 | key, 205 | value, 206 | n_heads, 207 | past_key_value=None, 208 | softmax_scale=None, 209 | attn_bias=None, 210 | key_padding_mask=None, 211 | is_causal=False, 212 | dropout_p=0.0, 213 | training=False, 214 | needs_weights=False, 215 | multiquery=False, 216 | ): 217 | try: 218 | from flash_attn import bert_padding, flash_attn_interface 219 | except: 220 | raise RuntimeError("Please install flash-attn==1.0.3.post0") 221 | check_valid_inputs(query, key, value) 222 | if past_key_value is not None: 223 | if len(past_key_value) != 0: 224 | key = torch.cat([past_key_value[0], key], dim=1) 225 | value = torch.cat([past_key_value[1], value], dim=1) 226 | past_key_value = (key, value) 227 | if attn_bias is not None: 228 | _s_q = max(0, attn_bias.size(2) - query.size(1)) 229 | _s_k = max(0, attn_bias.size(3) - key.size(1)) 230 | attn_bias = attn_bias[:, :, _s_q:, _s_k:] 231 | if attn_bias is not None: 232 | raise NotImplementedError(f"attn_bias not implemented for flash attn.") 233 | (batch_size, seqlen) = query.shape[:2] 234 | if key_padding_mask is None: 235 | key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool) 236 | query_padding_mask = key_padding_mask[:, -query.size(1) :] 237 | (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( 238 | query, query_padding_mask 239 | ) 240 | query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads) 241 | (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( 242 | key, key_padding_mask 243 | ) 244 | key_unpad = rearrange( 245 | key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads 246 | ) 247 | (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask) 248 | value_unpad = rearrange( 249 | value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads 250 | ) 251 | if multiquery: 252 | key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1)) 253 | value_unpad = value_unpad.expand( 254 | value_unpad.size(0), n_heads, value_unpad.size(-1) 255 | ) 256 | dropout_p = dropout_p if training else 0.0 257 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) 258 | output_unpad = flash_attn_interface.flash_attn_unpadded_func( 259 | query_unpad, 260 | key_unpad, 261 | value_unpad, 262 | cu_seqlens_q, 263 | cu_seqlens_k, 264 | max_seqlen_q, 265 | max_seqlen_k, 266 | dropout_p, 267 | softmax_scale=softmax_scale, 268 | causal=reset_is_causal, 269 | return_attn_probs=needs_weights, 270 | ) 271 | output = bert_padding.pad_input( 272 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen 273 | ) 274 | return (output, None, past_key_value) 275 | 276 | 277 | def triton_flash_attn_fn( 278 | query, 279 | key, 280 | value, 281 | n_heads, 282 | past_key_value=None, 283 | softmax_scale=None, 284 | attn_bias=None, 285 | key_padding_mask=None, 286 | is_causal=False, 287 | dropout_p=0.0, 288 | training=False, 289 | needs_weights=False, 290 | multiquery=False, 291 | ): 292 | try: 293 | from .flash_attn_triton import flash_attn_func 294 | except: 295 | _installed = False 296 | if version.parse(torch.__version__) < version.parse("2.0.0"): 297 | _installed = True 298 | try: 299 | from flash_attn.flash_attn_triton import flash_attn_func 300 | except: 301 
| _installed = False 302 | if not _installed: 303 | raise RuntimeError( 304 | "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed." 305 | ) 306 | check_valid_inputs(query, key, value) 307 | if past_key_value is not None: 308 | if len(past_key_value) != 0: 309 | key = torch.cat([past_key_value[0], key], dim=1) 310 | value = torch.cat([past_key_value[1], value], dim=1) 311 | past_key_value = (key, value) 312 | if attn_bias is not None: 313 | _s_q = max(0, attn_bias.size(2) - query.size(1)) 314 | _s_k = max(0, attn_bias.size(3) - key.size(1)) 315 | attn_bias = attn_bias[:, :, _s_q:, _s_k:] 316 | if dropout_p: 317 | raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.") 318 | if needs_weights: 319 | raise NotImplementedError(f"attn_impl: triton cannot return attn weights.") 320 | if key_padding_mask is not None: 321 | warnings.warn( 322 | "Propagating key_padding_mask to the attention module " 323 | + "and applying it within the attention module can cause " 324 | + "unnecessary computation/memory usage. Consider integrating " 325 | + "into attn_bias once and passing that to each attention " 326 | + "module instead." 327 | ) 328 | (b_size, s_k) = key_padding_mask.shape[:2] 329 | if attn_bias is None: 330 | attn_bias = query.new_zeros(b_size, 1, 1, s_k) 331 | attn_bias = attn_bias.masked_fill( 332 | ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min 333 | ) 334 | query = rearrange(query, "b s (h d) -> b s h d", h=n_heads) 335 | key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) 336 | value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads) 337 | if multiquery: 338 | key = key.expand(*key.shape[:2], n_heads, key.size(-1)) 339 | value = value.expand(*value.shape[:2], n_heads, value.size(-1)) 340 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal) 341 | attn_output = flash_attn_func( 342 | query, key, value, attn_bias, reset_is_causal, softmax_scale 343 | ) 344 | output = attn_output.view(*attn_output.shape[:2], -1) 345 | return (output, None, past_key_value) 346 | 347 | 348 | class MultiheadAttention(nn.Module): 349 | """Multi-head self attention. 350 | 351 | Using torch or triton attention implemetation enables user to also use 352 | additive bias. 
353 | """ 354 | 355 | def __init__( 356 | self, 357 | d_model: int, 358 | n_heads: int, 359 | attn_impl: str = "triton", 360 | clip_qkv: Optional[float] = None, 361 | qk_ln: bool = False, 362 | softmax_scale: Optional[float] = None, 363 | attn_pdrop: float = 0.0, 364 | keyformer: bool = False, 365 | kv_cache: float = 60.0, 366 | recent: float = 30.0, 367 | tau_init: float = 1.0, 368 | tau_delta: float = 0.01, 369 | low_precision_layernorm: bool = False, 370 | verbose: int = 0, 371 | device: Optional[str] = None, 372 | ): 373 | super().__init__() 374 | self.attn_impl = attn_impl 375 | self.clip_qkv = clip_qkv 376 | self.qk_ln = qk_ln 377 | self.d_model = d_model 378 | self.n_heads = n_heads 379 | self.softmax_scale = softmax_scale 380 | if self.softmax_scale is None: 381 | self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads) 382 | self.attn_dropout_p = attn_pdrop 383 | # ==================== Keyformer ========================= 384 | self.keyformer = keyformer 385 | self.kv_cache = kv_cache 386 | self.recent = recent 387 | self.tau_init = tau_init 388 | self.tau_delta = tau_delta 389 | # ======================================================== 390 | self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device) 391 | fuse_splits = (d_model, 2 * d_model) 392 | self.Wqkv._fused = (0, fuse_splits) 393 | if self.qk_ln: 394 | layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm 395 | self.q_ln = layernorm_class(self.d_model, device=device) 396 | self.k_ln = layernorm_class(self.d_model, device=device) 397 | if self.attn_impl == "flash": 398 | self.attn_fn = flash_attn_fn 399 | elif self.attn_impl == "triton": 400 | self.attn_fn = triton_flash_attn_fn 401 | if verbose: 402 | warnings.warn( 403 | "While `attn_impl: triton` can be faster than `attn_impl: flash` " 404 | + "it uses more memory. When training larger models this can trigger " 405 | + "alloc retries which hurts performance. If encountered, we recommend " 406 | + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`." 407 | ) 408 | elif self.attn_impl == "torch": 409 | self.attn_fn = scaled_multihead_dot_product_attention 410 | if torch.cuda.is_available() and verbose: 411 | warnings.warn( 412 | "Using `attn_impl: torch`. If your model does not use `alibi` or " 413 | + "`prefix_lm` we recommend using `attn_impl: flash` otherwise " 414 | + "we recommend using `attn_impl: triton`." 
415 | ) 416 | else: 417 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") 418 | self.out_proj = nn.Linear(self.d_model, self.d_model, device=device) 419 | self.out_proj._is_residual = True 420 | 421 | def forward( 422 | self, 423 | x, 424 | past_key_value=None, 425 | attn_bias=None, 426 | attention_mask=None, 427 | is_causal=True, 428 | needs_weights=False, 429 | score_fn=None, 430 | token_discard_mask=None, 431 | token_discard_idx=None, 432 | req_tokens=None, 433 | itr_count=0, 434 | ): 435 | qkv = self.Wqkv(x) 436 | if self.clip_qkv: 437 | qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) 438 | (query, key, value) = qkv.chunk(3, dim=2) 439 | key_padding_mask = attention_mask 440 | if self.qk_ln: 441 | dtype = query.dtype 442 | query = self.q_ln(query).to(dtype) 443 | key = self.k_ln(key).to(dtype) 444 | ( 445 | context, 446 | attn_weights, 447 | sparsity, 448 | past_key_value, 449 | score_fn, 450 | token_discard_mask, 451 | token_discard_idx, 452 | req_tokens, 453 | itr_count, 454 | ) = self.attn_fn( 455 | query, 456 | key, 457 | value, 458 | self.n_heads, 459 | past_key_value=past_key_value, 460 | softmax_scale=self.softmax_scale, 461 | attn_bias=attn_bias, 462 | key_padding_mask=key_padding_mask, 463 | is_causal=is_causal, 464 | dropout_p=self.attn_dropout_p, 465 | keyformer=self.keyformer, 466 | score_fn=score_fn, 467 | token_discard_mask=token_discard_mask, 468 | token_discard_idx=token_discard_idx, 469 | req_tokens=req_tokens, 470 | kv_cache=self.kv_cache, 471 | recent=self.recent, 472 | tau_init=self.tau_init, 473 | tau_delta=self.tau_delta, 474 | itr_count=itr_count, 475 | training=self.training, 476 | needs_weights=needs_weights, 477 | ) 478 | return ( 479 | self.out_proj(context), 480 | attn_weights, 481 | sparsity, 482 | past_key_value, 483 | score_fn, 484 | token_discard_mask, 485 | token_discard_idx, 486 | req_tokens, 487 | itr_count, 488 | ) 489 | 490 | 491 | def attn_bias_shape( 492 | attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id 493 | ): 494 | if attn_impl == "flash": 495 | return None 496 | elif attn_impl in ["torch", "triton"]: 497 | if alibi: 498 | if (prefix_lm or not causal) or use_sequence_id: 499 | return (1, n_heads, seq_len, seq_len) 500 | return (1, n_heads, 1, seq_len) 501 | elif prefix_lm or use_sequence_id: 502 | return (1, 1, seq_len, seq_len) 503 | return None 504 | else: 505 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") 506 | 507 | 508 | def build_attn_bias( 509 | attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8 510 | ): 511 | if attn_impl == "flash": 512 | return None 513 | elif attn_impl in ["torch", "triton"]: 514 | if alibi: 515 | (device, dtype) = (attn_bias.device, attn_bias.dtype) 516 | attn_bias = attn_bias.add( 517 | build_alibi_bias( 518 | n_heads, 519 | seq_len, 520 | full=not causal, 521 | alibi_bias_max=alibi_bias_max, 522 | device=device, 523 | dtype=dtype, 524 | ) 525 | ) 526 | return attn_bias 527 | else: 528 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.") 529 | 530 | 531 | def gen_slopes(n_heads, alibi_bias_max=8, device=None): 532 | _n_heads = 2 ** math.ceil(math.log2(n_heads)) 533 | m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device) 534 | m = m.mul(alibi_bias_max / _n_heads) 535 | slopes = 1.0 / torch.pow(2, m) 536 | if _n_heads != n_heads: 537 | slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads] 538 | return slopes.view(1, n_heads, 1, 1) 539 | 540 | 541 | def build_alibi_bias( 542 | 
n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None 543 | ): 544 | alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view( 545 | 1, 1, 1, seq_len 546 | ) 547 | if full: 548 | alibi_bias = alibi_bias - torch.arange( 549 | 1 - seq_len, 1, dtype=torch.int32, device=device 550 | ).view(1, 1, seq_len, 1) 551 | alibi_bias = alibi_bias.abs().mul(-1) 552 | slopes = gen_slopes(n_heads, alibi_bias_max, device=device) 553 | alibi_bias = alibi_bias * slopes 554 | return alibi_bias.to(dtype=dtype) 555 | 556 | 557 | ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention} 558 | -------------------------------------------------------------------------------- /models/mpt-keyformer-lib/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import sys 4 | import torch 5 | import torch.nn as nn 6 | from .attention import ATTN_CLASS_REGISTRY 7 | from .norm import NORM_CLASS_REGISTRY 8 | 9 | 10 | class MPTMLP(nn.Module): 11 | def __init__( 12 | self, d_model: int, expansion_ratio: int, device: Optional[str] = None 13 | ): 14 | super().__init__() 15 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 16 | self.act = nn.GELU(approximate="none") 17 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 18 | self.down_proj._is_residual = True 19 | 20 | def forward(self, x): 21 | return self.down_proj(self.act(self.up_proj(x))) 22 | 23 | 24 | class MPTBlock(nn.Module): 25 | def __init__( 26 | self, 27 | d_model: int, 28 | n_heads: int, 29 | expansion_ratio: int, 30 | attn_config: Dict = { 31 | "attn_type": "multihead_attention", 32 | "attn_pdrop": 0.0, 33 | "attn_impl": "triton", 34 | "qk_ln": False, 35 | "clip_qkv": None, 36 | "softmax_scale": None, 37 | "prefix_lm": False, 38 | "attn_uses_sequence_id": False, 39 | "alibi": False, 40 | "alibi_bias_max": 8, 41 | }, 42 | keyformer_config: Dict = { 43 | "keyformer": False, 44 | "kv_cache": 60, 45 | "recent": 30, 46 | "tau_init": 1.0, 47 | "tau_delta": 0.01, 48 | }, 49 | resid_pdrop: float = 0.0, 50 | norm_type: str = "low_precision_layernorm", 51 | verbose: int = 0, 52 | device: Optional[str] = None, 53 | **kwargs 54 | ): 55 | del kwargs 56 | super().__init__() 57 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 58 | attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]] 59 | self.norm_1 = norm_class(d_model, device=device) 60 | self.attn = attn_class( 61 | attn_impl=attn_config["attn_impl"], 62 | clip_qkv=attn_config["clip_qkv"], 63 | qk_ln=attn_config["qk_ln"], 64 | softmax_scale=attn_config["softmax_scale"], 65 | attn_pdrop=attn_config["attn_pdrop"], 66 | keyformer=keyformer_config["keyformer"], 67 | kv_cache=keyformer_config["kv_cache"], 68 | recent=keyformer_config["recent"], 69 | tau_init=keyformer_config["tau_init"], 70 | tau_delta=keyformer_config["tau_delta"], 71 | d_model=d_model, 72 | n_heads=n_heads, 73 | verbose=verbose, 74 | device=device, 75 | ) 76 | self.norm_2 = norm_class(d_model, device=device) 77 | self.ffn = MPTMLP( 78 | d_model=d_model, expansion_ratio=expansion_ratio, device=device 79 | ) 80 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 81 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 82 | # ==================== Keyformer ========================= 83 | self.keyformer = keyformer_config["keyformer"] 84 | self.kv_cache = keyformer_config["kv_cache"] 85 | self.recent = keyformer_config["recent"] 
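# tau_init / tau_delta below are the initial Gumbel temperature and the temperature
# change used by the Keyformer score function (see keyformer_config in configuration_mpt.py).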
86 | self.tau_init = keyformer_config["tau_init"] 87 | self.tau_delta = keyformer_config["tau_delta"] 88 | self.token_discard_mask = None # Token discard mask for each MPT block 89 | self.token_discard_idx = None # Token discard indices for each MPT block 90 | self.req_tokens = None # Required tokens to attend for each MPT block 91 | self.score_fn = None # Score fn for each MPT block 92 | self.itr_count = 0 93 | # ======================================================== 94 | 95 | # added output_attentions parameter for extracting output attentions 96 | def forward( 97 | self, 98 | x: torch.Tensor, 99 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 100 | attn_bias: Optional[torch.Tensor] = None, 101 | attention_mask: Optional[torch.ByteTensor] = None, 102 | is_causal: bool = True, 103 | output_attentions: bool = False, 104 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 105 | a = self.norm_1(x) 106 | # added needs_weight parameter missing for output_attentions 107 | ( 108 | b, 109 | attn_weights, 110 | sparsity, 111 | past_key_value, 112 | score_fn, 113 | token_discard_mask, 114 | token_discard_idx, 115 | req_tokens, 116 | itr_count, 117 | ) = self.attn( 118 | a, 119 | past_key_value=past_key_value, 120 | attn_bias=attn_bias, 121 | attention_mask=attention_mask, 122 | is_causal=is_causal, 123 | needs_weights=output_attentions, 124 | score_fn=self.score_fn, 125 | token_discard_mask=self.token_discard_mask, 126 | token_discard_idx=self.token_discard_idx, 127 | req_tokens=self.req_tokens, 128 | itr_count=self.itr_count, 129 | ) 130 | self.score_fn = score_fn 131 | self.token_discard_mask = token_discard_mask 132 | self.token_discard_idx = token_discard_idx 133 | self.req_tokens = req_tokens 134 | self.itr_count = itr_count 135 | 136 | x = x + self.resid_attn_dropout(b) 137 | m = self.norm_2(x) 138 | n = self.ffn(m) 139 | x = x + self.resid_ffn_dropout(n) 140 | return (x, attn_weights, sparsity, past_key_value) 141 | -------------------------------------------------------------------------------- /models/mpt-keyformer-lib/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "mosaicml/mpt-7b", 3 | "architectures": [ 4 | "MPTForCausalLM" 5 | ], 6 | "attn_config": { 7 | "alibi": true, 8 | "alibi_bias_max": 8, 9 | "attn_impl": "torch", 10 | "attn_pdrop": 0, 11 | "attn_type": "multihead_attention", 12 | "attn_uses_sequence_id": false, 13 | "clip_qkv": null, 14 | "prefix_lm": false, 15 | "qk_ln": false, 16 | "softmax_scale": null 17 | }, 18 | "auto_map": { 19 | "AutoConfig": "configuration_mpt.MPTConfig", 20 | "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM" 21 | }, 22 | "keyformer_config": { 23 | "keyformer": false, 24 | "kv_cache": 60, 25 | "recent": 30, 26 | "tau_init": 1.0, 27 | "tau_delta": 0.01 28 | }, 29 | "d_model": 4096, 30 | "emb_pdrop": 0, 31 | "embedding_fraction": 1.0, 32 | "expansion_ratio": 4, 33 | "init_config": { 34 | "emb_init_std": null, 35 | "emb_init_uniform_lim": null, 36 | "fan_mode": "fan_in", 37 | "init_div_is_residual": true, 38 | "init_gain": 0, 39 | "init_nonlinearity": "relu", 40 | "init_std": 0.02, 41 | "name": "kaiming_normal_", 42 | "verbose": 0 43 | }, 44 | "init_device": "cpu", 45 | "learned_pos_emb": true, 46 | "logit_scale": null, 47 | "max_seq_len": 2048, 48 | "model_type": "mpt", 49 | "n_heads": 32, 50 | "n_layers": 32, 51 | "no_bias": true, 52 | "norm_type": "low_precision_layernorm", 53 | "resid_pdrop": 0, 54 | "tokenizer_name": "EleutherAI/gpt-neox-20b", 55 | 
"torch_dtype": "float32", 56 | "transformers_version": "4.28.1", 57 | "use_cache": false, 58 | "verbose": 0, 59 | "vocab_size": 50432 60 | } 61 | -------------------------------------------------------------------------------- /models/mpt-keyformer-lib/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | 5 | attn_config_defaults: Dict = { 6 | "attn_type": "multihead_attention", 7 | "attn_pdrop": 0.0, 8 | "attn_impl": "triton", 9 | "qk_ln": False, 10 | "clip_qkv": None, 11 | "softmax_scale": None, 12 | "prefix_lm": False, 13 | "attn_uses_sequence_id": False, 14 | "alibi": False, 15 | "alibi_bias_max": 8, 16 | } 17 | init_config_defaults: Dict = { 18 | "name": "kaiming_normal_", 19 | "fan_mode": "fan_in", 20 | "init_nonlinearity": "relu", 21 | "init_div_is_residual": True, 22 | "emb_init_std": None, 23 | "emb_init_uniform_lim": None, 24 | "init_std": None, 25 | "init_gain": 0.0, 26 | } 27 | keyformer_config_defaults: Dict = { 28 | "keyformer": False, 29 | "kv_cache": 60, 30 | "recent": 30, 31 | "tau_init": 1.0, 32 | "tau_delta": 0.01, 33 | } 34 | 35 | 36 | class MPTConfig(PretrainedConfig): 37 | model_type = "mpt" 38 | 39 | def __init__( 40 | self, 41 | d_model: int = 2048, 42 | n_heads: int = 16, 43 | n_layers: int = 24, 44 | expansion_ratio: int = 4, 45 | max_seq_len: int = 2048, 46 | vocab_size: int = 50368, 47 | resid_pdrop: float = 0.0, 48 | emb_pdrop: float = 0.0, 49 | learned_pos_emb: bool = True, 50 | attn_config: Dict = attn_config_defaults, 51 | init_device: str = "cpu", 52 | logit_scale: Optional[Union[float, str]] = None, 53 | no_bias: bool = False, 54 | verbose: int = 0, 55 | embedding_fraction: float = 1.0, 56 | norm_type: str = "low_precision_layernorm", 57 | use_cache: bool = False, 58 | init_config: Dict = init_config_defaults, 59 | keyformer_config: Dict = keyformer_config_defaults, 60 | **kwargs, 61 | ): 62 | """The MPT configuration class. 63 | 64 | Args: 65 | d_model (int): The size of the embedding dimension of the model. 66 | n_heads (int): The number of attention heads. 67 | n_layers (int): The number of layers in the model. 68 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 69 | max_seq_len (int): The maximum sequence length of the model. 70 | vocab_size (int): The size of the vocabulary. 71 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 72 | emb_pdrop (float): The dropout probability for the embedding layer. 73 | learned_pos_emb (bool): Whether to use learned positional embeddings 74 | attn_config (Dict): A dictionary used to configure the model's attention module: 75 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, groupquery_attention 76 | attn_pdrop (float): The dropout probability for the attention layers. 77 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 78 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 79 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 80 | this value. 81 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 82 | use the default scale of ``1/sqrt(d_keys)``. 
83 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 84 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 85 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 86 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 87 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 88 | which sub-sequence each token belongs to. 89 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 90 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 91 | alibi_bias_max (int): The maximum value of the alibi bias. 92 | keyformer_config (Dict): A dictionary used to configure the keyformer parameters: 93 | keyformer (bool): When enabled, keyformer based KV cache reduction 94 | kv_cache (float): KV cache percentage for Keyformer. 95 | recent (float): Recent window percentage for Keyformer. 96 | tau_init (float): Initial temperature for Keyformer score function calculation. 97 | tau_delta (float): Delta temperature change for Keyformer score function. 98 | 99 | init_device (str): The device to use for parameter initialization. 100 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 101 | no_bias (bool): Whether to use bias in all layers. 102 | verbose (int): The verbosity level. 0 is silent. 103 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 104 | norm_type (str): choose type of norm to use 105 | multiquery_attention (bool): Whether to use multiquery attention implementation. 106 | use_cache (bool): Whether or not the model should return the last key/values attentions 107 | init_config (Dict): A dictionary used to configure the model initialization: 108 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 109 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 110 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 111 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 112 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 113 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 114 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 115 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 116 | if using the baseline_ parameter initialization scheme. 117 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 118 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 119 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
120 | --- 121 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 122 | """ 123 | self.d_model = d_model 124 | self.n_heads = n_heads 125 | self.n_layers = n_layers 126 | self.expansion_ratio = expansion_ratio 127 | self.max_seq_len = max_seq_len 128 | self.vocab_size = vocab_size 129 | self.resid_pdrop = resid_pdrop 130 | self.emb_pdrop = emb_pdrop 131 | self.learned_pos_emb = learned_pos_emb 132 | self.attn_config = attn_config 133 | self.init_device = init_device 134 | self.logit_scale = logit_scale 135 | self.no_bias = no_bias 136 | self.verbose = verbose 137 | self.embedding_fraction = embedding_fraction 138 | self.norm_type = norm_type 139 | self.use_cache = use_cache 140 | self.init_config = init_config 141 | # ====== Keyformer ====== 142 | self.keyformer_config = keyformer_config 143 | # ======================= 144 | if "name" in kwargs: 145 | del kwargs["name"] 146 | if "loss_fn" in kwargs: 147 | del kwargs["loss_fn"] 148 | super().__init__(**kwargs) 149 | self._validate_config() 150 | 151 | def _set_config_defaults(self, config, config_defaults): 152 | for k, v in config_defaults.items(): 153 | if k not in config: 154 | config[k] = v 155 | return config 156 | 157 | def _validate_config(self): 158 | self.attn_config = self._set_config_defaults( 159 | self.attn_config, attn_config_defaults 160 | ) 161 | self.init_config = self._set_config_defaults( 162 | self.init_config, init_config_defaults 163 | ) 164 | if self.d_model % self.n_heads != 0: 165 | raise ValueError("d_model must be divisible by n_heads") 166 | if any( 167 | ( 168 | prob < 0 or prob > 1 169 | for prob in [ 170 | self.attn_config["attn_pdrop"], 171 | self.resid_pdrop, 172 | self.emb_pdrop, 173 | ] 174 | ) 175 | ): 176 | raise ValueError( 177 | "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" 178 | ) 179 | if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]: 180 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 181 | if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [ 182 | "torch", 183 | "triton", 184 | ]: 185 | raise NotImplementedError( 186 | "prefix_lm only implemented with torch and triton attention." 187 | ) 188 | if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [ 189 | "torch", 190 | "triton", 191 | ]: 192 | raise NotImplementedError( 193 | "alibi only implemented with torch and triton attention." 194 | ) 195 | if self.attn_config["attn_uses_sequence_id"] and self.attn_config[ 196 | "attn_impl" 197 | ] not in ["torch", "triton"]: 198 | raise NotImplementedError( 199 | "attn_uses_sequence_id only implemented with torch and triton attention." 200 | ) 201 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 202 | raise ValueError( 203 | "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!" 204 | ) 205 | if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model": 206 | raise ValueError( 207 | f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." 208 | ) 209 | if self.init_config.get("name", None) is None: 210 | raise ValueError( 211 | f"self.init_config={self.init_config!r} 'name' needs to be set." 212 | ) 213 | if not self.learned_pos_emb and (not self.attn_config["alibi"]): 214 | raise ValueError( 215 | f"Positional information must be provided to the model using either learned_pos_emb or alibi." 
216 | ) 217 | # ====================== Keyformer parameters validation ==================== 218 | if ( 219 | self.keyformer_config["kv_cache"] < 0 220 | or self.keyformer_config["kv_cache"] > 100 221 | ): 222 | raise ValueError( 223 | f"KV cache percentage should be in the range from 0 to 100." 224 | ) 225 | if self.keyformer_config["recent"] < 0 or self.keyformer_config["recent"] > 100: 226 | raise ValueError(f"Recent percentage should be in the range from 0 to 100.") 227 | if self.keyformer_config["tau_init"] < 0: 228 | raise ValueError(f"Initial temperature parameter can not be less than 0.") 229 | if self.keyformer_config["tau_delta"] < 0: 230 | raise ValueError( 231 | f"Delta temperature change parameter can not be less than 0." 232 | ) 233 | # =========================================================================== 234 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.18.0 3 | aiohttp==3.8.4 4 | aiosignal==1.3.1 5 | appdirs==1.4.4 6 | async-timeout==4.0.2 7 | attrs==23.1.0 8 | certifi==2022.12.7 9 | charset-normalizer==3.1.0 10 | click==8.1.3 11 | cmake==3.26.3 12 | datasets==2.13.0 13 | dill==0.3.6 14 | docker-pycreds==0.4.0 15 | evaluate==0.4.0 16 | filelock==3.12.0 17 | fire==0.5.0 18 | frozenlist==1.3.3 19 | fsspec==2023.4.0 20 | gitdb==4.0.10 21 | gitpython==3.1.31 22 | huggingface-hub==0.14.1 23 | idna==3.4 24 | jinja2==3.1.2 25 | joblib==1.2.0 26 | lit==16.0.2 27 | markupsafe==2.1.2 28 | mpmath==1.3.0 29 | multidict==6.0.4 30 | multiprocess==0.70.14 31 | networkx==3.1 32 | nltk==3.8.1 33 | numpy==1.24.3 34 | nvidia-cublas-cu11==11.10.3.66 35 | nvidia-cuda-cupti-cu11==11.7.101 36 | nvidia-cuda-nvrtc-cu11==11.7.99 37 | nvidia-cuda-runtime-cu11==11.7.99 38 | nvidia-cudnn-cu11==8.5.0.96 39 | nvidia-cufft-cu11==10.9.0.58 40 | nvidia-curand-cu11==10.2.10.91 41 | nvidia-cusolver-cu11==11.4.0.1 42 | nvidia-cusparse-cu11==11.7.4.91 43 | nvidia-nccl-cu11==2.14.3 44 | nvidia-nvtx-cu11==11.7.91 45 | openai==0.27.6 46 | packaging==23.1 47 | pandas==2.0.1 48 | pathtools==0.1.2 49 | protobuf==4.22.4 50 | psutil==5.9.5 51 | pyarrow==12.0.0 52 | python-dateutil==2.8.2 53 | pytz==2023.3 54 | pyyaml==6.0 55 | regex==2023.5.5 56 | requests==2.30.0 57 | responses==0.18.0 58 | rouge-score==0.1.2 59 | sentencepiece==0.1.99 60 | sentry-sdk==1.21.1 61 | setproctitle==1.3.2 62 | simplejson==3.19.1 63 | six==1.16.0 64 | smmap==5.0.0 65 | sympy==1.11.1 66 | termcolor==2.3.0 67 | tokenizers==0.12.1 68 | torch==2.0.0 69 | tqdm==4.65.0 70 | transformers==4.28.1 71 | triton==2.0.0 72 | typing-extensions==4.5.0 73 | tzdata==2023.3 74 | urllib3==2.0.2 75 | wandb==0.15.1 76 | xxhash==3.2.0 77 | yarl==1.9.2 78 | einops -------------------------------------------------------------------------------- /summarization/README.md: -------------------------------------------------------------------------------- 1 | # 📖 Summarization Task 2 | 3 | ## Download dataset 4 | 5 | You can download the respective summarization datasets using below script. 6 | ```bash 7 | cd dataset_download 8 | python download_cnndm.py 9 | ``` 10 | 11 | We have provided data download script for below summarization datasets. 
12 | - [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) 13 | - [XSUM](https://huggingface.co/datasets/EdinburghNLP/xsum) 14 | - [GovReports](https://huggingface.co/datasets/ccdv/govreport-summarization?row=0) 15 | 16 | ## Quick start with the Summarization Task 17 | 18 | You can directly run the summarization task using the following: 19 | 20 | ``` 21 | bash run_summarization_task.sh 22 | ``` 23 | 24 | By default, this runs `mosaicml/mpt-7b`. It assumes that the model copy with 25 | `keyformer` is stored in `models/mpt-7b-keyformer` and the dataset is stored in 26 | `data/cnn_eval.json`. For custom execution, read below. 27 | 28 | 29 | ## Summarization Task 30 | 31 | To get started with the summarization task, set up the model parameters and use the script below: 32 | 33 | ``` 34 | python summarize.py --model_name \ 35 | --dataset_path \ 36 | --save_path \ 37 | --score_path \ 38 | --model_path \ 39 | --attentions_path \ 40 | --device cuda \ # Device 41 | --task summarization \ # Task - Currently only summarization supported 42 | --bs 1 \ # Batch Size 43 | --dtype float16 \ # Data type of model parameters 44 | --causal_lm \ # Causal language model for summarization 45 | --early_stopping \ # Enable early stopping during generation 46 | --output_summaries_only \ # Output only summary - No prompt 47 | --output_sequence_scores \ # Output sequence scores and enable for output attentions 48 | --save_attentions \ # Save attention weights - Token Generation only 49 | --save_prompt_attentions \ # If enabled, only prompt attention weights are stored 50 | --padding_side left \ # Padding side 51 | --beam 4 \ # Beam Size 52 | --model_parallelize \ # Parallelize Model across all available GPUs 53 | --keyformer \ # Enable Keyformer 54 | --kv_cache 60 \ # KV cache percentage of prompt length 55 | --recent 30 \ # Recent window percentage 56 | --tau_init 1 \ # Initial temperature parameter for Gumbel 57 | --tau_end 2 \ # End temperature parameter for Gumbel 58 | --no_repeat_ngram_size 0 \ 59 | --repetition_penalty 1 \ 60 | --max_tokenizer_length 1920 \ # Maximum prompt size 61 | --max_new_tokens 128 \ # Maximum newly generated tokens 62 | --min_gen_length 30 \ # Minimum newly generated tokens 63 | --num_return_sequences 1 \ # Number of returned summaries per input 64 | --seed 12345 \ # Random seed value for random samples 65 | --n_obs 1000 \ # Number of input samples 66 | ``` 67 | 68 | Note: For FP16, do not use `model.half()`; instead, set the `dtype` when creating the model [Link](https://stackoverflow.com/questions/69994731/what-is-the-difference-between-cuda-amp-and-model-half) 69 | -------------------------------------------------------------------------------- /summarization/__pycache__/dataset.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/dataset.cpython-310.pyc -------------------------------------------------------------------------------- /summarization/__pycache__/dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/dataset.cpython-39.pyc -------------------------------------------------------------------------------- /summarization/__pycache__/utils.cpython-310.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /summarization/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /summarization/data/data_cache/.._data_data_cache_cnn_dailymail_3.0.0_0.0.0_d3306fda2d44efc3.lock: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/data/data_cache/.._data_data_cache_cnn_dailymail_3.0.0_0.0.0_d3306fda2d44efc3.lock -------------------------------------------------------------------------------- /summarization/data/data_cache/data_cache: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/data/data_cache/data_cache -------------------------------------------------------------------------------- /summarization/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | from datasets import load_dataset, load_from_disk 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from torch.nn.functional import pad 8 | from torch.utils.data import DataLoader 9 | from typing import Optional, Dict, Sequence 10 | import io 11 | import utils 12 | import copy 13 | 14 | PROMPT_DICT = { 15 | "prompt_input": ( 16 | "Below is an instruction that describes a task, paired with an input that provides further context. " 17 | "Write a response that appropriately completes the request.\n\n" 18 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 19 | ), 20 | "prompt_no_input": ( 21 | "Below is an instruction that describes a task. 
" 22 | "Write a response that appropriately completes the request.\n\n" 23 | "### Instruction:\n{instruction}\n\n### Response:" 24 | ), 25 | } 26 | 27 | 28 | class Dataset(torch.utils.data.Dataset): 29 | """Characterizes a dataset for PyTorch""" 30 | 31 | def __init__( 32 | self, 33 | dataset_path, 34 | tokenizer, 35 | model_name, 36 | return_tensors, 37 | truncation, 38 | padding, 39 | max_length=None, 40 | total_count_override=None, 41 | perf_count_override=None, 42 | ): 43 | self.dataset = "cnn_dailymail" 44 | self.dataset_path = dataset_path 45 | self.model_name = model_name 46 | self.tokenizer = tokenizer 47 | self.return_tensors = return_tensors 48 | self.truncation = truncation 49 | self.padding = padding 50 | self.max_length = max_length 51 | 52 | self.list_data_dict = utils.jload(self.dataset_path) 53 | 54 | prompt_input, prompt_no_input = ( 55 | PROMPT_DICT["prompt_input"], 56 | PROMPT_DICT["prompt_no_input"], 57 | ) 58 | self.sources = [ 59 | prompt_input.format_map(example) for example in self.list_data_dict 60 | ] 61 | self.targets = [f"{example['output']}" for example in self.list_data_dict] 62 | 63 | # Getting random samples for evaluation 64 | if total_count_override > 0: 65 | self.rand_samples = np.random.randint( 66 | len(self.sources), size=(total_count_override) 67 | ).tolist() 68 | self.sources = [self.sources[i] for i in self.rand_samples] 69 | self.targets = [self.targets[i] for i in self.rand_samples] 70 | self.count = total_count_override 71 | else: 72 | self.count = len(self.sources) 73 | 74 | ( 75 | self.source_encoded_input_ids, 76 | self.source_encoded_attn_masks, 77 | ) = self.encode_samples() 78 | 79 | self.perf_count = perf_count_override or self.count 80 | 81 | def encode_samples(self): 82 | print("Encoding Samples") 83 | 84 | total_samples = self.count 85 | 86 | source_encoded_input_ids = [] 87 | source_encoded_attn_masks = [] 88 | 89 | for i in range(total_samples): 90 | source_encoded = self.tokenizer( 91 | self.sources[i], 92 | return_tensors=self.return_tensors, 93 | padding=self.padding, 94 | truncation=self.truncation, 95 | max_length=self.max_length, 96 | ) 97 | source_encoded_input_ids.append(source_encoded.input_ids) 98 | source_encoded_attn_masks.append(source_encoded.attention_mask) 99 | 100 | return source_encoded_input_ids, source_encoded_attn_masks 101 | 102 | def __len__(self): 103 | return self.count 104 | 105 | def __getitem__(self, index): 106 | return ( 107 | self.source_encoded_input_ids[index], 108 | self.source_encoded_attn_masks[index], 109 | ) 110 | -------------------------------------------------------------------------------- /summarization/dataset_download/download_cnndm.py: -------------------------------------------------------------------------------- 1 | # experiment config 2 | dataset_id = "cnn_dailymail" 3 | dataset_config = "3.0.0" 4 | text_column = "article" 5 | summary_column = "highlights" 6 | 7 | from datasets import load_dataset, concatenate_datasets 8 | from transformers import AutoTokenizer 9 | import numpy as np 10 | import os 11 | import simplejson as json 12 | import sys 13 | 14 | save_dataset_path = os.environ.get("DATASET_CNNDM_PATH", "../data") 15 | 16 | # Check whether the specified path exists or not 17 | isExist = os.path.exists(save_dataset_path) 18 | if not isExist: 19 | # Create a new directory because it does not exist 20 | os.makedirs(save_dataset_path) 21 | 22 | # Load dataset from the hub 23 | dataset = load_dataset(dataset_id, name=dataset_config, cache_dir="../data/data_cache") 24 | 25 | 
instruction_template = "Summarize the following news article:" 26 | 27 | 28 | def preprocess_function(sample): 29 | # create list of samples 30 | inputs = [] 31 | 32 | for i in range(0, len(sample[text_column])): 33 | x = dict() 34 | x["instruction"] = instruction_template 35 | x["input"] = sample[text_column][i] 36 | x["output"] = sample[summary_column][i] 37 | inputs.append(x) 38 | model_inputs = dict() 39 | model_inputs["text"] = inputs 40 | 41 | return model_inputs 42 | 43 | 44 | # process dataset 45 | tokenized_dataset = dataset.map( 46 | preprocess_function, batched=True, remove_columns=list(dataset["train"].features) 47 | ) 48 | 49 | # save dataset to disk 50 | with open(os.path.join(save_dataset_path, "cnn_eval.json"), "w") as write_f: 51 | json.dump( 52 | tokenized_dataset["validation"]["text"], write_f, indent=4, ensure_ascii=False 53 | ) 54 | 55 | 56 | print("Dataset saved in ", save_dataset_path) 57 | -------------------------------------------------------------------------------- /summarization/dataset_download/download_longbench.py: -------------------------------------------------------------------------------- 1 | # experiment config 2 | """ 3 | ====> Loading Longbanch Data 4 | ################################################################################################################ 5 | from datasets import load_dataset 6 | 7 | datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \ 8 | "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \ 9 | "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"] 10 | 11 | for dataset in datasets: 12 | data = load_dataset('THUDM/LongBench', dataset, split='test') 13 | ################################################################################################################ 14 | 15 | ====> Loading Longbanch-E Data (Uniformly Sampled Data) 16 | ################################################################################################################ 17 | from datasets import load_dataset 18 | 19 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", "trec", \ 20 | "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"] 21 | 22 | for dataset in datasets: 23 | data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test') 24 | ################################################################################################################ 25 | 26 | ################################################################################################################ 27 | { 28 | "input": "The input/command for the task, usually short, such as questions in QA, queries in Few-shot tasks, etc", 29 | "context": "The long context required for the task, such as documents, cross-file code, few-shot examples in Few-shot tasks", 30 | "answers": "A List of all true answers", 31 | "length": "Total length of the first three items (counted in characters for Chinese and words for English)", 32 | "dataset": "The name of the dataset to which this piece of data belongs", 33 | "language": "The language of this piece of data", 34 | "all_classes": "All categories in classification tasks, null for non-classification tasks", 35 | "_id": "Random id for each piece of data" 36 | } 37 | ################################################################################################################ 38 | """ 39 | dataset_id = "THUDM/LongBench" 40 | 
dataset_name = "gov_report" 41 | text_column = "context" 42 | summary_column = "answers" 43 | instruction_template_gov_reports = "You are given a report by a government agency. Write a one-page summary of the report." 44 | instruction_template_multi_news = ( 45 | "You are given several news passages. Write a one-page summary of all news." 46 | ) 47 | instruction_template_passage_count = ( 48 | "Determine the number of unique passages among the given set." 49 | ) 50 | instruction_template_lcc = "Write the next line of code." 51 | 52 | from datasets import load_dataset, concatenate_datasets 53 | from transformers import AutoTokenizer 54 | import numpy as np 55 | import os 56 | import simplejson as json 57 | import sys 58 | 59 | save_dataset_path = os.environ.get("DATASET_CNNDM_PATH", "../data") 60 | 61 | # Check whether the specified path exists or not 62 | isExist = os.path.exists(save_dataset_path) 63 | if not isExist: 64 | # Create a new directory because it does not exist 65 | os.makedirs(save_dataset_path) 66 | 67 | # Load dataset from the hub 68 | dataset = load_dataset( 69 | dataset_id, name=dataset_name, cache_dir="../data/data_cache", split="test" 70 | ) 71 | 72 | 73 | def preprocess_function(sample): 74 | # create list of samples 75 | inputs = [] 76 | 77 | for i in range(0, len(sample[text_column])): 78 | x = dict() 79 | # x["instruction"] = sample["input"][i] 80 | x["instruction"] = instruction_template_gov_reports 81 | x["input"] = sample[text_column][i] 82 | x["output"] = sample[summary_column][i][0] 83 | # x["classes"] = sample["all_classes"][i] 84 | inputs.append(x) 85 | model_inputs = dict() 86 | model_inputs["text"] = inputs 87 | 88 | return model_inputs 89 | 90 | 91 | # process dataset 92 | tokenized_dataset = dataset.map(preprocess_function, batched=True) 93 | 94 | # save dataset to disk 95 | 96 | file_name = dataset_name + ".json" 97 | with open(os.path.join(save_dataset_path, file_name), "w") as write_f: 98 | json.dump(tokenized_dataset["text"], write_f, indent=4, ensure_ascii=False) 99 | 100 | 101 | print("Dataset saved in ", save_dataset_path) 102 | -------------------------------------------------------------------------------- /summarization/dataset_download/download_xsum.py: -------------------------------------------------------------------------------- 1 | # experiment config 2 | text_column = "document" 3 | summary_column = "summary" 4 | 5 | from datasets import load_dataset, concatenate_datasets 6 | from transformers import AutoTokenizer 7 | import numpy as np 8 | import os 9 | import simplejson as json 10 | import sys 11 | 12 | save_dataset_path = os.environ.get("DATASET_XSUM_PATH", "../data") 13 | 14 | # Check whether the specified path exists or not 15 | isExist = os.path.exists(save_dataset_path) 16 | if not isExist: 17 | # Create a new directory because it does not exist 18 | os.makedirs(save_dataset_path) 19 | 20 | # Load dataset from the hub 21 | dataset = load_dataset("xsum", cache_dir="../data/data_cache") 22 | 23 | instruction_template = "Provide one sentence summary of the following news article:" 24 | 25 | 26 | def preprocess_function(sample): 27 | # create list of samples 28 | inputs = [] 29 | for i in range(0, len(sample[text_column])): 30 | x = dict() 31 | x["instruction"] = instruction_template 32 | x["input"] = sample[text_column][i] 33 | x["output"] = sample[summary_column][i] 34 | inputs.append(x) 35 | model_inputs = dict() 36 | model_inputs["text"] = inputs 37 | 38 | return model_inputs 39 | 40 | 41 | # process dataset 42 | tokenized_dataset 
= dataset.map(preprocess_function, batched=True) 43 | 44 | # save dataset to disk 45 | 46 | with open(os.path.join(save_dataset_path, "xsum_eval.json"), "w") as write_f: 47 | json.dump( 48 | tokenized_dataset["validation"]["text"], write_f, indent=4, ensure_ascii=False 49 | ) 50 | 51 | 52 | print("Dataset saved in ", save_dataset_path) 53 | -------------------------------------------------------------------------------- /summarization/dataset_synthetic_long_context.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import numpy as np 4 | import torch 5 | from datasets import load_dataset, load_from_disk 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | from torch.nn.functional import pad 8 | from torch.utils.data import DataLoader 9 | from typing import Optional, Dict, Sequence 10 | import io 11 | import utils 12 | import copy 13 | 14 | PROMPT_DICT = { 15 | "prompt_input": ( 16 | "Below is an instruction that describes a task, paired with an input that provides further context. " 17 | "Write a response that appropriately completes the request.\n\n" 18 | "### Instruction:\n{instruction}\n\n### Input:\n{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}\n\n### Response:" 19 | ), 20 | "prompt_no_input": ( 21 | "Below is an instruction that describes a task. " 22 | "Write a response that appropriately completes the request.\n\n" 23 | "### Instruction:\n{instruction}\n\n### Response:" 24 | ), 25 | } 26 | 27 | 28 | class Dataset(torch.utils.data.Dataset): 29 | """Characterizes a dataset for PyTorch""" 30 | 31 | def __init__( 32 | self, 33 | dataset_path, 34 | tokenizer, 35 | model_name, 36 | return_tensors, 37 | truncation, 38 | padding, 39 | max_length=None, 40 | total_count_override=None, 41 | perf_count_override=None, 42 | ): 43 | self.dataset = "cnn_dailymail" 44 | self.dataset_path = dataset_path 45 | self.model_name = model_name 46 | self.tokenizer = tokenizer 47 | self.return_tensors = return_tensors 48 | self.truncation = truncation 49 | self.padding = padding 50 | self.max_length = max_length 51 | 52 | self.list_data_dict = utils.jload(self.dataset_path) 53 | 54 | prompt_input, prompt_no_input = ( 55 | PROMPT_DICT["prompt_input"], 56 | PROMPT_DICT["prompt_no_input"], 57 | ) 58 | self.sources = [ 59 | prompt_input.format_map(example) for example in self.list_data_dict 60 | ] 61 | self.targets = [f"{example['output']}" for example in self.list_data_dict] 62 | 63 | # Getting random samples for evaluation 64 | if total_count_override > 0: 65 | self.rand_samples = np.random.randint( 66 | len(self.sources), size=(total_count_override) 67 | ).tolist() 68 | self.sources = [self.sources[i] for i in self.rand_samples] 69 | self.targets = [self.targets[i] for i in self.rand_samples] 70 | self.count = total_count_override 71 | else: 72 | self.count = len(self.sources) 73 | 74 | ( 75 | self.source_encoded_input_ids, 76 | self.source_encoded_attn_masks, 77 | ) = self.encode_samples() 78 | 79 | self.perf_count = perf_count_override or self.count 80 | 81 | def encode_samples(self): 82 | print("Encoding Samples") 83 | 84 | total_samples = self.count 85 | 86 | source_encoded_input_ids = [] 87 | source_encoded_attn_masks = [] 88 | 89 | for i in range(total_samples): 90 | 
source_encoded = self.tokenizer( 91 | self.sources[i], 92 | return_tensors=self.return_tensors, 93 | padding=self.padding, 94 | truncation=self.truncation, 95 | max_length=self.max_length, 96 | ) 97 | source_encoded_input_ids.append(source_encoded.input_ids) 98 | source_encoded_attn_masks.append(source_encoded.attention_mask) 99 | 100 | return source_encoded_input_ids, source_encoded_attn_masks 101 | 102 | def __len__(self): 103 | return self.count 104 | 105 | def __getitem__(self, index): 106 | return ( 107 | self.source_encoded_input_ids[index], 108 | self.source_encoded_attn_masks[index], 109 | ) 110 | -------------------------------------------------------------------------------- /summarization/out_model.score: -------------------------------------------------------------------------------- 1 | {"rouge1": 34.4293, "rouge2": 17.7591, "rougeL": 23.474, "rougeLsum": 32.8713, "gen_len": 5605.0, "gen_num": 10.0} -------------------------------------------------------------------------------- /summarization/out_model.summary: -------------------------------------------------------------------------------- 1 | Douglas Costa will spark a transfer scramble this summer with Shakhtar Donetsk ready to sell their prized-asset. Chelsea manager Jose Mourinho is a known admirer of the Brazil international having tried to land the midfielder in the previous transfer windows. Shakhtar chiefs are now open to selling Costa this summer and talks with third parties over his departure are underway. Brazil international Douglas Costa is set to depart Shakhtar Donetsk for £25million in the summer . Midfielder Costa (left) could spark a bidding war from Chelsea, Real Madrid and Barcelona . The 24-year-old 2 | ### Solution: #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include < 3 | A man was charged and prosecuted at a cost of £5,000 to the taxpayer - over the possession of cannabis worth less than £2. Martin Kewley was ordered to appear in court after police found a minuscule amount of the drug in a drawer by his bedside. Rather than caution the builder, of Onchan, Isle of Man, over the cannabis - worth £1.59 - officers and prosecutors decided to take the case to Douglas Magistrates' Court. Questionable decision: A man was charged and prosecuted at a cost of £5,000 to the taxpayer - over the possession of cannabis resin worth less 4 | Jimmy Walker held his nerve to close out victory at the Valero Texas Open at TPC San Antonio. The Ryder Cup player added to his Sony Open title this season and three PGA Tour victories in 2014 as a closing 70 maintained his four-shot overnight lead. Jordan Spieth was second on seven under, having perhaps left his run too late before finishing with four birdies in the last five holes. Jimmy Walker poses with the Valero Texas Open trophy after recording a four-shot win. Walker took the trophy at his hometown tournament and now has six wins since the start of 2014. Walker went into his final round 5 | Caroline Nyadzayo’s unpaid bill is part of the £62million cost of ‘health tourism’ to Britain. The 34-year-old, an advertising executive in Africa, was pictured with health minister Daniel Poulter on his visit to the Norfolk and Norwich University Hospital when the maternity suite was reopened. 
New mother: Zimbabwe-born Caroline Nyadzayo (left) was pictured with health minister Daniel Poulter (right) on his visit to the Norfolk and Norwich University Hospital in the maternity suite . The mother said she complained to Mr Poulter about the charge and 6 | PS frontman Memphis Depay has been tipped as a better player than Real Madrid star Cristiano Ronaldo at the same age. The 21-year-old is the current leading goalscorer in the Eredivisie with 16 league goals this season. Early years and PSV Eindhoven . Born in 1994, Depay started his career at hometown club PSV Eindhoven at the age of 12 before moving to Sparta Rotterdam at the age of 16. He returned to PSV Eindhoven at the age of 18 and made his first-team debut in September 2011. 7 | The Royal Navy’s new uniform was ridiculed yesterday by being compared to the outfits worn by garage mechanics. The all-navy blue gear will replace the outfit worn by serving military personnel since the Second World War and was welcomed by the Defence Ministry as ‘cool and more modern’. The old Royal Navy uniform of light blue shirt and dark blue trousers, left, is being replaced with all-navy blue gear, right. But Twitter user Tony Shumacher wrote: ‘Royal Navy unveils “modern” uniform… sponsored by Kwik Fit?’ Jamie Frost, a seaman who trains at the University Royal Naval Unit, said 8 | Obesity expert says mocking the overweight 'should be illegal'. Stock photo . Poking fun at fat people should be treated as seriously as racism and sexism, researchers have said. Obesity expert Dr Sarah Jackson said that the law should protect against weight discrimination, in the same way at it prohibits singling out people based on their age, gender or race. Dr Jackson, of University College London, spoke out after conducting two studies into the physical and psychological effects of fattism. Her latest study, of more than 5,000 British adults found that those who were made to feel ashamed of their size suffered more symptoms of depression. 9 | (CNN)The birthplace of Coca-Cola, "Gone with the Wind" and Martin Luther King Jr. is still home to the busiest passenger airport in the world. More than 96 million passengers went through Hartsfield-Jackson Atlanta International Airport in 2014, an increase of 1.9% over 2013, according to Airports Council International's preliminary passenger traffic data, released Thursday. With more than 86 million passengers last year, Beijing Capital International Airport remained in second place and continued to close the gap just a bit more with Atlanta, according to the worldwide association of airports. It saw an increase of 2.9 10 | Summarize the following news article: If you thought Rosie Huntington-Whiteley slept in baggy jumpers and elasticated bottoms, you can think again. As her new campaign images attest, when you're one of Britain's most famous supermodels, sleeping in silk and satin is the only option. The model and lingerie designer has unveiled her Rosie for Autograph summer sleepwear collection - and it's full of mix and match pieces featuring sophisticated hues of slate blue and silver with colour pops of peach and floral prints. 
Rosie, 27, models her sophisticated summer sleepwear range for Autograph at M& 11 | -------------------------------------------------------------------------------- /summarization/output.summary: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/output.summary -------------------------------------------------------------------------------- /summarization/run_summarization_task.sh: -------------------------------------------------------------------------------- 1 | python summarize.py --model_name mosaicml/mpt-7b \ 2 | --model_path ../models/mpt-7b-keyformer \ 3 | --dataset_path ./data/cnn_eval.json \ 4 | --save_path ./out_model.summary \ 5 | --score_path ./out_model.score \ 6 | --device cuda \ 7 | --task summarization \ 8 | --bs 1 \ 9 | --dtype bfloat16 \ 10 | --causal_lm \ 11 | --early_stopping \ 12 | --output_summaries_only \ 13 | --padding_side left \ 14 | --beam 4 \ 15 | --keyformer \ 16 | --kv_cache 60 \ 17 | --recent 30 \ 18 | --tau_init 1 \ 19 | --tau_end 2 \ 20 | --no_repeat_ngram_size 0 \ 21 | --repetition_penalty 1 \ 22 | --max_tokenizer_length 1920 \ 23 | --max_new_tokens 128 \ 24 | --min_gen_length 30 \ 25 | --num_return_sequences 1 \ 26 | --seed 12345 \ 27 | --n_obs 10 \ 28 | --model_parallelize 29 | -------------------------------------------------------------------------------- /summarization/utils.py: -------------------------------------------------------------------------------- 1 | """From Huggingface Transformers.""" 2 | import itertools 3 | import json 4 | import linecache 5 | import os 6 | import pickle 7 | import warnings 8 | from logging import getLogger 9 | from pathlib import Path 10 | from typing import Callable, Dict, Iterable, List 11 | import io 12 | 13 | import numpy as np 14 | import torch 15 | import nltk 16 | from rouge_score import rouge_scorer, scoring 17 | 18 | from torch import nn 19 | from torch.utils.data import Dataset, Sampler 20 | 21 | 22 | def _make_r_io_base(f, mode: str): 23 | if not isinstance(f, io.IOBase): 24 | f = open(f, mode=mode) 25 | return f 26 | 27 | 28 | def jload(f, mode="r"): 29 | """Load a .json file into a dictionary.""" 30 | f = _make_r_io_base(f, mode) 31 | jdict = json.load(f) 32 | f.close() 33 | return jdict 34 | 35 | 36 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100): 37 | """From fairseq""" 38 | if target.dim() == lprobs.dim() - 1: 39 | target = target.unsqueeze(-1) 40 | nll_loss = -lprobs.gather(dim=-1, index=target) 41 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 42 | if ignore_index is not None: 43 | pad_mask = target.eq(ignore_index) 44 | nll_loss.masked_fill_(pad_mask, 0.0) 45 | smooth_loss.masked_fill_(pad_mask, 0.0) 46 | else: 47 | nll_loss = nll_loss.squeeze(-1) 48 | smooth_loss = smooth_loss.squeeze(-1) 49 | 50 | nll_loss = nll_loss.sum() # mean()? Scared to break other math. 
51 | smooth_loss = smooth_loss.sum() 52 | eps_i = epsilon / lprobs.size(-1) 53 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss 54 | return loss, nll_loss 55 | 56 | 57 | def encode_line( 58 | tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt" 59 | ): 60 | extra_kw = ( 61 | {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {} 62 | ) 63 | return tokenizer( 64 | [line], 65 | max_length=max_length, 66 | padding="max_length" if pad_to_max_length else None, 67 | truncation=True, 68 | return_tensors=return_tensors, 69 | **extra_kw, 70 | ) 71 | 72 | 73 | def lmap(f: Callable, x: Iterable) -> List: 74 | """list(map(f, x))""" 75 | return list(map(f, x)) 76 | 77 | 78 | def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict: 79 | """Uses sacrebleu's corpus_bleu implementation.""" 80 | return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score} 81 | 82 | 83 | def trim_batch( 84 | input_ids, 85 | pad_token_id, 86 | attention_mask=None, 87 | ): 88 | """Remove columns that are populated exclusively by pad_token_id""" 89 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) 90 | if attention_mask is None: 91 | return input_ids[:, keep_column_mask] 92 | else: 93 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) 94 | 95 | 96 | class Seq2SeqDataset(Dataset): 97 | """class Seq2SeqDataset""" 98 | 99 | def __init__( 100 | self, 101 | tokenizer, 102 | data_dir, 103 | max_source_length, 104 | max_target_length, 105 | type_path="train", 106 | n_obs=None, 107 | src_lang=None, 108 | tgt_lang=None, 109 | prefix="", 110 | ): 111 | super().__init__() 112 | self.src_file = Path(data_dir).joinpath(type_path + ".source") 113 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target") 114 | self.src_lens = self.get_char_lens(self.src_file) 115 | self.max_source_length = max_source_length 116 | self.max_target_length = max_target_length 117 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}" 118 | self.tokenizer = tokenizer 119 | self.prefix = prefix 120 | if n_obs is not None: 121 | self.src_lens = self.src_lens[:n_obs] 122 | self.pad_token_id = self.tokenizer.pad_token_id 123 | self.src_lang = src_lang 124 | self.tgt_lang = tgt_lang 125 | 126 | def __len__(self): 127 | return len(self.src_lens) 128 | 129 | def __getitem__(self, index) -> Dict[str, torch.Tensor]: 130 | index = index + 1 # linecache starts at 1 131 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 132 | "\n" 133 | ) 134 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 135 | assert source_line, f"empty source line for index {index}" 136 | assert tgt_line, f"empty tgt line for index {index}" 137 | source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length) 138 | target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length) 139 | 140 | source_ids = source_inputs["input_ids"].squeeze() 141 | target_ids = target_inputs["input_ids"].squeeze() 142 | src_mask = source_inputs["attention_mask"].squeeze() 143 | return { 144 | "input_ids": source_ids, 145 | "attention_mask": src_mask, 146 | "decoder_input_ids": target_ids, 147 | } 148 | 149 | @staticmethod 150 | def get_char_lens(data_file): 151 | return [len(x) for x in Path(data_file).open().readlines()] 152 | 153 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 154 | """Reorganize batch data.""" 155 | input_ids = torch.stack([x["input_ids"] for x in batch]) 156 | masks = 
torch.stack([x["attention_mask"] for x in batch]) 157 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch]) 158 | pad_token_id = self.pad_token_id 159 | y = trim_batch(target_ids, pad_token_id) 160 | source_ids, source_mask = trim_batch( 161 | input_ids, pad_token_id, attention_mask=masks 162 | ) 163 | batch = { 164 | "input_ids": source_ids, 165 | "attention_mask": source_mask, 166 | "decoder_input_ids": y, 167 | } 168 | return batch 169 | 170 | def make_sortish_sampler(self, batch_size): 171 | return SortishSampler(self.src_lens, batch_size) 172 | 173 | 174 | class TranslationDataset(Seq2SeqDataset): 175 | """A dataset that calls prepare_seq2seq_batch.""" 176 | 177 | def __init__(self, *args, **kwargs): 178 | super().__init__(*args, **kwargs) 179 | if self.max_source_length != self.max_target_length: 180 | warnings.warn( 181 | f"Mbart is using sequence lengths {self.max_source_length}, \ 182 | {self.max_target_length}. " 183 | f"Imbalanced sequence lengths may be undesired for \ 184 | translation tasks" 185 | ) 186 | 187 | def __getitem__(self, index) -> Dict[str, str]: 188 | index = index + 1 # linecache starts at 1 189 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip( 190 | "\n" 191 | ) 192 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n") 193 | assert source_line, f"empty source line for index {index}" 194 | assert tgt_line, f"empty tgt line for index {index}" 195 | return { 196 | "tgt_texts": tgt_line, 197 | "src_texts": source_line, 198 | } 199 | 200 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]: 201 | batch_encoding = self.tokenizer.prepare_seq2seq_batch( 202 | [x["src_texts"] for x in batch], 203 | src_lang=self.src_lang, 204 | tgt_texts=[x["tgt_texts"] for x in batch], 205 | tgt_lang=self.tgt_lang, 206 | max_length=self.max_source_length, 207 | max_target_length=self.max_target_length, 208 | ) 209 | return batch_encoding.data 210 | 211 | 212 | class SortishSampler(Sampler): 213 | """ 214 | Go through the text data by order of src length with a bit of randomness. 215 | From fastai repo. 216 | """ 217 | 218 | def __init__(self, data, batch_size): 219 | self.data, self.bs = data, batch_size 220 | 221 | def key(self, i): 222 | return self.data[i] 223 | 224 | def __len__(self) -> int: 225 | return len(self.data) 226 | 227 | def __iter__(self): 228 | idxs = np.random.permutation(len(self.data)) 229 | sz = self.bs * 50 230 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)] 231 | sort_idx = np.concatenate( 232 | [sorted(s, key=self.key, reverse=True) for s in ck_idx] 233 | ) 234 | sz = self.bs 235 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)] 236 | max_ck = np.argmax( 237 | [self.key(ck[0]) for ck in ck_idx] 238 | ) # find the chunk with the largest key, 239 | ck_idx[0], ck_idx[max_ck] = ( 240 | ck_idx[max_ck], 241 | ck_idx[0], 242 | ) # then make sure it goes first. 
243 | sort_idx = ( 244 | np.concatenate(np.random.permutation(ck_idx[1:])) 245 | if len(ck_idx) > 1 246 | else np.array([], dtype=np.int) 247 | ) 248 | sort_idx = np.concatenate((ck_idx[0], sort_idx)) 249 | return iter(sort_idx) 250 | 251 | 252 | logger = getLogger(__name__) 253 | 254 | 255 | def use_task_specific_params(model, task): 256 | """Update config with summarization specific params.""" 257 | task_specific_params = model.config.task_specific_params 258 | 259 | if task_specific_params is not None: 260 | pars = task_specific_params.get(task, {}) 261 | print("Task Params: ", pars) 262 | logger.info(f"using task specific params for {task}: {pars}") 263 | model.config.update(pars) 264 | 265 | 266 | def pickle_load(path): 267 | """pickle.load(path)""" 268 | with open(path, "rb") as f: 269 | return pickle.load(f) 270 | 271 | 272 | def pickle_save(obj, path): 273 | """pickle.dump(obj, path)""" 274 | with open(path, "wb") as f: 275 | return pickle.dump(obj, f) 276 | 277 | 278 | def flatten_list(summary_ids: List[List]): 279 | return [x for x in itertools.chain.from_iterable(summary_ids)] 280 | 281 | 282 | def save_git_info(folder_path: str) -> None: 283 | """Save git information to output_dir/git_log.json""" 284 | repo_infos = get_git_info() 285 | save_json(repo_infos, os.path.join(folder_path, "git_log.json")) 286 | 287 | 288 | def save_json(content, path): 289 | with open(path, "w") as f: 290 | json.dump(content, f, indent=4) 291 | 292 | 293 | def load_json(path): 294 | with open(path) as f: 295 | return json.load(f) 296 | 297 | 298 | def get_git_info(): 299 | repo = git.Repo(search_parent_directories=True) 300 | repo_infos = { 301 | "repo_id": str(repo), 302 | "repo_sha": str(repo.head.object.hexsha), 303 | "repo_branch": str(repo.active_branch), 304 | } 305 | return repo_infos 306 | 307 | 308 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"] 309 | 310 | 311 | def calculate_rouge( 312 | output_lns: List[str], reference_lns: List[str], use_stemmer=True 313 | ) -> Dict: 314 | """Calculate rouge scores""" 315 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer) 316 | aggregator = scoring.BootstrapAggregator() 317 | 318 | for reference_ln, output_ln in zip(reference_lns, output_lns): 319 | scores = scorer.score(reference_ln, output_ln) 320 | aggregator.add_scores(scores) 321 | 322 | result = aggregator.aggregate() 323 | return {k: v.mid.fmeasure for k, v in result.items()} 324 | 325 | 326 | def postprocess_text(preds, targets): 327 | preds = [pred.strip() for pred in preds] 328 | targets = [target.strip() for target in targets] 329 | 330 | # rougeLSum expects newline after each sentence 331 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 332 | targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets] 333 | 334 | return preds, targets 335 | 336 | 337 | def freeze_params(model: nn.Module): 338 | for par in model.parameters(): 339 | par.requires_grad = False 340 | 341 | 342 | def grad_status(model: nn.Module) -> Iterable: 343 | return (par.requires_grad for par in model.parameters()) 344 | 345 | 346 | def any_requires_grad(model: nn.Module) -> bool: 347 | return any(grad_status(model)) 348 | 349 | 350 | def assert_all_frozen(model): 351 | model_grads: List[bool] = list(grad_status(model)) 352 | n_require_grad = sum(lmap(int, model_grads)) 353 | npars = len(model_grads) 354 | assert not any( 355 | model_grads 356 | ), f"{n_require_grad/npars:.1%} of {npars} weights require grad" 357 | 358 | 359 | def assert_not_all_frozen(model): 360 | 
model_grads: List[bool] = list(grad_status(model)) 361 | npars = len(model_grads) 362 | assert any(model_grads), f"none of {npars} weights require grad" 363 | --------------------------------------------------------------------------------
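For reference, the scoring helpers defined in `summarization/utils.py` above (`postprocess_text` and `calculate_rouge`) are typically combined as in the minimal sketch below. The example strings are hypothetical, and the snippet assumes it is run from the `summarization/` directory (so that `utils` is importable) with `nltk` and `rouge_score` installed.

```python
# Minimal usage sketch (hypothetical inputs): score generated summaries against
# references using the helpers from summarization/utils.py.
import nltk

from utils import calculate_rouge, postprocess_text

nltk.download("punkt", quiet=True)  # postprocess_text relies on nltk.sent_tokenize

preds = ["The agency report flags budget risks. It urges closer oversight."]       # generated summaries (assumed)
refs = ["The government report describes budget risks and recommends oversight."]  # reference summaries (assumed)

# postprocess_text puts each sentence on its own line before scoring.
preds, refs = postprocess_text(preds, refs)

# Aggregated mid F-measures keyed by rouge1 / rouge2 / rougeL.
print(calculate_rouge(preds, refs))
```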