├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── blog
│   ├── README.md
│   └── images
│       ├── Attention.png
│       ├── Keyformer_decoding.gif
│       ├── Keyformer_overview.gif
│       ├── accuracy.png
│       ├── attn_weights.png
│       ├── keyformer_decoding.gif
│       ├── long_context_accuracy.png
│       ├── motivation.png
│       ├── performance.png
│       ├── sparsity.png
│       ├── speedup.png
│       └── uneven_score.gif
├── conda-env.yml
├── conversation
│   ├── README.md
│   ├── data
│   │   └── data_cache
│   │       └── data_cache
│   ├── dataset.py
│   ├── dataset_download
│   │   ├── download_orca.py
│   │   └── download_soda.py
│   ├── dialogue.py
│   ├── run_conversation_task.sh
│   └── utils.py
├── lm_eval_harness
│   ├── README.md
│   ├── evaluate_task_result.py
│   ├── experiments.sh
│   ├── generate_task_data.py
│   ├── model_eval_data
│   │   └── model_eval_data
│   ├── output
│   │   └── output
│   ├── run.sh
│   ├── run_lm_eval_harness.py
│   ├── task_data
│   │   └── task_data
│   └── tasks
│       ├── __init__.py
│       ├── eval_harness.py
│       └── util.py
├── models
│   ├── README.md
│   ├── cerebras-keyformer-lib
│   │   ├── config.json
│   │   ├── configuration_gpt2.py
│   │   ├── modeling_gpt2.py
│   │   └── modeling_gpt2_lm_eval_harness.py
│   ├── gptj-keyformer-lib
│   │   ├── config.json
│   │   ├── configuration_gptj.py
│   │   ├── modeling_gptj.py
│   │   └── modeling_gptj_lm_eval_harness.py
│   ├── model_download
│   │   └── download_model.py
│   └── mpt-keyformer-lib
│       ├── attention.py
│       ├── attention_keyformer_new_bias.py
│       ├── attention_keyformer_old_bias.py
│       ├── attention_lm_eval_harness.py
│       ├── attention_streaming_llm.py
│       ├── blocks.py
│       ├── config.json
│       ├── configuration_mpt.py
│       └── modeling_mpt.py
├── requirements.txt
└── summarization
    ├── README.md
    ├── __pycache__
    │   ├── dataset.cpython-310.pyc
    │   ├── dataset.cpython-39.pyc
    │   ├── utils.cpython-310.pyc
    │   └── utils.cpython-39.pyc
    ├── data
    │   └── data_cache
    │       ├── .._data_data_cache_cnn_dailymail_3.0.0_0.0.0_d3306fda2d44efc3.lock
    │       └── data_cache
    ├── dataset.py
    ├── dataset_download
    │   ├── download_cnndm.py
    │   ├── download_longbench.py
    │   └── download_xsum.py
    ├── dataset_synthetic_long_context.py
    ├── out_model.score
    ├── out_model.summary
    ├── output.summary
    ├── run_summarization_task.sh
    ├── summarize.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/*.bin
2 | models/model/
3 | plots/
4 | **/*.pdf
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Thank you for your interest in contributing to Keyformer! We are preparing the next steps to invite community contributions and make Keyformer an integral part of LLM architectures.
2 |
3 | In the meantime, you can report any issues you observe while using Keyformer. When reporting an issue, please follow these guidelines -
4 |
5 | - If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it. If not, please file a new issue, providing as much relevant information as possible.
6 |
7 | ### Thank You
8 | Finally, thank you for your interest in contributing to Keyformer and helping make it a great tool for everyone.
9 |
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference
2 |
3 | **Keyformer** proposes KV Cache reduction through key token identification, without the need
4 | for fine-tuning.
5 |
6 | This repository contains the source code implementation of **Keyformer**.
7 |
8 | This source code is available under the [Apache 2.0 License](LICENSE).
9 |
10 | For a quick overview of how **Keyformer** reduces the KV Cache size and speeds up LLM inference, we invite you to browse through our blog -
11 |
12 | - [🎫 Blog](blog/README.md) - Quick summary of Keyformer
13 |
14 | But if you'd rather see **keyformer** in action first, please continue reading.
15 |
16 | ## ⚙️ Environment Setup
17 |
18 | #### conda Environment
19 | You can create a conda environment using the `conda-env.yml` file.
20 | ```bash
21 | conda env create --file=conda-env.yml
22 | conda activate keyformer-env
23 | ```
24 |
25 | ## 💻 System Requirements
26 |
27 | So far, **keyformer** has been tested on a 1xA100 node for `mosaicml/mpt-7b` and
28 | `EleutherAI/gpt-j-6B`, with the models stored in `FP32` and evaluation performed in `FP16`.
29 |
30 | We welcome contributions to port and evaluate **keyformer** with different
31 | data formats, targeting lower model memory requirements and a smaller overall KV Cache size.
32 |
33 | ## 🏁 Model Download and Integration with Keyformer
34 |
35 | Note - **Keyformer** has been evaluated on the following models, and this tutorial is restricted to their use -
36 | `['cerebras/Cerebras-GPT-6.7B',
37 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']`
38 |
39 | Clone the relevant model from huggingface into the `models` directory. For instance, to clone `mosaicml/mpt-7b` into `models/mpt-7b-keyformer`, do the following -
40 |
41 | ```bash
42 | cd models
43 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mosaicml/mpt-7b mpt-7b-keyformer
44 | cd ..
45 | ```
46 |
47 | Then, download the model weights (`*.bin`) using the `download_model.py` script, passing the model name as an argument.
48 |
49 | ```bash
50 | cd models/model_download
51 | python3 download_model.py --model_name mosaicml/mpt-7b
52 | cd ../../
53 | ```
54 |
55 | Please keep in mind that supported arguments for `model_name` are restricted to `['cerebras/Cerebras-GPT-6.7B',
56 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']`
57 |
58 | This step downloads the model parameters into a `model` folder inside the `model_download` directory.
59 | The downloaded model parameters then need to be moved to the same directory as the model files.
60 |
61 | To move the model parameters from `models/model_download/model/` to `models/mpt-7b-keyformer`
62 |
63 | ```bash
64 | mv models/model_download/model/* models/mpt-7b-keyformer/.
65 | ```
66 |
67 | ### Integrating Keyformer
68 | Copy the **Keyformer** source files from `models/mpt-keyformer-lib` to `models/mpt-7b-keyformer`
69 |
70 | ```bash
71 | cp -r models/mpt-keyformer-lib/* models/mpt-7b-keyformer/
72 | ```
73 |
74 | This sets up `models/mpt-7b-keyformer` to be used with the various tasks described below.
75 |
76 | Similarly, find the keyformer-lib files for other supported models in their respective directories in `models/`.
77 |
78 | ## Using Keyformer for KV Cache reduction
79 |
80 | After setting up a model with `keyformer`, let's run it on a downstream task of interest. For `keyformer`, we provide examples of how to run the downstream tasks of **summarization** and **conversation**. Further, for evaluation purposes, we provide a step-by-step tutorial on using **lm_eval_harness**.
81 |
82 | Depending on your task of interest, please refer to the following links for more details
83 |
84 | - [📖 Summarization](summarization/README.md) - Running the Summarization task with Keyformer
85 | - [💬 Conversation](conversation/README.md) - Running the Conversation task with Keyformer
86 | - [📊 LM Eval Harness](lm_eval_harness/README.md) - Using LM Eval harness for evaluation with Keyformer
87 |
88 | ## TODO
89 |
90 | - [ ] Instructions to integrate keyformer with any model from huggingface
91 | 
92 | - [ ] Using keyformer with quantized models
93 |
94 |
95 | ## Thank You
96 |
97 | Keyformer uses open source components available on [huggingface](https://huggingface.co/), [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b), [cerebras/Cerebras-GPT-6.7B](https://huggingface.co/cerebras/Cerebras-GPT-6.7B), and
98 | [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6b).
99 |
100 |
101 | ## 📝 Citation
102 | ```
103 | @article{2024keyformer,
104 | title={Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference},
105 | author={Adnan, Muhammad and Arunkumar, Akhil and Jain, Gaurav and Nair, Prashant and Soloveychik, Ilya and Kamath, Purushotham},
106 | journal={Proceedings of Machine Learning and Systems},
107 | volume={7},
108 | year={2024}
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/blog/README.md:
--------------------------------------------------------------------------------
1 | # Keyformer: KV Cache reduction through attention sparsification for Efficient Generative Inference
2 | **Muhammad Adnan<sup>\*1</sup>**, **Gaurav Jain<sup>2</sup>**, **Akhil Arunkumar<sup>2</sup>**, **Prashant J. Nair<sup>1</sup>**, **Ilya Soloveychik<sup>2</sup>**, **Purushotham Kamath<sup>2</sup>**
3 | 
4 | <sup>\*</sup> Work done while the author was an intern at d-Matrix.
5 | 
6 | <sup>1</sup> Department of Electrical and Computer Engineering, The University of British Columbia, Vancouver, BC, Canada.
7 | <sup>2</sup> d-Matrix, Santa Clara, California, USA.
8 |
9 | **TL;DR:** Generative AI inference is often bottlenecked by the growing KV cache. Numerous strategies have been proposed to compress the KV cache and allow longer inference-time context lengths, but most of these techniques require fine-tuning, and in some cases even pre-training. We introduce **Keyformer**, a novel inference-time token-discarding technique that reduces the KV cache size to improve overall inference latency and token generation throughput while preserving accuracy. Keyformer capitalizes on the observation that during generative inference, approximately 90% of the attention weight is concentrated on a select subset of tokens called **key tokens**, and discards the *irrelevant* tokens to reduce the overall size of the KV cache. By employing **Keyformer**, we reduce the required KV cache size by 50%, reduce latency by up to 2.1x, and boost token generation throughput by 2.4x, all while preserving the model's accuracy. Further, we are able to support larger batch sizes that would otherwise result in **Out-Of-Memory** errors.
10 |
11 | ## How Keyformer works
12 |
13 | The attention mechanism exhibits varying amounts of sparsity across the model's many decoder layers. As seen in Figure 1 (Left), attention sparsity varies significantly even between models of the same size, all on the same CNN/DailyMail summarization task. Figure 1 (Right), in turn, uses a cumulative distribution function (CDF) to show how the attention score is concentrated within a small number of tokens during text generation. What this translates into is the importance of certain key tokens during token generation and, more importantly, the relative irrelevance of the majority of tokens.
14 |
15 |
16 |
17 | 
18 | Figure 1: (Left) Default attention sparsity for different models across layers. (Right) CDF of attention score for different models with 90% of attention score dedicated to 40% of tokens.
19 |
20 |
21 | In this work, **Keyformer**, we exploit this inherent sparsity within the decoder layers by identifying key tokens while still emphasizing recent tokens. We adapt this token-discarding behavior by changing the score function and applying regularization to the unnormalized logits for key token identification.
22 |
23 | ### What do we do for Regularization - Gumbel Distribution
24 |
25 | Once we have identified and discarded the irrelevant tokens, it is important that we regularize the score function to account for this change. For this, we use the Gumbel distribution, which enables the model to remain robust and adaptive. As an implementation strategy, we maintain a constant KV cache size of *k* and remove the *n − k* remaining tokens from the context to avoid unwanted memory growth. A rough sketch of this cache bookkeeping is shown below.
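As an illustration of this fixed-size cache bookkeeping, here is a hedged sketch rather than the repository's implementation; the tensor layout, function name, and index handling are assumptions:

```python
import torch

def prune_kv_cache(key: torch.Tensor, value: torch.Tensor, keep_idx: torch.Tensor):
    """Keep only the k selected token positions in a per-layer KV cache.

    key, value: [batch, heads, seq_len, head_dim]  (assumed layout)
    keep_idx:   [k] indices of the tokens retained by the score function
    """
    key = key.index_select(dim=2, index=keep_idx)
    value = value.index_select(dim=2, index=keep_idx)
    # The cache stays at a constant size k; the other n - k entries are dropped.
    return key, value
```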
26 |
27 | ### Bias Towards Initial Tokens
28 |
29 | Previous research has indicated a bias towards initial tokens. For instance, [StreamingLLM](https://arxiv.org/abs/2309.17453) highlights the importance of initial tokens as attention sinks, particularly in streaming applications. Similarly, [H2O](https://arxiv.org/abs/2306.14048) utilizes an accumulated attention score as its score function, which results in a predisposition
30 | towards initial tokens due to the cumulative effect over decoding iterations. To exploit this bias towards initial tokens and effectively model the distribution of maximum values (key tokens), we propose introducing a distribution that is skewed towards initial tokens while simultaneously featuring an asymmetric profile. This asymmetry introduces a pronounced right tail, which is characteristic of tokens typically drawn from the recent context window. Our choice of distribution is inspired by the [Gumbel distribution](https://arxiv.org/pdf/1502.02708.pdf).
31 |
32 |
33 | 
34 | Figure 2: Overview of Keyformer across its phases. In the prompt processing phase, the KV cache holds n tokens and Keyformer injects noise to identify key tokens: it selects w tokens from the recent window and k − w tokens from the remaining n − w tokens, keeping k tokens in the KV cache. In the text generation phase, each decoding step operates on k tokens in the KV cache, with tokens discarded in the previous iteration.
35 |
36 |
37 | ### Keyformer Score Function
38 |
39 | To overcome the limitations of an uneven score distribution and the resulting key token identification, we introduce a novel Keyformer score function $f_{\theta(keyformer)}$. This score function incorporates the Gumbel distribution into the unnormalized logits. However, the discarded tokens are not incorporated in any way into the probability distribution that underlies the score function. To account for the discarded tokens, we introduce a temperature parameter denoted $\tau$, as shown in the equation below.
40 |
41 | ```math
42 | f_{\theta(keyformer)}(x_i) = e^{\frac{x_i + \zeta_i}{\tau}} / \sum_{j=1}^{k} e^{\frac{x_j + \zeta_j}{\tau}}, \ \ \ \ i=1,2,\dots,k
43 | ```
44 |
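For intuition, here is a minimal PyTorch sketch of the score function and the resulting key-token selection. It is a simplification of the logic in the `models/*-keyformer-lib` attention files; the function names, shapes, and exact placement of the softmax are illustrative assumptions rather than the repository's implementation.

```python
import torch

def keyformer_scores(logits: torch.Tensor, tau: float) -> torch.Tensor:
    """Gumbel-regularized score function over unnormalized attention logits.

    logits: [seq_len] unnormalized attention logits for one query position.
    tau:    temperature, annealed from tau_init to tau_end (see the task scripts).
    """
    u = torch.rand_like(logits).clamp_(1e-9, 1 - 1e-9)
    zeta = -torch.log(-torch.log(u))              # standard Gumbel noise
    return torch.softmax((logits + zeta) / tau, dim=-1)

def select_key_tokens(scores: torch.Tensor, k: int, w: int) -> torch.Tensor:
    """Keep the w most recent tokens plus the k - w highest-scoring earlier tokens."""
    n = scores.shape[-1]
    recent = torch.arange(n - w, n)                   # recent window, always kept
    key = torch.topk(scores[: n - w], k - w).indices  # key tokens from the remaining n - w
    return torch.cat([key, recent]).sort().values     # indices of the k tokens kept in the KV cache
```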
45 | 
46 | Figure 3: Design of Keyformer for a single decoder layer.
47 |
48 |
49 | ## Key Results
50 |
51 | We evaluate Keyformer across three significant model families: [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b), [Cerebras-GPT](https://huggingface.co/cerebras/Cerebras-GPT-6.7B), and [MPT](https://huggingface.co/mosaicml/mpt-7b), on two representative text generation tasks: summarization, using the [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail) dataset from [HELM](https://crfm.stanford.edu/helm/latest/), and conversation, using the [SODA](https://huggingface.co/datasets/allenai/soda) dataset. The GPT-J model is fine-tuned for summarization, while Cerebras-GPT and MPT are pretrained models. For the conversation task, we used the [MPT-chat](https://huggingface.co/mosaicml/mpt-7b-chat) version of the MPT model, which is fine-tuned for dialogue generation. Figure 4 shows that Keyformer achieves baseline accuracy with 70% of the prompt KV cache size for the summarization task across different models, and with 90% of the prompt KV cache for the conversation task, while the other baselines could not reach baseline accuracy.
52 |
53 |
54 | 
55 | Figure 4: Accuracy comparison of Full Attention, Window Attention, H2O, and Keyformer with varying KV cache size. The solid black line shows Full Attention, which discards no tokens and uses the full KV cache. The red dotted line shows the 99% accuracy mark.
56 |
57 |
58 | For long-context scenarios, we turned to the [GovReport](https://huggingface.co/datasets/ccdv/govreport-summarization) dataset for extended document summarization. To tackle long-document summarization, we employed the [MPT-storywriter](https://huggingface.co/mosaicml/mpt-7b-storywriter) version of the MPT model, which is fine-tuned for writing fictional stories with a context length of 65k and the ability to generate content as long as 84k tokens.
59 |
60 |
61 |
62 | 
63 | Figure 5: (Left) Long context summarization using MPT-7B-storywriter model for GovReport dataset with a sequence length
64 | of 8k. (Right) Speedup of Keyformer with 50% KV cache reduction.
65 |
66 |
67 | Figure 5 shows that for long-context summarization, Keyformer achieves baseline accuracy with 50% of the prompt KV cache, improving inference latency by 2.1x and token generation throughput by up to 2.4x.
68 |
69 | ## Get Started with Keyformer
70 | We have implemented Keyformer for multiple autoregressive models and provided respective model cards to run different tasks. Please find detailed instructions to use Keyformer [here](../README.md).
71 |
72 | ## Citation
73 | ```
74 | @article{2024keyformer,
75 | title={Keyformer: KV Cache reduction through key tokens selection for Efficient Generative Inference},
76 | author={Adnan, Muhammad and Arunkumar, Akhil and Jain, Gaurav and Nair, Prashant and Soloveychik, Ilya and Kamath, Purushotham},
77 | journal={Proceedings of Machine Learning and Systems},
78 | volume={7},
79 | year={2024}
80 | }
81 | ```
--------------------------------------------------------------------------------
/blog/images/Attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/Attention.png
--------------------------------------------------------------------------------
/blog/images/Keyformer_decoding.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/Keyformer_decoding.gif
--------------------------------------------------------------------------------
/blog/images/Keyformer_overview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/Keyformer_overview.gif
--------------------------------------------------------------------------------
/blog/images/accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/accuracy.png
--------------------------------------------------------------------------------
/blog/images/attn_weights.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/attn_weights.png
--------------------------------------------------------------------------------
/blog/images/keyformer_decoding.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/keyformer_decoding.gif
--------------------------------------------------------------------------------
/blog/images/long_context_accuracy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/long_context_accuracy.png
--------------------------------------------------------------------------------
/blog/images/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/motivation.png
--------------------------------------------------------------------------------
/blog/images/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/performance.png
--------------------------------------------------------------------------------
/blog/images/sparsity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/sparsity.png
--------------------------------------------------------------------------------
/blog/images/speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/speedup.png
--------------------------------------------------------------------------------
/blog/images/uneven_score.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/blog/images/uneven_score.gif
--------------------------------------------------------------------------------
/conda-env.yml:
--------------------------------------------------------------------------------
1 | name: keyformer-env
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.01.10=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.2=h6a678d5_6
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=1.1.1t=h7f8727e_0
15 | - pip=23.0.1=py39h06a4308_0
16 | - python=3.9.16=h7a1cb2a_2
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=66.0.0=py39h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - wheel=0.38.4=py39h06a4308_0
22 | - xz=5.4.2=h5eee18b_0
23 | - zlib=1.2.13=h5eee18b_0
24 | - pip:
25 | - absl-py==1.4.0
26 | - accelerate==0.18.0
27 | - aiohttp==3.8.4
28 | - aiosignal==1.3.1
29 | - appdirs==1.4.4
30 | - async-timeout==4.0.2
31 | - attrs==23.1.0
32 | - certifi==2022.12.7
33 | - charset-normalizer==3.1.0
34 | - click==8.1.3
35 | - cmake==3.26.3
36 | - datasets==2.14.5
37 | - dill==0.3.6
38 | - docker-pycreds==0.4.0
39 | - evaluate==0.4.0
40 | - filelock==3.12.0
41 | - fire==0.5.0
42 | - frozenlist==1.3.3
43 | - fsspec==2023.4.0
44 | - gitdb==4.0.10
45 | - gitpython==3.1.31
46 | - huggingface-hub==0.14.1
47 | - idna==3.4
48 | - jinja2==3.1.2
49 | - joblib==1.2.0
50 | - lit==16.0.2
51 | - markupsafe==2.1.2
52 | - mpmath==1.3.0
53 | - multidict==6.0.4
54 | - multiprocess==0.70.14
55 | - networkx==3.1
56 | - nltk==3.8.1
57 | - numpy==1.24.3
58 | - nvidia-cublas-cu11==11.10.3.66
59 | - nvidia-cuda-cupti-cu11==11.7.101
60 | - nvidia-cuda-nvrtc-cu11==11.7.99
61 | - nvidia-cuda-runtime-cu11==11.7.99
62 | - nvidia-cudnn-cu11==8.5.0.96
63 | - nvidia-cufft-cu11==10.9.0.58
64 | - nvidia-curand-cu11==10.2.10.91
65 | - nvidia-cusolver-cu11==11.4.0.1
66 | - nvidia-cusparse-cu11==11.7.4.91
67 | - nvidia-nccl-cu11==2.14.3
68 | - nvidia-nvtx-cu11==11.7.91
69 | - openai==0.27.6
70 | - packaging==23.1
71 | - pandas==2.0.1
72 | - pathtools==0.1.2
73 | - protobuf==4.22.4
74 | - psutil==5.9.5
75 | - pyarrow==12.0.0
76 | - python-dateutil==2.8.2
77 | - pytz==2023.3
78 | - pyyaml==6.0
79 | - regex==2023.5.5
80 | - requests==2.30.0
81 | - responses==0.18.0
82 | - rouge-score==0.1.2
83 | - sentencepiece==0.1.99
84 | - sentry-sdk==1.21.1
85 | - setproctitle==1.3.2
86 | - simplejson==3.19.1
87 | - six==1.16.0
88 | - smmap==5.0.0
89 | - sympy==1.11.1
90 | - termcolor==2.3.0
91 | - tokenizers==0.13.3
92 | - torch==2.0.0
93 | - tqdm==4.65.0
94 | - transformers==4.28.1
95 | - triton==2.0.0
96 | - typing-extensions==4.5.0
97 | - tzdata==2023.3
98 | - urllib3==2.0.2
99 | - wandb==0.15.1
100 | - xxhash==3.2.0
101 | - yarl==1.9.2
102 | - einops
103 | - lm-eval
104 | - ftfy
105 | prefix: /nfs/tattafos/miniconda3/envs/summarization-env
106 |
--------------------------------------------------------------------------------
/conversation/README.md:
--------------------------------------------------------------------------------
1 | # 💬 Conversation Task
2 |
3 | ## Download dataset
4 | ```bash
5 | cd dataset_download
6 | python download_soda.py
7 | ```
8 |
9 | ## Getting started with Conversation Task
10 |
11 | To get started with the conversation task, use the inputs below:
12 |
13 | ```bash
14 | python dialogue.py --model_name \
15 | --dataset_path \
16 | --save_path \
17 | --score_path \
18 | --model_path \
19 | --attentions_path \
20 | --device cuda \ # Device
21 | --task conversation \ # Task
22 | --bs 1 \ # Batch Size
23 | --dtype float16 \ # Data type of model parameters
24 | --causal_lm \ # Causal language model for summarization
25 | --early_stopping \ # Enable early stopping while generation
26 | --output_summaries_only \ # Output only summary - No prompt
27 | --output_sequence_scores \ # Output sequence scores and enable for output attentions
28 | --save_attentions \ # Save attention weights - Token Generation only
29 | --save_prompt_attentions \ # If enabled only prompt attention weights are stored
30 | --padding_side left \ # Padding side
31 | --beam 4 \ # Beam Size
32 | --model_parallelize \ # Parallelize Model across all available GPUs
33 | --keyformer \ # Enable Keyformer
34 | --kv_cache 60 \ # KV cache percentage of prompt length
35 | --recent 30 \ # Recent window percentage
36 | --tau_init 1 \ # Initial temperature parameter for Gumbel
37 | --tau_end 2 \ # End temperature parameter for Gumbel
38 | --no_repeat_ngram_size 0 \
39 | --repetition_penalty 1 \
40 | --max_tokenizer_length 1920 \ # Maximum prompt size
41 | --max_new_tokens 128 \ # Maximum newly generated tokens
42 | --min_gen_length 30 \ # Minimum newly generated tokens
43 | --num_return_sequences 1 \ # Number of returned sequences per input
44 | --seed 12345 \ # Random seed value for random samples
45 | --n_obs 1000 \ # Number of input samples
46 | ```
47 |
48 | Note: For FP16, do not use `model.half()`; instead, pass the dtype when creating the model, as sketched below. [Link](https://stackoverflow.com/questions/69994731/what-is-the-difference-between-cuda-amp-and-model-half)
49 |
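A minimal sketch of that loading pattern using the standard Hugging Face `from_pretrained` API; the model path below is a placeholder for your Keyformer-enabled model directory:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the parameters directly in float16 instead of converting afterwards with model.half()
model = AutoModelForCausalLM.from_pretrained(
    "../models/mpt-7b-keyformer",   # placeholder path to the Keyformer-enabled model
    torch_dtype=torch.float16,
    trust_remote_code=True,         # needed for the custom MPT/Keyformer modeling files
)
```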
--------------------------------------------------------------------------------
/conversation/data/data_cache/data_cache:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/conversation/data/data_cache/data_cache
--------------------------------------------------------------------------------
/conversation/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 | from datasets import load_dataset, load_from_disk
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from torch.nn.functional import pad
8 | from torch.utils.data import DataLoader
9 | from typing import Optional, Dict, Sequence
10 | import utils
11 | import io
12 | import copy
13 | import sys
14 |
15 | PROMPT_DICT = {
16 | "prompt_input": (
17 | "Below is an instruction that describes a task, paired with an input that provides further context. "
18 | "Write a response that appropriately completes the request.\n\n"
19 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
20 | ),
21 | "prompt_no_input": (
22 | "Below is an instruction that describes a task. "
23 | "Write a response that appropriately completes the request.\n\n"
24 | "### Instruction:\n{instruction}\n\n### Response:"
25 | ),
26 | }
27 |
28 |
29 | class Dataset(torch.utils.data.Dataset):
30 | """Characterizes a dataset for PyTorch"""
31 |
32 | def __init__(
33 | self,
34 | dataset_path,
35 | tokenizer,
36 | model_name,
37 | return_tensors,
38 | truncation,
39 | padding,
40 | max_length=None,
41 | total_count_override=None,
42 | perf_count_override=None,
43 | ):
44 | self.dataset = "soda"
45 | self.dataset_path = dataset_path
46 | self.model_name = model_name
47 | self.tokenizer = tokenizer
48 | self.return_tensors = return_tensors
49 | self.truncation = truncation
50 | self.padding = padding
51 | self.max_length = max_length
52 |
53 | self.list_data_dict = utils.jload(self.dataset_path)
54 |
55 | self.examples = [example for example in self.list_data_dict]
56 |
57 | # Getting random samples for evaluation
58 | if total_count_override is not None and total_count_override > 0:
59 | self.rand_samples = np.random.randint(
60 | len(self.examples), size=(total_count_override)
61 | ).tolist()
62 | self.examples = [self.examples[i] for i in self.rand_samples]
63 | self.count = total_count_override
64 | else:
65 | self.count = len(self.examples)
66 |
67 | self.perf_count = perf_count_override or self.count
68 |
69 | def encode_samples(self):
70 | print("Encoding Samples")
71 |
72 | total_samples = self.count
73 |
74 | source_encoded_input_ids = []
75 | source_encoded_attn_masks = []
76 |
77 | for i in range(total_samples):
78 | source_encoded = self.tokenizer(
79 | self.sources[i],
80 | return_tensors=self.return_tensors,
81 | padding=self.padding,
82 | truncation=self.truncation,
83 | max_length=self.max_length,
84 | )
85 | source_encoded_input_ids.append(source_encoded.input_ids)
86 | source_encoded_attn_masks.append(source_encoded.attention_mask)
87 |
88 | return source_encoded_input_ids, source_encoded_attn_masks
89 |
90 | def __len__(self):
91 | return self.count
92 |
93 | def __getitem__(self, index):
94 | return self.examples[index]
95 |
--------------------------------------------------------------------------------
/conversation/dataset_download/download_orca.py:
--------------------------------------------------------------------------------
1 | # experiment config
2 | dataset_id = "Open-Orca/OpenOrca"
3 |
4 | from datasets import load_dataset, concatenate_datasets
5 | from transformers import AutoTokenizer
6 | import numpy as np
7 | import os
8 | import simplejson as json
9 | import sys
10 |
11 | save_dataset_path = os.environ.get("DATASET_ORCA_PATH", "../data")
12 |
13 | # Check whether the specified path exists or not
14 | isExist = os.path.exists(save_dataset_path)
15 | if not isExist:
16 | # Create a new directory because it does not exist
17 | os.makedirs(save_dataset_path)
18 |
19 | # Load dataset from the hub
20 | print("Loading Dataset!!")
21 | dataset = load_dataset(dataset_id, cache_dir="../data/data_cache")
22 | print("Dataset Loaded!!")
23 |
24 |
25 | def preprocess_function(sample):
26 | # create list of samples
27 | inputs = []
28 |
29 | for i in range(0, len(sample["id"])):
30 | x = dict()
31 | x["instruction"] = sample["system_prompt"][i]
32 | x["input"] = sample["question"][i]
33 | x["output"] = sample["response"][i]
34 |
35 | inputs.append(x)
36 | model_inputs = dict()
37 | model_inputs["text"] = inputs
38 |
39 | return model_inputs
40 |
41 |
42 | # process dataset
43 | tokenized_dataset = dataset.map(preprocess_function, batched=True)
44 |
45 | # save dataset to disk
46 |
47 | with open(os.path.join(save_dataset_path, "orca_train.json"), "w") as write_f:
48 | json.dump(tokenized_dataset["train"]["text"], write_f, indent=4, ensure_ascii=False)
49 |
50 |
51 | print("Dataset saved in ", save_dataset_path)
52 |
--------------------------------------------------------------------------------
/conversation/dataset_download/download_soda.py:
--------------------------------------------------------------------------------
1 | # experiment config
2 | dataset_id = "allenai/soda"
3 |
4 | from datasets import load_dataset, concatenate_datasets
5 | from transformers import AutoTokenizer
6 | import numpy as np
7 | import os
8 | import simplejson as json
9 | import sys
10 |
11 | save_dataset_path = os.environ.get("DATASET_SODA_PATH", "../data")
12 |
13 | # Check whether the specified path exists or not
14 | isExist = os.path.exists(save_dataset_path)
15 | if not isExist:
16 | # Create a new directory because it does not exist
17 | os.makedirs(save_dataset_path)
18 |
19 | # Load dataset from the hub
20 | dataset = load_dataset(dataset_id, cache_dir="../data/data_cache")
21 |
22 |
23 | def preprocess_function(sample):
24 | # create list of samples
25 | inputs = []
26 |
27 | for i in range(0, len(sample["head"])):
28 | x = dict()
29 | x["head"] = sample["head"][i]
30 | x["relation"] = sample["relation"][i]
31 | x["tail"] = sample["tail"][i]
32 | x["literal"] = sample["literal"][i]
33 | x["narrative"] = sample["narrative"][i]
34 | x["dialogue"] = sample["dialogue"][i]
35 | x["speakers"] = sample["speakers"][i]
36 | x["PersonX"] = sample["PersonX"][i]
37 | x["PersonY"] = sample["PersonY"][i]
38 | x["PersonZ"] = sample["PersonZ"][i]
39 | x["original_index"] = sample["original_index"][i]
40 | x["split"] = sample["split"][i]
41 | x["head_answer"] = sample["head_answer"][i]
42 | x["pmi_head_answer"] = sample["pmi_head_answer"][i]
43 | x["relation_tail_answer"] = sample["relation_tail_answer"][i]
44 | x["pmi_relation_tail_answer"] = sample["pmi_relation_tail_answer"][i]
45 |
46 | inputs.append(x)
47 | model_inputs = dict()
48 | model_inputs["text"] = inputs
49 |
50 | return model_inputs
51 |
52 |
53 | # process dataset
54 | tokenized_dataset = dataset.map(
55 | preprocess_function, batched=True, remove_columns=list(dataset["train"].features)
56 | )
57 |
58 | # save dataset to disk
59 |
60 | with open(os.path.join(save_dataset_path, "soda_eval.json"), "w") as write_f:
61 | json.dump(
62 | tokenized_dataset["validation"]["text"], write_f, indent=4, ensure_ascii=False
63 | )
64 |
65 |
66 | print("Dataset saved in ", save_dataset_path)
67 |
--------------------------------------------------------------------------------
/conversation/run_conversation_task.sh:
--------------------------------------------------------------------------------
1 | python dialogue.py --model_name \
2 | --model_path \
3 | --dataset_path \
4 | --save_path \
5 | --score_path \
6 | --device cuda \
7 | --task conversation \
8 | --bs 1 \
9 | --dtype bfloat16 \
10 | --causal_lm \
11 | --early_stopping \
12 | --output_summaries_only \
13 | --padding_side left \
14 | --beam 4 \
15 | --model_parallelize \
16 | --keyformer \
17 | --kv_cache 60 \
18 | --recent 30 \
19 | --tau_init 1 \
20 | --tau_end 2 \
21 | --no_repeat_ngram_size 0 \
22 | --repetition_penalty 1 \
23 | --max_tokenizer_length 1920 \
24 | --max_new_tokens 128 \
25 | --min_gen_length 30 \
26 | --num_return_sequences 1 \
27 | --seed 12345 \
28 | --n_obs 1000
--------------------------------------------------------------------------------
/conversation/utils.py:
--------------------------------------------------------------------------------
1 | """From Huggingface Transformers."""
2 | import itertools
3 | import json
4 | import linecache
5 | import os
6 | import pickle
7 | import warnings
8 | from logging import getLogger
9 | from pathlib import Path
10 | from typing import Callable, Dict, Iterable, List
11 | import io
12 |
13 | import numpy as np
14 | import torch
15 | import nltk
16 | from rouge_score import rouge_scorer, scoring
17 | from sacrebleu import corpus_bleu  # needed by calculate_bleu_score
18 | from transformers import BartTokenizer  # needed by encode_line
19 | import git  # GitPython; needed by get_git_info
17 |
18 | from torch import nn
19 | from torch.utils.data import Dataset, Sampler
20 |
21 |
22 | def _make_r_io_base(f, mode: str):
23 | if not isinstance(f, io.IOBase):
24 | f = open(f, mode=mode)
25 | return f
26 |
27 |
28 | def jload(f, mode="r"):
29 | """Load a .json file into a dictionary."""
30 | f = _make_r_io_base(f, mode)
31 | jdict = json.load(f)
32 | f.close()
33 | return jdict
34 |
35 |
36 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=-100):
37 | """From fairseq"""
38 | if target.dim() == lprobs.dim() - 1:
39 | target = target.unsqueeze(-1)
40 | nll_loss = -lprobs.gather(dim=-1, index=target)
41 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
42 | if ignore_index is not None:
43 | pad_mask = target.eq(ignore_index)
44 | nll_loss.masked_fill_(pad_mask, 0.0)
45 | smooth_loss.masked_fill_(pad_mask, 0.0)
46 | else:
47 | nll_loss = nll_loss.squeeze(-1)
48 | smooth_loss = smooth_loss.squeeze(-1)
49 |
50 | nll_loss = nll_loss.sum() # mean()? Scared to break other math.
51 | smooth_loss = smooth_loss.sum()
52 | eps_i = epsilon / lprobs.size(-1)
53 | loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
54 | return loss, nll_loss
55 |
56 |
57 | def encode_line(
58 | tokenizer, line, max_length, pad_to_max_length=True, return_tensors="pt"
59 | ):
60 | extra_kw = (
61 | {"add_prefix_space": True} if isinstance(tokenizer, BartTokenizer) else {}
62 | )
63 | return tokenizer(
64 | [line],
65 | max_length=max_length,
66 | padding="max_length" if pad_to_max_length else None,
67 | truncation=True,
68 | return_tensors=return_tensors,
69 | **extra_kw,
70 | )
71 |
72 |
73 | def lmap(f: Callable, x: Iterable) -> List:
74 | """list(map(f, x))"""
75 | return list(map(f, x))
76 |
77 |
78 | def calculate_bleu_score(output_lns, refs_lns, **kwargs) -> dict:
79 | """Uses sacrebleu's corpus_bleu implementation."""
80 | return {"bleu": corpus_bleu(output_lns, [refs_lns], **kwargs).score}
81 |
82 |
83 | def trim_batch(
84 | input_ids,
85 | pad_token_id,
86 | attention_mask=None,
87 | ):
88 | """Remove columns that are populated exclusively by pad_token_id"""
89 | keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
90 | if attention_mask is None:
91 | return input_ids[:, keep_column_mask]
92 | else:
93 | return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])
94 |
95 |
96 | class Seq2SeqDataset(Dataset):
97 | """class Seq2SeqDataset"""
98 |
99 | def __init__(
100 | self,
101 | tokenizer,
102 | data_dir,
103 | max_source_length,
104 | max_target_length,
105 | type_path="train",
106 | n_obs=None,
107 | src_lang=None,
108 | tgt_lang=None,
109 | prefix="",
110 | ):
111 | super().__init__()
112 | self.src_file = Path(data_dir).joinpath(type_path + ".source")
113 | self.tgt_file = Path(data_dir).joinpath(type_path + ".target")
114 | self.src_lens = self.get_char_lens(self.src_file)
115 | self.max_source_length = max_source_length
116 | self.max_target_length = max_target_length
117 | assert min(self.src_lens) > 0, f"found empty line in {self.src_file}"
118 | self.tokenizer = tokenizer
119 | self.prefix = prefix
120 | if n_obs is not None:
121 | self.src_lens = self.src_lens[:n_obs]
122 | self.pad_token_id = self.tokenizer.pad_token_id
123 | self.src_lang = src_lang
124 | self.tgt_lang = tgt_lang
125 |
126 | def __len__(self):
127 | return len(self.src_lens)
128 |
129 | def __getitem__(self, index) -> Dict[str, torch.Tensor]:
130 | index = index + 1 # linecache starts at 1
131 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
132 | "\n"
133 | )
134 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
135 | assert source_line, f"empty source line for index {index}"
136 | assert tgt_line, f"empty tgt line for index {index}"
137 | source_inputs = encode_line(self.tokenizer, source_line, self.max_source_length)
138 | target_inputs = encode_line(self.tokenizer, tgt_line, self.max_target_length)
139 |
140 | source_ids = source_inputs["input_ids"].squeeze()
141 | target_ids = target_inputs["input_ids"].squeeze()
142 | src_mask = source_inputs["attention_mask"].squeeze()
143 | return {
144 | "input_ids": source_ids,
145 | "attention_mask": src_mask,
146 | "decoder_input_ids": target_ids,
147 | }
148 |
149 | @staticmethod
150 | def get_char_lens(data_file):
151 | return [len(x) for x in Path(data_file).open().readlines()]
152 |
153 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
154 | """Reorganize batch data."""
155 | input_ids = torch.stack([x["input_ids"] for x in batch])
156 | masks = torch.stack([x["attention_mask"] for x in batch])
157 | target_ids = torch.stack([x["decoder_input_ids"] for x in batch])
158 | pad_token_id = self.pad_token_id
159 | y = trim_batch(target_ids, pad_token_id)
160 | source_ids, source_mask = trim_batch(
161 | input_ids, pad_token_id, attention_mask=masks
162 | )
163 | batch = {
164 | "input_ids": source_ids,
165 | "attention_mask": source_mask,
166 | "decoder_input_ids": y,
167 | }
168 | return batch
169 |
170 | def make_sortish_sampler(self, batch_size):
171 | return SortishSampler(self.src_lens, batch_size)
172 |
173 |
174 | class TranslationDataset(Seq2SeqDataset):
175 | """A dataset that calls prepare_seq2seq_batch."""
176 |
177 | def __init__(self, *args, **kwargs):
178 | super().__init__(*args, **kwargs)
179 | if self.max_source_length != self.max_target_length:
180 | warnings.warn(
181 | f"Mbart is using sequence lengths {self.max_source_length}, \
182 | {self.max_target_length}. "
183 | f"Imbalanced sequence lengths may be undesired for \
184 | translation tasks"
185 | )
186 |
187 | def __getitem__(self, index) -> Dict[str, str]:
188 | index = index + 1 # linecache starts at 1
189 | source_line = self.prefix + linecache.getline(str(self.src_file), index).rstrip(
190 | "\n"
191 | )
192 | tgt_line = linecache.getline(str(self.tgt_file), index).rstrip("\n")
193 | assert source_line, f"empty source line for index {index}"
194 | assert tgt_line, f"empty tgt line for index {index}"
195 | return {
196 | "tgt_texts": tgt_line,
197 | "src_texts": source_line,
198 | }
199 |
200 | def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
201 | batch_encoding = self.tokenizer.prepare_seq2seq_batch(
202 | [x["src_texts"] for x in batch],
203 | src_lang=self.src_lang,
204 | tgt_texts=[x["tgt_texts"] for x in batch],
205 | tgt_lang=self.tgt_lang,
206 | max_length=self.max_source_length,
207 | max_target_length=self.max_target_length,
208 | )
209 | return batch_encoding.data
210 |
211 |
212 | class SortishSampler(Sampler):
213 | """
214 | Go through the text data by order of src length with a bit of randomness.
215 | From fastai repo.
216 | """
217 |
218 | def __init__(self, data, batch_size):
219 | self.data, self.bs = data, batch_size
220 |
221 | def key(self, i):
222 | return self.data[i]
223 |
224 | def __len__(self) -> int:
225 | return len(self.data)
226 |
227 | def __iter__(self):
228 | idxs = np.random.permutation(len(self.data))
229 | sz = self.bs * 50
230 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
231 | sort_idx = np.concatenate(
232 | [sorted(s, key=self.key, reverse=True) for s in ck_idx]
233 | )
234 | sz = self.bs
235 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
236 | max_ck = np.argmax(
237 | [self.key(ck[0]) for ck in ck_idx]
238 | ) # find the chunk with the largest key,
239 | ck_idx[0], ck_idx[max_ck] = (
240 | ck_idx[max_ck],
241 | ck_idx[0],
242 | ) # then make sure it goes first.
243 | sort_idx = (
244 | np.concatenate(np.random.permutation(ck_idx[1:]))
245 | if len(ck_idx) > 1
246 | else np.array([], dtype=np.int64)
247 | )
248 | sort_idx = np.concatenate((ck_idx[0], sort_idx))
249 | return iter(sort_idx)
250 |
251 |
252 | logger = getLogger(__name__)
253 |
254 |
255 | def use_task_specific_params(model, task):
256 | """Update config with summarization specific params."""
257 | task_specific_params = model.config.task_specific_params
258 |
259 | if task_specific_params is not None:
260 | pars = task_specific_params.get(task, {})
261 | print("Task Params: ", pars)
262 | logger.info(f"using task specific params for {task}: {pars}")
263 | model.config.update(pars)
264 |
265 |
266 | def pickle_load(path):
267 | """pickle.load(path)"""
268 | with open(path, "rb") as f:
269 | return pickle.load(f)
270 |
271 |
272 | def pickle_save(obj, path):
273 | """pickle.dump(obj, path)"""
274 | with open(path, "wb") as f:
275 | return pickle.dump(obj, f)
276 |
277 |
278 | def flatten_list(summary_ids: List[List]):
279 | return [x for x in itertools.chain.from_iterable(summary_ids)]
280 |
281 |
282 | def save_git_info(folder_path: str) -> None:
283 | """Save git information to output_dir/git_log.json"""
284 | repo_infos = get_git_info()
285 | save_json(repo_infos, os.path.join(folder_path, "git_log.json"))
286 |
287 |
288 | def save_json(content, path):
289 | with open(path, "w") as f:
290 | json.dump(content, f, indent=4)
291 |
292 |
293 | def load_json(path):
294 | with open(path) as f:
295 | return json.load(f)
296 |
297 |
298 | def get_git_info():
299 | repo = git.Repo(search_parent_directories=True)
300 | repo_infos = {
301 | "repo_id": str(repo),
302 | "repo_sha": str(repo.head.object.hexsha),
303 | "repo_branch": str(repo.active_branch),
304 | }
305 | return repo_infos
306 |
307 |
308 | ROUGE_KEYS = ["rouge1", "rouge2", "rougeL"]
309 |
310 |
311 | def calculate_rouge(
312 | output_lns: List[str], reference_lns: List[str], use_stemmer=True
313 | ) -> Dict:
314 | """Calculate rouge scores"""
315 | scorer = rouge_scorer.RougeScorer(ROUGE_KEYS, use_stemmer=use_stemmer)
316 | aggregator = scoring.BootstrapAggregator()
317 |
318 | for reference_ln, output_ln in zip(reference_lns, output_lns):
319 | scores = scorer.score(reference_ln, output_ln)
320 | aggregator.add_scores(scores)
321 |
322 | result = aggregator.aggregate()
323 | return {k: v.mid.fmeasure for k, v in result.items()}
324 |
325 |
326 | def postprocess_text(preds, targets):
327 | preds = [pred.strip() for pred in preds]
328 | targets = [target.strip() for target in targets]
329 |
330 | # rougeLSum expects newline after each sentence
331 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
332 | targets = ["\n".join(nltk.sent_tokenize(target)) for target in targets]
333 |
334 | return preds, targets
335 |
336 |
337 | def freeze_params(model: nn.Module):
338 | for par in model.parameters():
339 | par.requires_grad = False
340 |
341 |
342 | def grad_status(model: nn.Module) -> Iterable:
343 | return (par.requires_grad for par in model.parameters())
344 |
345 |
346 | def any_requires_grad(model: nn.Module) -> bool:
347 | return any(grad_status(model))
348 |
349 |
350 | def assert_all_frozen(model):
351 | model_grads: List[bool] = list(grad_status(model))
352 | n_require_grad = sum(lmap(int, model_grads))
353 | npars = len(model_grads)
354 | assert not any(
355 | model_grads
356 | ), f"{n_require_grad/npars:.1%} of {npars} weights require grad"
357 |
358 |
359 | def assert_not_all_frozen(model):
360 | model_grads: List[bool] = list(grad_status(model))
361 | npars = len(model_grads)
362 | assert any(model_grads), f"none of {npars} weights require grad"
363 |
--------------------------------------------------------------------------------
/lm_eval_harness/README.md:
--------------------------------------------------------------------------------
1 | # 📊 LM Eval Harness Tasks
2 |
3 | Evaluation of **Keyformer** on [lm-eval-harness](https://github.com/EleutherAI/lm-evaluation-harness) framework tasks.
4 |
5 | ## Generate Task data
6 | ```bash
7 | python -u generate_task_data.py \
8 | --output-file ./task_data/${task}-${shots}.jsonl \
9 | --task-name ${task} \
10 | --num-fewshot ${shots}
11 | ```
12 |
13 | ## Generate Output data with model
14 | #### Full Attention
15 | ```bash
16 | python -u run_lm_eval_harness.py \
17 | --input-path ./task_data/${task}-${shots}.jsonl \
18 | --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \
19 | --model-name ${model_name} \
20 | --model-path ${model_path} \
21 | --dtype ${dtype} \
22 | --kv_cache ${kv_cache} \
23 | --recent ${recent}
24 | ```
25 | #### Keyformer
26 | ```bash
27 | python -u run_lm_eval_harness.py \
28 | --input-path ./task_data/${task}-${shots}.jsonl \
29 | --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \
30 | --model-name ${model_name} \
31 | --model-path ${model_path} \
32 | --dtype ${dtype} \
33 | --keyformer \
34 | --kv_cache ${kv_cache} \
35 | --recent ${recent}
36 | ```
37 |
38 | ## Evaluate the performance
39 | ```bash
40 | python -u evaluate_task_result.py \
41 | --result-file ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \
42 | --output-file ./output/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl \
43 | --task-name ${task} \
44 | --num-fewshot ${shots} \
45 | --model-name ${model_name}
46 | ```
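47 | 
48 | ## End-to-end example
49 | 
50 | The three steps above can also be run back to back with `run.sh`. As a minimal example (assuming the MPT-7B checkpoint has been prepared under `./MPT-7B` as described in `models/README.md`), the following evaluates Keyformer on `openbookqa` with 0 shots, a 60% KV cache budget and a 30% recent window -
51 | ```bash
52 | # bash run.sh <task> <shots> <model_name> <model_path> <model_type> <dtype> <1/0 (1 for Keyformer)> <kv_cache> <recent>
53 | bash run.sh openbookqa 0 mosaicml/mpt-7b ./MPT-7B mosaicml-mpt-7b float16 1 60 30
54 | ```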
--------------------------------------------------------------------------------
/lm_eval_harness/evaluate_task_result.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | from lm_eval import evaluator, tasks
6 | from tasks import EvalHarnessAdaptor
7 |
8 |
9 | def json_to_key(obj):
10 | return json.dumps(obj)
11 |
12 |
13 | if __name__ == "__main__":
14 | parser = argparse.ArgumentParser(
15 |         prog="evaluate_task_result.py",
16 |         description="Score cached model outputs on an lm-eval-harness task",
17 |         epilog="Expects a result .jsonl produced by run_lm_eval_harness.py",
18 | )
19 |
20 | parser.add_argument("--result-file", type=str, default="result.jsonl")
21 | parser.add_argument("--output-file", type=str, default="output.jsonl")
22 | parser.add_argument("--task-name", type=str, default="hellaswag")
23 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
24 | parser.add_argument("--debug", action="store_true", default=False)
25 | parser.add_argument("--num-fewshot", type=int, default=0)
26 | args = parser.parse_args()
27 |
28 | os.environ["MODEL_NAME"] = args.model_name
29 | """
30 | if args.model_type == 'opt':
31 | os.environ['MODEL_NAME'] = "facebook/opt-66b"
32 | elif args.model_type == 'bloom':
33 | os.environ['MODEL_NAME'] = "bigscience/bloom"
34 | elif args.model_type == 'gpt_neox':
35 | os.environ['MODEL_NAME'] = "EleutherAI/gpt-neox-20b"
36 | elif args.model_type == 'llama':
37 | os.environ['MODEL_NAME'] = "huggyllama/llama-7b"
38 | else:
39 | assert False
40 | """
41 | seq = 1024
42 | total_batch = 1
43 | pe = "fixed"
44 |
45 | class RealRunner:
46 | def __init__(self, args):
47 | self.results = {}
48 |
49 | with open(args.result_file, "r") as f:
50 | for line in f:
51 | if line.strip() == "":
52 | continue
53 |
54 | item = json.loads(line)
55 |
56 | request = item["request"]
57 | result = item["result"]
58 |
59 | self.results[json_to_key(request)] = result
60 |
61 | print(f"{len(self.results)} items in the cache")
62 |
63 | def eval(self, batch):
64 | from tasks.eval_harness import tokenizer
65 |
66 | mask_loss = []
67 | each_correct = []
68 |
69 | for i, text in enumerate(batch["text"]):
70 | request = {
71 | "best_of": 1,
72 | "echo": True,
73 | "logprobs": 1,
74 | "max_tokens": 0,
75 | "model": "x",
76 | "n": 1,
77 | "prompt": text,
78 | "request_type": "language-model-inference",
79 | "stop": None,
80 | "temperature": 0,
81 | "top_p": 1,
82 | }
83 |
84 | key = json_to_key(request)
85 |
86 | correct = True
87 |
88 | if key in self.results:
89 | result = self.results[key]
90 |
91 | token_logprobs = result["choices"][0]["logprobs"]["token_logprobs"]
92 | tokens = result["choices"][0]["logprobs"]["tokens"]
93 | top_logprobs = result["choices"][0]["logprobs"]["top_logprobs"]
94 | assert token_logprobs[0] is None
95 |
96 | token_ids = tokenizer.convert_tokens_to_ids(tokens)
97 |
98 | obs = batch["obs"][i]
99 | target = batch["target"][i]
100 | eval_mask = batch["eval_mask"][i]
101 |
102 | n_positive = 0
103 |                     sum_logprob = 0
104 | if args.debug:
105 | print(target)
106 | for i, mask in enumerate(eval_mask):
107 | try:
108 | if i + 1 >= len(tokens):
109 | break
110 |
111 |                         if mask:
112 | if args.debug:
113 | print(
114 | tokens[i + 1],
115 | next(iter(top_logprobs[i + 1].keys())),
116 | )
117 | correct = correct and (
118 | tokens[i + 1]
119 | == next(iter(top_logprobs[i + 1].keys()))
120 | )
121 |                                 sum_logprob += token_logprobs[i + 1]
122 | n_positive += 1
123 | except Exception as e:
124 | raise e
125 |
126 | # avg_logprob = sum(token_logprobs[1:]) / (len(token_logprobs) - 1)
127 |                 avg_logprob = sum_logprob / n_positive
128 |
129 | mask_loss.append(-avg_logprob)
130 |
131 | each_correct.append(correct)
132 |
133 | else:
134 | assert False
135 |
136 | out = {
137 | "mask_loss": mask_loss,
138 | "each_correct": each_correct,
139 | }
140 |
141 | return out
142 |
143 | t = RealRunner(args)
144 |
145 | adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed")
146 |
147 | results = evaluator.evaluate(
148 | adaptor,
149 | tasks.get_task_dict(
150 | [
151 | args.task_name
152 | # "lambada_openai",
153 | # "piqa",
154 | # "hellaswag",
155 | # "winogrande",
156 | # "mathqa",
157 | # "pubmedqa",
158 | # "boolq",
159 | # "cb",
160 | # "copa",
161 | # "multirc",
162 | # "record",
163 | # "wic",
164 | # "wsc",
165 | ]
166 | ),
167 | False,
168 | args.num_fewshot,
169 | None,
170 | )
171 |
172 | dumped = json.dumps(results, indent=2)
173 | print(dumped)
174 |
175 | with open(args.output_file, "w") as outfile:
176 | outfile.write(dumped)
177 |
--------------------------------------------------------------------------------
/lm_eval_harness/experiments.sh:
--------------------------------------------------------------------------------
1 | ## =================== MPT-7B ====================
2 | #bash run.sh <task> <shots> <model_name> <model_path> <model_type> <dtype> <1/0 (1 for Keyformer)> <kv_cache> <recent>
3 | # ==> Full Attention
4 | bash run.sh openbookqa 0 mosaicml/mpt-7b ./MPT-7B mosaicml-mpt-7b float16 0 60 30
5 | # ==> Keyformer
6 | bash run.sh openbookqa 0 mosaicml/mpt-7b ./MPT-7B mosaicml-mpt-7b float16 1 60 30
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/lm_eval_harness/generate_task_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import sys
5 |
6 | from lm_eval import evaluator, tasks
7 | from tasks import EvalHarnessAdaptor
8 |
9 |
10 | if __name__ == "__main__":
11 | parser = argparse.ArgumentParser(
12 |         prog="generate_task_data.py",
13 |         description="Generate lm-eval-harness task request data as a .jsonl file",
14 |         epilog="The output file is consumed by run_lm_eval_harness.py",
15 | )
16 |
17 | parser.add_argument("--output-file", type=str, default="input.jsonl")
18 | parser.add_argument("--task-name", type=str, default="hellaswag")
19 | parser.add_argument("--num-fewshot", type=int, default=0)
20 | args = parser.parse_args()
21 |
22 | if os.path.isfile(args.output_file):
23 |         print(f"{args.task_name} data with {args.num_fewshot} shots already exists!!")
24 | sys.exit()
25 |
26 | seq = 1024
27 | total_batch = 1
28 | pe = "fixed"
29 |
30 | with open(args.output_file, "w") as f:
31 | pass
32 |
33 | class DryRunner:
34 | def eval(self, batch):
35 | with open(args.output_file, "a") as f:
36 | for text in batch["text"]:
37 | item = {
38 | "best_of": 1,
39 | "echo": True,
40 | "logprobs": 1,
41 | "max_tokens": 0,
42 | "model": "x",
43 | "n": 1,
44 | "prompt": text,
45 | "request_type": "language-model-inference",
46 | "stop": None,
47 | "temperature": 0,
48 | "top_p": 1,
49 | }
50 |
51 | f.write(json.dumps(item) + "\n")
52 |
53 | out = {
54 |             "mask_loss": [1.0] * len(batch["text"]),
55 |             "each_correct": [True] * len(batch["text"]),
56 | }
57 | return out
58 |
59 | t = DryRunner()
60 | adaptor = EvalHarnessAdaptor(t, seq, total_batch, shrink=pe != "fixed")
61 | results = evaluator.evaluate(
62 | adaptor,
63 | tasks.get_task_dict(
64 | [
65 | args.task_name
66 | # "lambada_openai",
67 | # "piqa",
68 | # "hellaswag",
69 | # "winogrande",
70 | # "mathqa",
71 | # "pubmedqa",
72 | # "boolq",
73 | # "cb",
74 | # "copa",
75 | # "multirc",
76 | # "record",
77 | # "wic",
78 | # "wsc",
79 | ]
80 | ),
81 | False,
82 | args.num_fewshot,
83 | None,
84 | )
85 | print(f"Finished Generating {args.task_name} data with {args.num_fewshot} shots!!")
86 |
--------------------------------------------------------------------------------
/lm_eval_harness/model_eval_data/model_eval_data:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/lm_eval_harness/model_eval_data/model_eval_data
--------------------------------------------------------------------------------
/lm_eval_harness/output/output:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/lm_eval_harness/output/output
--------------------------------------------------------------------------------
/lm_eval_harness/run.sh:
--------------------------------------------------------------------------------
1 | # ## Obtain inference data
2 | task=$1
3 | shots=$2
4 | model_name=$3
5 | model_path=$4
6 | model_type=$5
7 | dtype=$6
8 | keyformer=$7
9 | kv_cache=$8
10 | recent=$9
11 |
12 | echo " "
13 | echo "--------------------------------------------"
14 | echo "Task ${task} data with ${shots} shots generation"
15 | echo "--------------------------------------------"
16 |
17 | python -u generate_task_data.py --output-file ./task_data/${task}-${shots}.jsonl --task-name ${task} --num-fewshot ${shots}
18 |
19 | ## Inference, and generate output json file
20 | echo " "
21 | echo "--------------------------------------------"
22 | echo "LM Eval Harness for model ${model_name}"
23 | echo "--------------------------------------------"
24 |
25 | if [ $keyformer -eq 0 ]
26 | then
27 | echo "====> Full Attention!!"
28 | python -u run_lm_eval_harness.py --input-path ./task_data/${task}-${shots}.jsonl --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --model-name ${model_name} --model-path ${model_path} --dtype ${dtype} --kv_cache ${kv_cache} --recent ${recent}
29 | else
30 | echo "====> Keyformer enabled with ${kv_cache}% KV Cache with ${recent}% recent tokens!!"
31 | python -u run_lm_eval_harness.py --input-path ./task_data/${task}-${shots}.jsonl --output-path ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --model-name ${model_name} --model-path ${model_path} --dtype ${dtype} --keyformer --kv_cache ${kv_cache} --recent ${recent}
32 | fi
33 |
34 | ## Evaluate results
35 | echo " "
36 | echo "--------------------------------------------"
37 | echo "Results Generation"
38 | echo "--------------------------------------------"
39 | python -u evaluate_task_result.py --result-file ./model_eval_data/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --output-file ./output/${task}-${shots}-${model_type}-${keyformer}-${kv_cache}-${recent}.jsonl --task-name ${task} --num-fewshot ${shots} --model-name ${model_name}
--------------------------------------------------------------------------------
/lm_eval_harness/run_lm_eval_harness.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json, tqdm
3 | import torch
4 | import copy
5 | import sys
6 |
7 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
8 |
9 | if __name__ == "__main__":
10 | parser = argparse.ArgumentParser(
11 |         prog="run_lm_eval_harness.py",
12 |         description="Run a model on lm-eval-harness task requests and cache the outputs",
13 |         epilog="The output file is consumed by evaluate_task_result.py",
14 | )
15 |
16 | parser.add_argument("--input-path", type=str, default=None)
17 | parser.add_argument("--output-path", type=str, default=None)
18 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
19 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
20 | parser.add_argument("--model-type", type=str, default="opt")
21 | parser.add_argument("--cache-dir", type=str, default="./")
22 | parser.add_argument(
23 | "--dtype",
24 | default="float32",
25 | help="data type of the model, choose from float16, bfloat16 and float32",
26 | )
27 | ############################################### Keyformer #################################################################
28 | parser.add_argument(
29 | "--keyformer", action="store_true", help="Keyformer enabled - reduced KV cache"
30 | )
31 | parser.add_argument(
32 | "--kv_cache",
33 | type=float,
34 | default=60,
35 | required=False,
36 | help="KV cache percentage for Keyformer",
37 | )
38 | parser.add_argument(
39 | "--recent",
40 | type=float,
41 | default=30,
42 | required=False,
43 | help="Recent window percentage for Keyformer",
44 | )
45 | parser.add_argument(
46 | "--tau_init",
47 | type=float,
48 | default=1.0,
49 | required=False,
50 | help="Initial temperature",
51 | )
52 | parser.add_argument(
53 | "--tau_end", type=float, default=2.0, required=False, help="Final temperature"
54 | )
55 | ###########################################################################################################################
56 | parser.add_argument(
57 | "--token_discard",
58 | action="store_true",
59 |         help="When enabled, discard tokens below the sparsity threshold",
60 | )
61 | parser.add_argument(
62 | "--sparse_threshold",
63 | type=float,
64 | default=0.0,
65 | required=False,
66 | help="Threshold for sparsification",
67 | )
68 | parser.add_argument(
69 | "--prompt_sparse_threshold",
70 | type=float,
71 | default=0.0,
72 | required=False,
73 | help="Threshold for prompt sparsification",
74 | )
75 | parser.add_argument(
76 | "--sparse_itr",
77 | type=int,
78 | default=1,
79 | required=False,
80 |         help="Token generation iteration for updating the token discard mask",
81 | )
82 |
83 | args = parser.parse_args()
84 |
85 | input_path = args.input_path
86 | output_path = args.output_path
87 | model_name = args.model_name
88 | model_path = args.model_path
89 | dtype = args.dtype
90 | keyformer = args.keyformer
91 | kv_cache = args.kv_cache
92 | recent = args.recent
93 | tau_init = args.tau_init
94 | tau_end = args.tau_end
95 |
96 | if dtype == "bfloat16":
97 | amp_enabled = True
98 | amp_dtype = torch.bfloat16
99 | elif dtype == "float16":
100 | amp_enabled = True
101 | amp_dtype = torch.float16
102 | else:
103 | amp_enabled = False
104 | amp_dtype = torch.float32
105 |
106 | if model_name.split("/")[0] == "mosaicml":
107 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
108 | config.attn_config["attn_impl"] = "torch"
109 | config.init_device = "cuda:0"
110 | config.use_cache = True
111 | config.torch_dtype = dtype
112 | config.keyformer_config["keyformer"] = keyformer
113 | config.keyformer_config["kv_cache"] = kv_cache
114 | config.keyformer_config["recent"] = recent
115 | config.keyformer_config["tau_init"] = tau_init
116 | config.keyformer_config["tau_delta"] = tau_end - tau_init
117 | model = AutoModelForCausalLM.from_pretrained(
118 | model_path, config=config, torch_dtype=amp_dtype, trust_remote_code=True
119 | )
120 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
121 |
122 | elif model_name == "EleutherAI/gpt-j-6B":
123 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
124 | config.keyformer_config["keyformer"] = keyformer
125 | config.keyformer_config["kv_cache"] = kv_cache
126 | config.keyformer_config["recent"] = recent
127 | config.keyformer_config["tau_init"] = tau_init
128 | config.keyformer_config["tau_delta"] = tau_end - tau_init
129 | model = AutoModelForCausalLM.from_pretrained(
130 | model_path, config=config, torch_dtype=amp_dtype, trust_remote_code=True
131 | )
132 | tokenizer = AutoTokenizer.from_pretrained(model_name)
133 |
134 | else:
135 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
136 | config.keyformer_config["keyformer"] = keyformer
137 | config.keyformer_config["kv_cache"] = kv_cache
138 | config.keyformer_config["recent"] = recent
139 | config.keyformer_config["tau_init"] = tau_init
140 | config.keyformer_config["tau_delta"] = tau_end - tau_init
141 | model = AutoModelForCausalLM.from_pretrained(
142 | model_path, config=config, torch_dtype=amp_dtype, trust_remote_code=True
143 | )
144 | tokenizer = AutoTokenizer.from_pretrained(model_name)
145 |
146 | model.eval().cuda()
147 |
148 | requests = []
149 | with open(input_path, "r") as f:
150 | for line in f:
151 | if line.strip() != "":
152 | requests.append(json.loads(line))
153 |
154 | results = []
155 | with torch.no_grad():
156 | for request in tqdm.tqdm(requests):
157 | result = {"request": request, "result": {}}
158 | prompt = request["prompt"]
159 | input_ids = tokenizer(
160 | prompt, add_special_tokens=False, return_tensors="pt"
161 | ).input_ids.to(model.device)
162 |
163 | logits = model(input_ids).logits.log_softmax(dim=-1)
164 | values, indices = logits.squeeze(0).topk(dim=-1, k=1)
165 | tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0))
166 |
167 | gold_indices = input_ids[:, 1:] # skip first
168 | logprobs = [None] + torch.gather(
169 | logits, -1, gold_indices.unsqueeze(-1)
170 | ).squeeze(-1).squeeze(0).detach().cpu().tolist()
171 | top_logprobs = [None] + [
172 | {tokenizer.convert_ids_to_tokens(i.item()): v.item()}
173 | for v, i in zip(values.squeeze(-1), indices.squeeze(-1))
174 | ]
175 |
176 | result["result"] = {
177 | "choices": [
178 | {
179 | "text": prompt,
180 | "logprobs": {
181 | "tokens": tokens,
182 | "token_logprobs": logprobs,
183 | "top_logprobs": top_logprobs,
184 | "text_offset": [],
185 | },
186 | "finish_reason": "length",
187 | }
188 | ],
189 | "request_time": {"batch_time": 0, "batch_size": 1},
190 | }
191 |
192 | results.append(result)
193 |
194 | with open(output_path, "w") as f:
195 | for result in results:
196 | f.write(json.dumps(result) + "\n")
197 |
--------------------------------------------------------------------------------
/lm_eval_harness/task_data/task_data:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/lm_eval_harness/task_data/task_data
--------------------------------------------------------------------------------
/lm_eval_harness/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from tasks.eval_harness import EvalHarnessAdaptor
--------------------------------------------------------------------------------
/lm_eval_harness/tasks/eval_harness.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import os
4 | import transformers
5 | from lm_eval.base import LM
6 | from tqdm import tqdm
7 | import numpy as np
8 |
9 | from tasks.util import sample_batch, shrink_seq
10 | import multiprocessing
11 | import ftfy
12 |
13 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
14 |
15 | tokenizer = None
16 |
17 |
18 | def process_init():
19 | global tokenizer
20 | model_name = os.environ.get("MODEL_NAME", "facebook/opt-1.3b")
21 |
22 | if model_name.split("/")[0] == "mosaicml":
23 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
24 | tokenizer.model_max_length = int(1e30)
25 | tokenizer.pad_token = "<|endoftext|>"
26 | elif model_name == "EleutherAI/gpt-j-6B":
27 | tokenizer = AutoTokenizer.from_pretrained(model_name)
28 | tokenizer.model_max_length = int(1e30)
29 | tokenizer.pad_token = "<|endoftext|>"
30 | else:
31 | tokenizer = AutoTokenizer.from_pretrained(model_name)
32 | tokenizer.model_max_length = int(1e30)
33 | tokenizer.pad_token = "<|endoftext|>"
34 |
35 |
36 | def process_request(x, seq):
37 | global tokenizer
38 |
39 | ctx, cont = x
40 | # ctx_tokens = tokenizer.encode("<|endoftext|>" + ftfy.fix_text(ctx, normalization="NFKC"))
41 | ctx_text = ftfy.fix_text(ctx, normalization="NFKC")
42 | cont_text = ftfy.fix_text(cont, normalization="NFKC")
43 | all_text = ctx_text + cont_text
44 |
45 | ctx_tokens = tokenizer(ctx_text, add_special_tokens=False)["input_ids"]
46 | cont_tokens = tokenizer(cont_text, add_special_tokens=False)["input_ids"]
47 |
48 | all_tokens = ctx_tokens + cont_tokens
49 | all_tokens = np.array(all_tokens)[-seq:] # truncate sequence at seq length
50 |
51 | provided_ctx = len(all_tokens) - 1
52 | pad_amount = seq - provided_ctx
53 |
54 | return {
55 | "obs": np.pad(
56 | all_tokens[:-1], ((0, pad_amount),), constant_values=tokenizer.pad_token_id
57 | ),
58 | "target": np.pad(
59 | all_tokens[1:], ((0, pad_amount),), constant_values=tokenizer.pad_token_id
60 | ),
61 | "ctx_length": seq,
62 | "eval_mask": np.logical_and(
63 | np.arange(0, seq) > len(all_tokens) - len(cont_tokens) - 2,
64 | np.arange(0, seq) < len(all_tokens) - 1,
65 | ),
66 | "prompt": ctx_text,
67 | "target": cont_text,
68 | "text": all_text,
69 | }
70 |
71 |
72 | class EvalHarnessAdaptor(LM):
73 | def greedy_until(self, requests):
74 |         raise NotImplementedError("greedy_until is not implemented")
75 |
76 | def loglikelihood_rolling(self, requests):
77 |         raise NotImplementedError("loglikelihood_rolling is not implemented")
78 |
79 | def __init__(self, tpu_cluster, seq, batch, shrink, min_seq=None):
80 | super().__init__()
81 | self.tpu = tpu_cluster
82 | self.seq = seq
83 | self.batch = batch
84 | self.shrink = shrink
85 | self.min_seq = min_seq
86 |
87 | self.pool = multiprocessing.Pool(processes=1, initializer=process_init)
88 | # self.pool = multiprocessing.Pool(initializer=process_init)
89 | process_init()
90 |
91 | def convert_requests(self, requests):
92 | return self.pool.imap(partial(process_request, seq=self.seq), requests)
93 |
94 | def loglikelihood(self, requests):
95 | output = []
96 |
97 | r = self.convert_requests(requests)
98 | zero_example = process_request(requests[0], self.seq)
99 |
100 | for b in tqdm(
101 | sample_batch(r, self.batch, zero_example),
102 | desc="LM eval harness",
103 | total=len(requests) // self.batch,
104 | ):
105 | if self.shrink:
106 | b = shrink_seq(b, min_seq=self.min_seq)
107 |
108 | out = self.tpu.eval(b)
109 |
110 | for loss, correct in zip(out["mask_loss"], out["each_correct"]):
111 | output.append((float(-loss), bool(correct)))
112 |
113 | return output
114 |
--------------------------------------------------------------------------------
/lm_eval_harness/tasks/util.py:
--------------------------------------------------------------------------------
1 | from itertools import zip_longest
2 |
3 | import numpy as np
4 |
5 |
6 | def grouper(n, iterable, fillvalue):
7 | "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
8 | args = [iter(iterable)] * n
9 | return zip_longest(fillvalue=fillvalue, *args)
10 |
11 |
12 | # divide the seq length by 2 until it would truncate actual context
13 | def shrink_seq(examples, min_seq=None):
14 | length = examples["obs"].shape[-1]
15 |
16 | new_length = length // 2
17 |
18 | if min_seq is not None:
19 | if new_length < min_seq:
20 | return examples
21 |
22 | max_length = np.max(examples["eval_mask"] * np.arange(0, length)) + 1
23 |
24 | if max_length < new_length:
25 | examples["obs"] = examples["obs"][:, :new_length]
26 | examples["target"] = examples["target"][:, :new_length]
27 | examples["eval_mask"] = examples["eval_mask"][:, :new_length]
28 |
29 | return shrink_seq(examples, min_seq=min_seq)
30 | else:
31 | return examples
32 |
33 |
34 | def sample_batch(examples, bs, zero_example_shape):
35 | zero_example = {
36 | "obs": np.zeros_like(zero_example_shape["obs"]),
37 | "target": np.zeros_like(zero_example_shape["target"]),
38 | "eval_mask": np.zeros_like(zero_example_shape["eval_mask"]),
39 | "ctx_length": 0,
40 | }
41 |
42 | for batch in grouper(bs, examples, zero_example):
43 | batch_flattened = {
44 | "obs": [],
45 | "target": [],
46 | "eval_mask": [],
47 | "ctx_length": [],
48 | "text": [],
49 | }
50 |
51 | for sample in batch:
52 | batch_flattened["obs"].append(sample["obs"])
53 | batch_flattened["target"].append(sample["target"])
54 | batch_flattened["eval_mask"].append(sample["eval_mask"])
55 | batch_flattened["ctx_length"].append(sample["ctx_length"])
56 | batch_flattened["text"].append(sample["text"])
57 |
58 | batch_flattened["obs"] = np.array(batch_flattened["obs"])
59 | batch_flattened["target"] = np.array(batch_flattened["target"])
60 | batch_flattened["eval_mask"] = np.array(batch_flattened["eval_mask"])
61 | batch_flattened["ctx_length"] = np.array(batch_flattened["ctx_length"])
62 |
63 | yield batch_flattened
64 |
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 | # 🏁 Models
2 |
3 | ## LLM Model Checkpoint
4 | Note - **Keyformer** has been evaluated on the following models, and this tutorial is restricted to them -
5 | `['cerebras/Cerebras-GPT-6.7B',
6 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']`
7 |
8 | Clone the relevant model from huggingface into the `models` directory. For instance, to clone `mosaicml/mpt-7b` into `models/mpt-7b-keyformer`, do the following -
9 |
10 | ```bash
11 | cd models
12 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/mosaicml/mpt-7b mpt-7b-keyformer
13 | cd ..
14 | ```
15 |
16 | Then, download the model weights (`*.bin`) using the `download_model.py` script, passing the model name as the `--model_name` argument.
17 |
18 | ```bash
19 | cd models/model_download
20 | python3 download_model.py --model_name mosaicml/mpt-7b
21 | cd ../../
22 | ```
23 |
24 | Please keep in mind that supported arguments for `model_name` are restricted to `['cerebras/Cerebras-GPT-6.7B',
25 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']`
26 |
27 | This step downloads the model parameters into a `model` folder created inside the `model_download` directory.
28 | The downloaded parameters then need to be moved to the directory that contains the model files.
29 | 
30 | To move the model parameters from `models/model_download/model/` to `models/mpt-7b-keyformer`, run -
31 |
32 | ```bash
33 | mv models/model_download/model/* models/mpt-7b-keyformer/
34 | ```
35 |
36 | ### Integrating Keyformer
37 | Copy the **Keyformer** source files from `models/mpt-keyformer-lib` to `models/mpt-7b-keyformer`
38 |
39 | ```bash
40 | cp -r models/mpt-keyformer-lib/* models/mpt-7b-keyformer/
41 | ```
42 |
43 | This sets up `models/mpt-7b-keyformer` to be used with the various tasks described below.
44 |
45 | Similarly, find the keyformer-lib files for other supported models in their respective directories in `models/`.
46 |
47 | ## LLM Model Cards
48 |
49 | We have provided model cards for the models below.
50 | - [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6b)
51 | - [MPT](https://huggingface.co/mosaicml/mpt-7b)
52 | - [Cerebras-GPT](https://huggingface.co/cerebras/Cerebras-GPT-6.7B)
53 |
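54 | ## Loading the patched model
55 | 
56 | As a quick sanity check that the **Keyformer** files are in place, the patched checkpoint can be loaded the same way the task scripts load it. The sketch below mirrors the loading code in `lm_eval_harness/run_lm_eval_harness.py`; the path `models/mpt-7b-keyformer` assumes the setup steps above were followed, and the `kv_cache` / `recent` values are illustrative percentages rather than required settings.
57 | 
58 | ```python
59 | import torch
60 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
61 | 
62 | model_path = "models/mpt-7b-keyformer"
63 | 
64 | config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
65 | config.attn_config["attn_impl"] = "torch"    # Keyformer attention uses the torch implementation
66 | config.keyformer_config["keyformer"] = True  # enable Keyformer KV cache reduction
67 | config.keyformer_config["kv_cache"] = 60     # keep 60% of the KV cache
68 | config.keyformer_config["recent"] = 30       # reserve 30% of the cache budget for recent tokens
69 | 
70 | model = AutoModelForCausalLM.from_pretrained(
71 |     model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True
72 | )
73 | model.eval()  # move to GPU with model.cuda() before generation, as the task scripts do
74 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # MPT uses the GPT-NeoX tokenizer
75 | ```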
--------------------------------------------------------------------------------
/models/cerebras-keyformer-lib/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "cerebras/Cerebras-GPT-6.7B",
3 | "activation_function": "gelu",
4 | "architectures": [
5 | "GPT2LMHeadModel"
6 | ],
7 | "auto_map": {
8 | "AutoConfig": "configuration_gpt2.GPT2Config",
9 | "AutoModelForCausalLM": "modeling_gpt2.GPT2LMHeadModel"
10 | },
11 | "keyformer_config": {
12 | "keyformer": false,
13 | "kv_cache": 60,
14 | "recent": 30,
15 | "tau_init": 1.0,
16 | "tau_delta": 0.01
17 | },
18 | "attn_pdrop": 0.0,
19 | "bos_token_id": 50256,
20 | "embd_pdrop": 0.0,
21 | "eos_token_id": 50256,
22 | "initializer_range": 0.02,
23 | "layer_norm_epsilon": 1e-05,
24 | "model_type": "gpt2",
25 | "n_embd": 4096,
26 | "n_head": 32,
27 | "n_inner": 16384,
28 | "n_layer": 32,
29 | "n_positions": 2048,
30 | "reorder_and_upcast_attn": false,
31 | "resid_pdrop": 0.0,
32 | "scale_attn_by_inverse_layer_idx": false,
33 | "scale_attn_weights": true,
34 | "summary_activation": null,
35 | "summary_first_dropout": 0.1,
36 | "summary_proj_to_labels": true,
37 | "summary_type": "cls_index",
38 | "summary_use_proj": true,
39 | "torch_dtype": "float32",
40 | "transformers_version": "4.28.1",
41 | "use_cache": true,
42 | "vocab_size": 50257
43 | }
44 |
--------------------------------------------------------------------------------
/models/cerebras-keyformer-lib/configuration_gpt2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ OpenAI GPT-2 configuration"""
17 | from collections import OrderedDict
18 | from typing import Any, Dict, List, Mapping, Optional
19 |
20 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available
21 | from transformers import PretrainedConfig
22 | from transformers.onnx import OnnxConfigWithPast, PatchingSpec
23 | from transformers.utils import logging
24 |
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {
29 | "gpt2": "https://huggingface.co/gpt2/resolve/main/config.json",
30 | "gpt2-medium": "https://huggingface.co/gpt2-medium/resolve/main/config.json",
31 | "gpt2-large": "https://huggingface.co/gpt2-large/resolve/main/config.json",
32 | "gpt2-xl": "https://huggingface.co/gpt2-xl/resolve/main/config.json",
33 | "distilgpt2": "https://huggingface.co/distilgpt2/resolve/main/config.json",
34 | }
35 | keyformer_config_defaults: Dict = {
36 | "keyformer": False,
37 | "kv_cache": 60,
38 | "recent": 30,
39 | "tau_init": 1.0,
40 | "tau_delta": 0.01,
41 | }
42 |
43 |
44 | class GPT2Config(PretrainedConfig):
45 | """
46 | This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to
47 | instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a
48 | configuration with the defaults will yield a similar configuration to that of the GPT-2
49 | [gpt2](https://huggingface.co/gpt2) architecture.
50 |
51 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
52 | documentation from [`PretrainedConfig`] for more information.
53 |
54 |
55 | Args:
56 | vocab_size (`int`, *optional*, defaults to 50257):
57 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
58 | `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
59 | n_positions (`int`, *optional*, defaults to 1024):
60 | The maximum sequence length that this model might ever be used with. Typically set this to something large
61 | just in case (e.g., 512 or 1024 or 2048).
62 | n_embd (`int`, *optional*, defaults to 768):
63 | Dimensionality of the embeddings and hidden states.
64 | n_layer (`int`, *optional*, defaults to 12):
65 | Number of hidden layers in the Transformer encoder.
66 | n_head (`int`, *optional*, defaults to 12):
67 | Number of attention heads for each attention layer in the Transformer encoder.
68 | n_inner (`int`, *optional*, defaults to None):
69 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
70 | activation_function (`str`, *optional*, defaults to `"gelu"`):
71 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
72 | resid_pdrop (`float`, *optional*, defaults to 0.1):
73 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
74 | embd_pdrop (`float`, *optional*, defaults to 0.1):
75 | The dropout ratio for the embeddings.
76 | attn_pdrop (`float`, *optional*, defaults to 0.1):
77 | The dropout ratio for the attention.
78 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
79 | The epsilon to use in the layer normalization layers.
80 | initializer_range (`float`, *optional*, defaults to 0.02):
81 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
82 | summary_type (`string`, *optional*, defaults to `"cls_index"`):
83 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
84 | [`TFGPT2DoubleHeadsModel`].
85 |
86 | Has to be one of the following options:
87 |
88 | - `"last"`: Take the last token hidden state (like XLNet).
89 | - `"first"`: Take the first token hidden state (like BERT).
90 | - `"mean"`: Take the mean of all tokens hidden states.
91 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
92 | - `"attn"`: Not implemented now, use multi-head attention.
93 | summary_use_proj (`bool`, *optional*, defaults to `True`):
94 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
95 | [`TFGPT2DoubleHeadsModel`].
96 |
97 | Whether or not to add a projection after the vector extraction.
98 | summary_activation (`str`, *optional*):
99 | Argument used when doing sequence summary. Used in for the multiple choice head in
100 | [`GPT2DoubleHeadsModel`].
101 |
102 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
103 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
104 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
105 | [`TFGPT2DoubleHeadsModel`].
106 |
107 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
108 | summary_first_dropout (`float`, *optional*, defaults to 0.1):
109 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
110 | [`TFGPT2DoubleHeadsModel`].
111 |
112 | The dropout ratio to be used after the projection and activation.
113 | scale_attn_weights (`bool`, *optional*, defaults to `True`):
114 |             Scale attention weights by dividing by sqrt(hidden_size).
115 | use_cache (`bool`, *optional*, defaults to `True`):
116 | Whether or not the model should return the last key/values attentions (not used by all models).
117 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
118 | Whether to additionally scale attention weights by `1 / layer_idx + 1`.
119 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
120 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
121 | dot-product/softmax to float() when training with mixed precision.
122 | keyformer_config (Dict): A dictionary used to configure the keyformer parameters:
123 | keyformer (bool): When enabled, keyformer based KV cache reduction
124 | kv_cache (float): KV cache percentage for Keyformer.
125 | recent (float): Recent window percentage for Keyformer.
126 | tau_init (float): Initial temperature for Keyformer score function calculation.
127 | tau_delta (float): Delta temperature change for Keyformer score function.
128 |
129 | Example:
130 |
131 | ```python
132 | >>> from transformers import GPT2Config, GPT2Model
133 |
134 | >>> # Initializing a GPT2 configuration
135 | >>> configuration = GPT2Config()
136 |
137 | >>> # Initializing a model (with random weights) from the configuration
138 | >>> model = GPT2Model(configuration)
139 |
140 | >>> # Accessing the model configuration
141 | >>> configuration = model.config
142 | ```"""
143 |
144 | model_type = "gpt2"
145 | keys_to_ignore_at_inference = ["past_key_values"]
146 | attribute_map = {
147 | "hidden_size": "n_embd",
148 | "max_position_embeddings": "n_positions",
149 | "num_attention_heads": "n_head",
150 | "num_hidden_layers": "n_layer",
151 | }
152 |
153 | def __init__(
154 | self,
155 | vocab_size=50257,
156 | n_positions=1024,
157 | n_embd=768,
158 | n_layer=12,
159 | n_head=12,
160 | n_inner=None,
161 | activation_function="gelu_new",
162 | resid_pdrop=0.1,
163 | embd_pdrop=0.1,
164 | attn_pdrop=0.1,
165 | layer_norm_epsilon=1e-5,
166 | initializer_range=0.02,
167 | summary_type="cls_index",
168 | summary_use_proj=True,
169 | summary_activation=None,
170 | summary_proj_to_labels=True,
171 | summary_first_dropout=0.1,
172 | scale_attn_weights=True,
173 | use_cache=True,
174 | bos_token_id=50256,
175 | eos_token_id=50256,
176 | scale_attn_by_inverse_layer_idx=False,
177 | reorder_and_upcast_attn=False,
178 | keyformer_config: Dict = keyformer_config_defaults,
179 | **kwargs,
180 | ):
181 | self.vocab_size = vocab_size
182 | self.n_positions = n_positions
183 | self.n_embd = n_embd
184 | self.n_layer = n_layer
185 | self.n_head = n_head
186 | self.n_inner = n_inner
187 | self.activation_function = activation_function
188 | self.resid_pdrop = resid_pdrop
189 | self.embd_pdrop = embd_pdrop
190 | self.attn_pdrop = attn_pdrop
191 | self.layer_norm_epsilon = layer_norm_epsilon
192 | self.initializer_range = initializer_range
193 | self.summary_type = summary_type
194 | self.summary_use_proj = summary_use_proj
195 | self.summary_activation = summary_activation
196 | self.summary_first_dropout = summary_first_dropout
197 | self.summary_proj_to_labels = summary_proj_to_labels
198 | self.scale_attn_weights = scale_attn_weights
199 | self.use_cache = use_cache
200 | self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
201 | self.reorder_and_upcast_attn = reorder_and_upcast_attn
202 | # ====== Keyformer ======
203 | self.keyformer_config = keyformer_config
204 | # =======================
205 | self.bos_token_id = bos_token_id
206 | self.eos_token_id = eos_token_id
207 |
208 | super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
209 |
210 |
211 | class GPT2OnnxConfig(OnnxConfigWithPast):
212 | def __init__(
213 | self,
214 | config: PretrainedConfig,
215 | task: str = "default",
216 | patching_specs: List[PatchingSpec] = None,
217 | use_past: bool = False,
218 | ):
219 | super().__init__(
220 | config, task=task, patching_specs=patching_specs, use_past=use_past
221 | )
222 | if not getattr(self._config, "pad_token_id", None):
223 | # TODO: how to do that better?
224 | self._config.pad_token_id = 0
225 |
226 | @property
227 | def inputs(self) -> Mapping[str, Mapping[int, str]]:
228 | common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
229 | if self.use_past:
230 | self.fill_with_past_key_values_(common_inputs, direction="inputs")
231 | common_inputs["attention_mask"] = {
232 | 0: "batch",
233 | 1: "past_sequence + sequence",
234 | }
235 | else:
236 | common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
237 |
238 | return common_inputs
239 |
240 | @property
241 | def num_layers(self) -> int:
242 | return self._config.n_layer
243 |
244 | @property
245 | def num_attention_heads(self) -> int:
246 | return self._config.n_head
247 |
248 | def generate_dummy_inputs(
249 | self,
250 | tokenizer: PreTrainedTokenizer,
251 | batch_size: int = -1,
252 | seq_length: int = -1,
253 | is_pair: bool = False,
254 | framework: Optional[TensorType] = None,
255 | ) -> Mapping[str, Any]:
256 | common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
257 | tokenizer,
258 | batch_size=batch_size,
259 | seq_length=seq_length,
260 | is_pair=is_pair,
261 | framework=framework,
262 | )
263 |
264 | # We need to order the input in the way they appears in the forward()
265 | ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
266 |
267 | # Need to add the past_keys
268 | if self.use_past:
269 | if not is_torch_available():
270 | raise ValueError(
271 | "Cannot generate dummy past_keys inputs without PyTorch installed."
272 | )
273 | else:
274 | import torch
275 |
276 | batch, seqlen = common_inputs["input_ids"].shape
277 | # Not using the same length for past_key_values
278 | past_key_values_length = seqlen + 2
279 | past_shape = (
280 | batch,
281 | self.num_attention_heads,
282 | past_key_values_length,
283 | self._config.hidden_size // self.num_attention_heads,
284 | )
285 | ordered_inputs["past_key_values"] = [
286 | (torch.zeros(past_shape), torch.zeros(past_shape))
287 | for _ in range(self.num_layers)
288 | ]
289 |
290 | ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
291 | if self.use_past:
292 | mask_dtype = ordered_inputs["attention_mask"].dtype
293 | ordered_inputs["attention_mask"] = torch.cat(
294 | [
295 | ordered_inputs["attention_mask"],
296 | torch.ones(batch, past_key_values_length, dtype=mask_dtype),
297 | ],
298 | dim=1,
299 | )
300 |
301 | return ordered_inputs
302 |
303 | @property
304 | def default_onnx_opset(self) -> int:
305 | return 13
306 |
--------------------------------------------------------------------------------
/models/gptj-keyformer-lib/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "EleutherAI/gpt-j-6b",
3 | "activation_function": "gelu_new",
4 | "architectures": [
5 | "GPTJForCausalLM"
6 | ],
7 | "auto_map": {
8 | "AutoConfig": "configuration_gptj.GPTJConfig",
9 | "AutoModelForCausalLM": "modeling_gptj.GPTJForCausalLM"
10 | },
11 | "keyformer_config": {
12 | "keyformer": false,
13 | "kv_cache": 60,
14 | "recent": 30,
15 | "tau_init": 1.0,
16 | "tau_delta": 0.01
17 | },
18 | "attn_pdrop": 0.0,
19 | "bos_token_id": 50256,
20 | "embd_pdrop": 0.0,
21 | "eos_token_id": 50256,
22 | "gradient_checkpointing": false,
23 | "initializer_range": 0.02,
24 | "layer_norm_epsilon": 1e-05,
25 | "model_type": "gptj",
26 | "n_embd": 4096,
27 | "n_head": 16,
28 | "n_inner": null,
29 | "n_layer": 28,
30 | "n_positions": 2048,
31 | "resid_pdrop": 0.0,
32 | "rotary": true,
33 | "rotary_dim": 64,
34 | "scale_attn_weights": true,
35 | "summary_activation": null,
36 | "summary_first_dropout": 0.1,
37 | "summary_proj_to_labels": true,
38 | "summary_type": "cls_index",
39 | "summary_use_proj": true,
40 | "task_specific_params": {
41 | "text-generation": {
42 | "do_sample": true,
43 | "max_length": 50,
44 | "temperature": 1.0
45 | }
46 | },
47 | "tie_word_embeddings": false,
48 | "tokenizer_class": "GPT2Tokenizer",
49 | "torch_dtype": "float32",
50 | "transformers_version": "4.28.0.dev0",
51 | "use_cache": true,
52 | "vocab_size": 50401
53 | }
54 |
--------------------------------------------------------------------------------
/models/gptj-keyformer-lib/configuration_gptj.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ GPT-J model configuration"""
16 | from collections import OrderedDict
17 | from typing import Any, Dict, List, Mapping, Optional
18 |
19 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available
20 | from transformers import PretrainedConfig
21 | from transformers.onnx import OnnxConfigWithPast, PatchingSpec
22 | from transformers.utils import logging
23 |
24 |
25 | logger = logging.get_logger(__name__)
26 |
27 | GPTJ_PRETRAINED_CONFIG_ARCHIVE_MAP = {
28 | "EleutherAI/gpt-j-6B": "https://huggingface.co/EleutherAI/gpt-j-6B/resolve/main/config.json",
29 | # See all GPT-J models at https://huggingface.co/models?filter=gpt_j
30 | }
31 | keyformer_config_defaults: Dict = {
32 | "keyformer": False,
33 | "kv_cache": 60,
34 | "recent": 30,
35 | "tau_init": 1.0,
36 | "tau_delta": 0.01,
37 | }
38 |
39 |
40 | class GPTJConfig(PretrainedConfig):
41 | r"""
42 | This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J
43 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
44 | defaults will yield a similar configuration to that of the GPT-J
45 | [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from
46 | [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
47 | for more information.
48 |
49 | Args:
50 | vocab_size (`int`, *optional*, defaults to 50400):
51 | Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the
52 | `inputs_ids` passed when calling [`GPTJModel`].
53 | n_positions (`int`, *optional*, defaults to 2048):
54 | The maximum sequence length that this model might ever be used with. Typically set this to something large
55 | just in case (e.g., 512 or 1024 or 2048).
56 | n_embd (`int`, *optional*, defaults to 4096):
57 | Dimensionality of the embeddings and hidden states.
58 | n_layer (`int`, *optional*, defaults to 28):
59 | Number of hidden layers in the Transformer encoder.
60 | n_head (`int`, *optional*, defaults to 16):
61 | Number of attention heads for each attention layer in the Transformer encoder.
62 | rotary_dim (`int`, *optional*, defaults to 64):
63 | Number of dimensions in the embedding that Rotary Position Embedding is applied to.
64 | n_inner (`int`, *optional*, defaults to None):
65 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
66 | activation_function (`str`, *optional*, defaults to `"gelu_new"`):
67 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
68 | resid_pdrop (`float`, *optional*, defaults to 0.1):
69 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
70 | embd_pdrop (`int`, *optional*, defaults to 0.1):
71 | The dropout ratio for the embeddings.
72 | attn_pdrop (`float`, *optional*, defaults to 0.1):
73 | The dropout ratio for the attention.
74 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
75 | The epsilon to use in the layer normalization layers.
76 | initializer_range (`float`, *optional*, defaults to 0.02):
77 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
78 | use_cache (`bool`, *optional*, defaults to `True`):
79 | Whether or not the model should return the last key/values attentions (not used by all models).
80 | keyformer_config (Dict): A dictionary used to configure the keyformer parameters:
81 | keyformer (bool): When enabled, keyformer based KV cache reduction
82 | kv_cache (float): KV cache percentage for Keyformer.
83 | recent (float): Recent window percentage for Keyformer.
84 | tau_init (float): Initial temperature for Keyformer score function calculation.
85 | tau_delta (float): Delta temperature change for Keyformer score function.
86 |
87 | Example:
88 |
89 | ```python
90 | >>> from transformers import GPTJModel, GPTJConfig
91 |
92 | >>> # Initializing a GPT-J 6B configuration
93 | >>> configuration = GPTJConfig()
94 |
95 | >>> # Initializing a model from the configuration
96 | >>> model = GPTJModel(configuration)
97 |
98 | >>> # Accessing the model configuration
99 | >>> configuration = model.config
100 | ```"""
101 | model_type = "gptj"
102 | attribute_map = {
103 | "max_position_embeddings": "n_positions",
104 | "hidden_size": "n_embd",
105 | "num_attention_heads": "n_head",
106 | "num_hidden_layers": "n_layer",
107 | }
108 |
109 | def __init__(
110 | self,
111 | vocab_size=50400,
112 | n_positions=2048,
113 | n_embd=4096,
114 | n_layer=28,
115 | n_head=16,
116 | rotary_dim=64,
117 | n_inner=None,
118 | activation_function="gelu_new",
119 | resid_pdrop=0.0,
120 | embd_pdrop=0.0,
121 | attn_pdrop=0.0,
122 | layer_norm_epsilon=1e-5,
123 | initializer_range=0.02,
124 | use_cache=True,
125 | bos_token_id=50256,
126 | eos_token_id=50256,
127 | tie_word_embeddings=False,
128 | keyformer_config: Dict = keyformer_config_defaults,
129 | **kwargs,
130 | ):
131 | self.vocab_size = vocab_size
132 | self.n_positions = n_positions
133 | self.n_embd = n_embd
134 | self.n_layer = n_layer
135 | self.n_head = n_head
136 | self.n_inner = n_inner
137 | self.rotary_dim = rotary_dim
138 | self.activation_function = activation_function
139 | self.resid_pdrop = resid_pdrop
140 | self.embd_pdrop = embd_pdrop
141 | self.attn_pdrop = attn_pdrop
142 | self.layer_norm_epsilon = layer_norm_epsilon
143 | self.initializer_range = initializer_range
144 | self.use_cache = use_cache
145 | # ====== Keyformer ======
146 | self.keyformer_config = keyformer_config
147 | # =======================
148 | self.bos_token_id = bos_token_id
149 | self.eos_token_id = eos_token_id
150 |
151 | super().__init__(
152 | bos_token_id=bos_token_id,
153 | eos_token_id=eos_token_id,
154 | tie_word_embeddings=tie_word_embeddings,
155 | **kwargs,
156 | )
157 |
158 |
159 | # Copied from transformers.models.gpt2.configuration_gpt2.GPT2OnnxConfig
160 | class GPTJOnnxConfig(OnnxConfigWithPast):
161 | def __init__(
162 | self,
163 | config: PretrainedConfig,
164 | task: str = "default",
165 | patching_specs: List[PatchingSpec] = None,
166 | use_past: bool = False,
167 | ):
168 | super().__init__(
169 | config, task=task, patching_specs=patching_specs, use_past=use_past
170 | )
171 | if not getattr(self._config, "pad_token_id", None):
172 | # TODO: how to do that better?
173 | self._config.pad_token_id = 0
174 |
175 | @property
176 | def inputs(self) -> Mapping[str, Mapping[int, str]]:
177 | common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
178 | if self.use_past:
179 | self.fill_with_past_key_values_(common_inputs, direction="inputs")
180 | common_inputs["attention_mask"] = {
181 | 0: "batch",
182 | 1: "past_sequence + sequence",
183 | }
184 | else:
185 | common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
186 |
187 | return common_inputs
188 |
189 | @property
190 | def num_layers(self) -> int:
191 | return self._config.n_layer
192 |
193 | @property
194 | def num_attention_heads(self) -> int:
195 | return self._config.n_head
196 |
197 | def generate_dummy_inputs(
198 | self,
199 | tokenizer: PreTrainedTokenizer,
200 | batch_size: int = -1,
201 | seq_length: int = -1,
202 | is_pair: bool = False,
203 | framework: Optional[TensorType] = None,
204 | ) -> Mapping[str, Any]:
205 | common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
206 | tokenizer,
207 | batch_size=batch_size,
208 | seq_length=seq_length,
209 | is_pair=is_pair,
210 | framework=framework,
211 | )
212 |
213 | # We need to order the input in the way they appears in the forward()
214 | ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
215 |
216 | # Need to add the past_keys
217 | if self.use_past:
218 | if not is_torch_available():
219 | raise ValueError(
220 | "Cannot generate dummy past_keys inputs without PyTorch installed."
221 | )
222 | else:
223 | import torch
224 |
225 | batch, seqlen = common_inputs["input_ids"].shape
226 | # Not using the same length for past_key_values
227 | past_key_values_length = seqlen + 2
228 | past_shape = (
229 | batch,
230 | self.num_attention_heads,
231 | past_key_values_length,
232 | self._config.hidden_size // self.num_attention_heads,
233 | )
234 | ordered_inputs["past_key_values"] = [
235 | (torch.zeros(past_shape), torch.zeros(past_shape))
236 | for _ in range(self.num_layers)
237 | ]
238 |
239 | ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
240 | if self.use_past:
241 | mask_dtype = ordered_inputs["attention_mask"].dtype
242 | ordered_inputs["attention_mask"] = torch.cat(
243 | [
244 | ordered_inputs["attention_mask"],
245 | torch.ones(batch, past_key_values_length, dtype=mask_dtype),
246 | ],
247 | dim=1,
248 | )
249 |
250 | return ordered_inputs
251 |
252 | @property
253 | def default_onnx_opset(self) -> int:
254 | return 13
255 |
--------------------------------------------------------------------------------
/models/model_download/download_model.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import torch
5 | import transformers
6 | from transformers import AutoModelForCausalLM
7 |
8 | model_path = os.path.join(os.getcwd(), "model")
9 |
10 | if not os.path.exists(model_path):
11 | os.mkdir(model_path)
12 |
13 | def download_model(model_name):
14 | model = AutoModelForCausalLM.from_pretrained(
15 | model_name, device_map="auto", torchscript=True, trust_remote_code=True
16 | ) # torchscript will force `return_dict=False` to avoid jit errors
17 | print("Loaded model")
18 |
19 | model.save_pretrained(model_path)
20 |
21 | print("Model downloaded and Saved in : ", model_path)
22 |
23 |
24 | def get_args():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument(
27 | "--model_name", type=str, default=None,
28 | help="like cerebras/Cerebras-GPT-6.7B, etc."
29 | )
30 | args = parser.parse_args()
31 |
32 | return args
33 |
34 | if __name__ == '__main__':
35 | args = get_args()
36 |
37 | # This check will be relaxed as more models are supported.
38 | supported_models = ['cerebras/Cerebras-GPT-6.7B',
39 | 'mosaicml/mpt-7b', 'EleutherAI/gpt-j-6B']
40 | if args.model_name not in supported_models:
41 |         raise Exception(f'Unsupported model name. Only models in ' \
42 | f'{supported_models} are supported.')
43 |
44 | download_model(args.model_name)
--------------------------------------------------------------------------------
/models/mpt-keyformer-lib/attention_lm_eval_harness.py:
--------------------------------------------------------------------------------
1 | """Attention layers."""
2 | import sys
3 | import math
4 | import warnings
5 | import copy
6 | from typing import Optional
7 | from functools import reduce
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn.parameter import Parameter
11 | from einops import rearrange
12 | from packaging import version
13 | from torch import nn
14 | import numpy as np
15 | from .norm import LPLayerNorm
16 |
17 |
18 | def _reset_is_causal(
19 | num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
20 | ):
21 | if original_is_causal and num_query_tokens != num_key_tokens:
22 | if num_query_tokens != 1:
23 | raise NotImplementedError(
24 | "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
25 | )
26 | else:
27 | return False
28 | return original_is_causal
29 |
30 |
31 | def keyformer_mask(self, attn_weights, key_tokens, tau_init):
32 | # attn_weights (BS, head, query, keys)
33 | dtype_attn_weights = attn_weights.dtype
34 | seq_length = attn_weights.shape[-1]
35 | padding_length = 0
36 |
37 | offset = torch.finfo(attn_weights.dtype).min
38 | tmp_attn = nn.functional.gumbel_softmax(
39 | attn_weights, dim=-1, tau=tau_init, hard=False
40 | ).to(dtype_attn_weights)
41 |
42 | accumulated_score = torch.sum(
43 | tmp_attn[:, :, padding_length : key_tokens + padding_length, :], dim=-2
44 | ) # (head, keys)
45 | accumulated_score[:, :, key_tokens + padding_length :] = 0
46 | accumulated_score[:, :, :padding_length] = 0
47 |
48 | mask_bottom = torch.zeros_like(attn_weights, dtype=torch.bool)
49 | mask_bottom[
50 | :,
51 | :,
52 | padding_length : key_tokens + padding_length,
53 | padding_length : key_tokens + padding_length,
54 | ] = True
55 |
56 | for token_index in range(key_tokens + padding_length, seq_length):
57 | tmp_attn_index = nn.functional.gumbel_softmax(
58 | attn_weights[:, :, token_index, :], dim=-1, tau=tau_init, hard=False
59 | ).to(dtype_attn_weights)
60 | _, tmp_topk_index = accumulated_score.topk(k=key_tokens - 1, dim=-1)
61 | zeros_index = torch.zeros_like(tmp_attn_index, dtype=torch.bool)
62 | mask_bottom_index = zeros_index.scatter(
63 | -1, tmp_topk_index, True
64 | ) # (head, keys)
65 | mask_bottom_index[:, :, token_index] = True
66 |
67 | mask_bottom[:, :, token_index, :] = mask_bottom_index
68 |         accumulated_score += tmp_attn_index
69 |         accumulated_score = accumulated_score * mask_bottom_index
70 |
71 | return mask_bottom
72 |
73 |
74 | def scaled_multihead_dot_product_attention(
75 | query,
76 | key,
77 | value,
78 | n_heads,
79 | past_key_value=None,
80 | softmax_scale=None,
81 | attn_bias=None,
82 | key_padding_mask=None,
83 | is_causal=False,
84 | dropout_p=0.0,
85 | keyformer=False,
86 | score_fn=None,
87 | token_discard_mask=None,
88 | token_discard_idx=None,
89 | req_tokens=None,
90 | kv_cache=60.0,
91 | recent=30.0,
92 | tau_init=1.0,
93 | tau_delta=0.01,
94 | itr_count=0,
95 | training=False,
96 | needs_weights=False,
97 | multiquery=False,
98 | ):
99 | q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
100 | kv_n_heads = 1 if multiquery else n_heads
101 | k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
102 | v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
103 |
104 | if past_key_value is not None:
105 | if len(past_key_value) != 0:
106 | k = torch.cat([past_key_value[0], k], dim=3)
107 | v = torch.cat([past_key_value[1], v], dim=2)
108 | past_key_value = (k, v)
109 |
110 | (b, h, s_q, d) = q.shape
111 | s_k = k.size(-1)
112 |
113 | if softmax_scale is None:
114 | softmax_scale = 1 / math.sqrt(d)
115 | attn_weight = q.matmul(k) * softmax_scale
116 |
117 | _s_q = max(0, attn_bias.size(2) - s_q)
118 | _s_k = max(0, attn_bias.size(3) - s_k)
119 | attn_bias = attn_bias[:, :, _s_q:, _s_k:]
120 | if (
121 | attn_bias.size(-1) != 1
122 | and attn_bias.size(-1) != s_k
123 | or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
124 | ):
125 | raise RuntimeError(
126 | f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
127 | )
128 | attn_weight = attn_weight + attn_bias
129 |
130 |     min_val = torch.finfo(q.dtype).min
131 | 
132 |     if key_padding_mask is not None:
133 |         if attn_bias is not None:
134 |             warnings.warn(
135 |                 "Propagating key_padding_mask to the attention module "
136 |                 + "and applying it within the attention module can cause "
137 |                 + "unnecessary computation/memory usage. Consider integrating "
138 |                 + "into attn_bias once and passing that to each attention "
139 |                 + "module instead."
140 |             )
141 |         attn_weight = attn_weight.masked_fill(
142 |             ~key_padding_mask.view((b, 1, 1, s_k)), min_val
143 |         )
144 | 
145 | if is_causal and (not q.size(2) == 1):
146 | s = max(s_q, s_k)
147 | causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
148 | causal_mask = causal_mask.tril()
149 | causal_mask = causal_mask.to(torch.bool)
150 | causal_mask = ~causal_mask
151 | causal_mask = causal_mask[-s_q:, -s_k:]
152 | attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
153 | # Initialize score_fn
154 | score_fn = None
155 | # Initialize token discard mask during prompt
156 | token_discard_mask = None
157 | # Initialize itr_count during prompt
158 | itr_count = 0
159 | # Initialize required tokens during prompt
160 | req_tokens = None
161 |
162 | (b, h, s, d) = attn_weight.shape
163 | # ==============================================================
164 |
165 | # ==============================================================
166 | if keyformer:
167 | # Keyformer
168 | req_tokens = int((attn_weight.shape[-1] * kv_cache) / 100)
169 | recent_tokens = int((req_tokens * recent) / 100)
170 | prompt_tokens = req_tokens - recent_tokens
171 | req_tokens = (recent_tokens, prompt_tokens)
172 |
173 | if req_tokens[1] > 0:
174 |             mask_bottom = keyformer_mask(attn_weight, req_tokens[1], tau_init)
175 | else:
176 | mask_bottom = torch.zeros_like(attn_weight, dtype=torch.bool)
177 |
178 | ones = torch.ones_like(attn_weight, dtype=torch.bool)
179 | ones = torch.triu(ones, diagonal=-req_tokens[0])
180 | mask_bottom = torch.logical_or(mask_bottom, ones)
181 |
182 | mask_bottom = torch.tril(mask_bottom, diagonal=0)
183 |
184 |         attn_weight[~mask_bottom] = torch.finfo(attn_weight.dtype).min
185 | # ==============================================================
186 |
187 | attn_weight = torch.softmax(attn_weight, dim=-1)
188 | if dropout_p:
189 | attn_weight = torch.nn.functional.dropout(
190 | attn_weight, p=dropout_p, training=training, inplace=True
191 | )
192 |
193 |     sparsity = past_key_value[0].shape if past_key_value is not None else None
194 |
195 | out = attn_weight.matmul(v)
196 | out = rearrange(out, "b h s d -> b s (h d)")
197 |
198 | itr_count = itr_count + 1
199 |
200 | if needs_weights:
201 | return (
202 | out,
203 | attn_weight,
204 | sparsity,
205 | past_key_value,
206 | score_fn,
207 | token_discard_mask,
208 | token_discard_idx,
209 | req_tokens,
210 | itr_count,
211 | )
212 | return (
213 | out,
214 | None,
215 | None,
216 | past_key_value,
217 | score_fn,
218 | token_discard_mask,
219 | token_discard_idx,
220 | req_tokens,
221 | itr_count,
222 | )
223 |
224 |
225 | def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
226 | for tensor in tensors:
227 | if tensor.dtype not in valid_dtypes:
228 | raise TypeError(
229 | f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
230 | )
231 | if not tensor.is_cuda:
232 | raise TypeError(
233 | f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
234 | )
235 |
236 |
237 | def flash_attn_fn(
238 | query,
239 | key,
240 | value,
241 | n_heads,
242 | past_key_value=None,
243 | softmax_scale=None,
244 | attn_bias=None,
245 | key_padding_mask=None,
246 | is_causal=False,
247 | dropout_p=0.0,
248 | training=False,
249 | needs_weights=False,
250 | multiquery=False,
251 | ):
252 | try:
253 | from flash_attn import bert_padding, flash_attn_interface
254 | except:
255 | raise RuntimeError("Please install flash-attn==1.0.3.post0")
256 | check_valid_inputs(query, key, value)
257 | if past_key_value is not None:
258 | if len(past_key_value) != 0:
259 | key = torch.cat([past_key_value[0], key], dim=1)
260 | value = torch.cat([past_key_value[1], value], dim=1)
261 | past_key_value = (key, value)
262 | if attn_bias is not None:
263 | _s_q = max(0, attn_bias.size(2) - query.size(1))
264 | _s_k = max(0, attn_bias.size(3) - key.size(1))
265 | attn_bias = attn_bias[:, :, _s_q:, _s_k:]
266 | if attn_bias is not None:
267 | raise NotImplementedError(f"attn_bias not implemented for flash attn.")
268 | (batch_size, seqlen) = query.shape[:2]
269 | if key_padding_mask is None:
270 | key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
271 | query_padding_mask = key_padding_mask[:, -query.size(1) :]
272 | (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
273 | query, query_padding_mask
274 | )
275 | query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
276 | (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
277 | key, key_padding_mask
278 | )
279 | key_unpad = rearrange(
280 | key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
281 | )
282 | (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
283 | value_unpad = rearrange(
284 | value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
285 | )
286 | if multiquery:
287 | key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
288 | value_unpad = value_unpad.expand(
289 | value_unpad.size(0), n_heads, value_unpad.size(-1)
290 | )
291 | dropout_p = dropout_p if training else 0.0
292 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
293 | output_unpad = flash_attn_interface.flash_attn_unpadded_func(
294 | query_unpad,
295 | key_unpad,
296 | value_unpad,
297 | cu_seqlens_q,
298 | cu_seqlens_k,
299 | max_seqlen_q,
300 | max_seqlen_k,
301 | dropout_p,
302 | softmax_scale=softmax_scale,
303 | causal=reset_is_causal,
304 | return_attn_probs=needs_weights,
305 | )
306 | output = bert_padding.pad_input(
307 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
308 | )
309 | return (output, None, past_key_value)
310 |
311 |
312 | def triton_flash_attn_fn(
313 | query,
314 | key,
315 | value,
316 | n_heads,
317 | past_key_value=None,
318 | softmax_scale=None,
319 | attn_bias=None,
320 | key_padding_mask=None,
321 | is_causal=False,
322 | dropout_p=0.0,
323 | training=False,
324 | needs_weights=False,
325 | multiquery=False,
326 | ):
327 | try:
328 | from .flash_attn_triton import flash_attn_func
329 | except:
330 | _installed = False
331 | if version.parse(torch.__version__) < version.parse("2.0.0"):
332 | _installed = True
333 | try:
334 | from flash_attn.flash_attn_triton import flash_attn_func
335 | except:
336 | _installed = False
337 | if not _installed:
338 | raise RuntimeError(
339 | "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
340 | )
341 | check_valid_inputs(query, key, value)
342 | if past_key_value is not None:
343 | if len(past_key_value) != 0:
344 | key = torch.cat([past_key_value[0], key], dim=1)
345 | value = torch.cat([past_key_value[1], value], dim=1)
346 | past_key_value = (key, value)
347 | if attn_bias is not None:
348 | _s_q = max(0, attn_bias.size(2) - query.size(1))
349 | _s_k = max(0, attn_bias.size(3) - key.size(1))
350 | attn_bias = attn_bias[:, :, _s_q:, _s_k:]
351 | if dropout_p:
352 | raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
353 | if needs_weights:
354 | raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
355 | if key_padding_mask is not None:
356 | warnings.warn(
357 | "Propagating key_padding_mask to the attention module "
358 | + "and applying it within the attention module can cause "
359 | + "unnecessary computation/memory usage. Consider integrating "
360 | + "into attn_bias once and passing that to each attention "
361 | + "module instead."
362 | )
363 | (b_size, s_k) = key_padding_mask.shape[:2]
364 | if attn_bias is None:
365 | attn_bias = query.new_zeros(b_size, 1, 1, s_k)
366 | attn_bias = attn_bias.masked_fill(
367 | ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
368 | )
369 | query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
370 | key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
371 | value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
372 | if multiquery:
373 | key = key.expand(*key.shape[:2], n_heads, key.size(-1))
374 | value = value.expand(*value.shape[:2], n_heads, value.size(-1))
375 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
376 | attn_output = flash_attn_func(
377 | query, key, value, attn_bias, reset_is_causal, softmax_scale
378 | )
379 | output = attn_output.view(*attn_output.shape[:2], -1)
380 | return (output, None, past_key_value)
381 |
382 |
383 | class MultiheadAttention(nn.Module):
384 | """Multi-head self attention.
385 |
386 |     Using the torch or triton attention implementation enables the user to also
387 |     use additive bias.
388 | """
389 |
390 | def __init__(
391 | self,
392 | d_model: int,
393 | n_heads: int,
394 | attn_impl: str = "triton",
395 | clip_qkv: Optional[float] = None,
396 | qk_ln: bool = False,
397 | softmax_scale: Optional[float] = None,
398 | attn_pdrop: float = 0.0,
399 | keyformer: bool = False,
400 | kv_cache: float = 60.0,
401 | recent: float = 30.0,
402 | tau_init: float = 1.0,
403 | tau_delta: float = 0.01,
404 | low_precision_layernorm: bool = False,
405 | verbose: int = 0,
406 | device: Optional[str] = None,
407 | ):
408 | super().__init__()
409 | self.attn_impl = attn_impl
410 | self.clip_qkv = clip_qkv
411 | self.qk_ln = qk_ln
412 | self.d_model = d_model
413 | self.n_heads = n_heads
414 | self.softmax_scale = softmax_scale
415 | if self.softmax_scale is None:
416 | self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
417 | self.attn_dropout_p = attn_pdrop
418 | # ==================== Keyformer =========================
419 | self.keyformer = keyformer
420 | self.kv_cache = kv_cache
421 | self.recent = recent
422 | self.tau_init = tau_init
423 | self.tau_delta = tau_delta
424 | # ========================================================
425 | self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
426 | fuse_splits = (d_model, 2 * d_model)
427 | self.Wqkv._fused = (0, fuse_splits)
428 | if self.qk_ln:
429 | layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
430 | self.q_ln = layernorm_class(self.d_model, device=device)
431 | self.k_ln = layernorm_class(self.d_model, device=device)
432 | if self.attn_impl == "flash":
433 | self.attn_fn = flash_attn_fn
434 | elif self.attn_impl == "triton":
435 | self.attn_fn = triton_flash_attn_fn
436 | if verbose:
437 | warnings.warn(
438 | "While `attn_impl: triton` can be faster than `attn_impl: flash` "
439 | + "it uses more memory. When training larger models this can trigger "
440 | + "alloc retries which hurts performance. If encountered, we recommend "
441 | + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
442 | )
443 | elif self.attn_impl == "torch":
444 | self.attn_fn = scaled_multihead_dot_product_attention
445 | if torch.cuda.is_available() and verbose:
446 | warnings.warn(
447 | "Using `attn_impl: torch`. If your model does not use `alibi` or "
448 | + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
449 | + "we recommend using `attn_impl: triton`."
450 | )
451 | else:
452 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
453 | self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
454 | self.out_proj._is_residual = True
455 |
456 | def forward(
457 | self,
458 | x,
459 | past_key_value=None,
460 | attn_bias=None,
461 | attention_mask=None,
462 | is_causal=True,
463 | needs_weights=False,
464 | score_fn=None,
465 | token_discard_mask=None,
466 | token_discard_idx=None,
467 | req_tokens=None,
468 | itr_count=0,
469 | ):
470 | qkv = self.Wqkv(x)
471 | if self.clip_qkv:
472 | qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
473 | (query, key, value) = qkv.chunk(3, dim=2)
474 | key_padding_mask = attention_mask
475 | if self.qk_ln:
476 | dtype = query.dtype
477 | query = self.q_ln(query).to(dtype)
478 | key = self.k_ln(key).to(dtype)
479 | (
480 | context,
481 | attn_weights,
482 | sparsity,
483 | past_key_value,
484 | score_fn,
485 | token_discard_mask,
486 | token_discard_idx,
487 | req_tokens,
488 | itr_count,
489 | ) = self.attn_fn(
490 | query,
491 | key,
492 | value,
493 | self.n_heads,
494 | past_key_value=past_key_value,
495 | softmax_scale=self.softmax_scale,
496 | attn_bias=attn_bias,
497 | key_padding_mask=key_padding_mask,
498 | is_causal=is_causal,
499 | dropout_p=self.attn_dropout_p,
500 | keyformer=self.keyformer,
501 | score_fn=score_fn,
502 | token_discard_mask=token_discard_mask,
503 | token_discard_idx=token_discard_idx,
504 | req_tokens=req_tokens,
505 | kv_cache=self.kv_cache,
506 | recent=self.recent,
507 | tau_init=self.tau_init,
508 | tau_delta=self.tau_delta,
509 | itr_count=itr_count,
510 | training=self.training,
511 | needs_weights=needs_weights,
512 | )
513 | return (
514 | self.out_proj(context),
515 | attn_weights,
516 | sparsity,
517 | past_key_value,
518 | score_fn,
519 | token_discard_mask,
520 | token_discard_idx,
521 | req_tokens,
522 | itr_count,
523 | )
524 |
525 |
526 | def attn_bias_shape(
527 | attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
528 | ):
529 | if attn_impl == "flash":
530 | return None
531 | elif attn_impl in ["torch", "triton"]:
532 | if alibi:
533 | if (prefix_lm or not causal) or use_sequence_id:
534 | return (1, n_heads, seq_len, seq_len)
535 | return (1, n_heads, 1, seq_len)
536 | elif prefix_lm or use_sequence_id:
537 | return (1, 1, seq_len, seq_len)
538 | return None
539 | else:
540 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
541 |
542 |
543 | def build_attn_bias(
544 | attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
545 | ):
546 | if attn_impl == "flash":
547 | return None
548 | elif attn_impl in ["torch", "triton"]:
549 | if alibi:
550 | (device, dtype) = (attn_bias.device, attn_bias.dtype)
551 | attn_bias = attn_bias.add(
552 | build_alibi_bias(
553 | n_heads,
554 | seq_len,
555 | full=not causal,
556 | alibi_bias_max=alibi_bias_max,
557 | device=device,
558 | dtype=dtype,
559 | )
560 | )
561 | return attn_bias
562 | else:
563 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
564 |
565 |
566 | def gen_slopes(n_heads, alibi_bias_max=8, device=None):
567 | _n_heads = 2 ** math.ceil(math.log2(n_heads))
568 | m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
569 | m = m.mul(alibi_bias_max / _n_heads)
570 | slopes = 1.0 / torch.pow(2, m)
571 | if _n_heads != n_heads:
572 | slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
573 | return slopes.view(1, n_heads, 1, 1)
574 |
575 |
576 | def build_alibi_bias(
577 | n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
578 | ):
579 | alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
580 | 1, 1, 1, seq_len
581 | )
582 | if full:
583 | alibi_bias = alibi_bias - torch.arange(
584 | 1 - seq_len, 1, dtype=torch.int32, device=device
585 | ).view(1, 1, seq_len, 1)
586 | alibi_bias = alibi_bias.abs().mul(-1)
587 | slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
588 | alibi_bias = alibi_bias * slopes
589 | return alibi_bias.to(dtype=dtype)
590 |
591 |
592 | ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention}
593 |
--------------------------------------------------------------------------------
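To make the KV-cache budget arithmetic in `scaled_multihead_dot_product_attention` above concrete, here is a small illustrative sketch (plain Python, example numbers only) of how the `kv_cache` and `recent` percentages translate into the `(recent_tokens, prompt_tokens)` split used when `keyformer=True`:

```python
def keyformer_budget(seq_len, kv_cache=60.0, recent=30.0):
    # Mirrors the split computed above: kv_cache is the percentage of the
    # sequence that is kept, recent the share of that budget reserved for the
    # most recent tokens; the remainder is chosen via the gumbel-softmax score.
    req_tokens = int((seq_len * kv_cache) / 100)      # total KV budget
    recent_tokens = int((req_tokens * recent) / 100)  # recent-window tokens
    prompt_tokens = req_tokens - recent_tokens        # score-selected key tokens
    return recent_tokens, prompt_tokens

# A 2048-token prompt with the default 60% / 30% settings keeps 1228 tokens:
# 368 recent tokens plus 860 score-selected key tokens.
print(keyformer_budget(2048))  # -> (368, 860)
```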
/models/mpt-keyformer-lib/attention_streaming_llm.py:
--------------------------------------------------------------------------------
1 | """Attention layers."""
2 | import sys
3 | import math
4 | import warnings
5 | import copy
6 | from typing import Optional
7 | from functools import reduce
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn.parameter import Parameter
11 | from einops import rearrange
12 | from packaging import version
13 | from torch import nn
14 | import numpy as np
15 | from .norm import LPLayerNorm
16 |
17 |
18 | def _reset_is_causal(
19 | num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
20 | ):
21 | if original_is_causal and num_query_tokens != num_key_tokens:
22 | if num_query_tokens != 1:
23 | raise NotImplementedError(
24 | "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
25 | )
26 | else:
27 | return False
28 | return original_is_causal
29 |
30 |
31 | def scaled_multihead_dot_product_attention(
32 | query,
33 | key,
34 | value,
35 | n_heads,
36 | past_key_value=None,
37 | softmax_scale=None,
38 | attn_bias=None,
39 | key_padding_mask=None,
40 | is_causal=False,
41 | dropout_p=0.0,
42 | keyformer=False,
43 | score_fn=None,
44 | token_discard_mask=None,
45 | token_discard_idx=None,
46 | req_tokens=None,
47 | kv_cache=60.0,
48 | recent=30.0,
49 | tau_init=1.0,
50 | tau_delta=0.01,
51 | itr_count=0,
52 | training=False,
53 | needs_weights=False,
54 | multiquery=False,
55 | ):
56 | q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
57 | kv_n_heads = 1 if multiquery else n_heads
58 | k = rearrange(key, "b s (h d) -> b h d s", h=kv_n_heads)
59 | v = rearrange(value, "b s (h d) -> b h s d", h=kv_n_heads)
60 |
61 | if past_key_value is not None:
62 | if len(past_key_value) != 0:
63 | k = torch.cat([past_key_value[0], k], dim=3)
64 | v = torch.cat([past_key_value[1], v], dim=2)
65 | past_key_value = (k, v)
66 |
67 | (b, h, s_q, d) = q.shape
68 | s_k = k.size(-1)
69 |
70 | if softmax_scale is None:
71 | softmax_scale = 1 / math.sqrt(d)
72 | attn_weight = q.matmul(k) * softmax_scale
73 |
74 | _s_q = max(0, attn_bias.size(2) - s_q)
75 | _s_k = max(0, attn_bias.size(3) - s_k)
76 | attn_bias = attn_bias[:, :, _s_q:, _s_k:]
77 | if (
78 | attn_bias.size(-1) != 1
79 | and attn_bias.size(-1) != s_k
80 | or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
81 | ):
82 | raise RuntimeError(
83 | f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
84 | )
85 | attn_weight = attn_weight + attn_bias
86 |
87 |     min_val = torch.finfo(q.dtype).min
88 | 
89 |     if key_padding_mask is not None:
90 |         if attn_bias is not None:
91 |             warnings.warn(
92 |                 "Propagating key_padding_mask to the attention module "
93 |                 + "and applying it within the attention module can cause "
94 |                 + "unnecessary computation/memory usage. Consider integrating "
95 |                 + "into attn_bias once and passing that to each attention "
96 |                 + "module instead."
97 |             )
98 |         attn_weight = attn_weight.masked_fill(
99 |             ~key_padding_mask.view((b, 1, 1, s_k)), min_val
100 |         )
101 | 
102 | if is_causal and (not q.size(2) == 1):
103 | s = max(s_q, s_k)
104 | causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
105 | causal_mask = causal_mask.tril()
106 | causal_mask = causal_mask.to(torch.bool)
107 | causal_mask = ~causal_mask
108 | causal_mask = causal_mask[-s_q:, -s_k:]
109 | attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
110 | # Initialize score_fn
111 | score_fn = None
112 | # Initialize token discard mask during prompt
113 | token_discard_mask = None
114 | # Initialize itr_count during prompt
115 | itr_count = 0
116 | # Initialize required tokens during prompt
117 | req_tokens = None
118 |
119 | (b, h, s, d) = attn_weight.shape
120 |
121 | attn_weight = torch.softmax(attn_weight, dim=-1)
122 | if dropout_p:
123 | attn_weight = torch.nn.functional.dropout(
124 | attn_weight, p=dropout_p, training=training, inplace=True
125 | )
126 |
127 | # ==============================================================
128 | # ==> StreamingLLM (Attention sinks + Recent Window)
129 | if is_causal:
130 | # Recent + Attention Sink
131 | if not (q.size(2) == 1):
132 | req_tokens = int((attn_weight.shape[-1] * kv_cache) / 100)
133 | prompt_tokens = 4 # Attention Sinks
134 | recent_tokens = req_tokens - prompt_tokens
135 | req_tokens = (recent_tokens, prompt_tokens)
136 |
137 | token_mask = attn_weight.new_ones(b, h, 1, d, dtype=torch.float16)
138 | token_mask = token_mask.to(torch.bool)
139 |
140 | # activate most recent k-cache
141 | token_mask[:, :, :, : -req_tokens[0]] = False
142 | # activate attention sinks
143 | token_mask[:, :, :, : req_tokens[1]] = True
144 |
145 | sparse_len = req_tokens[0] + req_tokens[1]
146 |
147 | # ==> Reducing KV cache
148 | v_new = v[token_mask.squeeze()]
149 | v_new = rearrange(v_new, "(b h s) d -> b h s d", h=kv_n_heads, s=sparse_len)
150 |
151 | k_new = k.transpose(2, 3)[token_mask.squeeze()]
152 | k_new = rearrange(k_new, "(b h s) d -> b h s d", h=kv_n_heads, s=sparse_len)
153 | k_new = k_new.transpose(2, 3)
154 |
155 | # ==> Updating KV cache
156 | past_key_value = (k_new, v_new)
157 | # ================================================================================
158 |     sparsity = past_key_value[0].shape if past_key_value is not None else None
159 |
160 | out = attn_weight.matmul(v)
161 | out = rearrange(out, "b h s d -> b s (h d)")
162 |
163 | itr_count = itr_count + 1
164 |
165 | if needs_weights:
166 | return (
167 | out,
168 | attn_weight,
169 | sparsity,
170 | past_key_value,
171 | score_fn,
172 | token_discard_mask,
173 | token_discard_idx,
174 | req_tokens,
175 | itr_count,
176 | )
177 | return (
178 | out,
179 | None,
180 | None,
181 | past_key_value,
182 | score_fn,
183 | token_discard_mask,
184 | token_discard_idx,
185 | req_tokens,
186 | itr_count,
187 | )
188 |
189 |
190 | def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
191 | for tensor in tensors:
192 | if tensor.dtype not in valid_dtypes:
193 | raise TypeError(
194 | f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
195 | )
196 | if not tensor.is_cuda:
197 | raise TypeError(
198 | f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
199 | )
200 |
201 |
202 | def flash_attn_fn(
203 | query,
204 | key,
205 | value,
206 | n_heads,
207 | past_key_value=None,
208 | softmax_scale=None,
209 | attn_bias=None,
210 | key_padding_mask=None,
211 | is_causal=False,
212 | dropout_p=0.0,
213 | training=False,
214 | needs_weights=False,
215 | multiquery=False,
216 | ):
217 | try:
218 | from flash_attn import bert_padding, flash_attn_interface
219 | except:
220 | raise RuntimeError("Please install flash-attn==1.0.3.post0")
221 | check_valid_inputs(query, key, value)
222 | if past_key_value is not None:
223 | if len(past_key_value) != 0:
224 | key = torch.cat([past_key_value[0], key], dim=1)
225 | value = torch.cat([past_key_value[1], value], dim=1)
226 | past_key_value = (key, value)
227 | if attn_bias is not None:
228 | _s_q = max(0, attn_bias.size(2) - query.size(1))
229 | _s_k = max(0, attn_bias.size(3) - key.size(1))
230 | attn_bias = attn_bias[:, :, _s_q:, _s_k:]
231 | if attn_bias is not None:
232 | raise NotImplementedError(f"attn_bias not implemented for flash attn.")
233 | (batch_size, seqlen) = query.shape[:2]
234 | if key_padding_mask is None:
235 | key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
236 | query_padding_mask = key_padding_mask[:, -query.size(1) :]
237 | (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
238 | query, query_padding_mask
239 | )
240 | query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
241 | (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
242 | key, key_padding_mask
243 | )
244 | key_unpad = rearrange(
245 | key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
246 | )
247 | (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
248 | value_unpad = rearrange(
249 | value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
250 | )
251 | if multiquery:
252 | key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
253 | value_unpad = value_unpad.expand(
254 | value_unpad.size(0), n_heads, value_unpad.size(-1)
255 | )
256 | dropout_p = dropout_p if training else 0.0
257 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
258 | output_unpad = flash_attn_interface.flash_attn_unpadded_func(
259 | query_unpad,
260 | key_unpad,
261 | value_unpad,
262 | cu_seqlens_q,
263 | cu_seqlens_k,
264 | max_seqlen_q,
265 | max_seqlen_k,
266 | dropout_p,
267 | softmax_scale=softmax_scale,
268 | causal=reset_is_causal,
269 | return_attn_probs=needs_weights,
270 | )
271 | output = bert_padding.pad_input(
272 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
273 | )
274 | return (output, None, past_key_value)
275 |
276 |
277 | def triton_flash_attn_fn(
278 | query,
279 | key,
280 | value,
281 | n_heads,
282 | past_key_value=None,
283 | softmax_scale=None,
284 | attn_bias=None,
285 | key_padding_mask=None,
286 | is_causal=False,
287 | dropout_p=0.0,
288 | training=False,
289 | needs_weights=False,
290 | multiquery=False,
291 | ):
292 | try:
293 | from .flash_attn_triton import flash_attn_func
294 | except:
295 | _installed = False
296 | if version.parse(torch.__version__) < version.parse("2.0.0"):
297 | _installed = True
298 | try:
299 | from flash_attn.flash_attn_triton import flash_attn_func
300 | except:
301 | _installed = False
302 | if not _installed:
303 | raise RuntimeError(
304 | "Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed."
305 | )
306 | check_valid_inputs(query, key, value)
307 | if past_key_value is not None:
308 | if len(past_key_value) != 0:
309 | key = torch.cat([past_key_value[0], key], dim=1)
310 | value = torch.cat([past_key_value[1], value], dim=1)
311 | past_key_value = (key, value)
312 | if attn_bias is not None:
313 | _s_q = max(0, attn_bias.size(2) - query.size(1))
314 | _s_k = max(0, attn_bias.size(3) - key.size(1))
315 | attn_bias = attn_bias[:, :, _s_q:, _s_k:]
316 | if dropout_p:
317 | raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
318 | if needs_weights:
319 | raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
320 | if key_padding_mask is not None:
321 | warnings.warn(
322 | "Propagating key_padding_mask to the attention module "
323 | + "and applying it within the attention module can cause "
324 | + "unnecessary computation/memory usage. Consider integrating "
325 | + "into attn_bias once and passing that to each attention "
326 | + "module instead."
327 | )
328 | (b_size, s_k) = key_padding_mask.shape[:2]
329 | if attn_bias is None:
330 | attn_bias = query.new_zeros(b_size, 1, 1, s_k)
331 | attn_bias = attn_bias.masked_fill(
332 | ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
333 | )
334 | query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
335 | key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
336 | value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
337 | if multiquery:
338 | key = key.expand(*key.shape[:2], n_heads, key.size(-1))
339 | value = value.expand(*value.shape[:2], n_heads, value.size(-1))
340 | reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
341 | attn_output = flash_attn_func(
342 | query, key, value, attn_bias, reset_is_causal, softmax_scale
343 | )
344 | output = attn_output.view(*attn_output.shape[:2], -1)
345 | return (output, None, past_key_value)
346 |
347 |
348 | class MultiheadAttention(nn.Module):
349 | """Multi-head self attention.
350 |
351 |     Using the torch or triton attention implementation enables the user to also
352 |     use additive bias.
353 | """
354 |
355 | def __init__(
356 | self,
357 | d_model: int,
358 | n_heads: int,
359 | attn_impl: str = "triton",
360 | clip_qkv: Optional[float] = None,
361 | qk_ln: bool = False,
362 | softmax_scale: Optional[float] = None,
363 | attn_pdrop: float = 0.0,
364 | keyformer: bool = False,
365 | kv_cache: float = 60.0,
366 | recent: float = 30.0,
367 | tau_init: float = 1.0,
368 | tau_delta: float = 0.01,
369 | low_precision_layernorm: bool = False,
370 | verbose: int = 0,
371 | device: Optional[str] = None,
372 | ):
373 | super().__init__()
374 | self.attn_impl = attn_impl
375 | self.clip_qkv = clip_qkv
376 | self.qk_ln = qk_ln
377 | self.d_model = d_model
378 | self.n_heads = n_heads
379 | self.softmax_scale = softmax_scale
380 | if self.softmax_scale is None:
381 | self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
382 | self.attn_dropout_p = attn_pdrop
383 | # ==================== Keyformer =========================
384 | self.keyformer = keyformer
385 | self.kv_cache = kv_cache
386 | self.recent = recent
387 | self.tau_init = tau_init
388 | self.tau_delta = tau_delta
389 | # ========================================================
390 | self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
391 | fuse_splits = (d_model, 2 * d_model)
392 | self.Wqkv._fused = (0, fuse_splits)
393 | if self.qk_ln:
394 | layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
395 | self.q_ln = layernorm_class(self.d_model, device=device)
396 | self.k_ln = layernorm_class(self.d_model, device=device)
397 | if self.attn_impl == "flash":
398 | self.attn_fn = flash_attn_fn
399 | elif self.attn_impl == "triton":
400 | self.attn_fn = triton_flash_attn_fn
401 | if verbose:
402 | warnings.warn(
403 | "While `attn_impl: triton` can be faster than `attn_impl: flash` "
404 | + "it uses more memory. When training larger models this can trigger "
405 | + "alloc retries which hurts performance. If encountered, we recommend "
406 | + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
407 | )
408 | elif self.attn_impl == "torch":
409 | self.attn_fn = scaled_multihead_dot_product_attention
410 | if torch.cuda.is_available() and verbose:
411 | warnings.warn(
412 | "Using `attn_impl: torch`. If your model does not use `alibi` or "
413 | + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
414 | + "we recommend using `attn_impl: triton`."
415 | )
416 | else:
417 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
418 | self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
419 | self.out_proj._is_residual = True
420 |
421 | def forward(
422 | self,
423 | x,
424 | past_key_value=None,
425 | attn_bias=None,
426 | attention_mask=None,
427 | is_causal=True,
428 | needs_weights=False,
429 | score_fn=None,
430 | token_discard_mask=None,
431 | token_discard_idx=None,
432 | req_tokens=None,
433 | itr_count=0,
434 | ):
435 | qkv = self.Wqkv(x)
436 | if self.clip_qkv:
437 | qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
438 | (query, key, value) = qkv.chunk(3, dim=2)
439 | key_padding_mask = attention_mask
440 | if self.qk_ln:
441 | dtype = query.dtype
442 | query = self.q_ln(query).to(dtype)
443 | key = self.k_ln(key).to(dtype)
444 | (
445 | context,
446 | attn_weights,
447 | sparsity,
448 | past_key_value,
449 | score_fn,
450 | token_discard_mask,
451 | token_discard_idx,
452 | req_tokens,
453 | itr_count,
454 | ) = self.attn_fn(
455 | query,
456 | key,
457 | value,
458 | self.n_heads,
459 | past_key_value=past_key_value,
460 | softmax_scale=self.softmax_scale,
461 | attn_bias=attn_bias,
462 | key_padding_mask=key_padding_mask,
463 | is_causal=is_causal,
464 | dropout_p=self.attn_dropout_p,
465 | keyformer=self.keyformer,
466 | score_fn=score_fn,
467 | token_discard_mask=token_discard_mask,
468 | token_discard_idx=token_discard_idx,
469 | req_tokens=req_tokens,
470 | kv_cache=self.kv_cache,
471 | recent=self.recent,
472 | tau_init=self.tau_init,
473 | tau_delta=self.tau_delta,
474 | itr_count=itr_count,
475 | training=self.training,
476 | needs_weights=needs_weights,
477 | )
478 | return (
479 | self.out_proj(context),
480 | attn_weights,
481 | sparsity,
482 | past_key_value,
483 | score_fn,
484 | token_discard_mask,
485 | token_discard_idx,
486 | req_tokens,
487 | itr_count,
488 | )
489 |
490 |
491 | def attn_bias_shape(
492 | attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
493 | ):
494 | if attn_impl == "flash":
495 | return None
496 | elif attn_impl in ["torch", "triton"]:
497 | if alibi:
498 | if (prefix_lm or not causal) or use_sequence_id:
499 | return (1, n_heads, seq_len, seq_len)
500 | return (1, n_heads, 1, seq_len)
501 | elif prefix_lm or use_sequence_id:
502 | return (1, 1, seq_len, seq_len)
503 | return None
504 | else:
505 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
506 |
507 |
508 | def build_attn_bias(
509 | attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
510 | ):
511 | if attn_impl == "flash":
512 | return None
513 | elif attn_impl in ["torch", "triton"]:
514 | if alibi:
515 | (device, dtype) = (attn_bias.device, attn_bias.dtype)
516 | attn_bias = attn_bias.add(
517 | build_alibi_bias(
518 | n_heads,
519 | seq_len,
520 | full=not causal,
521 | alibi_bias_max=alibi_bias_max,
522 | device=device,
523 | dtype=dtype,
524 | )
525 | )
526 | return attn_bias
527 | else:
528 | raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
529 |
530 |
531 | def gen_slopes(n_heads, alibi_bias_max=8, device=None):
532 | _n_heads = 2 ** math.ceil(math.log2(n_heads))
533 | m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
534 | m = m.mul(alibi_bias_max / _n_heads)
535 | slopes = 1.0 / torch.pow(2, m)
536 | if _n_heads != n_heads:
537 | slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
538 | return slopes.view(1, n_heads, 1, 1)
539 |
540 |
541 | def build_alibi_bias(
542 | n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
543 | ):
544 | alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
545 | 1, 1, 1, seq_len
546 | )
547 | if full:
548 | alibi_bias = alibi_bias - torch.arange(
549 | 1 - seq_len, 1, dtype=torch.int32, device=device
550 | ).view(1, 1, seq_len, 1)
551 | alibi_bias = alibi_bias.abs().mul(-1)
552 | slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
553 | alibi_bias = alibi_bias * slopes
554 | return alibi_bias.to(dtype=dtype)
555 |
556 |
557 | ATTN_CLASS_REGISTRY = {"multihead_attention": MultiheadAttention}
558 |
--------------------------------------------------------------------------------
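The StreamingLLM baseline above keeps a fixed set of attention-sink tokens plus a recent window instead of score-selected tokens; an analogous illustrative sketch of that split (plain Python, example numbers only):

```python
def streaming_llm_budget(seq_len, kv_cache=60.0, sink_tokens=4):
    # Mirrors the split in attention_streaming_llm.py: a fixed number of
    # attention sinks plus a recent window filling the rest of the KV budget.
    req_tokens = int((seq_len * kv_cache) / 100)  # total KV budget
    recent_tokens = req_tokens - sink_tokens      # most recent tokens kept
    return recent_tokens, sink_tokens

# A 2048-token prompt with a 60% budget keeps 1224 recent tokens plus 4 sinks.
print(streaming_llm_budget(2048))  # -> (1224, 4)
```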
/models/mpt-keyformer-lib/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import sys
4 | import torch
5 | import torch.nn as nn
6 | from .attention import ATTN_CLASS_REGISTRY
7 | from .norm import NORM_CLASS_REGISTRY
8 |
9 |
10 | class MPTMLP(nn.Module):
11 | def __init__(
12 | self, d_model: int, expansion_ratio: int, device: Optional[str] = None
13 | ):
14 | super().__init__()
15 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
16 | self.act = nn.GELU(approximate="none")
17 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
18 | self.down_proj._is_residual = True
19 |
20 | def forward(self, x):
21 | return self.down_proj(self.act(self.up_proj(x)))
22 |
23 |
24 | class MPTBlock(nn.Module):
25 | def __init__(
26 | self,
27 | d_model: int,
28 | n_heads: int,
29 | expansion_ratio: int,
30 | attn_config: Dict = {
31 | "attn_type": "multihead_attention",
32 | "attn_pdrop": 0.0,
33 | "attn_impl": "triton",
34 | "qk_ln": False,
35 | "clip_qkv": None,
36 | "softmax_scale": None,
37 | "prefix_lm": False,
38 | "attn_uses_sequence_id": False,
39 | "alibi": False,
40 | "alibi_bias_max": 8,
41 | },
42 | keyformer_config: Dict = {
43 | "keyformer": False,
44 | "kv_cache": 60,
45 | "recent": 30,
46 | "tau_init": 1.0,
47 | "tau_delta": 0.01,
48 | },
49 | resid_pdrop: float = 0.0,
50 | norm_type: str = "low_precision_layernorm",
51 | verbose: int = 0,
52 | device: Optional[str] = None,
53 | **kwargs
54 | ):
55 | del kwargs
56 | super().__init__()
57 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
58 | attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
59 | self.norm_1 = norm_class(d_model, device=device)
60 | self.attn = attn_class(
61 | attn_impl=attn_config["attn_impl"],
62 | clip_qkv=attn_config["clip_qkv"],
63 | qk_ln=attn_config["qk_ln"],
64 | softmax_scale=attn_config["softmax_scale"],
65 | attn_pdrop=attn_config["attn_pdrop"],
66 | keyformer=keyformer_config["keyformer"],
67 | kv_cache=keyformer_config["kv_cache"],
68 | recent=keyformer_config["recent"],
69 | tau_init=keyformer_config["tau_init"],
70 | tau_delta=keyformer_config["tau_delta"],
71 | d_model=d_model,
72 | n_heads=n_heads,
73 | verbose=verbose,
74 | device=device,
75 | )
76 | self.norm_2 = norm_class(d_model, device=device)
77 | self.ffn = MPTMLP(
78 | d_model=d_model, expansion_ratio=expansion_ratio, device=device
79 | )
80 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
81 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
82 | # ==================== Keyformer =========================
83 | self.keyformer = keyformer_config["keyformer"]
84 | self.kv_cache = keyformer_config["kv_cache"]
85 | self.recent = keyformer_config["recent"]
86 | self.tau_init = keyformer_config["tau_init"]
87 | self.tau_delta = keyformer_config["tau_delta"]
88 | self.token_discard_mask = None # Token discard mask for each MPT block
89 | self.token_discard_idx = None # Token discard indices for each MPT block
90 | self.req_tokens = None # Required tokens to attend for each MPT block
91 | self.score_fn = None # Score fn for each MPT block
92 | self.itr_count = 0
93 | # ========================================================
94 |
95 | # added output_attentions parameter for extracting output attentions
96 | def forward(
97 | self,
98 | x: torch.Tensor,
99 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
100 | attn_bias: Optional[torch.Tensor] = None,
101 | attention_mask: Optional[torch.ByteTensor] = None,
102 | is_causal: bool = True,
103 | output_attentions: bool = False,
104 | ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
105 | a = self.norm_1(x)
106 |         # added needs_weights parameter (previously missing) for output_attentions
107 | (
108 | b,
109 | attn_weights,
110 | sparsity,
111 | past_key_value,
112 | score_fn,
113 | token_discard_mask,
114 | token_discard_idx,
115 | req_tokens,
116 | itr_count,
117 | ) = self.attn(
118 | a,
119 | past_key_value=past_key_value,
120 | attn_bias=attn_bias,
121 | attention_mask=attention_mask,
122 | is_causal=is_causal,
123 | needs_weights=output_attentions,
124 | score_fn=self.score_fn,
125 | token_discard_mask=self.token_discard_mask,
126 | token_discard_idx=self.token_discard_idx,
127 | req_tokens=self.req_tokens,
128 | itr_count=self.itr_count,
129 | )
130 | self.score_fn = score_fn
131 | self.token_discard_mask = token_discard_mask
132 | self.token_discard_idx = token_discard_idx
133 | self.req_tokens = req_tokens
134 | self.itr_count = itr_count
135 |
136 | x = x + self.resid_attn_dropout(b)
137 | m = self.norm_2(x)
138 | n = self.ffn(m)
139 | x = x + self.resid_ffn_dropout(n)
140 | return (x, attn_weights, sparsity, past_key_value)
141 |
--------------------------------------------------------------------------------
/models/mpt-keyformer-lib/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "mosaicml/mpt-7b",
3 | "architectures": [
4 | "MPTForCausalLM"
5 | ],
6 | "attn_config": {
7 | "alibi": true,
8 | "alibi_bias_max": 8,
9 | "attn_impl": "torch",
10 | "attn_pdrop": 0,
11 | "attn_type": "multihead_attention",
12 | "attn_uses_sequence_id": false,
13 | "clip_qkv": null,
14 | "prefix_lm": false,
15 | "qk_ln": false,
16 | "softmax_scale": null
17 | },
18 | "auto_map": {
19 | "AutoConfig": "configuration_mpt.MPTConfig",
20 | "AutoModelForCausalLM": "modeling_mpt.MPTForCausalLM"
21 | },
22 | "keyformer_config": {
23 | "keyformer": false,
24 | "kv_cache": 60,
25 | "recent": 30,
26 | "tau_init": 1.0,
27 | "tau_delta": 0.01
28 | },
29 | "d_model": 4096,
30 | "emb_pdrop": 0,
31 | "embedding_fraction": 1.0,
32 | "expansion_ratio": 4,
33 | "init_config": {
34 | "emb_init_std": null,
35 | "emb_init_uniform_lim": null,
36 | "fan_mode": "fan_in",
37 | "init_div_is_residual": true,
38 | "init_gain": 0,
39 | "init_nonlinearity": "relu",
40 | "init_std": 0.02,
41 | "name": "kaiming_normal_",
42 | "verbose": 0
43 | },
44 | "init_device": "cpu",
45 | "learned_pos_emb": true,
46 | "logit_scale": null,
47 | "max_seq_len": 2048,
48 | "model_type": "mpt",
49 | "n_heads": 32,
50 | "n_layers": 32,
51 | "no_bias": true,
52 | "norm_type": "low_precision_layernorm",
53 | "resid_pdrop": 0,
54 | "tokenizer_name": "EleutherAI/gpt-neox-20b",
55 | "torch_dtype": "float32",
56 | "transformers_version": "4.28.1",
57 | "use_cache": false,
58 | "verbose": 0,
59 | "vocab_size": 50432
60 | }
61 |
--------------------------------------------------------------------------------
/models/mpt-keyformer-lib/configuration_mpt.py:
--------------------------------------------------------------------------------
1 | """A HuggingFace-style model configuration."""
2 | from typing import Dict, Optional, Union
3 | from transformers import PretrainedConfig
4 |
5 | attn_config_defaults: Dict = {
6 | "attn_type": "multihead_attention",
7 | "attn_pdrop": 0.0,
8 | "attn_impl": "triton",
9 | "qk_ln": False,
10 | "clip_qkv": None,
11 | "softmax_scale": None,
12 | "prefix_lm": False,
13 | "attn_uses_sequence_id": False,
14 | "alibi": False,
15 | "alibi_bias_max": 8,
16 | }
17 | init_config_defaults: Dict = {
18 | "name": "kaiming_normal_",
19 | "fan_mode": "fan_in",
20 | "init_nonlinearity": "relu",
21 | "init_div_is_residual": True,
22 | "emb_init_std": None,
23 | "emb_init_uniform_lim": None,
24 | "init_std": None,
25 | "init_gain": 0.0,
26 | }
27 | keyformer_config_defaults: Dict = {
28 | "keyformer": False,
29 | "kv_cache": 60,
30 | "recent": 30,
31 | "tau_init": 1.0,
32 | "tau_delta": 0.01,
33 | }
34 |
35 |
36 | class MPTConfig(PretrainedConfig):
37 | model_type = "mpt"
38 |
39 | def __init__(
40 | self,
41 | d_model: int = 2048,
42 | n_heads: int = 16,
43 | n_layers: int = 24,
44 | expansion_ratio: int = 4,
45 | max_seq_len: int = 2048,
46 | vocab_size: int = 50368,
47 | resid_pdrop: float = 0.0,
48 | emb_pdrop: float = 0.0,
49 | learned_pos_emb: bool = True,
50 | attn_config: Dict = attn_config_defaults,
51 | init_device: str = "cpu",
52 | logit_scale: Optional[Union[float, str]] = None,
53 | no_bias: bool = False,
54 | verbose: int = 0,
55 | embedding_fraction: float = 1.0,
56 | norm_type: str = "low_precision_layernorm",
57 | use_cache: bool = False,
58 | init_config: Dict = init_config_defaults,
59 | keyformer_config: Dict = keyformer_config_defaults,
60 | **kwargs,
61 | ):
62 | """The MPT configuration class.
63 |
64 | Args:
65 | d_model (int): The size of the embedding dimension of the model.
66 | n_heads (int): The number of attention heads.
67 | n_layers (int): The number of layers in the model.
68 | expansion_ratio (int): The ratio of the up/down scale in the MLP.
69 | max_seq_len (int): The maximum sequence length of the model.
70 | vocab_size (int): The size of the vocabulary.
71 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
72 | emb_pdrop (float): The dropout probability for the embedding layer.
73 | learned_pos_emb (bool): Whether to use learned positional embeddings
74 | attn_config (Dict): A dictionary used to configure the model's attention module:
75 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, groupquery_attention
76 | attn_pdrop (float): The dropout probability for the attention layers.
77 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
78 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
79 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
80 | this value.
81 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
82 | use the default scale of ``1/sqrt(d_keys)``.
83 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
84 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
85 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
86 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
87 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
88 | which sub-sequence each token belongs to.
89 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
90 | alibi (bool): Whether to use the alibi bias instead of position embeddings.
91 | alibi_bias_max (int): The maximum value of the alibi bias.
92 | keyformer_config (Dict): A dictionary used to configure the keyformer parameters:
93 |                 keyformer (bool): When enabled, applies Keyformer-based KV cache reduction.
94 | kv_cache (float): KV cache percentage for Keyformer.
95 | recent (float): Recent window percentage for Keyformer.
96 | tau_init (float): Initial temperature for Keyformer score function calculation.
97 | tau_delta (float): Delta temperature change for Keyformer score function.
98 |
99 | init_device (str): The device to use for parameter initialization.
100 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
101 | no_bias (bool): Whether to use bias in all layers.
102 | verbose (int): The verbosity level. 0 is silent.
103 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
104 | norm_type (str): choose type of norm to use
105 | multiquery_attention (bool): Whether to use multiquery attention implementation.
106 | use_cache (bool): Whether or not the model should return the last key/values attentions
107 | init_config (Dict): A dictionary used to configure the model initialization:
108 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
109 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
110 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
111 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
112 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
113 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
114 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
115 | init_std (float): The standard deviation of the normal distribution used to initialize the model,
116 | if using the baseline_ parameter initialization scheme.
117 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
118 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
119 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
120 | ---
121 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
122 | """
123 | self.d_model = d_model
124 | self.n_heads = n_heads
125 | self.n_layers = n_layers
126 | self.expansion_ratio = expansion_ratio
127 | self.max_seq_len = max_seq_len
128 | self.vocab_size = vocab_size
129 | self.resid_pdrop = resid_pdrop
130 | self.emb_pdrop = emb_pdrop
131 | self.learned_pos_emb = learned_pos_emb
132 | self.attn_config = attn_config
133 | self.init_device = init_device
134 | self.logit_scale = logit_scale
135 | self.no_bias = no_bias
136 | self.verbose = verbose
137 | self.embedding_fraction = embedding_fraction
138 | self.norm_type = norm_type
139 | self.use_cache = use_cache
140 | self.init_config = init_config
141 | # ====== Keyformer ======
142 | self.keyformer_config = keyformer_config
143 | # =======================
144 | if "name" in kwargs:
145 | del kwargs["name"]
146 | if "loss_fn" in kwargs:
147 | del kwargs["loss_fn"]
148 | super().__init__(**kwargs)
149 | self._validate_config()
150 |
151 | def _set_config_defaults(self, config, config_defaults):
152 | for k, v in config_defaults.items():
153 | if k not in config:
154 | config[k] = v
155 | return config
156 |
157 | def _validate_config(self):
158 | self.attn_config = self._set_config_defaults(
159 | self.attn_config, attn_config_defaults
160 | )
161 | self.init_config = self._set_config_defaults(
162 | self.init_config, init_config_defaults
163 | )
164 | if self.d_model % self.n_heads != 0:
165 | raise ValueError("d_model must be divisible by n_heads")
166 | if any(
167 | (
168 | prob < 0 or prob > 1
169 | for prob in [
170 | self.attn_config["attn_pdrop"],
171 | self.resid_pdrop,
172 | self.emb_pdrop,
173 | ]
174 | )
175 | ):
176 | raise ValueError(
177 | "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
178 | )
179 | if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
180 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
181 | if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
182 | "torch",
183 | "triton",
184 | ]:
185 | raise NotImplementedError(
186 | "prefix_lm only implemented with torch and triton attention."
187 | )
188 | if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [
189 | "torch",
190 | "triton",
191 | ]:
192 | raise NotImplementedError(
193 | "alibi only implemented with torch and triton attention."
194 | )
195 | if self.attn_config["attn_uses_sequence_id"] and self.attn_config[
196 | "attn_impl"
197 | ] not in ["torch", "triton"]:
198 | raise NotImplementedError(
199 | "attn_uses_sequence_id only implemented with torch and triton attention."
200 | )
201 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
202 | raise ValueError(
203 | "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
204 | )
205 | if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
206 | raise ValueError(
207 | f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
208 | )
209 | if self.init_config.get("name", None) is None:
210 | raise ValueError(
211 | f"self.init_config={self.init_config!r} 'name' needs to be set."
212 | )
213 | if not self.learned_pos_emb and (not self.attn_config["alibi"]):
214 | raise ValueError(
215 | f"Positional information must be provided to the model using either learned_pos_emb or alibi."
216 | )
217 | # ====================== Keyformer parameters validation ====================
218 | if (
219 | self.keyformer_config["kv_cache"] < 0
220 | or self.keyformer_config["kv_cache"] > 100
221 | ):
222 | raise ValueError(
223 | f"KV cache percentage should be in the range from 0 to 100."
224 | )
225 | if self.keyformer_config["recent"] < 0 or self.keyformer_config["recent"] > 100:
226 | raise ValueError(f"Recent percentage should be in the range from 0 to 100.")
227 | if self.keyformer_config["tau_init"] < 0:
228 | raise ValueError(f"Initial temperature parameter can not be less than 0.")
229 | if self.keyformer_config["tau_delta"] < 0:
230 | raise ValueError(
231 | f"Delta temperature change parameter can not be less than 0."
232 | )
233 | # ===========================================================================
234 |
--------------------------------------------------------------------------------
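As a usage sketch (not part of the library itself), the `keyformer_config` documented above can be overridden at load time. The path below matches the local model copy assumed elsewhere in this repository; the percentages are illustrative defaults:

```python
# Illustrative sketch: enable Keyformer on a local MPT checkpoint that ships
# with the keyformer-lib files. Path and percentages are placeholders.
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "models/mpt-7b-keyformer"
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
config.keyformer_config["keyformer"] = True  # turn on KV cache reduction
config.keyformer_config["kv_cache"] = 60     # keep 60% of the prompt KV cache
config.keyformer_config["recent"] = 30       # 30% of that budget as the recent window
model = AutoModelForCausalLM.from_pretrained(
    model_path, config=config, trust_remote_code=True
)
```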
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.4.0
2 | accelerate==0.18.0
3 | aiohttp==3.8.4
4 | aiosignal==1.3.1
5 | appdirs==1.4.4
6 | async-timeout==4.0.2
7 | attrs==23.1.0
8 | certifi==2022.12.7
9 | charset-normalizer==3.1.0
10 | click==8.1.3
11 | cmake==3.26.3
12 | datasets==2.13.0
13 | dill==0.3.6
14 | docker-pycreds==0.4.0
15 | evaluate==0.4.0
16 | filelock==3.12.0
17 | fire==0.5.0
18 | frozenlist==1.3.3
19 | fsspec==2023.4.0
20 | gitdb==4.0.10
21 | gitpython==3.1.31
22 | huggingface-hub==0.14.1
23 | idna==3.4
24 | jinja2==3.1.2
25 | joblib==1.2.0
26 | lit==16.0.2
27 | markupsafe==2.1.2
28 | mpmath==1.3.0
29 | multidict==6.0.4
30 | multiprocess==0.70.14
31 | networkx==3.1
32 | nltk==3.8.1
33 | numpy==1.24.3
34 | nvidia-cublas-cu11==11.10.3.66
35 | nvidia-cuda-cupti-cu11==11.7.101
36 | nvidia-cuda-nvrtc-cu11==11.7.99
37 | nvidia-cuda-runtime-cu11==11.7.99
38 | nvidia-cudnn-cu11==8.5.0.96
39 | nvidia-cufft-cu11==10.9.0.58
40 | nvidia-curand-cu11==10.2.10.91
41 | nvidia-cusolver-cu11==11.4.0.1
42 | nvidia-cusparse-cu11==11.7.4.91
43 | nvidia-nccl-cu11==2.14.3
44 | nvidia-nvtx-cu11==11.7.91
45 | openai==0.27.6
46 | packaging==23.1
47 | pandas==2.0.1
48 | pathtools==0.1.2
49 | protobuf==4.22.4
50 | psutil==5.9.5
51 | pyarrow==12.0.0
52 | python-dateutil==2.8.2
53 | pytz==2023.3
54 | pyyaml==6.0
55 | regex==2023.5.5
56 | requests==2.30.0
57 | responses==0.18.0
58 | rouge-score==0.1.2
59 | sentencepiece==0.1.99
60 | sentry-sdk==1.21.1
61 | setproctitle==1.3.2
62 | simplejson==3.19.1
63 | six==1.16.0
64 | smmap==5.0.0
65 | sympy==1.11.1
66 | termcolor==2.3.0
67 | tokenizers==0.12.1
68 | torch==2.0.0
69 | tqdm==4.65.0
70 | transformers==4.28.1
71 | triton==2.0.0
72 | typing-extensions==4.5.0
73 | tzdata==2023.3
74 | urllib3==2.0.2
75 | wandb==0.15.1
76 | xxhash==3.2.0
77 | yarl==1.9.2
78 | einops
--------------------------------------------------------------------------------
/summarization/README.md:
--------------------------------------------------------------------------------
1 | # 📖 Summarization Task
2 |
3 | ## Download dataset
4 |
5 | You can download the respective summarization datasets using the script below.
6 | ```bash
7 | cd dataset_download
8 | python download_cnndm.py
9 | ```
10 |
11 | We have provided data download scripts for the following summarization datasets:
12 | - [CNN/DailyMail](https://huggingface.co/datasets/cnn_dailymail)
13 | - [XSUM](https://huggingface.co/datasets/EdinburghNLP/xsum)
14 | - [GovReports](https://huggingface.co/datasets/ccdv/govreport-summarization?row=0)
15 |
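By default, each download script writes its JSON output to `../data`; the output directory can be overridden through the environment variable read by the script (for example, `DATASET_CNNDM_PATH` in `download_cnndm.py` and `DATASET_XSUM_PATH` in `download_xsum.py`):

```bash
# Write cnn_eval.json to a custom location instead of the default ../data
DATASET_CNNDM_PATH=/path/to/data python download_cnndm.py
```
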
16 | ## Quick start with the Summarization Task
17 |
18 | You can directly run the summarization task using the following command:
19 |
20 | ```bash
21 | bash run_summarization_task.sh
22 | ```
23 |
24 | By default, this runs `mosaicml/mpt-7b`. It assumes that the model copy with
25 | `keyformer` is stored in `models/mpt-7b-keyformer` and the dataset is stored in
26 | `data/cnn_eval.json`. For custom execution, read below.
27 |
28 |
29 | ## Summarization Task
30 |
31 | To get started with the summarization task, set up the model parameters and use the script below:
32 |
33 | ```bash
34 | python summarize.py --model_name \
35 | --dataset_path \
36 | --save_path \
37 | --score_path \
38 | --model_path \
39 |                     --attentions_path \
40 | --device cuda \ # Device
41 | --task summarization \ # Task - Currently only summarization supported
42 | --bs 1 \ # Batch Size
43 | --dtype float16 \ # Data type of model parameters
44 | --causal_lm \ # Causal language model for summarization
45 | --early_stopping \ # Enable early stopping while generation
46 | --output_summaries_only \ # Output only summary - No prompt
47 |                     --output_sequence_scores \ # Output sequence scores and enable output attentions
48 | --save_attentions \ # Save attention weights - Token Generation only
49 |                     --save_prompt_attentions \ # If enabled, only prompt attention weights are stored
50 | --padding_side left \ # Padding side
51 | --beam 4 \ # Beam Size
52 | --model_parallelize \ # Parallelize Model across all available GPUs
53 | --keyformer \ # Enable Keyformer
54 | --kv_cache 60 \ # KV cache percentage of prompt length
55 | --recent 30 \ # Recent window percentage
56 | --tau_init 1 \ # Initial temperature parameter for Gumbel
57 | --tau_end 2 \ # End temperature parameter for Gumbel
58 | --no_repeat_ngram_size 0 \
59 | --repetition_penalty 1 \
60 | --max_tokenizer_length 1920 \ # Maximum prompt size
61 | --max_new_tokens 128 \ # Maximum newly generated tokens
62 | --min_gen_length 30 \ # Minimum newly generated tokens
63 |                     --num_return_sequences 1 \ # Number of returned summaries per input
64 |                     --seed 12345 \ # Random seed value for random samples
65 |                     --n_obs 1000  # Number of input samples
66 | ```
67 |
68 | Note: For FP16, do not use `model.half()`; instead, pass the desired `dtype` when creating the model. [Link](https://stackoverflow.com/questions/69994731/what-is-the-difference-between-cuda-amp-and-model-half)
69 |
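A minimal sketch of loading in FP16 this way (assuming a Hugging Face `transformers` model such as `mosaicml/mpt-7b`):

```python
import torch
from transformers import AutoModelForCausalLM

# Create the model directly in FP16 rather than converting afterwards with model.half()
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b",
    torch_dtype=torch.float16,  # dtype set at model creation
    trust_remote_code=True,     # MPT models ship custom modeling code
)
```
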
--------------------------------------------------------------------------------
/summarization/__pycache__/dataset.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/dataset.cpython-310.pyc
--------------------------------------------------------------------------------
/summarization/__pycache__/dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/summarization/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/summarization/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/summarization/data/data_cache/.._data_data_cache_cnn_dailymail_3.0.0_0.0.0_d3306fda2d44efc3.lock:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/data/data_cache/.._data_data_cache_cnn_dailymail_3.0.0_0.0.0_d3306fda2d44efc3.lock
--------------------------------------------------------------------------------
/summarization/data/data_cache/data_cache:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/d-matrix-ai/keyformer-llm/f327831056a9909b133aaa14a48f2bd0722acede/summarization/data/data_cache/data_cache
--------------------------------------------------------------------------------
/summarization/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 | from datasets import load_dataset, load_from_disk
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from torch.nn.functional import pad
8 | from torch.utils.data import DataLoader
9 | from typing import Optional, Dict, Sequence
10 | import io
11 | import utils
12 | import copy
13 |
14 | PROMPT_DICT = {
15 | "prompt_input": (
16 | "Below is an instruction that describes a task, paired with an input that provides further context. "
17 | "Write a response that appropriately completes the request.\n\n"
18 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
19 | ),
20 | "prompt_no_input": (
21 | "Below is an instruction that describes a task. "
22 | "Write a response that appropriately completes the request.\n\n"
23 | "### Instruction:\n{instruction}\n\n### Response:"
24 | ),
25 | }
26 |
27 |
28 | class Dataset(torch.utils.data.Dataset):
29 | """Characterizes a dataset for PyTorch"""
30 |
31 | def __init__(
32 | self,
33 | dataset_path,
34 | tokenizer,
35 | model_name,
36 | return_tensors,
37 | truncation,
38 | padding,
39 | max_length=None,
40 | total_count_override=None,
41 | perf_count_override=None,
42 | ):
43 | self.dataset = "cnn_dailymail"
44 | self.dataset_path = dataset_path
45 | self.model_name = model_name
46 | self.tokenizer = tokenizer
47 | self.return_tensors = return_tensors
48 | self.truncation = truncation
49 | self.padding = padding
50 | self.max_length = max_length
51 |
52 | self.list_data_dict = utils.jload(self.dataset_path)
53 |
54 | prompt_input, prompt_no_input = (
55 | PROMPT_DICT["prompt_input"],
56 | PROMPT_DICT["prompt_no_input"],
57 | )
58 | self.sources = [
59 | prompt_input.format_map(example) for example in self.list_data_dict
60 | ]
61 | self.targets = [f"{example['output']}" for example in self.list_data_dict]
62 |
63 | # Getting random samples for evaluation
64 |         if total_count_override is not None and total_count_override > 0:
65 | self.rand_samples = np.random.randint(
66 | len(self.sources), size=(total_count_override)
67 | ).tolist()
68 | self.sources = [self.sources[i] for i in self.rand_samples]
69 | self.targets = [self.targets[i] for i in self.rand_samples]
70 | self.count = total_count_override
71 | else:
72 | self.count = len(self.sources)
73 |
74 | (
75 | self.source_encoded_input_ids,
76 | self.source_encoded_attn_masks,
77 | ) = self.encode_samples()
78 |
79 | self.perf_count = perf_count_override or self.count
80 |
81 | def encode_samples(self):
82 | print("Encoding Samples")
83 |
84 | total_samples = self.count
85 |
86 | source_encoded_input_ids = []
87 | source_encoded_attn_masks = []
88 |
89 | for i in range(total_samples):
90 | source_encoded = self.tokenizer(
91 | self.sources[i],
92 | return_tensors=self.return_tensors,
93 | padding=self.padding,
94 | truncation=self.truncation,
95 | max_length=self.max_length,
96 | )
97 | source_encoded_input_ids.append(source_encoded.input_ids)
98 | source_encoded_attn_masks.append(source_encoded.attention_mask)
99 |
100 | return source_encoded_input_ids, source_encoded_attn_masks
101 |
102 | def __len__(self):
103 | return self.count
104 |
105 | def __getitem__(self, index):
106 | return (
107 | self.source_encoded_input_ids[index],
108 | self.source_encoded_attn_masks[index],
109 | )
110 |
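# Usage sketch (hypothetical values; summarize.py is the actual entry point):
#
#     from transformers import AutoTokenizer
#     from dataset import Dataset
#
#     tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b", padding_side="left")
#     data = Dataset(
#         dataset_path="data/cnn_eval.json",
#         tokenizer=tokenizer,
#         model_name="mosaicml/mpt-7b",
#         return_tensors="pt",
#         truncation=True,
#         padding=True,
#         max_length=1920,
#         total_count_override=1000,
#     )
#     input_ids, attention_mask = data[0]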
--------------------------------------------------------------------------------
/summarization/dataset_download/download_cnndm.py:
--------------------------------------------------------------------------------
1 | # experiment config
2 | dataset_id = "cnn_dailymail"
3 | dataset_config = "3.0.0"
4 | text_column = "article"
5 | summary_column = "highlights"
6 |
7 | from datasets import load_dataset, concatenate_datasets
8 | from transformers import AutoTokenizer
9 | import numpy as np
10 | import os
11 | import simplejson as json
12 | import sys
13 |
14 | save_dataset_path = os.environ.get("DATASET_CNNDM_PATH", "../data")
15 |
16 | # Check whether the specified path exists or not
17 | isExist = os.path.exists(save_dataset_path)
18 | if not isExist:
19 | # Create a new directory because it does not exist
20 | os.makedirs(save_dataset_path)
21 |
22 | # Load dataset from the hub
23 | dataset = load_dataset(dataset_id, name=dataset_config, cache_dir="../data/data_cache")
24 |
25 | instruction_template = "Summarize the following news article:"
26 |
27 |
28 | def preprocess_function(sample):
29 | # create list of samples
30 | inputs = []
31 |
32 | for i in range(0, len(sample[text_column])):
33 | x = dict()
34 | x["instruction"] = instruction_template
35 | x["input"] = sample[text_column][i]
36 | x["output"] = sample[summary_column][i]
37 | inputs.append(x)
38 | model_inputs = dict()
39 | model_inputs["text"] = inputs
40 |
41 | return model_inputs
42 |
43 |
44 | # process dataset
45 | tokenized_dataset = dataset.map(
46 | preprocess_function, batched=True, remove_columns=list(dataset["train"].features)
47 | )
48 |
49 | # save dataset to disk
50 | with open(os.path.join(save_dataset_path, "cnn_eval.json"), "w") as write_f:
51 | json.dump(
52 | tokenized_dataset["validation"]["text"], write_f, indent=4, ensure_ascii=False
53 | )
54 |
55 |
56 | print("Dataset saved in ", save_dataset_path)
57 |
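# Each entry written to cnn_eval.json has the form (illustrative):
#     {
#         "instruction": "Summarize the following news article:",
#         "input": "<article text>",
#         "output": "<reference highlights>"
#     }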
--------------------------------------------------------------------------------
/summarization/dataset_download/download_longbench.py:
--------------------------------------------------------------------------------
1 | # experiment config
2 | """
3 | ====> Loading LongBench Data
4 | ################################################################################################################
5 | from datasets import load_dataset
6 |
7 | datasets = ["narrativeqa", "qasper", "multifieldqa_en", "multifieldqa_zh", "hotpotqa", "2wikimqa", "musique", \
8 | "dureader", "gov_report", "qmsum", "multi_news", "vcsum", "trec", "triviaqa", "samsum", "lsht", \
9 | "passage_count", "passage_retrieval_en", "passage_retrieval_zh", "lcc", "repobench-p"]
10 |
11 | for dataset in datasets:
12 | data = load_dataset('THUDM/LongBench', dataset, split='test')
13 | ################################################################################################################
14 |
15 | ====> Loading LongBench-E Data (Uniformly Sampled Data)
16 | ################################################################################################################
17 | from datasets import load_dataset
18 |
19 | datasets = ["qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "gov_report", "multi_news", "trec", \
20 | "triviaqa", "samsum", "passage_count", "passage_retrieval_en", "lcc", "repobench-p"]
21 |
22 | for dataset in datasets:
23 | data = load_dataset('THUDM/LongBench', f"{dataset}_e", split='test')
24 | ################################################################################################################
25 |
26 | ################################################################################################################
27 | {
28 | "input": "The input/command for the task, usually short, such as questions in QA, queries in Few-shot tasks, etc",
29 | "context": "The long context required for the task, such as documents, cross-file code, few-shot examples in Few-shot tasks",
30 | "answers": "A List of all true answers",
31 | "length": "Total length of the first three items (counted in characters for Chinese and words for English)",
32 | "dataset": "The name of the dataset to which this piece of data belongs",
33 | "language": "The language of this piece of data",
34 | "all_classes": "All categories in classification tasks, null for non-classification tasks",
35 | "_id": "Random id for each piece of data"
36 | }
37 | ################################################################################################################
38 | """
39 | dataset_id = "THUDM/LongBench"
40 | dataset_name = "gov_report"
41 | text_column = "context"
42 | summary_column = "answers"
43 | instruction_template_gov_reports = "You are given a report by a government agency. Write a one-page summary of the report."
44 | instruction_template_multi_news = (
45 | "You are given several news passages. Write a one-page summary of all news."
46 | )
47 | instruction_template_passage_count = (
48 | "Determine the number of unique passages among the given set."
49 | )
50 | instruction_template_lcc = "Write the next line of code."
51 |
52 | from datasets import load_dataset, concatenate_datasets
53 | from transformers import AutoTokenizer
54 | import numpy as np
55 | import os
56 | import simplejson as json
57 | import sys
58 |
59 | save_dataset_path = os.environ.get("DATASET_CNNDM_PATH", "../data")
60 |
61 | # Check whether the specified path exists or not
62 | isExist = os.path.exists(save_dataset_path)
63 | if not isExist:
64 | # Create a new directory because it does not exist
65 | os.makedirs(save_dataset_path)
66 |
67 | # Load dataset from the hub
68 | dataset = load_dataset(
69 | dataset_id, name=dataset_name, cache_dir="../data/data_cache", split="test"
70 | )
71 |
72 |
73 | def preprocess_function(sample):
74 | # create list of samples
75 | inputs = []
76 |
77 | for i in range(0, len(sample[text_column])):
78 | x = dict()
79 | # x["instruction"] = sample["input"][i]
80 | x["instruction"] = instruction_template_gov_reports
81 | x["input"] = sample[text_column][i]
82 | x["output"] = sample[summary_column][i][0]
83 | # x["classes"] = sample["all_classes"][i]
84 | inputs.append(x)
85 | model_inputs = dict()
86 | model_inputs["text"] = inputs
87 |
88 | return model_inputs
89 |
90 |
91 | # process dataset
92 | tokenized_dataset = dataset.map(preprocess_function, batched=True)
93 |
94 | # save dataset to disk
95 |
96 | file_name = dataset_name + ".json"
97 | with open(os.path.join(save_dataset_path, file_name), "w") as write_f:
98 | json.dump(tokenized_dataset["text"], write_f, indent=4, ensure_ascii=False)
99 |
100 |
101 | print("Dataset saved in ", save_dataset_path)
102 |
--------------------------------------------------------------------------------
/summarization/dataset_download/download_xsum.py:
--------------------------------------------------------------------------------
1 | # experiment config
2 | text_column = "document"
3 | summary_column = "summary"
4 |
5 | from datasets import load_dataset, concatenate_datasets
6 | from transformers import AutoTokenizer
7 | import numpy as np
8 | import os
9 | import simplejson as json
10 | import sys
11 |
12 | save_dataset_path = os.environ.get("DATASET_XSUM_PATH", "../data")
13 |
14 | # Check whether the specified path exists or not
15 | isExist = os.path.exists(save_dataset_path)
16 | if not isExist:
17 | # Create a new directory because it does not exist
18 | os.makedirs(save_dataset_path)
19 |
20 | # Load dataset from the hub
21 | dataset = load_dataset("xsum", cache_dir="../data/data_cache")
22 |
23 | instruction_template = "Provide one sentence summary of the following news article:"
24 |
25 |
26 | def preprocess_function(sample):
27 | # create list of samples
28 | inputs = []
29 | for i in range(0, len(sample[text_column])):
30 | x = dict()
31 | x["instruction"] = instruction_template
32 | x["input"] = sample[text_column][i]
33 | x["output"] = sample[summary_column][i]
34 | inputs.append(x)
35 | model_inputs = dict()
36 | model_inputs["text"] = inputs
37 |
38 | return model_inputs
39 |
40 |
41 | # process dataset
42 | tokenized_dataset = dataset.map(preprocess_function, batched=True)
43 |
44 | # save dataset to disk
45 |
46 | with open(os.path.join(save_dataset_path, "xsum_eval.json"), "w") as write_f:
47 | json.dump(
48 | tokenized_dataset["validation"]["text"], write_f, indent=4, ensure_ascii=False
49 | )
50 |
51 |
52 | print("Dataset saved in ", save_dataset_path)
53 |
--------------------------------------------------------------------------------
/summarization/dataset_synthetic_long_context.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import numpy as np
4 | import torch
5 | from datasets import load_dataset, load_from_disk
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | from torch.nn.functional import pad
8 | from torch.utils.data import DataLoader
9 | from typing import Optional, Dict, Sequence
10 | import io
11 | import utils
12 | import copy
13 |
14 | PROMPT_DICT = {
15 | "prompt_input": (
16 | "Below is an instruction that describes a task, paired with an input that provides further context. "
17 | "Write a response that appropriately completes the request.\n\n"
18 | "### Instruction:\n{instruction}\n\n### Input:\n{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}{input}\n\n### Response:"
19 | ),
20 | "prompt_no_input": (
21 | "Below is an instruction that describes a task. "
22 | "Write a response that appropriately completes the request.\n\n"
23 | "### Instruction:\n{instruction}\n\n### Response:"
24 | ),
25 | }
26 |
27 |
28 | class Dataset(torch.utils.data.Dataset):
29 | """Characterizes a dataset for PyTorch"""
30 |
31 | def __init__(
32 | self,
33 | dataset_path,
34 | tokenizer,
35 | model_name,
36 | return_tensors,
37 | truncation,
38 | padding,
39 | max_length=None,
40 | total_count_override=None,
41 | perf_count_override=None,
42 | ):
43 | self.dataset = "cnn_dailymail"
44 | self.dataset_path = dataset_path
45 | self.model_name = model_name
46 | self.tokenizer = tokenizer
47 | self.return_tensors = return_tensors
48 | self.truncation = truncation
49 | self.padding = padding
50 | self.max_length = max_length
51 |
52 | self.list_data_dict = utils.jload(self.dataset_path)
53 |
54 | prompt_input, prompt_no_input = (
55 | PROMPT_DICT["prompt_input"],
56 | PROMPT_DICT["prompt_no_input"],
57 | )
58 | self.sources = [
59 | prompt_input.format_map(example) for example in self.list_data_dict
60 | ]
61 | self.targets = [f"{example['output']}" for example in self.list_data_dict]
62 |
63 | # Getting random samples for evaluation
64 |         if total_count_override is not None and total_count_override > 0:
65 | self.rand_samples = np.random.randint(
66 | len(self.sources), size=(total_count_override)
67 | ).tolist()
68 | self.sources = [self.sources[i] for i in self.rand_samples]
69 | self.targets = [self.targets[i] for i in self.rand_samples]
70 | self.count = total_count_override
71 | else:
72 | self.count = len(self.sources)
73 |
74 | (
75 | self.source_encoded_input_ids,
76 | self.source_encoded_attn_masks,
77 | ) = self.encode_samples()
78 |
79 | self.perf_count = perf_count_override or self.count
80 |
81 | def encode_samples(self):
82 | print("Encoding Samples")
83 |
84 | total_samples = self.count
85 |
86 | source_encoded_input_ids = []
87 | source_encoded_attn_masks = []
88 |
89 | for i in range(total_samples):
90 | source_encoded = self.tokenizer(
91 | self.sources[i],
92 | return_tensors=self.return_tensors,
93 | padding=self.padding,
94 | truncation=self.truncation,
95 | max_length=self.max_length,
96 | )
97 | source_encoded_input_ids.append(source_encoded.input_ids)
98 | source_encoded_attn_masks.append(source_encoded.attention_mask)
99 |
100 | return source_encoded_input_ids, source_encoded_attn_masks
101 |
102 | def __len__(self):
103 | return self.count
104 |
105 | def __getitem__(self, index):
106 | return (
107 | self.source_encoded_input_ids[index],
108 | self.source_encoded_attn_masks[index],
109 | )
110 |
--------------------------------------------------------------------------------
/summarization/out_model.score:
--------------------------------------------------------------------------------
1 | {"rouge1": 34.4293, "rouge2": 17.7591, "rougeL": 23.474, "rougeLsum": 32.8713, "gen_len": 5605.0, "gen_num": 10.0}
--------------------------------------------------------------------------------
/summarization/out_model.summary:
--------------------------------------------------------------------------------
1 | Douglas Costa will spark a transfer scramble this summer with Shakhtar Donetsk ready to sell their prized-asset. Chelsea manager Jose Mourinho is a known admirer of the Brazil international having tried to land the midfielder in the previous transfer windows. Shakhtar chiefs are now open to selling Costa this summer and talks with third parties over his departure are underway. Brazil international Douglas Costa is set to depart Shakhtar Donetsk for £25million in the summer . Midfielder Costa (left) could spark a bidding war from Chelsea, Real Madrid and Barcelona . The 24-year-old
2 | ### Solution: #include #include #include #include #include #include