├── LICENSE
├── README.md
├── applications
├── chatbot.py
├── eval_classeval.py
├── eval_cnndm.py
├── eval_humaneval.py
├── eval_mtbench.py
├── eval_xsum.py
├── run_chat.sh
└── run_mtbench.sh
├── lade
├── __init__.py
├── decoding.py
├── lade_distributed.py
├── models
│ └── modeling_llama.py
└── utils.py
├── media
├── acc-demo.gif
├── jacobi-iteration.gif
├── lookahead-decoding.gif
├── lookahead-perf.png
└── mask.png
├── minimal-flash.py
├── minimal.py
├── requirements.txt
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Break the Sequential Dependency of LLM Inference Using Lookahead Decoding
2 |
3 |
4 | | Paper | Blog | Roadmap |
5 |
6 |
7 | ---
8 | *News* 🔥
9 | - [2024/2] The Lookahead Decoding paper is now available on [arXiv](https://arxiv.org/abs/2402.02057). [Sampling](#use-lookahead-decoding-in-your-own-code) and [FlashAttention](#flashattention-support) are supported. Advanced features for better token prediction have been added.
10 |
11 | ---
12 | ## Introduction
13 | We introduce lookahead decoding:
14 | - A parallel decoding algorithm that accelerates LLM inference.
15 | - It requires no draft model and no data store.
16 | - It linearly decreases the number of decoding steps relative to log(FLOPs) used per decoding step.
17 |
18 | Below is a demo of lookahead decoding accelerating LLaMA-2-Chat 7B generation:
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 | Demo of speedups by lookahead decoding on LLaMA-2-Chat 7B generation. Blue fonts are tokens generated in parallel in a decoding step.
27 |
28 |
29 |
30 |
31 | ### Background: Parallel LLM Decoding Using Jacobi Iteration
32 |
33 | Lookahead decoding is motivated by [Jacobi decoding](https://arxiv.org/pdf/2305.10427.pdf), which views autoregressive decoding as solving nonlinear systems and decodes all future tokens simultaneously using a fixed-point iteration method. Below is a Jacobi decoding example.
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | Illustration of applying the Jacobi iteration method for parallel LLM decoding.
42 |
43 |
44 |
45 |
46 | However, Jacobi decoding rarely achieves a wall-clock speedup in real-world LLM applications.
47 |
48 | ### Lookahead Decoding: Make Jacobi Decoding Feasible
49 |
50 | Lookahead decoding builds on Jacobi decoding by collecting and caching the n-grams generated along Jacobi iteration trajectories.
51 |
52 | The following GIF shows the process of collecting 2-grams via Jacobi decoding and verifying them to accelerate decoding (a toy code sketch follows the figure).
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 | Illustration of lookahead decoding with 2-grams.
61 |
62 |
63 |
64 |
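As a textual complement to the animation, here is a toy sketch (illustrative only, not the repo's code) of how 2-grams could be collected along a Jacobi trajectory: at iteration t+1, the guess at position p+1 is produced conditioned on the guess at position p from iteration t, so the diagonal pair forms a candidate 2-gram.

```python
# Toy illustration only (hypothetical data, not the lade implementation):
# each row holds the Jacobi guesses for the same future positions at one iteration.
trajectory = [
    [11, 12, 13, 14],  # guesses at iteration t
    [21, 22, 23, 24],  # guesses at iteration t + 1
]

# Collect 2-grams along the diagonal: the token guessed at position p in
# iteration t paired with the token guessed at position p + 1 in iteration t + 1.
two_grams = [
    (trajectory[0][p], trajectory[1][p + 1])
    for p in range(len(trajectory[0]) - 1)
]
print(two_grams)  # [(11, 22), (12, 23), (13, 24)]
```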
65 | To enhance the efficiency of this process, each lookahead decoding step is divided into two parallel branches: the lookahead branch and the verification branch. The lookahead branch maintains a fixed-size 2D window to generate n-grams from the Jacobi iteration trajectory. Simultaneously, the verification branch selects and verifies promising n-gram candidates.
66 |
67 | ### Lookahead Branch and Verification Branch
68 |
69 | The lookahead branch generates new n-grams. It operates over a two-dimensional window defined by two parameters:
70 | - Window size W: How far ahead we look in future token positions to conduct parallel decoding.
71 | - N-gram size N: How many steps we look back into the past Jacobi iteration trajectory to retrieve n-grams.
72 |
73 | In the verification branch, we identify n-grams whose first token matches the last input token, using a simple string match. Once identified, these n-grams are appended to the current input and verified in a single LLM forward pass. A simplified sketch of this branch is shown below.
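The sketch below uses hypothetical helper names (`select_candidates`, `count_accepted`) purely to make the idea concrete; the actual fused implementation lives in lade/decoding.py.

```python
from typing import Dict, List, Tuple

# Simplified sketch of the verification branch; names and data layout are
# illustrative, not the actual lade/decoding.py implementation.

NGram = Tuple[int, ...]  # remaining tokens of a cached n-gram, keyed by its first token


def select_candidates(pool: Dict[int, List[NGram]], last_token: int,
                      guess_set_size: int) -> List[NGram]:
    """Pick up to GUESS_SET_SIZE cached n-grams whose first token matches the
    last input token (a simple lookup / string match)."""
    return pool.get(last_token, [])[:guess_set_size]


def count_accepted(candidate: NGram, greedy_preds: List[int]) -> int:
    """Given the model's greedy predictions for each candidate position
    (obtained in the same forward pass), count the longest verified prefix."""
    accepted = 0
    for cand_tok, pred_tok in zip(candidate, greedy_preds):
        if cand_tok != pred_tok:
            break
        accepted += 1
    return accepted


# Toy usage: the pool maps a first token to the remaining tokens of cached 3-grams.
pool = {42: [(7, 13), (9, 5)]}
candidates = select_candidates(pool, last_token=42, guess_set_size=7)
print(count_accepted(candidates[0], greedy_preds=[7, 13]))  # -> 2 tokens accepted
```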
74 |
75 | We implement both branches within one attention mask to further utilize the GPU's parallel computing power.
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 | Attention mask for lookahead decoding with 4-grams and window size 5. In this mask, two 4-gram candidates (bottom right) are verified concurrently with parallel decoding.
84 |
85 |
86 |
87 |
88 | ### Experimental Results
89 |
90 | Our study shows that lookahead decoding substantially reduces latency, yielding 1.5x to 2.3x speedups across different datasets on a single GPU. See the figure below.
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 | Speedup of lookahead decoding on different models and datasets.
99 |
100 |
101 |
102 |
103 | ## Contents
104 | - [Introduction](#introduction)
105 | - [Contents](#contents)
106 | - [Installation](#installation)
107 | - [Install With Pip](#install-with-pip)
108 | - [Install From The Source](#install-from-the-source)
109 | - [Inference](#inference-with-lookahead-decoding)
110 | - [Use In Your Own Code](#use-lookahead-decoding-in-your-own-code)
111 | - [Citation](#citation)
112 | - [Guidance](#guidance)
113 |
114 |
115 | ## Installation
116 | ### Install with pip
117 | ```bash
118 | pip install lade
119 | ```
120 | ### Install from the source
121 | ```bash
122 | git clone https://github.com/hao-ai-lab/LookaheadDecoding.git
123 | cd LookaheadDecoding
124 | pip install -r requirements.txt
125 | pip install -e .
126 | ```
127 |
128 | ### Inference With Lookahead decoding
129 | You can run the minimal example to see the speedup that Lookahead decoding brings.
130 | ```bash
131 | python minimal.py #no Lookahead decoding
132 | USE_LADE=1 LOAD_LADE=1 python minimal.py #use Lookahead decoding, 1.6x speedup
133 | ```
134 |
135 | You can also enjoy chatting with your own chatbots with Lookahead decoding.
136 | ```bash
137 | USE_LADE=1 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat #chat, with lookahead
138 | USE_LADE=0 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat #chat, without lookahead
139 |
140 |
141 | USE_LADE=1 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug #no chat, with lookahead
142 | USE_LADE=0 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug #no chat, without lookahead
143 | ```
144 |
145 | ### Use Lookahead decoding in your own code
146 | You can import and use Lookahead decoding in your own code with three lines. You also need to set ```USE_LADE=1``` on the command line or set ```os.environ["USE_LADE"]="1"``` in your Python script (an example of setting the environment variable from Python follows the snippet below). Note that Lookahead decoding currently only supports LLaMA models.
147 |
148 | ```python
149 | import lade
150 | lade.augment_all()
151 | lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)
152 | #LEVEL, WINDOW_SIZE and GUESS_SET_SIZE are three important configurations (N,W,G) in lookahead decoding, please refer to our blog!
153 | #You can obtain better performance by tuning LEVEL/WINDOW_SIZE/GUESS_SET_SIZE on your own device.
154 | ```
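For reference, here is one way to set the flag from Python instead of the command line. Setting the variable before the `import lade` line is the safest order, since the flag may be read at import time.

```python
import os
os.environ["USE_LADE"] = "1"  # equivalent to prefixing the command with USE_LADE=1

import lade
lade.augment_all()
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0)
```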
155 |
156 | Then the decoding process is sped up transparently. Here is an example using greedy search (it assumes the usual `torch`/`transformers` imports and that `model_name`, `torch_device`, and `input_text` are already defined):
157 | ```python
158 | tokenizer = AutoTokenizer.from_pretrained(model_name)
159 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
160 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
161 | greedy_output = model.generate(**model_inputs, max_new_tokens=1024) #speedup obtained
162 | ```
163 |
164 | Here is an example using sampling (note that `do_sample=True` is required for `temperature` to take effect):
165 | ```python
166 | tokenizer = AutoTokenizer.from_pretrained(model_name)
167 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
168 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
169 | sample_output = model.generate(**model_inputs, max_new_tokens=1024, do_sample=True, temperature=0.7) #speedup obtained
170 | ```
171 |
172 | ### FlashAttention Support
173 | Install the original FlashAttention
174 | ```bash
175 | pip install flash-attn==2.3.3 #original FlashAttention
176 | ```
177 | There are two ways to install the FlashAttention build specialized for Lookahead Decoding:
178 | 1) Download a pre-built package from https://github.com/Viol2000/flash-attention-lookahead/releases/tag/v2.3.3 and install it (fast, recommended).
179 | For example, with cuda==11.8, python==3.9 and torch==2.1, you should do the following:
180 | ```bash
181 | wget https://github.com/Viol2000/flash-attention-lookahead/releases/download/v2.3.3/flash_attn_lade-2.3.3+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
182 | pip install flash_attn_lade-2.3.3+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl
183 | ```
184 | 2) Install from the source (slow, not recommended)
185 | ```bash
186 | git clone https://github.com/Viol2000/flash-attention-lookahead.git
187 | cd flash-attention-lookahead && python setup.py install
188 | ```
189 |
190 | Here is an example script to run the models with FlashAttention:
191 | ```bash
192 | python minimal-flash.py #no Lookahead decoding, w/ FlashAttention
193 | USE_LADE=1 LOAD_LADE=1 python minimal-flash.py #use Lookahead decoding, w/ FlashAttention, about 20% faster than w/o FlashAttention
194 | ```
195 |
196 | In your own code, you need to set ```USE_FLASH=True``` when calling ```config_lade```, and set ```attn_implementation="flash_attention_2"``` when calling ```AutoModelForCausalLM.from_pretrained```.
197 | ```python
198 | import lade
199 | lade.augment_all()
200 | lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, USE_FLASH=True, DEBUG=0)
201 | tokenizer = AutoTokenizer.from_pretrained(model_name)
202 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device, attn_implementation="flash_attention_2")
203 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
204 | greedy_output = model.generate(**model_inputs, max_new_tokens=1024) #speedup obtained
205 | ```
206 | We plan to integrate FlashAttention directly into this repo for simpler installation and usage.
207 |
208 | ## Citation
209 | ```bibtex
210 | @article{fu2024break,
211 | title={Break the sequential dependency of llm inference using lookahead decoding},
212 | author={Fu, Yichao and Bailis, Peter and Stoica, Ion and Zhang, Hao},
213 | journal={arXiv preprint arXiv:2402.02057},
214 | year={2024}
215 | }
216 | ```
217 | ## Guidance
218 | The core implementation is in lade/decoding.py. Lookahead decoding requires an adaptation for each specific model; an example is in lade/models/modeling_llama.py.
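The applications in this repo enable Lookahead decoding behind the `USE_LADE` environment variable. The following excerpt mirrors the pattern used in applications/chatbot.py:

```python
import os
import lade

if int(os.environ.get("USE_LADE", 0)):
    lade.augment_all()
    # (LEVEL, WINDOW_SIZE, GUESS_SET_SIZE) correspond to (N, W, G)
    lade.config_lade(LEVEL=5, WINDOW_SIZE=15, GUESS_SET_SIZE=15, DEBUG=1)
```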
219 |
220 |
--------------------------------------------------------------------------------
/applications/chatbot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import subprocess
4 | import lade
5 | from lade.utils import get_model
6 | import time, os
7 |
8 | if __name__ == "__main__":
9 | if int(os.environ.get("USE_LADE", 0)):
10 | lade.augment_all()
11 |         lade.config_lade(LEVEL=5, WINDOW_SIZE=15, GUESS_SET_SIZE=15, DEBUG=1) #values tuned for an A100; (LEVEL, WINDOW_SIZE, GUESS_SET_SIZE) correspond to (N, W, G)
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--local_rank", type=int, default=0)
15 | parser.add_argument("--model_path", type=str, help="model path", default="meta-llama/Llama-2-7b-chat-hf") #tiiuae/falcon-7b-instruct #"TheBloke/Falcon-180B-Chat-GPTQ"
16 | parser.add_argument("--model_type", type=str, default="llama")
17 | parser.add_argument("--quant", type=str, default="")
18 | parser.add_argument("--use_ds", action="store_true")
19 | parser.add_argument("--debug", action="store_true")
20 | parser.add_argument("--chat", action="store_true")
21 | parser.add_argument("--dtype", type=str, default="float16")
22 | parser.add_argument("--device", type=str, default="cuda:0")
23 | parser.add_argument("--cache_dir", type=str, default="")
24 | parser.add_argument(
25 | "--max_new_tokens",
26 | type=int,
27 | default=128,
28 | help="Maximum new tokens to generate per response",
29 | )
30 | args = parser.parse_args()
31 |
32 | if args.dtype == "float16":
33 | args.dtype = torch.float16
34 | elif args.dtype == "bfloat16":
35 | args.dtype = torch.bfloat16
36 |
37 | #if args.use_ds:
38 | model, tokenizer = get_model(args.model_path, args.quant, args.dtype, args.device, args.cache_dir, args.use_ds, False)
39 |
40 | user_input = ""
41 | num_rounds = 0
42 | if args.model_type == "llama":
43 | roles = ("[INST]", "[/INST]") #support llama2 only
44 | else:
45 | assert False
46 |
47 | user_input = ""
48 | if args.model_type == "llama":
49 |         system_prompt = "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
50 |
51 | first_time = True
52 | while True:
53 | num_rounds += 1
54 | if args.chat:
55 | model_input = input("User: ")
56 | else:
57 | model_input = '''Which methods did Socrates employ to challenge the prevailing thoughts of his time?'''
58 | print("User: " + model_input)
59 | if system_prompt is not None and first_time:
60 | if args.model_type == "llama":
61 | new_inputs = system_prompt + f"{model_input}\n {roles[1]} "
62 | new_inputs = "[INST]" + f"{model_input}\n {roles[1]} "
63 | first_time = False
64 | else:
65 | new_inputs = f"{roles[0]}: {model_input}\n {roles[1]}: "
66 | user_input += new_inputs
67 |
68 | generate_kwargs = dict(max_new_tokens=1024, do_sample=False, stop_token=None, top_p=1.0, temperature=1.0) #greedy
69 |
70 | print("Assistant: " , flush=True, end="")
71 | input_ids = tokenizer(user_input, return_tensors="pt",
72 | max_length=1024, truncation=True).input_ids.to(args.device)
73 |
74 | if not args.chat:
75 | lade.config_lade(DEBUG=0)
76 | tmp_kwargs = dict(max_new_tokens=1, do_sample=False, stop_token=None, top_p=1.0, temperature=1.0)
77 | tmp_greedy_output = model.generate(input_ids=input_ids, **tmp_kwargs).tolist() #warmup
78 | lade.config_lade(DEBUG=1)
79 |
80 | os.environ["CHAT"] = "0"
81 | t0 = time.time()
82 | greedy_output = model.generate(input_ids=input_ids, **generate_kwargs).tolist()
83 |
84 | t1 = time.time()
85 | os.environ["CHAT"] = "0"
86 | output = tokenizer.decode(greedy_output[0], skip_special_tokens=False)
87 | print(output)
88 | user_input = f"{output}\n\n"
89 |
90 | if args.debug:
91 | generated_tokens = len(greedy_output[0]) - input_ids.numel()
92 | print()
93 | print("======================================SUMMARY=========================================")
94 | print("Input tokens: ", input_ids.numel() ,"Generated tokens: ", generated_tokens,"Time: ", round(t1 - t0, 2), "s Throughput: ", round(generated_tokens / (t1 - t0), 2), "tokens/s")
95 | print("======================================================================================")
96 | #print("\n\n\n\n")
97 | if not args.chat:
98 | break
99 |
--------------------------------------------------------------------------------
/applications/eval_classeval.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 eval_classeval.py --model-path meta-llama/Llama-2-7b-chat-hf --model-id llama-2-7b-chat-hf
5 | """
6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py
7 | import argparse
8 | import json
9 | import os
10 | import random
11 | import time
12 |
13 | import shortuuid
14 | import torch
15 | from tqdm import tqdm
16 | from typing import Dict, List, Optional
17 | from fastchat.llm_judge.common import load_questions, temperature_config
18 | from fastchat.model import get_conversation_template
19 | from fastchat.utils import str_to_torch_dtype
20 | import math, warnings, psutil  #used by load_model's CPU-offloading / multi-GPU branches
21 | import lade
22 | from datasets import load_dataset
23 |
24 | def run_eval(
25 | model_path,
26 | model_id,
27 | question_file,
28 | question_begin,
29 | question_end,
30 | answer_file,
31 | max_new_token,
32 | num_choices,
33 | num_gpus_per_model,
34 | num_gpus_total,
35 | max_gpu_memory,
36 | dtype,
37 | debug,
38 | cache_dir,
39 | cpu_offloading,
40 | use_pp,
41 | use_tp,
42 | use_tp_ds,
43 | use_flash,
44 | do_sample
45 | ):
46 | #questions = load_questions(question_file, question_begin, question_end)
47 | ClassEval = load_dataset("FudanSELab/ClassEval")
48 | questions = ClassEval["test"]
49 | # random shuffle the questions to balance the loading
50 | ###not shuffle
51 | #random.shuffle(questions)
52 |
53 | # Split the question file into `num_gpus` files
54 | assert num_gpus_total % num_gpus_per_model == 0
55 |
56 | get_answers_func = get_model_answers
57 |
58 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
59 | ans_handles = []
60 | for i in range(0, len(questions), chunk_size):
61 | ans_handles.append(
62 | get_answers_func(
63 | model_path,
64 | model_id,
65 | questions[i : i + chunk_size],
66 | question_end,
67 | answer_file,
68 | max_new_token,
69 | num_choices,
70 | num_gpus_per_model,
71 | max_gpu_memory,
72 | dtype=dtype,
73 | debug=debug,
74 | cache_dir=cache_dir,
75 | cpu_offloading=cpu_offloading,
76 | use_tp=use_tp,
77 | use_pp=use_pp,
78 | use_tp_ds=use_tp_ds,
79 | use_flash=use_flash,
80 | do_sample=do_sample
81 | )
82 | )
83 |
84 |
85 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM
86 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration
87 |
88 | def load_model(
89 | model_path: str,
90 | device: str = "cuda",
91 | device_map: str= "",
92 | num_gpus: int = 1,
93 | max_gpu_memory: Optional[str] = None,
94 | dtype: Optional[torch.dtype] = None,
95 | load_8bit: bool = False,
96 | cpu_offloading: bool = False,
97 | revision: str = "main",
98 | debug: bool = False,
99 | use_flash:bool = False
100 | ):
101 | """Load a model from Hugging Face."""
102 | # get model adapter
103 | adapter = Llama2Adapter()
104 | # Handle device mapping
105 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
106 | device, load_8bit, cpu_offloading
107 | )
108 | if device == "cpu":
109 | kwargs = {"torch_dtype": torch.float32}
110 | if CPU_ISA in ["avx512_bf16", "amx"]:
111 | try:
112 | import intel_extension_for_pytorch as ipex
113 |
114 | kwargs = {"torch_dtype": torch.bfloat16}
115 | except ImportError:
116 | warnings.warn(
117 | "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference"
118 | )
119 | elif device.startswith("cuda"):
120 | kwargs = {"torch_dtype": torch.float16}
121 | if num_gpus != 1:
122 | kwargs["device_map"] = "auto"
123 | if max_gpu_memory is None:
124 | kwargs[
125 | "device_map"
126 | ] = "sequential" # This is important for not the same VRAM sizes
127 | available_gpu_memory = get_gpu_memory(num_gpus)
128 | kwargs["max_memory"] = {
129 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
130 | for i in range(num_gpus)
131 | }
132 | else:
133 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
134 |
135 | if cpu_offloading:
136 | # raises an error on incompatible platforms
137 | from transformers import BitsAndBytesConfig
138 |
139 | if "max_memory" in kwargs:
140 | kwargs["max_memory"]["cpu"] = (
141 | str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
142 | )
143 | kwargs["quantization_config"] = BitsAndBytesConfig(
144 | load_in_8bit_fp32_cpu_offload=cpu_offloading
145 | )
146 | kwargs["load_in_8bit"] = load_8bit
147 | elif load_8bit:
148 | if num_gpus != 1:
149 | warnings.warn(
150 | "8-bit quantization is not supported for multi-gpu inference."
151 | )
152 | else:
153 | model, tokenizer = adapter.load_compress_model(
154 | model_path=model_path,
155 | device=device,
156 | torch_dtype=kwargs["torch_dtype"],
157 | revision=revision,
158 | )
159 | if debug:
160 | print(model)
161 | return model, tokenizer
162 | kwargs["revision"] = revision
163 |
164 | if dtype is not None: # Overwrite dtype if it is provided in the arguments.
165 | kwargs["torch_dtype"] = dtype
166 | if use_flash:
167 | kwargs["use_flash_attention_2"] = use_flash
168 | if len(device_map) > 0:
169 | kwargs["device_map"] = device_map
170 | # Load model
171 | model, tokenizer = adapter.load_model(model_path, kwargs)
172 |
173 | if len(device_map) > 0:
174 | return model, tokenizer
175 |
176 | if (
177 | device == "cpu"
178 | and kwargs["torch_dtype"] is torch.bfloat16
179 | and CPU_ISA is not None
180 | ):
181 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
182 |
183 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in (
184 | "mps",
185 | "xpu",
186 | "npu",
187 | ):
188 | model.to(device)
189 |
190 | if device == "xpu":
191 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True)
192 |
193 | if debug:
194 | print(model)
195 |
196 | return model, tokenizer
197 |
198 | #@torch.inference_mode()
199 | def get_model_answers(
200 | model_path,
201 | model_id,
202 | questions,
203 | question_end,
204 | answer_file,
205 | max_new_token,
206 | num_choices,
207 | num_gpus_per_model,
208 | max_gpu_memory,
209 | dtype,
210 | debug,
211 | cache_dir,
212 | cpu_offloading,
213 | use_pp,
214 | use_tp_ds,
215 | use_tp,
216 | use_flash,
217 | do_sample
218 | ):
219 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
220 |
221 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices)
222 | #tokenizer = AutoTokenizer.from_pretrained(model_path)
223 | #cfg = AutoConfig.from_pretrained(model_path)
224 | #cfg._flash_attn_2_enabled= use_flash
225 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
226 | if use_pp:
227 | model, tokenizer = load_model(
228 | model_path,
229 | use_flash=use_flash,
230 | device=f"cuda",
231 | device_map="balanced",
232 | num_gpus=num_gpus_per_model,
233 | max_gpu_memory=max_gpu_memory,
234 | dtype=dtype,
235 | load_8bit=False,
236 | cpu_offloading=cpu_offloading,
237 | debug=debug,
238 | )
239 |
240 | elif use_tp_ds:
241 | import deepspeed
242 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))
243 | model, tokenizer = load_model(
244 | model_path,
245 | use_flash=use_flash,
246 | device_map="cpu",
247 | num_gpus=num_gpus_per_model,
248 | max_gpu_memory=max_gpu_memory,
249 | dtype=dtype,
250 | load_8bit=False,
251 | cpu_offloading=cpu_offloading,
252 | debug=debug,
253 | )
254 | model = deepspeed.init_inference(
255 | model,
256 | mp_size=int(os.getenv("WORLD_SIZE", "1")),
257 | dtype=torch.half
258 | )
259 | else:
260 | model, tokenizer = load_model(
261 | model_path,
262 | use_flash=use_flash,
263 | device=f"cuda:{lade.get_device()}",
264 | num_gpus=num_gpus_per_model,
265 | max_gpu_memory=max_gpu_memory,
266 | dtype=dtype,
267 | load_8bit=False,
268 | cpu_offloading=cpu_offloading,
269 | debug=debug,
270 | )
271 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device())
272 | model.tokenizer = tokenizer
273 |
274 | overall_time = 0
275 | overall_tp = 0
276 | overall_gen = 0
277 | count_gen = 0
278 | stats = {}
279 | for question_idx, description in enumerate(tqdm(questions["class_description"][:question_end])):
280 |         #use greedy decoding unless --do-sample is set; 1.0 is an assumed default sampling temperature
281 |         temperature = 0.0 if not do_sample else 1.0
282 |
283 | stats[question_idx] = {} #
284 | choices = []
285 | for i in range(num_choices):
286 | torch.manual_seed(i)
287 | conv = get_conversation_template(model_id)
288 | turns = []
289 | prompts = []
290 |
291 | for j in range(1):
292 | qs = ""
293 |
294 | import_stat = '\n'.join(questions["import_statement"][question_idx])
295 | qs += import_stat
296 |
297 | class_init = questions["class_constructor"][question_idx]
298 | class_init_list = class_init.split('\n')
299 | class_init_list[0] += " \n" + description
300 | class_init = '\n'.join(class_init_list)
301 |
302 | qs = qs + "\n" + class_init
303 | prompt = qs
304 |
305 | input_ids = tokenizer(prompt, return_tensors="pt",
306 | max_length=1024, truncation=True).input_ids.to("cuda")
307 |
308 | if temperature < 1e-4:
309 | do_sample = False
310 | else:
311 | do_sample = True
312 |
313 |
314 | # some models may error out when generating long outputs
315 | if True:
316 | start_time = time.time()
317 | output_ids = model.generate(
318 | input_ids,
319 | do_sample=do_sample,
320 | temperature=temperature,
321 | max_new_tokens=max_new_token,
322 | )
323 | end_time = time.time()
324 | gap_time = end_time - start_time
325 | tokens = output_ids.numel() - input_ids.numel()
326 | overall_time += gap_time
327 | overall_gen += tokens
328 | overall_tp += tokens / gap_time
329 | count_gen += 1
330 |
331 | stats[question_idx][j] = [gap_time, tokens]
332 | if lade.get_device() == 0 and ds_local_rank == 0:
333 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time])
334 |
335 | output = tokenizer.decode(
336 | output_ids[0].tolist(),
337 | skip_special_tokens=False,
338 | )
339 |
340 | turns.append(output)
341 | prompts.append(prompt)
342 |
343 | choices.append({"index": i, "turns": turns, "prompts" : prompts})
344 |
345 | if lade.get_device() == 0 and ds_local_rank == 0:
346 | # Dump answers
347 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
348 | with open(os.path.expanduser(answer_file), "a") as fout:
349 | ans_json = {
350 | "question_id": question_idx,
351 | "answer_id": shortuuid.uuid(),
352 | "model_id": model_id,
353 | "choices": choices,
354 | "tstamp": time.time(),
355 | }
356 | fout.write(json.dumps(ans_json) + "\n")
357 | #if question_idx == 1:
358 | # break
359 |
360 | if lade.get_device() == 0 and ds_local_rank == 0:
361 | torch.save(stats[question_idx], answer_file + ".pt")
362 | print("LOG SAVE TO ", answer_file + ".pt")
363 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}")
364 | lade.log_history()
365 | lade.save_log(answer_file + "-lade-log.pt")
366 |
367 |
368 | def reorg_answer_file(answer_file):
369 | """Sort by question id and de-duplication"""
370 | answers = {}
371 | with open(answer_file, "r") as fin:
372 | for l in fin:
373 | qid = json.loads(l)["question_id"]
374 | answers[qid] = l
375 |
376 | qids = sorted(list(answers.keys()))
377 | with open(answer_file, "w") as fout:
378 | for qid in qids:
379 | fout.write(answers[qid])
380 |
381 |
382 | if __name__ == "__main__":
383 | parser = argparse.ArgumentParser()
384 | parser.add_argument(
385 | "--model-path",
386 | type=str,
387 | required=True,
388 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
389 | )
390 | parser.add_argument(
391 | "--model-id", type=str, required=True, help="A custom name for the model."
392 | )
393 | parser.add_argument(
394 | "--cache-dir",
395 | type=str,
396 | default="",
397 | )
398 | parser.add_argument(
399 | "--debug",
400 | action="store_true",
401 | )
402 | parser.add_argument(
403 | "--bench-name",
404 | type=str,
405 | default="classeval",
406 | help="The name of the benchmark question set.",
407 | )
408 | parser.add_argument(
409 | "--question-begin",
410 | type=int,
411 | help="A debug option. The begin index of questions.",
412 | )
413 | parser.add_argument(
414 | "--question-end", type=int, help="A debug option. The end index of questions."
415 | )
416 | parser.add_argument(
417 | "--cpu_offloading", action="store_true"
418 | )
419 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
420 | parser.add_argument(
421 | "--max-new-token",
422 | type=int,
423 | default=2048,
424 | help="The maximum number of new generated tokens.",
425 | )
426 | parser.add_argument(
427 | "--num-choices",
428 | type=int,
429 | default=1,
430 | help="How many completion choices to generate.",
431 | )
432 | parser.add_argument(
433 | "--num-gpus-per-model",
434 | type=int,
435 | default=1,
436 | help="The number of GPUs per model.",
437 | )
438 | parser.add_argument(
439 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
440 | )
441 | parser.add_argument(
442 | "--max-gpu-memory",
443 | type=str,
444 | help="Maxmum GPU memory used for model weights per GPU.",
445 | )
446 | parser.add_argument(
447 | "--dtype",
448 | type=str,
449 | choices=["float32", "float64", "float16", "bfloat16"],
450 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
451 | default=None,
452 | )
453 | parser.add_argument(
454 | "--local_rank",
455 | type=int,
456 | default=0,
457 | )
458 | parser.add_argument(
459 | "--local-rank",
460 | type=int,
461 | default=0,
462 | )
463 | parser.add_argument(
464 | "--level",
465 | type=int,
466 | default=3,
467 | )
468 | parser.add_argument(
469 | "--window",
470 | type=int,
471 | default=10,
472 | )
473 | parser.add_argument(
474 | "--guess",
475 | type=int,
476 | default=10,
477 | )
478 | parser.add_argument(
479 | "--use-tp",
480 | type=int,
481 | default=0,
482 | )
483 | parser.add_argument(
484 | "--use-pp",
485 | type=int,
486 | default=0,
487 | )
488 | parser.add_argument(
489 | "--use-tp-ds",
490 | type=int,
491 | default=0,
492 | )
493 | parser.add_argument(
494 | "--use-flash",
495 | type=int,
496 | default=0,
497 | )
498 | parser.add_argument(
499 | "--do-sample",
500 | type=int,
501 | default=0,
502 | )
503 |
504 | args = parser.parse_args()
505 | if int(os.environ.get("USE_LADE", 0)):
506 |
507 | lade.augment_all()
508 | lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES").split(",")))
509 | print("lade activated config: ", lade.decoding.CONFIG_MAP)
510 |
511 | question_file = f"mtbench.jsonl"
512 | if args.answer_file:
513 | answer_file = args.answer_file
514 | else:
515 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
516 |
517 | print(f"Output to {answer_file}")
518 |
519 | run_eval(
520 | model_path=args.model_path,
521 | model_id=args.model_id,
522 | question_file=question_file,
523 | question_begin=args.question_begin,
524 | question_end=args.question_end,
525 | answer_file=answer_file,
526 | max_new_token=args.max_new_token,
527 | num_choices=args.num_choices,
528 | num_gpus_per_model=args.num_gpus_per_model,
529 | num_gpus_total=args.num_gpus_total,
530 | max_gpu_memory=args.max_gpu_memory,
531 | dtype=str_to_torch_dtype(args.dtype),
532 | debug=args.debug,
533 | cache_dir=args.cache_dir,
534 | cpu_offloading=args.cpu_offloading,
535 | use_pp=args.use_pp,
536 | use_tp_ds=args.use_tp_ds,
537 | use_tp=args.use_tp,
538 | use_flash=args.use_flash,
539 | do_sample=args.do_sample
540 | )
541 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
542 | if lade.get_device() == 0 and ds_local_rank == 0:
543 | reorg_answer_file(answer_file)
544 |
--------------------------------------------------------------------------------
/applications/eval_cnndm.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 eval_cnndm.py --model-path meta-llama/Llama-2-7b-chat-hf --model-id llama-2-7b-chat-hf
5 | """
6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py
7 | import argparse
8 | import json
9 | import os
10 | import random
11 | import time
12 |
13 | import shortuuid
14 | import torch
15 | from tqdm import tqdm
16 | from typing import Dict, List, Optional
17 | from fastchat.llm_judge.common import load_questions, temperature_config
18 | from fastchat.model import get_conversation_template
19 | from fastchat.utils import str_to_torch_dtype
20 | import math, warnings, psutil  #used by load_model's CPU-offloading / multi-GPU branches
21 | import lade
22 | from datasets import load_dataset
23 |
24 | def run_eval(
25 | model_path,
26 | model_id,
27 | question_file,
28 | question_begin,
29 | question_end,
30 | answer_file,
31 | max_new_token,
32 | num_choices,
33 | num_gpus_per_model,
34 | num_gpus_total,
35 | max_gpu_memory,
36 | dtype,
37 | debug,
38 | cache_dir,
39 | cpu_offloading,
40 | use_pp,
41 | use_tp,
42 | use_tp_ds,
43 | use_flash,
44 | do_sample
45 | ):
46 | questions = load_dataset("cnn_dailymail", "3.0.0", split="validation", streaming=False)["article"][question_begin:question_end]
47 | # random shuffle the questions to balance the loading
48 | ###not shuffle
49 | #random.shuffle(questions)
50 |
51 | # Split the question file into `num_gpus` files
52 | assert num_gpus_total % num_gpus_per_model == 0
53 |
54 | get_answers_func = get_model_answers
55 |
56 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
57 | ans_handles = []
58 | for i in range(0, len(questions), chunk_size):
59 | ans_handles.append(
60 | get_answers_func(
61 | model_path,
62 | model_id,
63 | questions[i : i + chunk_size],
64 | answer_file,
65 | max_new_token,
66 | num_choices,
67 | num_gpus_per_model,
68 | max_gpu_memory,
69 | dtype=dtype,
70 | debug=debug,
71 | cache_dir=cache_dir,
72 | cpu_offloading=cpu_offloading,
73 | use_tp=use_tp,
74 | use_pp=use_pp,
75 | use_tp_ds=use_tp_ds,
76 | use_flash=use_flash,
77 | do_sample=do_sample
78 | )
79 | )
80 |
81 |
82 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM
83 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration
84 |
85 | def load_model(
86 | model_path: str,
87 | device: str = "cuda",
88 | device_map: str= "",
89 | num_gpus: int = 1,
90 | max_gpu_memory: Optional[str] = None,
91 | dtype: Optional[torch.dtype] = None,
92 | load_8bit: bool = False,
93 | cpu_offloading: bool = False,
94 | revision: str = "main",
95 | debug: bool = False,
96 | use_flash:bool = False
97 | ):
98 | """Load a model from Hugging Face."""
99 | # get model adapter
100 | adapter = Llama2Adapter()
101 | # Handle device mapping
102 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
103 | device, load_8bit, cpu_offloading
104 | )
105 | if device == "cpu":
106 | kwargs = {"torch_dtype": torch.float32}
107 | if CPU_ISA in ["avx512_bf16", "amx"]:
108 | try:
109 | import intel_extension_for_pytorch as ipex
110 |
111 | kwargs = {"torch_dtype": torch.bfloat16}
112 | except ImportError:
113 | warnings.warn(
114 | "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference"
115 | )
116 | elif device.startswith("cuda"):
117 | kwargs = {"torch_dtype": torch.float16}
118 | if num_gpus != 1:
119 | kwargs["device_map"] = "auto"
120 | if max_gpu_memory is None:
121 | kwargs[
122 | "device_map"
123 | ] = "sequential" # This is important for not the same VRAM sizes
124 | available_gpu_memory = get_gpu_memory(num_gpus)
125 | kwargs["max_memory"] = {
126 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
127 | for i in range(num_gpus)
128 | }
129 | else:
130 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
131 |
132 | if cpu_offloading:
133 | # raises an error on incompatible platforms
134 | from transformers import BitsAndBytesConfig
135 |
136 | if "max_memory" in kwargs:
137 | kwargs["max_memory"]["cpu"] = (
138 | str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib"
139 | )
140 | kwargs["quantization_config"] = BitsAndBytesConfig(
141 | load_in_8bit_fp32_cpu_offload=cpu_offloading
142 | )
143 | kwargs["load_in_8bit"] = load_8bit
144 | elif load_8bit:
145 | if num_gpus != 1:
146 | warnings.warn(
147 | "8-bit quantization is not supported for multi-gpu inference."
148 | )
149 | else:
150 | model, tokenizer = adapter.load_compress_model(
151 | model_path=model_path,
152 | device=device,
153 | torch_dtype=kwargs["torch_dtype"],
154 | revision=revision,
155 | )
156 | if debug:
157 | print(model)
158 | return model, tokenizer
159 | kwargs["revision"] = revision
160 |
161 | if dtype is not None: # Overwrite dtype if it is provided in the arguments.
162 | kwargs["torch_dtype"] = dtype
163 | if use_flash:
164 | kwargs["use_flash_attention_2"] = use_flash
165 | if len(device_map) > 0:
166 | kwargs["device_map"] = device_map
167 | # Load model
168 | model, tokenizer = adapter.load_model(model_path, kwargs)
169 |
170 | if len(device_map) > 0:
171 | return model, tokenizer
172 |
173 | if (
174 | device == "cpu"
175 | and kwargs["torch_dtype"] is torch.bfloat16
176 | and CPU_ISA is not None
177 | ):
178 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
179 |
180 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in (
181 | "mps",
182 | "xpu",
183 | "npu",
184 | ):
185 | model.to(device)
186 |
187 | if device == "xpu":
188 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True)
189 |
190 | if debug:
191 | print(model)
192 |
193 | return model, tokenizer
194 |
195 | #@torch.inference_mode()
196 | def get_model_answers(
197 | model_path,
198 | model_id,
199 | questions,
200 | answer_file,
201 | max_new_token,
202 | num_choices,
203 | num_gpus_per_model,
204 | max_gpu_memory,
205 | dtype,
206 | debug,
207 | cache_dir,
208 | cpu_offloading,
209 | use_pp,
210 | use_tp_ds,
211 | use_tp,
212 | use_flash,
213 | do_sample
214 | ):
215 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
216 |
217 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices)
218 |
219 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
220 | if use_pp:
221 | model, tokenizer = load_model(
222 | model_path,
223 | use_flash=use_flash,
224 | device=f"cuda",
225 | device_map="balanced",
226 | num_gpus=num_gpus_per_model,
227 | max_gpu_memory=max_gpu_memory,
228 | dtype=dtype,
229 | load_8bit=False,
230 | cpu_offloading=cpu_offloading,
231 | debug=debug,
232 | )
233 |
234 | elif use_tp_ds:
235 | import deepspeed
236 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))
237 | model, tokenizer = load_model(
238 | model_path,
239 | use_flash=use_flash,
240 | device_map="cpu",
241 | num_gpus=num_gpus_per_model,
242 | max_gpu_memory=max_gpu_memory,
243 | dtype=dtype,
244 | load_8bit=False,
245 | cpu_offloading=cpu_offloading,
246 | debug=debug,
247 | )
248 | model = deepspeed.init_inference(
249 | model,
250 | mp_size=int(os.getenv("WORLD_SIZE", "1")),
251 | dtype=torch.half
252 | )
253 | else:
254 | model, tokenizer = load_model(
255 | model_path,
256 | use_flash=use_flash,
257 | device=f"cuda:{lade.get_device()}",
258 | num_gpus=num_gpus_per_model,
259 | max_gpu_memory=max_gpu_memory,
260 | dtype=dtype,
261 | load_8bit=False,
262 | cpu_offloading=cpu_offloading,
263 | debug=debug,
264 | )
265 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device())
266 | model.tokenizer = tokenizer
267 |
268 | overall_time = 0
269 | overall_tp = 0
270 | overall_gen = 0
271 | count_gen = 0
272 | stats = {}
273 | for question_idx, question in enumerate(tqdm(questions)):
274 |
275 | stats[question_idx] = {} #
276 | choices = []
277 | for i in range(num_choices):
278 | torch.manual_seed(i)
279 | conv = get_conversation_template(model_id)
280 | turns = []
281 | prompts = []
282 |
283 | for j in range(1):
284 |
285 |                 prompt = f'''[INST] <<SYS>>
286 | You are an intelligent chatbot. Answer the questions only using the following context:
287 |
288 | {question}
289 |
290 | Here are some rules you always follow:
291 |
292 | - Generate human readable output, avoid creating output with gibberish text.
293 | - Generate only the requested output, don't include any other language before or after the requested output.
294 | - Never say thank you, that you are happy to help, that you are an AI agent, etc. Just answer directly.
295 | - Generate professional language typically used in business documents in North America.
296 | - Never generate offensive or foul language.
297 |
298 | <</SYS>>
299 |
300 | Briefly summarize the given context. [/INST]
301 | Summary: '''
302 |
303 | prompts.append(prompt)
304 |
305 | input_ids = tokenizer([prompt]).input_ids
306 |
307 | #print("len: ", len(input_ids[0]))
308 | if len(input_ids[0]) > 2048: #skip input len > 2048 tokens
309 | continue
310 |
311 | # some models may error out when generating long outputs
312 | if True:
313 | if do_sample:
314 | start_time = time.time()
315 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=True, top_k=0, temperature=1.0, top_p=1.0)
316 | end_time = time.time()
317 | else:
318 | start_time = time.time()
319 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=False, top_k=0)
320 | end_time = time.time()
321 |
322 | gap_time = end_time - start_time
323 | tokens = output_ids.numel() - len(input_ids[0])
324 | overall_time += gap_time
325 | overall_gen += tokens
326 | overall_tp += tokens / gap_time
327 | count_gen += 1
328 |
329 | stats[question_idx][j] = [gap_time, tokens]
330 | if lade.get_device() == 0 and ds_local_rank == 0:
331 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time])
332 |
333 | if model.config.is_encoder_decoder:
334 | output_ids = output_ids[0]
335 | else:
336 | output_ids = output_ids[0][len(input_ids[0]) :]
337 |
338 | # be consistent with the template's stop_token_ids
339 | if conv.stop_token_ids:
340 | stop_token_ids_index = [
341 | i
342 | for i, id in enumerate(output_ids)
343 | if id in conv.stop_token_ids
344 | ]
345 | if len(stop_token_ids_index) > 0:
346 | output_ids = output_ids[: stop_token_ids_index[0]]
347 |
348 | output = tokenizer.decode(
349 | output_ids,
350 | spaces_between_special_tokens=False,
351 | )
352 | if conv.stop_str and output.find(conv.stop_str) > 0:
353 | output = output[: output.find(conv.stop_str)]
354 | for special_token in tokenizer.special_tokens_map.values():
355 | if isinstance(special_token, list):
356 | for special_tok in special_token:
357 | output = output.replace(special_tok, "")
358 | else:
359 | output = output.replace(special_token, "")
360 | output = output.strip()
361 |
362 | if conv.name == "xgen" and output.startswith("Assistant:"):
363 | output = output.replace("Assistant:", "", 1).strip()
364 |
365 | #print("output: ", output)
366 | '''
367 | except RuntimeError as e:
368 | print("ERROR question ID: ", question["question_id"])
369 | output = "ERROR"
370 | '''
371 | turns.append(output)
372 |
373 |
374 | choices.append({"index": i, "turns": turns, "prompts" : prompts})
375 |
376 | if lade.get_device() == 0 and ds_local_rank == 0:
377 | # Dump answers
378 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
379 | with open(os.path.expanduser(answer_file), "a") as fout:
380 | ans_json = {
381 | "question_id": question_idx,
382 | "answer_id": shortuuid.uuid(),
383 | "model_id": model_id,
384 | "choices": choices,
385 | "tstamp": time.time(),
386 | }
387 | fout.write(json.dumps(ans_json) + "\n")
388 | #if question_idx == 1:
389 | # break
390 |
391 | if lade.get_device() == 0 and ds_local_rank == 0:
392 | torch.save(stats[question_idx], answer_file + ".pt")
393 | print("LOG SAVE TO ", answer_file + ".pt")
394 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}")
395 | lade.log_history()
396 | lade.save_log(answer_file + "-lade-log.pt")
397 |
398 |
399 | def reorg_answer_file(answer_file):
400 | """Sort by question id and de-duplication"""
401 | answers = {}
402 | with open(answer_file, "r") as fin:
403 | for l in fin:
404 | qid = json.loads(l)["question_id"]
405 | answers[qid] = l
406 |
407 | qids = sorted(list(answers.keys()))
408 | with open(answer_file, "w") as fout:
409 | for qid in qids:
410 | fout.write(answers[qid])
411 |
412 |
413 | if __name__ == "__main__":
414 | parser = argparse.ArgumentParser()
415 | parser.add_argument(
416 | "--model-path",
417 | type=str,
418 | required=True,
419 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
420 | )
421 | parser.add_argument(
422 | "--model-id", type=str, required=True, help="A custom name for the model."
423 | )
424 | parser.add_argument(
425 | "--cache-dir",
426 | type=str,
427 | default="",
428 | )
429 | parser.add_argument(
430 | "--debug",
431 | action="store_true",
432 | )
433 | parser.add_argument(
434 | "--bench-name",
435 | type=str,
436 | default="cnndm",
437 | help="The name of the benchmark question set.",
438 | )
439 | parser.add_argument(
440 | "--question-begin",
441 | type=int,
442 | help="A debug option. The begin index of questions.",
443 | )
444 | parser.add_argument(
445 | "--question-end", type=int, help="A debug option. The end index of questions."
446 | )
447 | parser.add_argument(
448 | "--cpu_offloading", action="store_true"
449 | )
450 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
451 | parser.add_argument(
452 | "--max-new-token",
453 | type=int,
454 | default=1024,
455 | help="The maximum number of new generated tokens.",
456 | )
457 | parser.add_argument(
458 | "--num-choices",
459 | type=int,
460 | default=1,
461 | help="How many completion choices to generate.",
462 | )
463 | parser.add_argument(
464 | "--num-gpus-per-model",
465 | type=int,
466 | default=1,
467 | help="The number of GPUs per model.",
468 | )
469 | parser.add_argument(
470 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
471 | )
472 | parser.add_argument(
473 | "--max-gpu-memory",
474 | type=str,
475 |         help="Maximum GPU memory used for model weights per GPU.",
476 | )
477 | parser.add_argument(
478 | "--dtype",
479 | type=str,
480 | choices=["float32", "float64", "float16", "bfloat16"],
481 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
482 | default=None,
483 | )
484 | parser.add_argument(
485 | "--local_rank",
486 | type=int,
487 | default=0,
488 | )
489 | parser.add_argument(
490 | "--local-rank",
491 | type=int,
492 | default=0,
493 | )
494 | parser.add_argument(
495 | "--level",
496 | type=int,
497 | default=3,
498 | )
499 | parser.add_argument(
500 | "--window",
501 | type=int,
502 | default=10,
503 | )
504 | parser.add_argument(
505 | "--guess",
506 | type=int,
507 | default=10,
508 | )
509 | parser.add_argument(
510 | "--use-tp",
511 | type=int,
512 | default=0,
513 | )
514 | parser.add_argument(
515 | "--use-pp",
516 | type=int,
517 | default=0,
518 | )
519 | parser.add_argument(
520 | "--use-tp-ds",
521 | type=int,
522 | default=0,
523 | )
524 | parser.add_argument(
525 | "--use-flash",
526 | type=int,
527 | default=0,
528 | )
529 | parser.add_argument(
530 | "--do-sample",
531 | type=int,
532 | default=0,
533 | )
534 |
535 | args = parser.parse_args()
536 | if int(os.environ.get("USE_LADE", 0)):
537 |
538 | lade.augment_all()
539 |         lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")))
540 | print("lade activated config: ", lade.decoding.CONFIG_MAP)
541 |
542 |     question_file = "mtbench.jsonl"
543 | if args.answer_file:
544 | answer_file = args.answer_file
545 | else:
546 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
547 |
548 | print(f"Output to {answer_file}")
549 |
550 | run_eval(
551 | model_path=args.model_path,
552 | model_id=args.model_id,
553 | question_file=question_file,
554 | question_begin=args.question_begin,
555 | question_end=args.question_end,
556 | answer_file=answer_file,
557 | max_new_token=args.max_new_token,
558 | num_choices=args.num_choices,
559 | num_gpus_per_model=args.num_gpus_per_model,
560 | num_gpus_total=args.num_gpus_total,
561 | max_gpu_memory=args.max_gpu_memory,
562 | dtype=str_to_torch_dtype(args.dtype),
563 | debug=args.debug,
564 | cache_dir=args.cache_dir,
565 | cpu_offloading=args.cpu_offloading,
566 | use_pp=args.use_pp,
567 | use_tp_ds=args.use_tp_ds,
568 | use_tp=args.use_tp,
569 | use_flash=args.use_flash,
570 | do_sample=args.do_sample
571 | )
572 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
573 | if lade.get_device() == 0 and ds_local_rank == 0:
574 | reorg_answer_file(answer_file)
575 |
--------------------------------------------------------------------------------
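The evaluation scripts in this directory report two throughput numbers: AVERAGE THROUGHPUT1 is the mean of the per-call tokens-per-second values, while AVERAGE THROUGHPUT2 divides the total number of generated tokens by the total wall-clock time. They differ whenever individual calls vary in length. A minimal sketch of the distinction, using made-up (seconds, tokens) pairs rather than real measurements:

    # Illustrative only: (wall-clock seconds, generated tokens) for three hypothetical calls.
    calls = [(2.0, 100), (1.0, 80), (4.0, 120)]

    # THROUGHPUT1: mean of per-call throughputs (every call weighted equally).
    tp1 = sum(tokens / secs for secs, tokens in calls) / len(calls)

    # THROUGHPUT2: total tokens over total time (longer calls weigh more).
    tp2 = sum(tokens for _, tokens in calls) / sum(secs for secs, _ in calls)

    print(tp1, tp2)  # ~53.3 vs ~42.9 tokens/s
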
/applications/eval_humaneval.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 eval_humaneval.py --model-path meta-llama/Llama-2-7b-chat-hf --model-id llama-2-7b-chat
5 | """
6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py
7 | import argparse
8 | import json
9 | import os
10 | import random
11 | import time
12 |
13 | import shortuuid
14 | import torch
15 | from tqdm import tqdm
16 | from typing import Dict, List, Optional
17 | from fastchat.llm_judge.common import load_questions, temperature_config
18 | from fastchat.model import get_conversation_template
19 | from fastchat.utils import str_to_torch_dtype
20 | import time
21 | import lade
22 | from human_eval.data import write_jsonl, read_problems
23 |
24 | def run_eval(
25 | model_path,
26 | model_id,
27 | question_file,
28 | question_begin,
29 | question_end,
30 | answer_file,
31 | max_new_token,
32 | num_choices,
33 | num_gpus_per_model,
34 | num_gpus_total,
35 | max_gpu_memory,
36 | dtype,
37 | debug,
38 | cache_dir,
39 | cpu_offloading,
40 | use_pp,
41 | use_tp,
42 | use_tp_ds,
43 | use_flash,
44 | do_sample
45 | ):
46 | #questions = load_questions(question_file, question_begin, question_end)
47 | questions = read_problems()
48 | questions = list(questions.values())[question_begin:question_end]
49 |
50 |     # Shuffling the questions (to balance load across workers) is intentionally
51 |     # disabled here so answers are generated in the original order.
52 |     #random.shuffle(questions)
53 |
54 | # Split the question file into `num_gpus` files
55 | assert num_gpus_total % num_gpus_per_model == 0
56 |
57 | get_answers_func = get_model_answers
58 |
59 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
60 | ans_handles = []
61 | for i in range(0, len(questions), chunk_size):
62 | ans_handles.append(
63 | get_answers_func(
64 | model_path,
65 | model_id,
66 | questions[i : i + chunk_size],
67 | answer_file,
68 | max_new_token,
69 | num_choices,
70 | num_gpus_per_model,
71 | max_gpu_memory,
72 | dtype=dtype,
73 | debug=debug,
74 | cache_dir=cache_dir,
75 | cpu_offloading=cpu_offloading,
76 | use_tp=use_tp,
77 | use_pp=use_pp,
78 | use_tp_ds=use_tp_ds,
79 | use_flash=use_flash,
80 | do_sample=do_sample
81 | )
82 | )
83 |
84 |
85 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM
86 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration
87 |
88 | def load_model(
89 | model_path: str,
90 | device: str = "cuda",
91 | device_map: str= "",
92 | num_gpus: int = 1,
93 | max_gpu_memory: Optional[str] = None,
94 | dtype: Optional[torch.dtype] = None,
95 | load_8bit: bool = False,
96 | cpu_offloading: bool = False,
97 | revision: str = "main",
98 | debug: bool = False,
99 | use_flash:bool = False
100 | ):
101 | """Load a model from Hugging Face."""
102 | # get model adapter
103 | adapter = Llama2Adapter()
104 | # Handle device mapping
105 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
106 | device, load_8bit, cpu_offloading
107 | )
108 | if device == "cpu":
109 | kwargs = {"torch_dtype": torch.float32}
110 | if CPU_ISA in ["avx512_bf16", "amx"]:
111 | try:
112 | import intel_extension_for_pytorch as ipex
113 |
114 | kwargs = {"torch_dtype": torch.bfloat16}
115 | except ImportError:
116 | warnings.warn(
117 |                     "Intel Extension for PyTorch is not installed; install it to accelerate CPU inference"
118 | )
119 | elif device.startswith("cuda"):
120 | kwargs = {"torch_dtype": torch.float16}
121 | if num_gpus != 1:
122 | kwargs["device_map"] = "auto"
123 | if max_gpu_memory is None:
124 | kwargs[
125 | "device_map"
126 |             ] = "sequential"  # important when the GPUs do not all have the same VRAM size
127 | available_gpu_memory = get_gpu_memory(num_gpus)
128 | kwargs["max_memory"] = {
129 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
130 | for i in range(num_gpus)
131 | }
132 | else:
133 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
134 |
135 | if cpu_offloading:
136 | # raises an error on incompatible platforms
137 | from transformers import BitsAndBytesConfig
138 |
139 | if "max_memory" in kwargs:
140 | kwargs["max_memory"]["cpu"] = (
141 |                 str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
142 | )
143 | kwargs["quantization_config"] = BitsAndBytesConfig(
144 | load_in_8bit_fp32_cpu_offload=cpu_offloading
145 | )
146 | kwargs["load_in_8bit"] = load_8bit
147 | elif load_8bit:
148 | if num_gpus != 1:
149 | warnings.warn(
150 | "8-bit quantization is not supported for multi-gpu inference."
151 | )
152 | else:
153 | model, tokenizer = adapter.load_compress_model(
154 | model_path=model_path,
155 | device=device,
156 | torch_dtype=kwargs["torch_dtype"],
157 | revision=revision,
158 | )
159 | if debug:
160 | print(model)
161 | return model, tokenizer
162 | kwargs["revision"] = revision
163 |
164 | if dtype is not None: # Overwrite dtype if it is provided in the arguments.
165 | kwargs["torch_dtype"] = dtype
166 | if use_flash:
167 | kwargs["use_flash_attention_2"] = use_flash
168 | if len(device_map) > 0:
169 | kwargs["device_map"] = device_map
170 | # Load model
171 | model, tokenizer = adapter.load_model(model_path, kwargs)
172 |
173 | if len(device_map) > 0:
174 | return model, tokenizer
175 |
176 | if (
177 | device == "cpu"
178 | and kwargs["torch_dtype"] is torch.bfloat16
179 | and CPU_ISA is not None
180 | ):
181 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
182 |
183 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in (
184 | "mps",
185 | "xpu",
186 | "npu",
187 | ):
188 | model.to(device)
189 |
190 | if device == "xpu":
191 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True)
192 |
193 | if debug:
194 | print(model)
195 |
196 | return model, tokenizer
197 |
198 | #@torch.inference_mode()
199 | def get_model_answers(
200 | model_path,
201 | model_id,
202 | questions,
203 | answer_file,
204 | max_new_token,
205 | num_choices,
206 | num_gpus_per_model,
207 | max_gpu_memory,
208 | dtype,
209 | debug,
210 | cache_dir,
211 | cpu_offloading,
212 | use_pp,
213 | use_tp_ds,
214 | use_tp,
215 | use_flash,
216 | do_sample
217 | ):
218 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
219 |
220 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices)
221 | #tokenizer = AutoTokenizer.from_pretrained(model_path)
222 | #cfg = AutoConfig.from_pretrained(model_path)
223 | #cfg._flash_attn_2_enabled= use_flash
224 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
225 | if use_pp:
226 | model, tokenizer = load_model(
227 | model_path,
228 | use_flash=use_flash,
229 | device=f"cuda",
230 | device_map="balanced",
231 | num_gpus=num_gpus_per_model,
232 | max_gpu_memory=max_gpu_memory,
233 | dtype=dtype,
234 | load_8bit=False,
235 | cpu_offloading=cpu_offloading,
236 | debug=debug,
237 | )
238 |
239 | elif use_tp_ds:
240 | import deepspeed
241 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))
242 | model, tokenizer = load_model(
243 | model_path,
244 | use_flash=use_flash,
245 | device_map="cpu",
246 | num_gpus=num_gpus_per_model,
247 | max_gpu_memory=max_gpu_memory,
248 | dtype=dtype,
249 | load_8bit=False,
250 | cpu_offloading=cpu_offloading,
251 | debug=debug,
252 | )
253 | model = deepspeed.init_inference(
254 | model,
255 | mp_size=int(os.getenv("WORLD_SIZE", "1")),
256 | dtype=torch.half
257 | )
258 | else:
259 | model, tokenizer = load_model(
260 | model_path,
261 | use_flash=use_flash,
262 | device=f"cuda:{lade.get_device()}",
263 | num_gpus=num_gpus_per_model,
264 | max_gpu_memory=max_gpu_memory,
265 | dtype=dtype,
266 | load_8bit=False,
267 | cpu_offloading=cpu_offloading,
268 | debug=debug,
269 | )
270 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device())
271 | model.tokenizer = tokenizer
272 |
273 | overall_time = 0
274 | overall_tp = 0
275 | overall_gen = 0
276 | count_gen = 0
277 | stats = {}
278 | for question_idx, question in enumerate(tqdm(questions)):
279 |
280 |         # default to temperature 1.0 when sampling; force greedy (0.0) otherwise
281 |         temperature = 1.0 if do_sample else 0.0
282 |
283 | stats[question_idx] = {} #
284 | choices = []
285 | for i in range(num_choices):
286 | torch.manual_seed(i)
287 | conv = get_conversation_template(model_id)
288 | turns = []
289 | prompts = []
290 |
291 | for j in range(1):
292 | qs = question["prompt"]
293 | prompt = qs
294 |
295 | input_ids = tokenizer(prompt, return_tensors="pt",
296 | max_length=1024, truncation=True).input_ids.to("cuda")
297 |
298 | if temperature < 1e-4:
299 | do_sample = False
300 | else:
301 | do_sample = True
302 |
303 |
304 | # some models may error out when generating long outputs
305 | if True:
306 | start_time = time.time()
307 | output_ids = model.generate(
308 | input_ids,
309 | do_sample=do_sample,
310 | temperature=temperature,
311 | max_new_tokens=max_new_token,
312 | )
313 | end_time = time.time()
314 | gap_time = end_time - start_time
315 | tokens = output_ids.numel() - input_ids.numel()
316 | overall_time += gap_time
317 | overall_gen += tokens
318 | overall_tp += tokens / gap_time
319 | count_gen += 1
320 |
321 | stats[question_idx][j] = [gap_time, tokens]
322 | if lade.get_device() == 0 and ds_local_rank == 0:
323 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time])
324 |
325 | output = tokenizer.decode(
326 | output_ids[0].tolist(),
327 | skip_special_tokens=False,
328 | )
329 |
330 | turns.append(output)
331 | prompts.append(prompt)
332 |
333 | choices.append({"index": i, "turns": turns, "prompts" : prompts})
334 |
335 | if lade.get_device() == 0 and ds_local_rank == 0:
336 | # Dump answers
337 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
338 | with open(os.path.expanduser(answer_file), "a") as fout:
339 | ans_json = {
340 | "question_id": question_idx,
341 | "answer_id": shortuuid.uuid(),
342 | "model_id": model_id,
343 | "choices": choices,
344 | "tstamp": time.time(),
345 | }
346 | fout.write(json.dumps(ans_json) + "\n")
347 | #if question_idx == 1:
348 | # break
349 |
350 | if lade.get_device() == 0 and ds_local_rank == 0:
351 |         torch.save(stats, answer_file + ".pt")  # per-question timing stats
352 | print("LOG SAVE TO ", answer_file + ".pt")
353 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}")
354 | lade.log_history()
355 | lade.save_log(answer_file + "-lade-log.pt")
356 |
357 |
358 | def reorg_answer_file(answer_file):
359 |     """Sort answers by question id and de-duplicate."""
360 | answers = {}
361 | with open(answer_file, "r") as fin:
362 | for l in fin:
363 | qid = json.loads(l)["question_id"]
364 | answers[qid] = l
365 |
366 | qids = sorted(list(answers.keys()))
367 | with open(answer_file, "w") as fout:
368 | for qid in qids:
369 | fout.write(answers[qid])
370 |
371 |
372 | if __name__ == "__main__":
373 | parser = argparse.ArgumentParser()
374 | parser.add_argument(
375 | "--model-path",
376 | type=str,
377 | required=True,
378 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
379 | )
380 | parser.add_argument(
381 | "--model-id", type=str, required=True, help="A custom name for the model."
382 | )
383 | parser.add_argument(
384 | "--cache-dir",
385 | type=str,
386 | default="",
387 | )
388 | parser.add_argument(
389 | "--debug",
390 | action="store_true",
391 | )
392 | parser.add_argument(
393 | "--bench-name",
394 | type=str,
395 | default="humaneval",
396 | help="The name of the benchmark question set.",
397 | )
398 | parser.add_argument(
399 | "--question-begin",
400 | type=int,
401 | help="A debug option. The begin index of questions.",
402 | )
403 | parser.add_argument(
404 | "--question-end", type=int, help="A debug option. The end index of questions."
405 | )
406 | parser.add_argument(
407 | "--cpu_offloading", action="store_true"
408 | )
409 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
410 | parser.add_argument(
411 | "--max-new-token",
412 | type=int,
413 | default=512,
414 | help="The maximum number of new generated tokens.",
415 | )
416 | parser.add_argument(
417 | "--num-choices",
418 | type=int,
419 | default=1,
420 | help="How many completion choices to generate.",
421 | )
422 | parser.add_argument(
423 | "--num-gpus-per-model",
424 | type=int,
425 | default=1,
426 | help="The number of GPUs per model.",
427 | )
428 | parser.add_argument(
429 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
430 | )
431 | parser.add_argument(
432 | "--max-gpu-memory",
433 | type=str,
434 |         help="Maximum GPU memory used for model weights per GPU.",
435 | )
436 | parser.add_argument(
437 | "--dtype",
438 | type=str,
439 | choices=["float32", "float64", "float16", "bfloat16"],
440 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
441 | default=None,
442 | )
443 | parser.add_argument(
444 | "--local_rank",
445 | type=int,
446 | default=0,
447 | )
448 | parser.add_argument(
449 | "--local-rank",
450 | type=int,
451 | default=0,
452 | )
453 | parser.add_argument(
454 | "--level",
455 | type=int,
456 | default=3,
457 | )
458 | parser.add_argument(
459 | "--window",
460 | type=int,
461 | default=10,
462 | )
463 | parser.add_argument(
464 | "--guess",
465 | type=int,
466 | default=10,
467 | )
468 | parser.add_argument(
469 | "--use-tp",
470 | type=int,
471 | default=0,
472 | )
473 | parser.add_argument(
474 | "--use-pp",
475 | type=int,
476 | default=0,
477 | )
478 | parser.add_argument(
479 | "--use-tp-ds",
480 | type=int,
481 | default=0,
482 | )
483 | parser.add_argument(
484 | "--use-flash",
485 | type=int,
486 | default=0,
487 | )
488 | parser.add_argument(
489 | "--do-sample",
490 | type=int,
491 | default=0,
492 | )
493 |
494 | args = parser.parse_args()
495 | if int(os.environ.get("USE_LADE", 0)):
496 |
497 | lade.augment_all()
498 |         lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")))
499 | print("lade activated config: ", lade.decoding.CONFIG_MAP)
500 |
501 |     question_file = "mtbench.jsonl"  # unused here; HumanEval problems are loaded via read_problems()
502 | if args.answer_file:
503 | answer_file = args.answer_file
504 | else:
505 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
506 |
507 | print(f"Output to {answer_file}")
508 |
509 | run_eval(
510 | model_path=args.model_path,
511 | model_id=args.model_id,
512 | question_file=question_file,
513 | question_begin=args.question_begin,
514 | question_end=args.question_end,
515 | answer_file=answer_file,
516 | max_new_token=args.max_new_token,
517 | num_choices=args.num_choices,
518 | num_gpus_per_model=args.num_gpus_per_model,
519 | num_gpus_total=args.num_gpus_total,
520 | max_gpu_memory=args.max_gpu_memory,
521 | dtype=str_to_torch_dtype(args.dtype),
522 | debug=args.debug,
523 | cache_dir=args.cache_dir,
524 | cpu_offloading=args.cpu_offloading,
525 | use_pp=args.use_pp,
526 | use_tp_ds=args.use_tp_ds,
527 | use_tp=args.use_tp,
528 | use_flash=args.use_flash,
529 | do_sample=args.do_sample
530 | )
531 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
532 | if lade.get_device() == 0 and ds_local_rank == 0:
533 | reorg_answer_file(answer_file)
534 |
--------------------------------------------------------------------------------
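eval_humaneval.py stores the full decoded sequence (prompt plus completion, with special tokens kept) in each answer's turns. To score the answers with the human_eval toolkit, the completion usually has to be isolated first. The helper below is only a hedged sketch of that post-processing step; it is not part of this repository, and the stop-sequence list is an assumption:

    def extract_completion(output: str, prompt: str) -> str:
        # Remove the Llama special tokens kept by skip_special_tokens=False.
        for tok in ("<s>", "</s>"):
            output = output.replace(tok, "")
        # Drop the echoed prompt, if present.
        pos = output.find(prompt)
        if pos != -1:
            output = output[pos + len(prompt):]
        # Truncate at common HumanEval stop sequences (an assumed, non-exhaustive list).
        for stop in ("\nclass ", "\ndef ", "\nif __name__", "\nprint("):
            idx = output.find(stop)
            if idx != -1:
                output = output[:idx]
        return output
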
/applications/eval_mtbench.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 eval_mtbench.py --model-path meta-llama/Llama-2-7b-chat-hf --model-id llama-2-7b-chat
5 | """
6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py
7 | import argparse
8 | import json
9 | import os
10 | import random
11 | import time
12 |
13 | import shortuuid
14 | import torch
15 | from tqdm import tqdm
16 | from typing import Dict, List, Optional
17 | from fastchat.llm_judge.common import load_questions, temperature_config
18 | from fastchat.model import get_conversation_template
19 | from fastchat.utils import str_to_torch_dtype
20 | import time
21 | import lade
22 |
23 | def run_eval(
24 | model_path,
25 | model_id,
26 | question_file,
27 | question_begin,
28 | question_end,
29 | answer_file,
30 | max_new_token,
31 | num_choices,
32 | num_gpus_per_model,
33 | num_gpus_total,
34 | max_gpu_memory,
35 | dtype,
36 | debug,
37 | cache_dir,
38 | cpu_offloading,
39 | use_pp,
40 | use_tp,
41 | use_tp_ds,
42 | use_flash,
43 | do_sample
44 | ):
45 | questions = load_questions(question_file, question_begin, question_end)
46 |     # Shuffling the questions (to balance load across workers) is intentionally
47 |     # disabled here so answers are generated in the original order.
48 |     #random.shuffle(questions)
49 |
50 | # Split the question file into `num_gpus` files
51 | assert num_gpus_total % num_gpus_per_model == 0
52 |
53 | get_answers_func = get_model_answers
54 |
55 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
56 | ans_handles = []
57 | for i in range(0, len(questions), chunk_size):
58 | ans_handles.append(
59 | get_answers_func(
60 | model_path,
61 | model_id,
62 | questions[i : i + chunk_size],
63 | answer_file,
64 | max_new_token,
65 | num_choices,
66 | num_gpus_per_model,
67 | max_gpu_memory,
68 | dtype=dtype,
69 | debug=debug,
70 | cache_dir=cache_dir,
71 | cpu_offloading=cpu_offloading,
72 | use_tp=use_tp,
73 | use_pp=use_pp,
74 | use_tp_ds=use_tp_ds,
75 | use_flash=use_flash,
76 | do_sample=do_sample
77 | )
78 | )
79 |
80 |
81 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM
82 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration
83 |
84 | def load_model(
85 | model_path: str,
86 | device: str = "cuda",
87 | device_map: str= "",
88 | num_gpus: int = 1,
89 | max_gpu_memory: Optional[str] = None,
90 | dtype: Optional[torch.dtype] = None,
91 | load_8bit: bool = False,
92 | cpu_offloading: bool = False,
93 | revision: str = "main",
94 | debug: bool = False,
95 | use_flash:bool = False
96 | ):
97 | """Load a model from Hugging Face."""
98 | # get model adapter
99 | adapter = Llama2Adapter()
100 | # Handle device mapping
101 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
102 | device, load_8bit, cpu_offloading
103 | )
104 | if device == "cpu":
105 | kwargs = {"torch_dtype": torch.float32}
106 | if CPU_ISA in ["avx512_bf16", "amx"]:
107 | try:
108 | import intel_extension_for_pytorch as ipex
109 |
110 | kwargs = {"torch_dtype": torch.bfloat16}
111 | except ImportError:
112 | warnings.warn(
113 |                     "Intel Extension for PyTorch is not installed; install it to accelerate CPU inference"
114 | )
115 | elif device.startswith("cuda"):
116 | kwargs = {"torch_dtype": torch.float16}
117 | if num_gpus != 1:
118 | kwargs["device_map"] = "auto"
119 | if max_gpu_memory is None:
120 | kwargs[
121 | "device_map"
122 |             ] = "sequential"  # important when the GPUs do not all have the same VRAM size
123 | available_gpu_memory = get_gpu_memory(num_gpus)
124 | kwargs["max_memory"] = {
125 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
126 | for i in range(num_gpus)
127 | }
128 | else:
129 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
130 |
131 | if cpu_offloading:
132 | # raises an error on incompatible platforms
133 | from transformers import BitsAndBytesConfig
134 |
135 | if "max_memory" in kwargs:
136 | kwargs["max_memory"]["cpu"] = (
137 |                 str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
138 | )
139 | kwargs["quantization_config"] = BitsAndBytesConfig(
140 | load_in_8bit_fp32_cpu_offload=cpu_offloading
141 | )
142 | kwargs["load_in_8bit"] = load_8bit
143 | elif load_8bit:
144 | if num_gpus != 1:
145 | warnings.warn(
146 | "8-bit quantization is not supported for multi-gpu inference."
147 | )
148 | else:
149 | model, tokenizer = adapter.load_compress_model(
150 | model_path=model_path,
151 | device=device,
152 | torch_dtype=kwargs["torch_dtype"],
153 | revision=revision,
154 | )
155 | if debug:
156 | print(model)
157 | return model, tokenizer
158 | kwargs["revision"] = revision
159 |
160 | if dtype is not None: # Overwrite dtype if it is provided in the arguments.
161 | kwargs["torch_dtype"] = dtype
162 | if use_flash:
163 | kwargs["use_flash_attention_2"] = use_flash
164 | if len(device_map) > 0:
165 | kwargs["device_map"] = device_map
166 | # Load model
167 | model, tokenizer = adapter.load_model(model_path, kwargs)
168 |
169 | if len(device_map) > 0:
170 | return model, tokenizer
171 |
172 | if (
173 | device == "cpu"
174 | and kwargs["torch_dtype"] is torch.bfloat16
175 | and CPU_ISA is not None
176 | ):
177 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
178 |
179 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in (
180 | "mps",
181 | "xpu",
182 | "npu",
183 | ):
184 | model.to(device)
185 |
186 | if device == "xpu":
187 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True)
188 |
189 | if debug:
190 | print(model)
191 |
192 | return model, tokenizer
193 |
194 | #@torch.inference_mode()
195 | def get_model_answers(
196 | model_path,
197 | model_id,
198 | questions,
199 | answer_file,
200 | max_new_token,
201 | num_choices,
202 | num_gpus_per_model,
203 | max_gpu_memory,
204 | dtype,
205 | debug,
206 | cache_dir,
207 | cpu_offloading,
208 | use_pp,
209 | use_tp_ds,
210 | use_tp,
211 | use_flash,
212 | do_sample
213 | ):
214 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
215 |
216 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices)
217 |
218 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
219 | if use_pp:
220 | model, tokenizer = load_model(
221 | model_path,
222 | use_flash=use_flash,
223 | device=f"cuda",
224 | device_map="balanced",
225 | num_gpus=num_gpus_per_model,
226 | max_gpu_memory=max_gpu_memory,
227 | dtype=dtype,
228 | load_8bit=False,
229 | cpu_offloading=cpu_offloading,
230 | debug=debug,
231 | )
232 |
233 | elif use_tp_ds:
234 | import deepspeed
235 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))
236 | model, tokenizer = load_model(
237 | model_path,
238 | use_flash=use_flash,
239 | device_map="cpu",
240 | num_gpus=num_gpus_per_model,
241 | max_gpu_memory=max_gpu_memory,
242 | dtype=dtype,
243 | load_8bit=False,
244 | cpu_offloading=cpu_offloading,
245 | debug=debug,
246 | )
247 | model = deepspeed.init_inference(
248 | model,
249 | mp_size=int(os.getenv("WORLD_SIZE", "1")),
250 | dtype=torch.half
251 | )
252 | else:
253 | model, tokenizer = load_model(
254 | model_path,
255 | use_flash=use_flash,
256 | device=f"cuda:{lade.get_device()}",
257 | num_gpus=num_gpus_per_model,
258 | max_gpu_memory=max_gpu_memory,
259 | dtype=dtype,
260 | load_8bit=False,
261 | cpu_offloading=cpu_offloading,
262 | debug=debug,
263 | )
264 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device())
265 | model.tokenizer = tokenizer
266 |
267 | overall_time = 0
268 | overall_tp = 0
269 | overall_gen = 0
270 | count_gen = 0
271 | stats = {}
272 | for question_idx, question in enumerate(tqdm(questions)):
273 | if question["category"] in temperature_config:
274 | temperature = temperature_config[question["category"]]
275 | else:
276 | temperature = 0.7
277 |
278 | if not do_sample:
279 | temperature = 0.0 #force greedy
280 |
281 | stats[question_idx] = {} #
282 | choices = []
283 | for i in range(num_choices):
284 | torch.manual_seed(i)
285 | conv = get_conversation_template(model_id)
286 | turns = []
287 | prompts = []
288 |
289 | for j in range(len(question["turns"])):
290 | qs = question["turns"][j]
291 | conv.append_message(conv.roles[0], qs)
292 | conv.append_message(conv.roles[1], None)
293 | prompt = conv.get_prompt()
294 | prompts.append(prompt)
295 | input_ids = tokenizer([prompt]).input_ids
296 |
297 | if temperature < 1e-4:
298 | do_sample = False
299 | else:
300 | do_sample = True
301 |
302 |
303 | # some models may error out when generating long outputs
304 | if True:
305 | start_time = time.time()
306 | output_ids = model.generate(
307 | torch.as_tensor(input_ids).cuda(),
308 | do_sample=do_sample,
309 | temperature=temperature,
310 | max_new_tokens=max_new_token,
311 |                         top_k=0, top_p=1.0,  # top_k must be an int; 0 disables top-k filtering
312 | )
313 | end_time = time.time()
314 | gap_time = end_time - start_time
315 | tokens = output_ids.numel() - len(input_ids[0])
316 | overall_time += gap_time
317 | overall_gen += tokens
318 | overall_tp += tokens / gap_time
319 | count_gen += 1
320 |
321 | stats[question_idx][j] = [gap_time, tokens]
322 | if lade.get_device() == 0 and ds_local_rank == 0:
323 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time])
324 |
325 | if model.config.is_encoder_decoder:
326 | output_ids = output_ids[0]
327 | else:
328 | output_ids = output_ids[0][len(input_ids[0]) :]
329 |
330 | # be consistent with the template's stop_token_ids
331 | if conv.stop_token_ids:
332 | stop_token_ids_index = [
333 | i
334 | for i, id in enumerate(output_ids)
335 | if id in conv.stop_token_ids
336 | ]
337 | if len(stop_token_ids_index) > 0:
338 | output_ids = output_ids[: stop_token_ids_index[0]]
339 |
340 | output = tokenizer.decode(
341 | output_ids,
342 | spaces_between_special_tokens=False,
343 | )
344 | if conv.stop_str and output.find(conv.stop_str) > 0:
345 | output = output[: output.find(conv.stop_str)]
346 | for special_token in tokenizer.special_tokens_map.values():
347 | if isinstance(special_token, list):
348 | for special_tok in special_token:
349 | output = output.replace(special_tok, "")
350 | else:
351 | output = output.replace(special_token, "")
352 | output = output.strip()
353 |
354 | if conv.name == "xgen" and output.startswith("Assistant:"):
355 | output = output.replace("Assistant:", "", 1).strip()
356 | '''
357 | except RuntimeError as e:
358 | print("ERROR question ID: ", question["question_id"])
359 | output = "ERROR"
360 | '''
361 | turns.append(output)
362 | conv.messages[-1][-1] = output
363 |
364 | choices.append({"index": i, "turns": turns, "prompts" : prompts})
365 |
366 | if lade.get_device() == 0 and ds_local_rank == 0:
367 | # Dump answers
368 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
369 | with open(os.path.expanduser(answer_file), "a") as fout:
370 | ans_json = {
371 | "question_id": question["question_id"],
372 | "answer_id": shortuuid.uuid(),
373 | "model_id": model_id,
374 | "choices": choices,
375 | "tstamp": time.time(),
376 | }
377 | fout.write(json.dumps(ans_json) + "\n")
378 | #if question_idx == 1:
379 | # break
380 |
381 | if lade.get_device() == 0 and ds_local_rank == 0:
382 |         torch.save(stats, answer_file + ".pt")  # per-question timing stats
383 | print("LOG SAVE TO ", answer_file + ".pt")
384 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}")
385 | lade.log_history()
386 | lade.save_log(answer_file + "-lade-log.pt")
387 |
388 |
389 | def reorg_answer_file(answer_file):
390 |     """Sort answers by question id and de-duplicate."""
391 | answers = {}
392 | with open(answer_file, "r") as fin:
393 | for l in fin:
394 | qid = json.loads(l)["question_id"]
395 | answers[qid] = l
396 |
397 | qids = sorted(list(answers.keys()))
398 | with open(answer_file, "w") as fout:
399 | for qid in qids:
400 | fout.write(answers[qid])
401 |
402 |
403 | if __name__ == "__main__":
404 | parser = argparse.ArgumentParser()
405 | parser.add_argument(
406 | "--model-path",
407 | type=str,
408 | required=True,
409 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
410 | )
411 | parser.add_argument(
412 | "--model-id", type=str, required=True, help="A custom name for the model."
413 | )
414 | parser.add_argument(
415 | "--cache-dir",
416 | type=str,
417 | default="",
418 | )
419 | parser.add_argument(
420 | "--debug",
421 | action="store_true",
422 | )
423 | parser.add_argument(
424 | "--bench-name",
425 | type=str,
426 | default="mt_bench",
427 | help="The name of the benchmark question set.",
428 | )
429 | parser.add_argument(
430 | "--question-begin",
431 | type=int,
432 | help="A debug option. The begin index of questions.",
433 | )
434 | parser.add_argument(
435 | "--question-end", type=int, help="A debug option. The end index of questions."
436 | )
437 | parser.add_argument(
438 | "--cpu_offloading", action="store_true"
439 | )
440 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
441 | parser.add_argument(
442 | "--max-new-token",
443 | type=int,
444 | default=1024,
445 | help="The maximum number of new generated tokens.",
446 | )
447 | parser.add_argument(
448 | "--num-choices",
449 | type=int,
450 | default=1,
451 | help="How many completion choices to generate.",
452 | )
453 | parser.add_argument(
454 | "--num-gpus-per-model",
455 | type=int,
456 | default=1,
457 | help="The number of GPUs per model.",
458 | )
459 | parser.add_argument(
460 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
461 | )
462 | parser.add_argument(
463 | "--max-gpu-memory",
464 | type=str,
465 |         help="Maximum GPU memory used for model weights per GPU.",
466 | )
467 | parser.add_argument(
468 | "--dtype",
469 | type=str,
470 | choices=["float32", "float64", "float16", "bfloat16"],
471 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
472 | default=None,
473 | )
474 | parser.add_argument(
475 | "--local_rank",
476 | type=int,
477 | default=0,
478 | )
479 | parser.add_argument(
480 | "--local-rank",
481 | type=int,
482 | default=0,
483 | )
484 | parser.add_argument(
485 | "--level",
486 | type=int,
487 | default=3,
488 | )
489 | parser.add_argument(
490 | "--window",
491 | type=int,
492 | default=10,
493 | )
494 | parser.add_argument(
495 | "--guess",
496 | type=int,
497 | default=10,
498 | )
499 | parser.add_argument(
500 | "--use-tp",
501 | type=int,
502 | default=0,
503 | )
504 | parser.add_argument(
505 | "--use-pp",
506 | type=int,
507 | default=0,
508 | )
509 | parser.add_argument(
510 | "--use-tp-ds",
511 | type=int,
512 | default=0,
513 | )
514 | parser.add_argument(
515 | "--use-flash",
516 | type=int,
517 | default=0,
518 | )
519 | parser.add_argument(
520 | "--do-sample",
521 | type=int,
522 | default=0,
523 | )
524 |
525 | args = parser.parse_args()
526 | if int(os.environ.get("USE_LADE", 0)):
527 |
528 | lade.augment_all()
529 |         lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")))
530 | print("lade activated config: ", lade.decoding.CONFIG_MAP)
531 |
532 |     question_file = "mtbench.jsonl"  # downloaded by run_mtbench.sh
533 | if args.answer_file:
534 | answer_file = args.answer_file
535 | else:
536 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
537 |
538 | print(f"Output to {answer_file}")
539 |
540 | run_eval(
541 | model_path=args.model_path,
542 | model_id=args.model_id,
543 | question_file=question_file,
544 | question_begin=args.question_begin,
545 | question_end=args.question_end,
546 | answer_file=answer_file,
547 | max_new_token=args.max_new_token,
548 | num_choices=args.num_choices,
549 | num_gpus_per_model=args.num_gpus_per_model,
550 | num_gpus_total=args.num_gpus_total,
551 | max_gpu_memory=args.max_gpu_memory,
552 | dtype=str_to_torch_dtype(args.dtype),
553 | debug=args.debug,
554 | cache_dir=args.cache_dir,
555 | cpu_offloading=args.cpu_offloading,
556 | use_pp=args.use_pp,
557 | use_tp_ds=args.use_tp_ds,
558 | use_tp=args.use_tp,
559 | use_flash=args.use_flash,
560 | do_sample=args.do_sample
561 | )
562 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
563 | if lade.get_device() == 0 and ds_local_rank == 0:
564 | reorg_answer_file(answer_file)
565 |
--------------------------------------------------------------------------------
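eval_mtbench.py looks up a per-category temperature in FastChat's temperature_config and forces greedy decoding unless --do-sample is passed; a temperature below 1e-4 is then treated as greedy. A condensed, standalone sketch of that decision (the small category table here is a stand-in for the real temperature_config):

    # Stand-in for fastchat.llm_judge.common.temperature_config.
    temperature_config = {"writing": 0.7, "math": 0.0, "coding": 0.0}

    def decoding_params(category: str, force_sample: bool):
        temperature = temperature_config.get(category, 0.7)
        if not force_sample:
            temperature = 0.0              # matches the script's greedy default
        do_sample = temperature >= 1e-4    # effectively-zero temperatures fall back to greedy
        return do_sample, temperature

    print(decoding_params("writing", force_sample=False))  # (False, 0.0)
    print(decoding_params("writing", force_sample=True))   # (True, 0.7)
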
/applications/eval_xsum.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0
5 | """
6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py
7 | import argparse
8 | import json
9 | import os
10 | import random
11 | import time
12 |
13 | import shortuuid
14 | import torch
15 | from tqdm import tqdm
16 | from typing import Dict, List, Optional
17 | from fastchat.llm_judge.common import load_questions, temperature_config
18 | from fastchat.model import get_conversation_template
19 | from fastchat.utils import str_to_torch_dtype
20 | import time
21 | import lade
22 | from datasets import load_dataset
23 |
24 | def run_eval(
25 | model_path,
26 | model_id,
27 | question_file,
28 | question_begin,
29 | question_end,
30 | answer_file,
31 | max_new_token,
32 | num_choices,
33 | num_gpus_per_model,
34 | num_gpus_total,
35 | max_gpu_memory,
36 | dtype,
37 | debug,
38 | cache_dir,
39 | cpu_offloading,
40 | use_pp,
41 | use_tp,
42 | use_tp_ds,
43 | use_flash,
44 | do_sample
45 | ):
46 | questions = load_dataset("EdinburghNLP/xsum", split="validation", streaming=False)["document"][question_begin:question_end]
47 |     # Shuffling the questions (to balance load across workers) is intentionally
48 |     # disabled here so answers are generated in the original order.
49 |     #random.shuffle(questions)
50 |
51 | # Split the question file into `num_gpus` files
52 | assert num_gpus_total % num_gpus_per_model == 0
53 |
54 | get_answers_func = get_model_answers
55 |
56 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
57 | ans_handles = []
58 | for i in range(0, len(questions), chunk_size):
59 | ans_handles.append(
60 | get_answers_func(
61 | model_path,
62 | model_id,
63 | questions[i : i + chunk_size],
64 | answer_file,
65 | max_new_token,
66 | num_choices,
67 | num_gpus_per_model,
68 | max_gpu_memory,
69 | dtype=dtype,
70 | debug=debug,
71 | cache_dir=cache_dir,
72 | cpu_offloading=cpu_offloading,
73 | use_tp=use_tp,
74 | use_pp=use_pp,
75 | use_tp_ds=use_tp_ds,
76 | use_flash=use_flash,
77 | do_sample=do_sample
78 | )
79 | )
80 |
81 |
82 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM
83 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration
84 |
85 | def load_model(
86 | model_path: str,
87 | device: str = "cuda",
88 | device_map: str= "",
89 | num_gpus: int = 1,
90 | max_gpu_memory: Optional[str] = None,
91 | dtype: Optional[torch.dtype] = None,
92 | load_8bit: bool = False,
93 | cpu_offloading: bool = False,
94 | revision: str = "main",
95 | debug: bool = False,
96 | use_flash:bool = False
97 | ):
98 | """Load a model from Hugging Face."""
99 | # get model adapter
100 | adapter = Llama2Adapter()
101 | # Handle device mapping
102 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
103 | device, load_8bit, cpu_offloading
104 | )
105 | if device == "cpu":
106 | kwargs = {"torch_dtype": torch.float32}
107 | if CPU_ISA in ["avx512_bf16", "amx"]:
108 | try:
109 | import intel_extension_for_pytorch as ipex
110 |
111 | kwargs = {"torch_dtype": torch.bfloat16}
112 | except ImportError:
113 | warnings.warn(
114 |                     "Intel Extension for PyTorch is not installed; install it to accelerate CPU inference"
115 | )
116 | elif device.startswith("cuda"):
117 | kwargs = {"torch_dtype": torch.float16}
118 | if num_gpus != 1:
119 | kwargs["device_map"] = "auto"
120 | if max_gpu_memory is None:
121 | kwargs[
122 | "device_map"
123 |             ] = "sequential"  # important when the GPUs do not all have the same VRAM size
124 | available_gpu_memory = get_gpu_memory(num_gpus)
125 | kwargs["max_memory"] = {
126 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
127 | for i in range(num_gpus)
128 | }
129 | else:
130 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
131 |
132 | if cpu_offloading:
133 | # raises an error on incompatible platforms
134 | from transformers import BitsAndBytesConfig
135 |
136 | if "max_memory" in kwargs:
137 | kwargs["max_memory"]["cpu"] = (
138 |                 str(math.floor(psutil.virtual_memory().available / 2**20)) + "MiB"
139 | )
140 | kwargs["quantization_config"] = BitsAndBytesConfig(
141 | load_in_8bit_fp32_cpu_offload=cpu_offloading
142 | )
143 | kwargs["load_in_8bit"] = load_8bit
144 | elif load_8bit:
145 | if num_gpus != 1:
146 | warnings.warn(
147 | "8-bit quantization is not supported for multi-gpu inference."
148 | )
149 | else:
150 | model, tokenizer = adapter.load_compress_model(
151 | model_path=model_path,
152 | device=device,
153 | torch_dtype=kwargs["torch_dtype"],
154 | revision=revision,
155 | )
156 | if debug:
157 | print(model)
158 | return model, tokenizer
159 | kwargs["revision"] = revision
160 |
161 | if dtype is not None: # Overwrite dtype if it is provided in the arguments.
162 | kwargs["torch_dtype"] = dtype
163 | if use_flash:
164 | kwargs["use_flash_attention_2"] = use_flash
165 | if len(device_map) > 0:
166 | kwargs["device_map"] = device_map
167 | # Load model
168 | model, tokenizer = adapter.load_model(model_path, kwargs)
169 |
170 | if len(device_map) > 0:
171 | return model, tokenizer
172 |
173 | if (
174 | device == "cpu"
175 | and kwargs["torch_dtype"] is torch.bfloat16
176 | and CPU_ISA is not None
177 | ):
178 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"])
179 |
180 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in (
181 | "mps",
182 | "xpu",
183 | "npu",
184 | ):
185 | model.to(device)
186 |
187 | if device == "xpu":
188 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True)
189 |
190 | if debug:
191 | print(model)
192 |
193 | return model, tokenizer
194 |
195 | #@torch.inference_mode()
196 | def get_model_answers(
197 | model_path,
198 | model_id,
199 | questions,
200 | answer_file,
201 | max_new_token,
202 | num_choices,
203 | num_gpus_per_model,
204 | max_gpu_memory,
205 | dtype,
206 | debug,
207 | cache_dir,
208 | cpu_offloading,
209 | use_pp,
210 | use_tp_ds,
211 | use_tp,
212 | use_flash,
213 | do_sample
214 | ):
215 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
216 |
217 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices)
218 |
219 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
220 | if use_pp:
221 | model, tokenizer = load_model(
222 | model_path,
223 | use_flash=use_flash,
224 | device=f"cuda",
225 | device_map="balanced",
226 | num_gpus=num_gpus_per_model,
227 | max_gpu_memory=max_gpu_memory,
228 | dtype=dtype,
229 | load_8bit=False,
230 | cpu_offloading=cpu_offloading,
231 | debug=debug,
232 | )
233 |
234 | elif use_tp_ds:
235 | import deepspeed
236 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))
237 | model, tokenizer = load_model(
238 | model_path,
239 | use_flash=use_flash,
240 | device_map="cpu",
241 | num_gpus=num_gpus_per_model,
242 | max_gpu_memory=max_gpu_memory,
243 | dtype=dtype,
244 | load_8bit=False,
245 | cpu_offloading=cpu_offloading,
246 | debug=debug,
247 | )
248 | model = deepspeed.init_inference(
249 | model,
250 | mp_size=int(os.getenv("WORLD_SIZE", "1")),
251 | dtype=torch.half
252 | )
253 | else:
254 | model, tokenizer = load_model(
255 | model_path,
256 | use_flash=use_flash,
257 | device=f"cuda:{lade.get_device()}",
258 | num_gpus=num_gpus_per_model,
259 | max_gpu_memory=max_gpu_memory,
260 | dtype=dtype,
261 | load_8bit=False,
262 | cpu_offloading=cpu_offloading,
263 | debug=debug,
264 | )
265 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device())
266 | model.tokenizer = tokenizer
267 |
268 | overall_time = 0
269 | overall_tp = 0
270 | overall_gen = 0
271 | count_gen = 0
272 | stats = {}
273 | for question_idx, question in enumerate(tqdm(questions)):
274 |
275 | stats[question_idx] = {} #
276 | choices = []
277 | for i in range(num_choices):
278 | torch.manual_seed(i)
279 | conv = get_conversation_template(model_id)
280 | turns = []
281 | prompts = []
282 |
283 | for j in range(1):
284 |
285 |                 prompt = f'''[INST] <<SYS>>
286 | You are an intelligent chatbot. Answer the questions only using the following context:
287 |
288 | {question}
289 |
290 | Here are some rules you always follow:
291 |
292 | - Generate human readable output, avoid creating output with gibberish text.
293 | - Generate only the requested output, don't include any other language before or after the requested output.
294 | - Never say thank you, that you are happy to help, that you are an AI agent, etc. Just answer directly.
295 | - Generate professional language typically used in business documents in North America.
296 | - Never generate offensive or foul language.
297 |
298 | <</SYS>>
299 |
300 | Briefly summarize the given context. [/INST]
301 | Summary: '''
302 |
303 | prompts.append(prompt)
304 |
305 | input_ids = tokenizer([prompt]).input_ids
306 |
307 | #print("len: ", len(input_ids[0]))
308 | if len(input_ids[0]) > 2048: #skip input len > 2048 tokens
309 | continue
310 |
311 | # some models may error out when generating long outputs
312 | if True:
313 | if do_sample:
314 | start_time = time.time()
315 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=True, top_k=0, temperature=1.0, top_p=1.0)
316 | end_time = time.time()
317 | else:
318 | start_time = time.time()
319 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=False, top_k=0)
320 | end_time = time.time()
321 |
322 | gap_time = end_time - start_time
323 | tokens = output_ids.numel() - len(input_ids[0])
324 | overall_time += gap_time
325 | overall_gen += tokens
326 | overall_tp += tokens / gap_time
327 | count_gen += 1
328 |
329 | stats[question_idx][j] = [gap_time, tokens]
330 | if lade.get_device() == 0 and ds_local_rank == 0:
331 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time])
332 |
333 | if model.config.is_encoder_decoder:
334 | output_ids = output_ids[0]
335 | else:
336 | output_ids = output_ids[0][len(input_ids[0]) :]
337 |
338 | # be consistent with the template's stop_token_ids
339 | if conv.stop_token_ids:
340 | stop_token_ids_index = [
341 | i
342 | for i, id in enumerate(output_ids)
343 | if id in conv.stop_token_ids
344 | ]
345 | if len(stop_token_ids_index) > 0:
346 | output_ids = output_ids[: stop_token_ids_index[0]]
347 |
348 | output = tokenizer.decode(
349 | output_ids,
350 | spaces_between_special_tokens=False,
351 | )
352 | if conv.stop_str and output.find(conv.stop_str) > 0:
353 | output = output[: output.find(conv.stop_str)]
354 | for special_token in tokenizer.special_tokens_map.values():
355 | if isinstance(special_token, list):
356 | for special_tok in special_token:
357 | output = output.replace(special_tok, "")
358 | else:
359 | output = output.replace(special_token, "")
360 | output = output.strip()
361 |
362 | if conv.name == "xgen" and output.startswith("Assistant:"):
363 | output = output.replace("Assistant:", "", 1).strip()
364 |
365 | #print("output: ", output)
366 | '''
367 | except RuntimeError as e:
368 | print("ERROR question ID: ", question["question_id"])
369 | output = "ERROR"
370 | '''
371 | turns.append(output)
372 |
373 |
374 | choices.append({"index": i, "turns": turns, "prompts" : prompts})
375 |
376 | if lade.get_device() == 0 and ds_local_rank == 0:
377 | # Dump answers
378 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
379 | with open(os.path.expanduser(answer_file), "a") as fout:
380 | ans_json = {
381 | "question_id": question_idx,
382 | "answer_id": shortuuid.uuid(),
383 | "model_id": model_id,
384 | "choices": choices,
385 | "tstamp": time.time(),
386 | }
387 | fout.write(json.dumps(ans_json) + "\n")
388 | #if question_idx == 1:
389 | # break
390 |
391 | if lade.get_device() == 0 and ds_local_rank == 0:
392 |         torch.save(stats, answer_file + ".pt")  # per-question timing stats
393 | print("LOG SAVE TO ", answer_file + ".pt")
394 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}")
395 | lade.log_history()
396 | lade.save_log(answer_file + "-lade-log.pt")
397 |
398 |
399 | def reorg_answer_file(answer_file):
400 |     """Sort answers by question id and de-duplicate."""
401 | answers = {}
402 | with open(answer_file, "r") as fin:
403 | for l in fin:
404 | qid = json.loads(l)["question_id"]
405 | answers[qid] = l
406 |
407 | qids = sorted(list(answers.keys()))
408 | with open(answer_file, "w") as fout:
409 | for qid in qids:
410 | fout.write(answers[qid])
411 |
412 |
413 | if __name__ == "__main__":
414 | parser = argparse.ArgumentParser()
415 | parser.add_argument(
416 | "--model-path",
417 | type=str,
418 | required=True,
419 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
420 | )
421 | parser.add_argument(
422 | "--model-id", type=str, required=True, help="A custom name for the model."
423 | )
424 | parser.add_argument(
425 | "--cache-dir",
426 | type=str,
427 | default="",
428 | )
429 | parser.add_argument(
430 | "--debug",
431 | action="store_true",
432 | )
433 | parser.add_argument(
434 | "--bench-name",
435 | type=str,
436 | default="xsum",
437 | help="The name of the benchmark question set.",
438 | )
439 | parser.add_argument(
440 | "--question-begin",
441 | type=int,
442 | help="A debug option. The begin index of questions.",
443 | )
444 | parser.add_argument(
445 | "--question-end", type=int, help="A debug option. The end index of questions."
446 | )
447 | parser.add_argument(
448 | "--cpu_offloading", action="store_true"
449 | )
450 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
451 | parser.add_argument(
452 | "--max-new-token",
453 | type=int,
454 | default=1024,
455 | help="The maximum number of new generated tokens.",
456 | )
457 | parser.add_argument(
458 | "--num-choices",
459 | type=int,
460 | default=1,
461 | help="How many completion choices to generate.",
462 | )
463 | parser.add_argument(
464 | "--num-gpus-per-model",
465 | type=int,
466 | default=1,
467 | help="The number of GPUs per model.",
468 | )
469 | parser.add_argument(
470 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
471 | )
472 | parser.add_argument(
473 | "--max-gpu-memory",
474 | type=str,
475 | help="Maxmum GPU memory used for model weights per GPU.",
476 | )
477 | parser.add_argument(
478 | "--dtype",
479 | type=str,
480 | choices=["float32", "float64", "float16", "bfloat16"],
481 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
482 | default=None,
483 | )
484 | parser.add_argument(
485 | "--local_rank",
486 | type=int,
487 | default=0,
488 | )
489 | parser.add_argument(
490 | "--local-rank",
491 | type=int,
492 | default=0,
493 | )
494 | parser.add_argument(
495 | "--level",
496 | type=int,
497 | default=3,
498 | )
499 | parser.add_argument(
500 | "--window",
501 | type=int,
502 | default=10,
503 | )
504 | parser.add_argument(
505 | "--guess",
506 | type=int,
507 | default=10,
508 | )
509 | parser.add_argument(
510 | "--use-tp",
511 | type=int,
512 | default=0,
513 | )
514 | parser.add_argument(
515 | "--use-pp",
516 | type=int,
517 | default=0,
518 | )
519 | parser.add_argument(
520 | "--use-tp-ds",
521 | type=int,
522 | default=0,
523 | )
524 | parser.add_argument(
525 | "--use-flash",
526 | type=int,
527 | default=0,
528 | )
529 | parser.add_argument(
530 | "--do-sample",
531 | type=int,
532 | default=0,
533 | )
534 |
535 | args = parser.parse_args()
536 | if int(os.environ.get("USE_LADE", 0)):
537 |
538 | lade.augment_all()
539 |         lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")))
540 | print("lade activated config: ", lade.decoding.CONFIG_MAP)
541 |
542 |     question_file = "mtbench.jsonl"  # unused here; the XSum split is loaded from the Hugging Face hub
543 | if args.answer_file:
544 | answer_file = args.answer_file
545 | else:
546 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
547 |
548 | print(f"Output to {answer_file}")
549 |
550 | run_eval(
551 | model_path=args.model_path,
552 | model_id=args.model_id,
553 | question_file=question_file,
554 | question_begin=args.question_begin,
555 | question_end=args.question_end,
556 | answer_file=answer_file,
557 | max_new_token=args.max_new_token,
558 | num_choices=args.num_choices,
559 | num_gpus_per_model=args.num_gpus_per_model,
560 | num_gpus_total=args.num_gpus_total,
561 | max_gpu_memory=args.max_gpu_memory,
562 | dtype=str_to_torch_dtype(args.dtype),
563 | debug=args.debug,
564 | cache_dir=args.cache_dir,
565 | cpu_offloading=args.cpu_offloading,
566 | use_pp=args.use_pp,
567 | use_tp_ds=args.use_tp_ds,
568 | use_tp=args.use_tp,
569 | use_flash=args.use_flash,
570 | do_sample=args.do_sample
571 | )
572 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0'))
573 | if lade.get_device() == 0 and ds_local_rank == 0:
574 | reorg_answer_file(answer_file)
575 |
--------------------------------------------------------------------------------
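The XSum prompt above uses the Llama-2 chat format: a system message wrapped in <<SYS>> ... <</SYS>> inside a single [INST] ... [/INST] turn. A small helper that produces the same shape of prompt (illustrative only; the exact rule text lives in the script itself):

    def llama2_single_turn(system_prompt: str, user_message: str) -> str:
        return (
            "[INST] <<SYS>>\n"
            f"{system_prompt}\n"
            "<</SYS>>\n\n"
            f"{user_message} [/INST]\n"
        )

    prompt = llama2_single_turn(
        "You are an intelligent chatbot. Answer the questions only using the following context: ...",
        "Briefly summarize the given context.",
    )
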
/applications/run_chat.sh:
--------------------------------------------------------------------------------
1 | USE_LADE=1 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug
2 | USE_LADE=0 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug
3 |
4 | USE_LADE=1 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat
5 | USE_LADE=0 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat
6 |
--------------------------------------------------------------------------------
/applications/run_mtbench.sh:
--------------------------------------------------------------------------------
1 | # download the MT-Bench questions
2 | wget https://raw.githubusercontent.com/lm-sys/FastChat/v0.2.31/fastchat/llm_judge/data/mt_bench/question.jsonl -O mtbench.jsonl
3 |
4 | export CUDA=0
5 | export LADE=0
6 | export LEVEL=0
7 | export WIN=0
8 | export GUESS=0
9 | export FLASH=0
10 | export PP=0
11 | CUDA_VISIBLE_DEVICES=$CUDA USE_LADE=$LADE python eval_mtbench.py \
12 | --model-path meta-llama/Llama-2-7b-chat-hf --model-id \
13 | llama-2-7b-level-$LEVEL-win-$WIN-guess-$GUESS-f$FLASH-pp$CUDA \
14 | --level $LEVEL --window $WIN --guess $GUESS --use-flash $FLASH --use-pp $PP
15 |
16 | export CUDA=0
17 | export LADE=1
18 | export LEVEL=5
19 | export WIN=15
20 | export GUESS=15
21 | export FLASH=0
22 | export PP=0
23 | CUDA_VISIBLE_DEVICES=$CUDA USE_LADE=$LADE python eval_mtbench.py \
24 | --model-path meta-llama/Llama-2-7b-chat-hf --model-id \
25 | llama-2-7b-level-$LEVEL-win-$WIN-guess-$GUESS-f$FLASH-pp$CUDA \
26 | --level $LEVEL --window $WIN --guess $GUESS --use-flash $FLASH --use-pp $PP
27 |
28 | export GPUS=1
29 | export LEVEL=0
30 | export WIN=0
31 | export GUESS=0
32 | export FLASH=0
33 | deepspeed --num_gpus $GPUS eval_mtbench.py --model-path meta-llama/Llama-2-7b-chat-hf \
34 | --model-id llama-2-7b-level-$LEVEL-win-$WIN-guess-$GUESS-f$FLASH-ds$GPUS \
35 | --level $LEVEL --window $WIN --guess $GUESS --use-flash $FLASH --use-tp-ds 1
36 |
--------------------------------------------------------------------------------
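Each eval script appends one JSON object per question to the answer file, containing question_id, answer_id, model_id, choices (with index, turns and prompts) and tstamp. A short sketch for inspecting such a file after a run (the path below is a placeholder for whatever --model-id produced):

    import json

    answer_file = "data/mt_bench/model_answer/llama-2-7b-chat.jsonl"  # placeholder path

    with open(answer_file) as fin:
        for line in fin:
            ans = json.loads(line)
            first_turn = ans["choices"][0]["turns"][0]
            print(ans["question_id"], first_turn[:80].replace("\n", " "))
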
/lade/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import augment_llama
2 | from .utils import augment_generate
3 | from .utils import augment_all
4 | from .utils import config_lade, save_log, log_history
5 | from .lade_distributed import *
6 |
--------------------------------------------------------------------------------
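The package's public surface mirrors how the evaluation scripts enable lookahead decoding: lade.augment_all() patches the Hugging Face generation routines, and lade.config_lade(...) sets the lookahead parameters. The sequence below reproduces what those scripts do (the LEVEL/WINDOW/GUESS values are the ones run_mtbench.sh passes; adjust to taste):

    import os
    import lade

    os.environ["USE_LADE"] = "1"   # the patched decoding path is gated on this variable

    lade.augment_all()
    lade.config_lade(LEVEL=5, WINDOW_SIZE=15, GUESS_SET_SIZE=15, DEBUG=1,
                     USE_FLASH=0, DIST_WORKERS=1)
    # ...then load a model and call model.generate() as usual; the proxies in
    # lade/decoding.py take over greedy search and sampling from here.
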
/lade/decoding.py:
--------------------------------------------------------------------------------
1 | from transformers import GenerationMixin
2 | import torch
3 | import copy
4 | import inspect
5 | import warnings
6 | from dataclasses import dataclass
7 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
8 | from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GreedySearchOutput, SampleOutput, TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper
9 | import torch.distributed as dist
10 | import os, time, random
11 | FUNC_MAP = {}
12 | CONFIG_MAP = {}
13 | COLOR_PRINT = int(os.environ.get("COLOR_PRINT", 0))
14 |
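# The two proxies below are installed over GenerationMixin.greedy_search / .sample by the
# augment_* helpers in lade/utils.py, which are expected to stash the original Hugging Face
# implementations in FUNC_MAP. At call time the proxies consult the USE_LADE and CHAT
# environment variables: with USE_LADE set they route generation through the multi-level
# Jacobi (lookahead) decoders defined later in this file; otherwise they fall back to the
# saved original (note that, as written, both proxies fall back to FUNC_MAP["greedy_search"]).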
15 | def greedy_search_proxy(self, *args, **kwargs):
16 | USE_LADE = int(os.environ.get("USE_LADE", 0))
17 | CHAT = int(os.environ.get("CHAT", 0))
18 | if CHAT and USE_LADE:
19 | return jacobi_greedy_search_multilevel(self, chat=True, *args, **kwargs)
20 | elif CHAT:
21 | return greedy_search_chat(self, chat=True, *args, **kwargs)
22 |
23 | if USE_LADE:
24 | return jacobi_greedy_search_multilevel(self, chat=False, *args, **kwargs)
25 | else:
26 | return FUNC_MAP["greedy_search"](self, *args, **kwargs)
27 |
28 | def sample_proxy(self, *args, **kwargs):
29 | USE_LADE = int(os.environ.get("USE_LADE", 0))
30 |
31 | if USE_LADE:
32 | return jacobi_sample_multilevel(self, chat=int(os.environ.get("CHAT", 0)), *args, **kwargs)
33 | else:
34 | return FUNC_MAP["greedy_search"](self, *args, **kwargs)
35 |
36 |
37 | def update_token_map(token_map, lst_token, past_tokens, new_results, LEVEL, WINDOW_SIZE, GUESS_SET_SIZE):
38 | if GUESS_SET_SIZE != -1: #limited guess set size for each key, lru policy
39 | if lst_token not in token_map:
40 | token_map[lst_token] = []
41 | tup = tuple(past_tokens[ll][0] for ll in range(1, LEVEL - 1)) + (new_results[0],)
42 | if tup in token_map[lst_token]:
43 | token_map[lst_token].remove(tup)
44 | token_map[lst_token].append(tup)
45 | elif len(token_map[lst_token]) < GUESS_SET_SIZE:
46 | token_map[lst_token].append(tup)
47 | else:
48 | assert len(token_map[lst_token]) == GUESS_SET_SIZE
49 | token_map[lst_token] = token_map[lst_token][1:] + [tup]
50 |
51 | for i in range(1, WINDOW_SIZE):
52 | if past_tokens[0][i - 1] not in token_map:
53 | token_map[past_tokens[0][i - 1]] = []
54 | tup = tuple(past_tokens[ll][i] for ll in range(1, LEVEL - 1)) + (new_results[i],)
55 |
56 | if tup in token_map[past_tokens[0][i - 1]]:
57 | token_map[past_tokens[0][i - 1]].remove(tup)
58 | token_map[past_tokens[0][i - 1]].append(tup)
59 | elif len(token_map[past_tokens[0][i - 1]]) < GUESS_SET_SIZE:
60 | token_map[past_tokens[0][i - 1]].append(tup)
61 | else:
62 | assert len(token_map[past_tokens[0][i - 1]]) == GUESS_SET_SIZE
63 | token_map[past_tokens[0][i - 1]] = token_map[past_tokens[0][i - 1]][1:] + [tup]
64 |
65 | else: #unlimited guess set size for each key
66 | #first add
67 | if lst_token not in token_map:
68 | token_map[lst_token] = set()
69 | #build tuple
70 | tup = tuple(past_tokens[ll][0] for ll in range(1, LEVEL - 1)) + (new_results[0],)
71 | #add tuple
72 | token_map[lst_token].add(tup)
73 |
74 | for i in range(1, WINDOW_SIZE):
75 | if past_tokens[0][i - 1] not in token_map:
76 | token_map[past_tokens[0][i - 1]] = set()
77 | tup = tuple(past_tokens[ll][i] for ll in range(1, LEVEL - 1)) + (new_results[i],)
78 | token_map[past_tokens[0][i - 1]].add(tup)
79 |
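# Illustration of the n-gram pool maintained by update_token_map (made-up values):
# with LEVEL = 4 and GUESS_SET_SIZE = 2, token_map might look like
#     token_map[17] = [(23, 5, 9), (23, 8, 41)]
# i.e. after token id 17 the lookahead window recently produced the 3-grams
# 23 -> 5 -> 9 and 23 -> 8 -> 41. The list is kept in LRU order: re-adding an
# existing tuple moves it to the end, and the oldest tuple is evicted once the
# list already holds GUESS_SET_SIZE tuples. With GUESS_SET_SIZE == -1 a set is
# used instead and the pool grows without bound.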
80 | def append_new_generated_pool(tokens, token_map, LEVEL, GUESS_SET_SIZE):
81 | if len(tokens) != LEVEL:
82 | return
83 | lst_token = tokens[0]
84 | tup = tuple(tokens[1:])
85 |
86 | if GUESS_SET_SIZE != -1: #limited guess set size for each key, lru policy
87 | if lst_token not in token_map:
88 | token_map[lst_token] = []
89 | if tup in token_map[lst_token]:
90 | token_map[lst_token].remove(tup)
91 | token_map[lst_token].append(tup)
92 | elif len(token_map[lst_token]) < GUESS_SET_SIZE:
93 | token_map[lst_token].append(tup)
94 | else:
95 | assert len(token_map[lst_token]) == GUESS_SET_SIZE
96 | token_map[lst_token] = token_map[lst_token][1:] + [tup]
97 | else: #unlimited guess set size for each key
98 | #first add
99 | if lst_token not in token_map:
100 | token_map[lst_token] = set()
101 | token_map[lst_token].add(tup)
102 |
103 |
104 | def fill_pool_with_prompt(prompts, token_map, LEVEL, GUESS_SET_SIZE):
105 | for start_idx in range(len(prompts) - LEVEL + 1):
106 | lst_token = prompts[start_idx]
107 | tup = tuple(prompts[start_idx+1:start_idx+LEVEL])
108 |
109 | if len(tup) != LEVEL - 1:
110 | return
111 |
112 | if GUESS_SET_SIZE != -1: #limited guess set size for each key, lru policy
113 | if lst_token not in token_map:
114 | token_map[lst_token] = []
115 | if tup in token_map[lst_token]:
116 | token_map[lst_token].remove(tup)
117 | token_map[lst_token].append(tup)
118 | elif len(token_map[lst_token]) < GUESS_SET_SIZE:
119 | token_map[lst_token].append(tup)
120 | else:
121 | assert len(token_map[lst_token]) == GUESS_SET_SIZE
122 | token_map[lst_token] = token_map[lst_token][1:] + [tup]
123 | else: #unlimited guess set size for each key
124 | #first add
125 | if lst_token not in token_map:
126 | token_map[lst_token] = set()
127 | token_map[lst_token].add(tup)
128 |
129 |
130 |
131 | def filter_window(level_window, eos_token_id, reset_func):
132 |
133 | for idx in range(len(level_window)):
134 | if level_window[idx] == eos_token_id:
135 | level_window[idx] = reset_func()
136 |
137 | def jacobi_sample_multilevel(
138 | self,
139 | input_ids: torch.LongTensor,
140 | logits_processor: Optional[LogitsProcessorList] = None,
141 | stopping_criteria: Optional[StoppingCriteriaList] = None,
142 | logits_warper: Optional[LogitsProcessorList] = None,
143 | max_length: Optional[int] = None,
144 | pad_token_id: Optional[int] = None,
145 | eos_token_id: Optional[Union[int, List[int]]] = None,
146 | output_attentions: Optional[bool] = None,
147 | output_hidden_states: Optional[bool] = None,
148 | output_scores: Optional[bool] = None,
149 | return_dict_in_generate: Optional[bool] = None,
150 | synced_gpus: bool = False,
151 | streamer: Optional["BaseStreamer"] = None,
152 | chat: bool = False,
153 | **model_kwargs,
154 | ) -> Union[SampleOutput, torch.LongTensor]:
155 | r"""
156 | Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
157 | can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
158 |
159 |
160 |
161 | In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead.
162 | For an overview of generation strategies and code examples, check the [following
163 | guide](../generation_strategies).
164 |
165 |
166 |
167 | Parameters:
168 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
169 | The sequence used as a prompt for the generation.
170 | logits_processor (`LogitsProcessorList`, *optional*):
171 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
172 | used to modify the prediction scores of the language modeling head applied at each generation step.
173 | stopping_criteria (`StoppingCriteriaList`, *optional*):
174 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
175 | used to tell if the generation loop should stop.
176 | logits_warper (`LogitsProcessorList`, *optional*):
177 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
178 | to warp the prediction score distribution of the language modeling head applied before multinomial
179 | sampling at each generation step.
180 | max_length (`int`, *optional*, defaults to 20):
181 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
182 | tokens. The maximum length of the sequence to be generated.
183 | pad_token_id (`int`, *optional*):
184 | The id of the *padding* token.
185 | eos_token_id (`Union[int, List[int]]`, *optional*):
186 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
187 | output_attentions (`bool`, *optional*, defaults to `False`):
188 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
189 | returned tensors for more details.
190 | output_hidden_states (`bool`, *optional*, defaults to `False`):
191 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
192 | for more details.
193 | output_scores (`bool`, *optional*, defaults to `False`):
194 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
195 | return_dict_in_generate (`bool`, *optional*, defaults to `False`):
196 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
197 | synced_gpus (`bool`, *optional*, defaults to `False`):
198 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
199 | streamer (`BaseStreamer`, *optional*):
200 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed
201 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
202 | model_kwargs:
203 | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
204 | an encoder-decoder model the kwargs should include `encoder_outputs`.
205 |
206 | Return:
207 | [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
208 | A `torch.LongTensor` containing the generated tokens (default behaviour) or a
209 | [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
210 | `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
211 | `model.config.is_encoder_decoder=True`.
212 |
213 | Examples:
214 |
215 | ```python
216 | >>> from transformers import (
217 | ... AutoTokenizer,
218 | ... AutoModelForCausalLM,
219 | ... LogitsProcessorList,
220 | ... MinLengthLogitsProcessor,
221 | ... TopKLogitsWarper,
222 | ... TemperatureLogitsWarper,
223 | ... StoppingCriteriaList,
224 | ... MaxLengthCriteria,
225 | ... )
226 | >>> import torch
227 |
228 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
229 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
230 |
231 |     >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
232 | >>> model.config.pad_token_id = model.config.eos_token_id
233 | >>> model.generation_config.pad_token_id = model.config.eos_token_id
234 |
235 | >>> input_prompt = "Today is a beautiful day, and"
236 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
237 |
238 | >>> # instantiate logits processors
239 | >>> logits_processor = LogitsProcessorList(
240 | ... [
241 | ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
242 | ... ]
243 | ... )
244 | >>> # instantiate logits processors
245 | >>> logits_warper = LogitsProcessorList(
246 | ... [
247 | ... TopKLogitsWarper(50),
248 | ... TemperatureLogitsWarper(0.7),
249 | ... ]
250 | ... )
251 |
252 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
253 |
254 | >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT
255 | >>> outputs = model.sample(
256 | ... input_ids,
257 | ... logits_processor=logits_processor,
258 | ... logits_warper=logits_warper,
259 | ... stopping_criteria=stopping_criteria,
260 | ... )
261 |
262 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
263 | ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.']
264 | ```"""
265 | # init values
266 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
267 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
268 | if max_length is not None:
269 | warnings.warn(
270 | "`max_length` is deprecated in this function, use"
271 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
272 | UserWarning,
273 | )
274 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
275 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
276 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
277 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
278 | if isinstance(eos_token_id, int):
279 | eos_token_id = [eos_token_id]
280 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
281 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
282 | output_attentions = (
283 | output_attentions if output_attentions is not None else self.generation_config.output_attentions
284 | )
285 | output_hidden_states = (
286 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
287 | )
288 | return_dict_in_generate = (
289 | return_dict_in_generate
290 | if return_dict_in_generate is not None
291 | else self.generation_config.return_dict_in_generate
292 | )
293 |
294 | # init attention / hidden states / scores tuples
295 | scores = () if (return_dict_in_generate and output_scores) else None
296 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
297 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None
298 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
299 |
300 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
301 | if return_dict_in_generate and self.config.is_encoder_decoder:
302 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
303 | encoder_hidden_states = (
304 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
305 | )
306 |
307 | # keep track of which sequences are already finished
308 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
309 |
310 | this_peer_finished = False # used by synced_gpus only
311 |
312 | WINDOW_SIZE = CONFIG_MAP.get("WINDOW_SIZE", 60)
313 | GUESS_SET_SIZE = CONFIG_MAP.get("GUESS_SET_SIZE", 60)
314 | ALWAYS_FWD_ONE = CONFIG_MAP.get("ALWAYS_FWD_ONE", 1)
315 | LEVEL = CONFIG_MAP.get("LEVEL", 8)
316 | DEBUG = CONFIG_MAP.get("DEBUG", 0)
317 | DIST_WORKERS = CONFIG_MAP.get("DIST_WORKERS", 1)
318 | LOCAL_RANK = CONFIG_MAP.get("LOCAL_RANK", 0)
319 | USE_FLASH = CONFIG_MAP.get("USE_FLASH", 0) #not use flash by default
320 | POOL_FROM_PROMPT = CONFIG_MAP.get("POOL_FROM_PROMPT", 0)
321 |     USE_AWQ = False #AWQ is not supported
322 |     #NOTE: with flash attention enabled, the lookahead window is reordered
323 |
324 | GUESS_SIZE = LEVEL - 1
325 | NOT_SEQ = 0
326 | CONTINUE_ALL = 0
327 | TEMP_FOR_GUESS = 0.0
328 |
329 | assert TEMP_FOR_GUESS == 0
330 | #assert LEVEL <= 8
331 | def random_set():
332 | return random.randint(0,self.vocab_size - 1)
333 |
334 | all_old_tokens = input_ids[0].tolist()
335 | init_len = len(all_old_tokens)
336 | #print("original: ", init_len, input_ids.numel())
337 |
338 | def copy_from():
339 | return random.choice(all_old_tokens)
340 |
341 | order_copy_from_idx = [0]
342 |
343 | def order_copy_from():
344 | if order_copy_from_idx[0] >= len(all_old_tokens):
345 | order_copy_from_idx[0] = 0
346 | ret = all_old_tokens[order_copy_from_idx[0]]
347 | order_copy_from_idx[0] = 1 + order_copy_from_idx[0]
348 | return ret
349 |
350 | def copy_from_last():
351 | return all_old_tokens[-1]
352 |
353 | set_token = copy_from
354 |
355 | past_tokens = [[set_token() for _ in range(WINDOW_SIZE + LEVEL - 3)]] + [None for _ in range(LEVEL - 2)]
356 |
357 | if DIST_WORKERS > 1:
358 | dist.broadcast_object_list(past_tokens, src=0) #keep past_tokens always the same on different GPUs
359 |
360 | ###############end Init methods
361 | fill_level = 0
362 | guess_tokens = None
363 | token_map = {}
364 | steps = 0
365 | guess_skip_dist = 0
366 |
367 | if POOL_FROM_PROMPT:
368 | fill_pool_with_prompt(all_old_tokens, token_map, LEVEL, GUESS_SET_SIZE)
369 |
370 | if chat:
371 | init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
372 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
373 | prev = len(init)
374 |
375 | for warper in logits_warper:
376 |         #only temperature / top-k / top-p warpers are expected here; ideally set top_k=0 and top_p=1.0 so that only temperature is applied
377 |         assert type(warper) == TemperatureLogitsWarper or type(warper) == TopKLogitsWarper or type(warper) == TopPLogitsWarper, f"unexpected logits warper {warper}; please set top_k=0 and top_p=1.0"
378 |
379 | # auto-regressive generation
380 | while True:
381 | if synced_gpus:
382 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
383 | # The following logic allows an early break if all peers finished generating their sequence
384 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
385 | # send 0.0 if we finished, 1.0 otherwise
386 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
387 | # did all peers finish? the reduced sum will be 0.0 then
388 | if this_peer_finished_flag.item() == 0.0:
389 | break
390 |
391 | # prepare model inputs
392 |         #this currently only supports llama; check compatibility with other models
393 | past_key_values = model_kwargs.pop("past_key_values", None)
394 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
395 | if past_key_values is None:
396 | model_inputs["input_ids"] = input_ids
397 | else:
398 | model_inputs["input_ids"] = model_inputs["input_ids"][:, -1 - guess_skip_dist:]
399 | model_inputs["position_ids"] = model_inputs["position_ids"][:, -1 - guess_skip_dist:]
400 | model_inputs["past_key_values"] = past_key_values
401 |
402 | if past_tokens[LEVEL - 2] is not None and lst_token in token_map and GUESS_SET_SIZE > 0:
403 | guess_tokens_ = token_map[lst_token]
404 | guess_tokens = []
405 | for tok in list(guess_tokens_):
406 | guess_tokens += list(tok)
407 | else:
408 | guess_tokens = None
409 |
410 | #not support logits_processor yet
411 | assert return_dict_in_generate == False
412 | assert len(logits_processor) == 0
413 |
414 | # forward pass to get next token
415 | outputs = self.jforward_multilevel(
416 | **model_inputs,
417 | past_tokens=past_tokens,
418 | guess_tokens=guess_tokens,
419 | return_dict=True,
420 | not_seq=NOT_SEQ,
421 | continue_all=CONTINUE_ALL,
422 | output_attentions=output_attentions,
423 | output_hidden_states=output_hidden_states,
424 | level=LEVEL,
425 | WINDOWS_SIZE=WINDOW_SIZE,
426 | guess_size=GUESS_SIZE,
427 | fill_level=fill_level,
428 | dist_workers=DIST_WORKERS,
429 | la_mask_offset=0,
430 | local_rank=LOCAL_RANK,
431 | use_flash=USE_FLASH
432 | )
433 |
434 | steps += 1
435 |
436 | if synced_gpus and this_peer_finished:
437 | continue # don't waste resources running the code we don't need
438 |
439 | next_token_logits = outputs.out_logits #outputs.logits[:, -1, :]
440 |
441 |         #logits_processor is not supported yet; only temperature warping is handled (top-p / top-k support is planned)
442 | # pre-process distribution
443 | next_token_scores = logits_warper(input_ids, next_token_logits)
444 |
445 | #delete return_dict_in_generate here, we set it to False
446 | # Store scores, attentions and hidden_states when required
447 |
448 | # finished sentences should have their next token be a padding token
449 | #if eos_token_id is not None:
450 | # if pad_token_id is None:
451 | # raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
452 | # next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
453 |         #the padding logic above only matters for batch size > 1, so it is commented out here
454 |
455 | #handling output
456 | max_hit = 0
457 |
458 | if past_tokens[1] is None:
459 | #first fill, not use verification branch
460 | assert fill_level == 0
461 | probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
462 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
463 | hits = [next_tokens.item()]
464 |
465 | past_tokens[0] = past_tokens[0][1:]
466 | past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist() #fill window with argmax
467 |
468 | fill_level += 1
469 | elif past_tokens[LEVEL - 2] is None:
470 | #fill other levels, not use verification branch
471 | probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
472 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
473 | hits = [next_tokens.item()]
474 |
475 | for level in range(fill_level + 1):
476 | past_tokens[level] = past_tokens[level][1:]
477 |
478 | past_tokens[fill_level + 1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()[1:] #fill window with argmax
479 |
480 | fill_level += 1
481 | else:
482 |
483 |
484 | if guess_tokens is not None:
485 | probs_next = torch.nn.functional.softmax(next_token_scores, dim=-1)[0]
486 | hits = []
487 | #= original model output
488 | guess_logits = logits_warper(input_ids, outputs.guess_logits[0])
489 | guess_probs = torch.nn.functional.softmax(guess_logits, dim=-1) #
490 | #guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
491 | guess_indices = list(range(outputs.guess_logits.size(1) // GUESS_SIZE))
492 | #algorithm modified from specinfer
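                # Acceptance rule sketch: the draft tokens come from greedy lookahead, so their
                # draft probability is 1; a draft token d is accepted with probability
                # min(1, p_model(d)). On rejection, d's probability mass is zeroed out and the
                # distribution renormalized before trying the next candidate n-gram at this
                # position. If no candidate is accepted, a token is sampled from the residual
                # distribution and verification stops for this step.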
493 | for idx_in_ngram in range(GUESS_SIZE):
494 |
495 | g_idx = 0
496 | is_accept = False
497 | #print("gues: ", guess_indices)
498 |
499 | while g_idx < len(guess_indices):
500 | guess_idx = guess_indices[g_idx]
501 | guess_offset = guess_idx * GUESS_SIZE
502 |
503 |
504 | #draft_guess is draft model (by lookahead) generation
505 | draft_guess = guess_tokens[guess_offset + idx_in_ngram]
506 | prob_accept = min(1, probs_next[draft_guess].item()) #min(1, prob_llm/prob_draft) #use argmax, prob_draft is 1
507 | sample_prob = random.random()
508 |
509 | if sample_prob < prob_accept:
510 | #accept
511 | hits.append(draft_guess)
512 | is_accept = True
513 | max_hit_idx = guess_idx
514 | new_guess_indices = []
515 | for guess_idx_n in guess_indices:
516 | guess_offset_n = guess_idx_n * GUESS_SIZE
517 | new_draft_guess = guess_tokens[guess_offset_n + idx_in_ngram]
518 | if new_draft_guess == draft_guess:
519 | new_guess_indices.append(guess_idx_n)
520 | guess_indices = new_guess_indices
521 | break
522 | else:
523 | #not accept
524 | #max norm (argmax)
525 | probs_next[draft_guess] = 0
526 | probs_next = probs_next / probs_next.sum()
527 | g_idx += 1
528 |
529 | if is_accept:
530 | probs_next = guess_probs[guess_offset + idx_in_ngram]
531 | continue
532 | else:
533 | new_token_gen = torch.multinomial(probs_next, num_samples=1).item()
534 | #print("non accept: ", probs_next.size(), new_token_gen)
535 | hits.append(new_token_gen)
536 | break
537 |
538 | #hits.append(new_token_gen)
539 |
540 | max_hit = len(hits) - 1
541 |
542 | else:
543 | probs_next = torch.nn.functional.softmax(next_token_scores, dim=-1)
544 | next_tokens = torch.multinomial(probs_next, num_samples=1).squeeze(1)
545 | hits = [next_tokens.item()]
546 |
547 |
548 | #new window level, use argmax to generate
549 | new_results = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
550 |
551 | assert len(past_tokens[LEVEL - 2]) == WINDOW_SIZE and len(new_results) == WINDOW_SIZE
552 |
553 | update_token_map(token_map, lst_token, past_tokens, new_results, LEVEL, WINDOW_SIZE, GUESS_SET_SIZE)
554 |
555 | #update windows when max_hit > 1
556 | if ALWAYS_FWD_ONE:
557 | past_tokens[0] = past_tokens[1][1:]
558 | for level in range(1, LEVEL - 2):
559 | past_tokens[level] = past_tokens[level + 1][:]
560 |
561 | past_tokens[LEVEL - 2] = new_results
562 | else:
563 | past_tokens[0] = past_tokens[1][1 + max_hit:]
564 | for level in range(1, LEVEL - 2):
565 | past_tokens[level] = past_tokens[level + 1][max_hit:]
566 |
567 | past_tokens[LEVEL - 2] = new_results[max_hit:]
568 |
569 |
570 | if max_hit > 0:
571 | if not ALWAYS_FWD_ONE:
572 | for level in range(LEVEL - 1):
573 | past_tokens[level] = past_tokens[level] + [set_token() for _ in range(max_hit)]
574 |
575 | attention_mask = model_kwargs["attention_mask"]
576 | model_kwargs["attention_mask"] = torch.cat((attention_mask, torch.ones(1, max_hit, device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)
577 |
578 | if eos_token_id is not None:
579 |             #filter out EOS tokens from the window (too many of them can cause numerical errors)
580 | filter_window(past_tokens[LEVEL - 2], eos_token_id[0], set_token)
581 |
582 | #update kv cache of correctly speculated tokens
583 | past_key_values = []
584 | for idx, kv in enumerate(outputs.past_key_values):
585 | for hh in range(max_hit):
586 | assert outputs.step_len == kv[0].size(2)
587 | kv[0][:,:,outputs.kvcache_len + hh,:] = kv[0][:,:,outputs.step_len-len(guess_tokens)+max_hit_idx * GUESS_SIZE + hh,:]
588 | kv[1][:,:,outputs.kvcache_len + hh,:] = kv[1][:,:,outputs.step_len-len(guess_tokens)+max_hit_idx * GUESS_SIZE + hh,:]
589 | past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
590 | outputs.past_key_values = past_key_values
591 |
592 | lst_token = hits[max_hit]
593 |
594 | for hit_ids in range(max_hit + 1):
595 | if eos_token_id is not None and hits[hit_ids] == eos_token_id[0]:
596 | all_old_tokens.append(hits[hit_ids])
597 | next_tokens = eos_token_id_tensor
598 | max_hit = hit_ids
599 | break
600 | else:
601 | all_old_tokens.append(hits[hit_ids])
602 | if POOL_FROM_PROMPT:
603 | append_new_generated_pool(all_old_tokens[-LEVEL:], token_map, LEVEL, GUESS_SET_SIZE)
604 |
605 | if chat:
606 |
607 | all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
608 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
609 | if COLOR_PRINT:
610 | from termcolor import colored
611 | if max_hit > 1:
612 | not_hit = self.tokenizer.decode(all_old_tokens[:-max_hit + 1], skip_special_tokens=True, \
613 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
614 | pt = colored(not_hit[prev:],"blue") + colored(all_str[len(not_hit):], "blue")
615 | else:
616 | pt = all_str[prev:]
617 | print(pt, flush=True, end="")
618 | else:
619 | print(all_str[prev:], flush=True, end="")
620 | prev = len(all_str)
621 |
622 | # update generated ids, model inputs, and length for next step
623 | input_ids = torch.cat([input_ids, torch.tensor(hits[:max_hit + 1], device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)], dim=-1)
624 |
625 | #input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
626 |
627 |
628 |         ###the code below follows the original HuggingFace generation loop and is left unchanged
629 | if streamer is not None:
630 | streamer.put(next_tokens.cpu())
631 | model_kwargs = self._update_model_kwargs_for_generation(
632 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
633 | )
634 |
635 | # if eos_token was found in one sentence, set sentence to finished
636 | if eos_token_id_tensor is not None:
637 | unfinished_sequences = unfinished_sequences.mul(
638 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
639 | )
640 |
641 | # stop when each sentence is finished
642 | if unfinished_sequences.max() == 0:
643 | this_peer_finished = True
644 |
645 | # stop if we exceed the maximum length
646 | if stopping_criteria(input_ids, scores):
647 | this_peer_finished = True
648 |
649 | if this_peer_finished and not synced_gpus:
650 | break
651 |
652 |     #if more tokens than max_length were generated, truncate them
653 | for criteria in stopping_criteria:
654 | if hasattr(criteria, "max_length"):
655 | all_old_tokens = all_old_tokens[:criteria.max_length]
656 | input_ids = input_ids[:,:criteria.max_length]
657 |
658 | if max_length is not None:
659 | all_old_tokens = all_old_tokens[:init_len + max_length]
660 |         input_ids = input_ids[:, :init_len + max_length]
661 | #end handling
662 | if DEBUG and LOCAL_RANK == 0:
663 | print("\n==========================ACCELERATION===SUMMARY======================================")
664 | print("Generated tokens: ", len(all_old_tokens) - init_len, "Total steps: ", steps, " Compression ratio: ", round((len(all_old_tokens) - init_len) / steps, 2))
665 | print("======================================================================================", end="")
666 | CONFIG_MAP["log"].append([len(all_old_tokens) - init_len, steps, round((len(all_old_tokens) - init_len) / steps, 2)])
667 |
668 | if streamer is not None:
669 | streamer.end()
670 |
671 | if return_dict_in_generate:
672 | if self.config.is_encoder_decoder:
673 | return SampleEncoderDecoderOutput(
674 | sequences=input_ids,
675 | scores=scores,
676 | encoder_attentions=encoder_attentions,
677 | encoder_hidden_states=encoder_hidden_states,
678 | decoder_attentions=decoder_attentions,
679 | cross_attentions=cross_attentions,
680 | decoder_hidden_states=decoder_hidden_states,
681 | past_key_values=model_kwargs.get("past_key_values"),
682 | )
683 | else:
684 | return SampleDecoderOnlyOutput(
685 | sequences=input_ids,
686 | scores=scores,
687 | attentions=decoder_attentions,
688 | hidden_states=decoder_hidden_states,
689 | past_key_values=model_kwargs.get("past_key_values"),
690 | )
691 | else:
692 | return input_ids
693 |
694 |
695 |
696 |
697 | def jacobi_greedy_search_multilevel(
698 | self,
699 | input_ids: torch.LongTensor,
700 | logits_processor: Optional[LogitsProcessorList] = None,
701 | stopping_criteria: Optional[StoppingCriteriaList] = None,
702 | max_length: Optional[int] = None,
703 | pad_token_id: Optional[int] = None,
704 | eos_token_id: Optional[Union[int, List[int]]] = None,
705 | output_attentions: Optional[bool] = None,
706 | output_hidden_states: Optional[bool] = None,
707 | output_scores: Optional[bool] = None,
708 | return_dict_in_generate: Optional[bool] = None,
709 | synced_gpus: bool = False,
710 | streamer: Optional["BaseStreamer"] = None,
711 |
712 | chat: bool = False,
713 | stop_token: Optional[str]= None,
714 | **model_kwargs,
715 | ) -> Union[GreedySearchOutput, torch.LongTensor]:
716 | r"""
717 | Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
718 | used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
719 |
720 |
721 |
722 | In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
723 | instead. For an overview of generation strategies and code examples, check the [following
724 | guide](../generation_strategies).
725 |
726 |
727 |
728 |
729 | Parameters:
730 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
731 | The sequence used as a prompt for the generation.
732 | logits_processor (`LogitsProcessorList`, *optional*):
733 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
734 | used to modify the prediction scores of the language modeling head applied at each generation step.
735 | stopping_criteria (`StoppingCriteriaList`, *optional*):
736 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
737 | used to tell if the generation loop should stop.
738 |
739 | max_length (`int`, *optional*, defaults to 20):
740 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
741 | tokens. The maximum length of the sequence to be generated.
742 | pad_token_id (`int`, *optional*):
743 | The id of the *padding* token.
744 | eos_token_id (`Union[int, List[int]]`, *optional*):
745 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
746 | output_attentions (`bool`, *optional*, defaults to `False`):
747 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
748 | returned tensors for more details.
749 | output_hidden_states (`bool`, *optional*, defaults to `False`):
750 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
751 | for more details.
752 | output_scores (`bool`, *optional*, defaults to `False`):
753 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
754 | return_dict_in_generate (`bool`, *optional*, defaults to `False`):
755 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
756 | synced_gpus (`bool`, *optional*, defaults to `False`):
757 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
758 | streamer (`BaseStreamer`, *optional*):
759 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed
760 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
761 | model_kwargs:
762 | Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
763 | If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
764 |
765 | Return:
766 | [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
767 | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
768 | [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
769 | `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
770 | `model.config.is_encoder_decoder=True`.
771 |
772 | Examples:
773 |
774 | ```python
775 | >>> from transformers import (
776 | ... AutoTokenizer,
777 | ... AutoModelForCausalLM,
778 | ... LogitsProcessorList,
779 | ... MinLengthLogitsProcessor,
780 | ... StoppingCriteriaList,
781 | ... MaxLengthCriteria,
782 | ... )
783 |
784 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
785 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
786 |
787 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
788 | >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
789 |
790 | >>> input_prompt = "It might be possible to"
791 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
792 |
793 | >>> # instantiate logits processors
794 | >>> logits_processor = LogitsProcessorList(
795 | ... [
796 | ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
797 | ... ]
798 | ... )
799 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
800 |
801 | >>> outputs = model.greedy_search(
802 | ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
803 | ... )
804 |
805 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
806 | ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
807 | ```"""
808 | # init values
809 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
810 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
811 | if max_length is not None:
812 | warnings.warn(
813 | "`max_length` is deprecated in this function, use"
814 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
815 | UserWarning,
816 | )
817 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
818 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
819 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
820 | if isinstance(eos_token_id, int):
821 | eos_token_id = [eos_token_id]
822 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
823 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
824 | output_attentions = (
825 | output_attentions if output_attentions is not None else self.generation_config.output_attentions
826 | )
827 | output_hidden_states = (
828 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
829 | )
830 | return_dict_in_generate = (
831 | return_dict_in_generate
832 | if return_dict_in_generate is not None
833 | else self.generation_config.return_dict_in_generate
834 | )
835 |
836 | # init attention / hidden states / scores tuples
837 | scores = () if (return_dict_in_generate and output_scores) else None
838 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
839 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None
840 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
841 |
842 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
843 | if return_dict_in_generate and self.config.is_encoder_decoder:
844 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
845 | encoder_hidden_states = (
846 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
847 | )
848 |
849 | # keep track of which sequences are already finished
850 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
851 |
852 | this_peer_finished = False # used by synced_gpus only
853 | ############### configurations
854 | WINDOW_SIZE = CONFIG_MAP.get("WINDOW_SIZE", 60)
855 | GUESS_SET_SIZE = CONFIG_MAP.get("GUESS_SET_SIZE", 60)
856 | ALWAYS_FWD_ONE = CONFIG_MAP.get("ALWAYS_FWD_ONE", 1)
857 | LEVEL = CONFIG_MAP.get("LEVEL", 8)
858 | DEBUG = CONFIG_MAP.get("DEBUG", 0)
859 | DIST_WORKERS = CONFIG_MAP.get("DIST_WORKERS", 1)
860 | LOCAL_RANK = CONFIG_MAP.get("LOCAL_RANK", 0)
861 | USE_FLASH = CONFIG_MAP.get("USE_FLASH", 0) #not use flash by default
862 | POOL_FROM_PROMPT = CONFIG_MAP.get("POOL_FROM_PROMPT", 0)
863 |     USE_AWQ = False #AWQ is not supported
864 |     #NOTE: with flash attention enabled, the lookahead window is reordered
865 |
866 | GUESS_SIZE = LEVEL - 1
867 | NOT_SEQ = 0
868 | CONTINUE_ALL = 0
869 | TEMP_FOR_GUESS = 0.0
870 | USE_AWQ = False
871 | import random
872 | assert TEMP_FOR_GUESS == 0
873 | assert ALWAYS_FWD_ONE == 1
874 | assert USE_AWQ == False
875 |
876 | ############### Init methods
877 | #random.seed(10) #unset this random seed later
878 |
879 | all_old_tokens = input_ids[0].tolist()
880 | init_len = len(all_old_tokens)
881 | order_copy_from_idx = [0]
882 |
883 |
884 | def random_set():
885 | return random.randint(0,self.vocab_size - 1)
886 |
887 | def copy_from():
888 | return random.choice(all_old_tokens)
889 |
890 | def order_copy_from():
891 | if order_copy_from_idx[0] >= len(all_old_tokens):
892 | order_copy_from_idx[0] = 0
893 | ret = all_old_tokens[order_copy_from_idx[0]]
894 | order_copy_from_idx[0] = 1 + order_copy_from_idx[0]
895 | return ret
896 |
897 | def copy_from_last():
898 | return all_old_tokens[-1]
899 |
900 | set_token = copy_from
901 |
902 | past_tokens = [[set_token() for _ in range(WINDOW_SIZE + LEVEL - 3)]] + [None for _ in range(LEVEL - 2)]
903 |     #past_tokens is the lookahead window; currently we initialize it with tokens randomly copied from the prompt
904 |
905 | if DIST_WORKERS > 1:
906 | dist.broadcast_object_list(past_tokens, src=0) #keep past_tokens always the same on different GPUs
907 |
908 | ###############end Init methods
909 | fill_level = 0
910 | guess_tokens = None
911 | token_map = {}
912 | steps = 0
913 | guess_skip_dist = 0
914 |
915 | if POOL_FROM_PROMPT:
916 | fill_pool_with_prompt(all_old_tokens, token_map, LEVEL, GUESS_SET_SIZE)
917 |
918 | if chat:
919 | init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
920 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
921 | prev = len(init)
922 |
923 | while True:
924 | if synced_gpus:
925 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
926 | # The following logic allows an early break if all peers finished generating their sequence
927 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
928 | # send 0.0 if we finished, 1.0 otherwise
929 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
930 | # did all peers finish? the reduced sum will be 0.0 then
931 | if this_peer_finished_flag.item() == 0.0:
932 | break
933 |
934 | # prepare model inputs
935 |         #this currently only supports llama; check compatibility with other models
936 | past_key_values = model_kwargs.pop("past_key_values", None)
937 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
938 | if past_key_values is None:
939 | model_inputs["input_ids"] = input_ids
940 | else:
941 | model_inputs["input_ids"] = model_inputs["input_ids"][:, -1 - guess_skip_dist:]
942 | model_inputs["position_ids"] = model_inputs["position_ids"][:, -1 - guess_skip_dist:]
943 | model_inputs["past_key_values"] = past_key_values
944 |
945 | ori_guess = None
946 | #set up guess_tokens for verification branch
947 |         # the verification branch is only used after warmup, i.e. once past_tokens[LEVEL - 2] has been filled
948 | if past_tokens[LEVEL - 2] is not None and lst_token in token_map and GUESS_SET_SIZE > 0:
949 |             ###############collect the candidate n-grams recorded for lst_token in the pool
950 | guess_tokens_ = token_map[lst_token]
951 | guess_tokens = []
952 | for tok in list(guess_tokens_):
953 | guess_tokens += list(tok)
954 | ori_guess = guess_tokens
955 | #shards guess_tokens on different GPUs
956 | if DIST_WORKERS > 1 and guess_tokens is not None:
957 | assert len(guess_tokens) % GUESS_SIZE == 0
958 | cnt_guess = (len(guess_tokens) // GUESS_SIZE + DIST_WORKERS - 1) // DIST_WORKERS
959 | guess_base = cnt_guess * LOCAL_RANK
960 | guess_end = cnt_guess * (LOCAL_RANK + 1)
961 | guess_tokens = guess_tokens[GUESS_SIZE * guess_base: GUESS_SIZE * guess_end]
962 | if len(guess_tokens) == 0:
963 | guess_tokens = None
964 | else:
965 | guess_tokens = None
966 |
967 | assert return_dict_in_generate == False
968 | assert len(logits_processor) == 0
969 | # forward pass to get next token
970 | #if LOCAL_RANK == 0:
971 | # print("position: ", model_inputs["input_ids"], model_inputs["position_ids"], ori_guess, guess_tokens)
972 | #forward
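        # Multi-GPU case: the lookahead window is partitioned across ranks (see
        # past_tokens_inp below), each rank runs its forward pass on its own portion,
        # and the per-rank results are all-gathered further below before the window
        # is updated.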
973 | if DIST_WORKERS > 1:
974 | window_len = len(past_tokens[0]) + 1
975 | split_window_len = (window_len + DIST_WORKERS - 1) // DIST_WORKERS
976 | window_start = min(split_window_len * LOCAL_RANK, window_len)
977 | window_end = min(split_window_len * (LOCAL_RANK + 1), window_len)
978 |
979 | if LOCAL_RANK == DIST_WORKERS - 1:
980 | assert len(past_tokens[0]) == window_end - 1
981 | past_tokens_inp = [past_tokens[0][: window_end - 1]]
982 | for l in range(1, len(past_tokens)):
983 | tokens = past_tokens[l]
984 | past_tokens_inp.append(tokens[window_start: window_end] if tokens is not None else None)
985 | else:
986 | past_tokens_inp = past_tokens
987 |
988 | outputs = self.jforward_multilevel(
989 | **model_inputs,
990 | past_tokens=past_tokens_inp,
991 | guess_tokens=guess_tokens,
992 | return_dict=True,
993 | not_seq=NOT_SEQ,
994 | continue_all=CONTINUE_ALL,
995 | output_attentions=output_attentions,
996 | output_hidden_states=output_hidden_states,
997 | level=LEVEL,
998 | WINDOWS_SIZE=WINDOW_SIZE,
999 | guess_size=GUESS_SIZE,
1000 | fill_level=fill_level,
1001 | dist_workers=DIST_WORKERS,
1002 | la_mask_offset=0,
1003 | local_rank=LOCAL_RANK,
1004 | use_flash=USE_FLASH
1005 | )
1006 |
1007 | steps += 1
1008 |
1009 | if synced_gpus and this_peer_finished:
1010 | continue # don't waste resources running the code we don't need
1011 |
1012 | if past_tokens[LEVEL - 2] is None: #prefill
1013 | next_token_logits = outputs.out_logits
1014 | else:
1015 | next_token_logits = outputs.out_logits #outputs.logits[:, -1, :]
1016 |
1017 | # pre-process distribution
1018 | #next_tokens_scores = logits_processor(input_ids, next_token_logits)
1019 | next_tokens_scores = next_token_logits
1020 | # argmax
1021 | next_tokens = torch.argmax(next_tokens_scores, dim=-1)
1022 |
1023 | if DIST_WORKERS > 1:
1024 | torch.distributed.broadcast(next_tokens, src=0)
1025 |
1026 | # finished sentences should have their next token be a padding token
1027 | if eos_token_id is not None:
1028 | if pad_token_id is None:
1029 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
1030 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1031 |
1032 | first_guess = next_tokens.item()
1033 | max_hit = 0
1034 | hits = [first_guess] + [0] * (GUESS_SIZE - 1)
1035 |
1036 | new_results = []
1037 |
1038 | if past_tokens[1] is None: #filling multi-level window, the very first step is different
1039 | assert fill_level == 0
1040 | past_tokens[0] = past_tokens[0][1:]
1041 | past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
1042 |
1043 | if DIST_WORKERS > 1:
1044 | nn_past_tokens = [copy.deepcopy(past_tokens[1])]
1045 | torch.distributed.broadcast_object_list(nn_past_tokens, src=DIST_WORKERS - 1)
1046 | past_tokens[1] = nn_past_tokens[0]
1047 |
1048 | fill_level += 1
1049 | elif past_tokens[LEVEL - 2] is None: #filling multi-level window
1050 | for level in range(fill_level + 1):
1051 | past_tokens[level] = past_tokens[level][1:]
1052 | current_past_tokens = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
1053 |
1054 |
1055 | if DIST_WORKERS > 1:
1056 | nn_past_tokens = [None] * DIST_WORKERS
1057 | torch.distributed.all_gather_object(nn_past_tokens, current_past_tokens)
1058 | current_past_tokens = sum(nn_past_tokens, [])
1059 |
1060 |
1061 | #time.sleep(10000)
1062 | past_tokens[fill_level + 1] = current_past_tokens[1:]
1063 | #print("new past: ", (LOCAL_RANK, past_tokens))
1064 |
1065 |
1066 | fill_level += 1
1067 | else:
1068 | #time.sleep(10000)
1069 | #multi-level window is filled
1070 | #match guess tokens
1071 | if guess_tokens is not None:
1072 | guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
1073 | for eg in range(len(guess_results) // GUESS_SIZE):
1074 | egx = eg * GUESS_SIZE
1075 | correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
1076 | myguess = guess_tokens[egx:egx + GUESS_SIZE]
1077 | gg = 0
1078 | for gg in range(len(myguess)):
1079 | if myguess[gg] != correct[gg]:
1080 | break
1081 | if gg > max_hit:
1082 | max_hit = gg
1083 | max_hit_idx = eg
1084 | hits[:max_hit + 1] = correct[:max_hit + 1]
1085 | #max_hit is the length of longest accepted sequence in verification branch
1086 |
1087 | #sync max_hit if we have multi-GPUs
1088 | if DIST_WORKERS > 1:
1089 | max_hit_all_ranks = [0] * DIST_WORKERS
1090 | torch.distributed.all_gather_object(max_hit_all_ranks, max_hit)
1091 | max_hit = max(max_hit_all_ranks)
1092 | max_hit_rank = max_hit_all_ranks.index(max_hit)
1093 |
1094 | if max_hit > 0:
1095 | hit_info = [hits]
1096 | torch.distributed.broadcast_object_list(hit_info, src=max_hit_rank)
1097 | hits = hit_info[0]
1098 | #print("rank: ", [hits, torch.distributed.get_rank(), max_hit, LOCAL_RANK, max_hit_rank])
1099 | #if LOCAL_RANK == 0:
1100 | # print("rank: ",hits, max_hit)
1101 | #sync new_results
1102 | new_results = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()
1103 |
1104 | if DIST_WORKERS > 1:
1105 | nn_past_tokens = [None] * DIST_WORKERS
1106 | torch.distributed.all_gather_object(nn_past_tokens, new_results)
1107 | new_results = sum(nn_past_tokens, [])
1108 | #else:
1109 | # current_past_tokens = new_results
1110 | #print("brand new past: ", (LOCAL_RANK, past_tokens, new_results))
1111 |
1112 | #time.sleep(1000)
1113 |
1114 | assert len(past_tokens[LEVEL - 2]) == WINDOW_SIZE and len(new_results) == WINDOW_SIZE
1115 |
1116 | update_token_map(token_map, lst_token, past_tokens, new_results, LEVEL, WINDOW_SIZE, GUESS_SET_SIZE)
1117 |
1118 |
1119 | if ALWAYS_FWD_ONE:
1120 | past_tokens[0] = past_tokens[1][1:]
1121 | for level in range(1, LEVEL - 2):
1122 | past_tokens[level] = past_tokens[level + 1][:]
1123 |
1124 | past_tokens[LEVEL - 2] = new_results
1125 | else:
1126 | past_tokens[0] = past_tokens[1][1 + max_hit:]
1127 | for level in range(1, LEVEL - 2):
1128 | past_tokens[level] = past_tokens[level + 1][max_hit:]
1129 |
1130 | past_tokens[LEVEL - 2] = new_results[max_hit:]
1131 |
1132 |
1133 |
1134 | if max_hit > 0:
1135 | if not ALWAYS_FWD_ONE:
1136 | for level in range(LEVEL - 1):
1137 | past_tokens[level] = past_tokens[level] + [set_token() for _ in range(max_hit)]
1138 |
1139 | attention_mask = model_kwargs["attention_mask"]
1140 | model_kwargs["attention_mask"] = torch.cat((attention_mask, torch.ones(1, max_hit, device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)
1141 |
1142 | #not support awq
1143 | assert USE_AWQ == False
1144 |
1145 | past_key_values = []
1146 |
1147 |         #TODO: when dist_workers > 1, avoid the kv-cache copy and feed the accepted tokens into the next input instead, since cross-GPU communication is costly
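        # Two cases below: with multiple workers and a hit, the cache is simply truncated to
        # kvcache_len and the accepted tokens are re-fed as inputs next step (guess_skip_dist);
        # otherwise the accepted tokens' key/value entries, which live in the verification-branch
        # region at the tail of this step's cache, are copied back to the contiguous positions
        # kvcache_len .. kvcache_len + max_hit - 1 and the cache is truncated to that length.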
1148 | if DIST_WORKERS > 1 and max_hit > 0:
1149 |
1150 | guess_skip_dist = max_hit
1151 | for idx, kv in enumerate(outputs.past_key_values):
1152 | past_key_values.append( (kv[0][:,:,:outputs.kvcache_len,:], kv[1][:,:,:outputs.kvcache_len,:]) )
1153 | outputs.past_key_values = past_key_values
1154 | else:
1155 | guess_skip_dist = 0
1156 | offset_kv_cache = outputs.step_len-len(guess_tokens)+max_hit_idx * GUESS_SIZE if max_hit > 0 else 0
1157 | for idx, kv in enumerate(outputs.past_key_values):
1158 | #update kv-cache from verification branch
1159 | if max_hit > 0:
1160 | kv[0][:,:,outputs.kvcache_len:outputs.kvcache_len+max_hit,:] = kv[0][:,:,offset_kv_cache:offset_kv_cache+max_hit,:]
1161 | kv[1][:,:,outputs.kvcache_len:outputs.kvcache_len+max_hit,:] = kv[1][:,:,offset_kv_cache:offset_kv_cache+max_hit,:]
1162 | past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) )
1163 | outputs.past_key_values = past_key_values
1164 |
1165 | lst_token = hits[max_hit]
1166 |
1167 | #stopping condition
1168 | for hit_idx in range(max_hit + 1):
1169 | if eos_token_id is not None and hits[hit_idx] == eos_token_id[0]:
1170 | all_old_tokens.append(hits[hit_idx])
1171 | next_tokens = eos_token_id_tensor
1172 | max_hit = hit_idx
1173 | break
1174 |             else:
1175 |                 all_old_tokens.append(hits[hit_idx])
1176 | if POOL_FROM_PROMPT:
1177 | append_new_generated_pool(all_old_tokens[-LEVEL:], token_map, LEVEL, GUESS_SET_SIZE)
1178 |
1179 |
1180 | if chat and LOCAL_RANK == 0:
1181 | all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \
1182 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
1183 | if COLOR_PRINT:
1184 | from termcolor import colored
1185 | if max_hit > 1:
1186 | not_hit = self.tokenizer.decode(all_old_tokens[:-max_hit + 1], skip_special_tokens=True, \
1187 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,)
1188 | pt = colored(not_hit[prev:],"blue") + colored(all_str[len(not_hit):], "blue")
1189 | else:
1190 | pt = all_str[prev:]
1191 | print(pt, flush=True, end="")
1192 | else:
1193 | print(all_str[prev:], flush=True, end="")
1194 | prev = len(all_str)
1195 |
1196 | input_ids = torch.cat([input_ids, torch.tensor(hits[:max_hit + 1], device=next_tokens.device, dtype=next_tokens.dtype).unsqueeze(0)], dim=-1)
1197 |
1198 | if streamer is not None:
1199 | streamer.put(next_tokens.cpu())
1200 | model_kwargs = self._update_model_kwargs_for_generation(
1201 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1202 | )
1203 |
1204 | # if eos_token was found in one sentence, set sentence to finished
1205 | if eos_token_id_tensor is not None:
1206 | unfinished_sequences = unfinished_sequences.mul(
1207 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1208 | )
1209 |
1210 | # stop when each sentence is finished
1211 | if unfinished_sequences.max() == 0:
1212 | this_peer_finished = True
1213 |
1214 | # stop if we exceed the maximum length
1215 | if stopping_criteria(input_ids, scores):
1216 | this_peer_finished = True
1217 |
1218 | if this_peer_finished and not synced_gpus:
1219 | break
1220 |
1221 | for criteria in stopping_criteria:
1222 | if hasattr(criteria, "max_length"):
1223 | #print("steop: ", criteria.max_length, init_len, len(all_old_tokens), input_ids.size())
1224 | all_old_tokens = all_old_tokens[:criteria.max_length]
1225 | input_ids = input_ids[:,:criteria.max_length]
1226 | if max_length is not None:
1227 | #print("max : ", max_length, init_len)
1228 | all_old_tokens = all_old_tokens[:init_len + max_length]
1229 |         input_ids = input_ids[:, :init_len + max_length]
1230 |
1231 | if DEBUG and LOCAL_RANK == 0:
1232 | print("\n==========================ACCELERATION===SUMMARY======================================")
1233 | print("Generated tokens: ", len(all_old_tokens) - init_len, "Total steps: ", steps, " Compression ratio: ", round((len(all_old_tokens) - init_len) / steps, 2))
1234 | print("======================================================================================", end="")
1235 | CONFIG_MAP["log"].append([len(all_old_tokens) - init_len, steps, round((len(all_old_tokens) - init_len) / steps, 2)])
1236 |
1237 | if streamer is not None:
1238 | streamer.end()
1239 |
1240 | if return_dict_in_generate:
1241 | if self.config.is_encoder_decoder:
1242 | return GreedySearchEncoderDecoderOutput(
1243 | sequences=input_ids,
1244 | scores=scores,
1245 | encoder_attentions=encoder_attentions,
1246 | encoder_hidden_states=encoder_hidden_states,
1247 | decoder_attentions=decoder_attentions,
1248 | cross_attentions=cross_attentions,
1249 | decoder_hidden_states=decoder_hidden_states,
1250 | )
1251 | else:
1252 | return GreedySearchDecoderOnlyOutput(
1253 | sequences=input_ids,
1254 | scores=scores,
1255 | attentions=decoder_attentions,
1256 | hidden_states=decoder_hidden_states,
1257 | )
1258 | else:
1259 | return input_ids
1260 |
1261 |
1262 |
1263 |
1264 |
1265 |
1266 | def greedy_search_chat(
1267 | self,
1268 | input_ids: torch.LongTensor,
1269 | logits_processor: Optional[LogitsProcessorList] = None,
1270 | stopping_criteria: Optional[StoppingCriteriaList] = None,
1271 | max_length: Optional[int] = None,
1272 | pad_token_id: Optional[int] = None,
1273 | eos_token_id: Optional[Union[int, List[int]]] = None,
1274 | output_attentions: Optional[bool] = None,
1275 | output_hidden_states: Optional[bool] = None,
1276 | output_scores: Optional[bool] = None,
1277 | return_dict_in_generate: Optional[bool] = None,
1278 | synced_gpus: bool = False,
1279 | streamer: Optional["BaseStreamer"] = None,
1280 | chat: int=True,
1281 | stop_token: Optional[str] = None,
1282 | **model_kwargs,
1283 | ) -> Union[GreedySearchOutput, torch.LongTensor]:
1284 | r"""
1285 | Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be
1286 | used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
1287 |
1288 |
1289 |
1290 | In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate()
1291 | instead. For an overview of generation strategies and code examples, check the [following
1292 | guide](../generation_strategies).
1293 |
1294 |
1295 |
1296 |
1297 | Parameters:
1298 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1299 | The sequence used as a prompt for the generation.
1300 | logits_processor (`LogitsProcessorList`, *optional*):
1301 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
1302 | used to modify the prediction scores of the language modeling head applied at each generation step.
1303 | stopping_criteria (`StoppingCriteriaList`, *optional*):
1304 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
1305 | used to tell if the generation loop should stop.
1306 |
1307 | max_length (`int`, *optional*, defaults to 20):
1308 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
1309 | tokens. The maximum length of the sequence to be generated.
1310 | pad_token_id (`int`, *optional*):
1311 | The id of the *padding* token.
1312 | eos_token_id (`Union[int, List[int]]`, *optional*):
1313 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
1314 | output_attentions (`bool`, *optional*, defaults to `False`):
1315 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
1316 | returned tensors for more details.
1317 | output_hidden_states (`bool`, *optional*, defaults to `False`):
1318 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
1319 | for more details.
1320 | output_scores (`bool`, *optional*, defaults to `False`):
1321 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
1322 | return_dict_in_generate (`bool`, *optional*, defaults to `False`):
1323 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1324 | synced_gpus (`bool`, *optional*, defaults to `False`):
1325 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
1326 | streamer (`BaseStreamer`, *optional*):
1327 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed
1328 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
1329 | model_kwargs:
1330 | Additional model specific keyword arguments will be forwarded to the `forward` function of the model.
1331 | If model is an encoder-decoder model the kwargs should include `encoder_outputs`.
1332 |
1333 | Return:
1334 | [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or
1335 | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
1336 | [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
1337 | `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if
1338 | `model.config.is_encoder_decoder=True`.
1339 |
1340 | Examples:
1341 |
1342 | ```python
1343 | >>> from transformers import (
1344 | ... AutoTokenizer,
1345 | ... AutoModelForCausalLM,
1346 | ... LogitsProcessorList,
1347 | ... MinLengthLogitsProcessor,
1348 | ... StoppingCriteriaList,
1349 | ... MaxLengthCriteria,
1350 | ... )
1351 |
1352 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
1353 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
1354 |
1355 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token
1356 | >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id
1357 |
1358 | >>> input_prompt = "It might be possible to"
1359 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
1360 |
1361 | >>> # instantiate logits processors
1362 | >>> logits_processor = LogitsProcessorList(
1363 | ... [
1364 | ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id),
1365 | ... ]
1366 | ... )
1367 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
1368 |
1369 | >>> outputs = model.greedy_search(
1370 | ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria
1371 | ... )
1372 |
1373 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
1374 | ["It might be possible to get a better understanding of the nature of the problem, but it's not"]
1375 | ```"""
1376 | # init values
1377 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
1378 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
1379 | if max_length is not None:
1380 | warnings.warn(
1381 | "`max_length` is deprecated in this function, use"
1382 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
1383 | UserWarning,
1384 | )
1385 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
1386 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
1387 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
1388 | if isinstance(eos_token_id, int):
1389 | eos_token_id = [eos_token_id]
1390 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
1391 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
1392 | output_attentions = (
1393 | output_attentions if output_attentions is not None else self.generation_config.output_attentions
1394 | )
1395 | output_hidden_states = (
1396 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
1397 | )
1398 | return_dict_in_generate = (
1399 | return_dict_in_generate
1400 | if return_dict_in_generate is not None
1401 | else self.generation_config.return_dict_in_generate
1402 | )
1403 |
1404 | # init attention / hidden states / scores tuples
1405 | scores = () if (return_dict_in_generate and output_scores) else None
1406 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
1407 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None
1408 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
1409 |
1410 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
1411 | if return_dict_in_generate and self.config.is_encoder_decoder:
1412 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
1413 | encoder_hidden_states = (
1414 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
1415 | )
1416 |
1417 | # keep track of which sequences are already finished
1418 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
1419 |
1420 | assert input_ids.size(0) == 1
1421 | all_old_tokens = input_ids[0].tolist()
1422 | init_len = len(all_old_tokens)
1423 |         init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True,
1424 |             spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
1425 | prev = len(init)
1426 | steps = 0
1427 | this_peer_finished = False # used by synced_gpus only
1428 | while True:
1429 | steps += 1
1430 | if synced_gpus:
1431 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
1432 | # The following logic allows an early break if all peers finished generating their sequence
1433 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
1434 | # send 0.0 if we finished, 1.0 otherwise
1435 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
1436 | # did all peers finish? the reduced sum will be 0.0 then
1437 | if this_peer_finished_flag.item() == 0.0:
1438 | break
1439 |
1440 | # prepare model inputs
1441 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
1442 |
1443 | # forward pass to get next token
1444 | outputs = self(
1445 | **model_inputs,
1446 | return_dict=True,
1447 | output_attentions=output_attentions,
1448 | output_hidden_states=output_hidden_states,
1449 | )
1450 |
1451 | if synced_gpus and this_peer_finished:
1452 | continue # don't waste resources running the code we don't need
1453 |
1454 | next_token_logits = outputs.logits[:, -1, :]
1455 |
1456 | # pre-process distribution
1457 | next_tokens_scores = logits_processor(input_ids, next_token_logits)
1458 |
1459 | # Store scores, attentions and hidden_states when required
1460 | if return_dict_in_generate:
1461 | if output_scores:
1462 | scores += (next_tokens_scores,)
1463 | if output_attentions:
1464 | decoder_attentions += (
1465 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
1466 | )
1467 | if self.config.is_encoder_decoder:
1468 | cross_attentions += (outputs.cross_attentions,)
1469 |
1470 | if output_hidden_states:
1471 | decoder_hidden_states += (
1472 | (outputs.decoder_hidden_states,)
1473 | if self.config.is_encoder_decoder
1474 | else (outputs.hidden_states,)
1475 | )
1476 |
1477 | # argmax
1478 | next_tokens = torch.argmax(next_tokens_scores, dim=-1)
1479 |
1480 | # finished sentences should have their next token be a padding token
1481 | if eos_token_id is not None:
1482 | if pad_token_id is None:
1483 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
1484 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
1485 |
1486 | all_old_tokens.append(next_tokens.item())
1487 |             all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True,
1488 |                 spaces_between_special_tokens=False, clean_up_tokenization_spaces=True)
1489 | if chat:
1490 | print(all_str[prev:], flush=True, end="")
1491 | prev = len(all_str)
1492 |
1493 |
1494 | # update generated ids, model inputs, and length for next step
1495 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
1496 | if streamer is not None:
1497 | streamer.put(next_tokens.cpu())
1498 | model_kwargs = self._update_model_kwargs_for_generation(
1499 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
1500 | )
1501 |
1502 | # if eos_token was found in one sentence, set sentence to finished
1503 | if eos_token_id_tensor is not None:
1504 | unfinished_sequences = unfinished_sequences.mul(
1505 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
1506 | )
1507 |
1508 | # stop when each sentence is finished
1509 | if unfinished_sequences.max() == 0:
1510 | this_peer_finished = True
1511 |
1512 | # stop if we exceed the maximum length
1513 | if stopping_criteria(input_ids, scores):
1514 | this_peer_finished = True
1515 |
1516 | if this_peer_finished and not synced_gpus:
1517 | break
1518 | DEBUG = CONFIG_MAP.get("DEBUG", 0)
1519 | if DEBUG:
1520 | #print("===DEBUG INFO===", " generated tokens: ", len(all_old_tokens) - init_len, "total step: ", steps, len(token_map.keys()), sum(len(value) for value in token_map.values()), input_ids.numel(), reps)
1521 |
1522 | print("\n==========================ACCELERATION===SUMMARY======================================")
1523 | print("Generated tokens: ", len(all_old_tokens) - init_len, "Total steps: ", steps, " Compression ratio: N/A ")
1524 | print("======================================================================================", end="")
1525 |
1526 | if streamer is not None:
1527 | streamer.end()
1528 |
1529 | if return_dict_in_generate:
1530 | if self.config.is_encoder_decoder:
1531 | return GreedySearchEncoderDecoderOutput(
1532 | sequences=input_ids,
1533 | scores=scores,
1534 | encoder_attentions=encoder_attentions,
1535 | encoder_hidden_states=encoder_hidden_states,
1536 | decoder_attentions=decoder_attentions,
1537 | cross_attentions=cross_attentions,
1538 | decoder_hidden_states=decoder_hidden_states,
1539 | )
1540 | else:
1541 | return GreedySearchDecoderOnlyOutput(
1542 | sequences=input_ids,
1543 | scores=scores,
1544 | attentions=decoder_attentions,
1545 | hidden_states=decoder_hidden_states,
1546 | )
1547 | else:
1548 | return input_ids
1549 |
--------------------------------------------------------------------------------
/lade/lade_distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from .decoding import CONFIG_MAP
4 |
5 | def get_device():
6 | if "LOCAL_RANK" not in CONFIG_MAP:
7 | return 0
8 | local_rank = CONFIG_MAP["LOCAL_RANK"]
9 | return local_rank
10 |
11 | def distributed():
12 | return "DIST_WORKERS" in CONFIG_MAP and CONFIG_MAP["DIST_WORKERS"] > 1
13 |
--------------------------------------------------------------------------------
/lade/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers import GenerationMixin
3 | from transformers.models.llama import modeling_llama
4 |
5 | from lade.decoding import greedy_search_proxy, sample_proxy, FUNC_MAP, CONFIG_MAP
6 | from lade.models import modeling_llama as lade_modeling_llama
7 | #from .from lade.models import modeling_llama
8 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
9 | import torch
10 | import torch.distributed as dist
11 | import inspect
12 |
13 | def config_lade(WINDOW_SIZE=None, LEVEL=None, DEBUG=None, GUESS_SET_SIZE=None, ALWAYS_FWD_ONE=None, SPLIT_FLAG=None, DIST_WORKERS=None, POOL_FROM_PROMPT=None, backend = 'nccl', USE_FLASH=None):
14 | if WINDOW_SIZE is not None:
15 | CONFIG_MAP["WINDOW_SIZE"] = WINDOW_SIZE
16 | if LEVEL is not None:
17 | CONFIG_MAP["LEVEL"] = LEVEL
18 | if GUESS_SET_SIZE is not None:
19 | CONFIG_MAP["GUESS_SET_SIZE"] = GUESS_SET_SIZE
20 | if ALWAYS_FWD_ONE is not None:
21 | CONFIG_MAP["ALWAYS_FWD_ONE"] = ALWAYS_FWD_ONE
22 | if DEBUG is not None:
23 | CONFIG_MAP["DEBUG"] = DEBUG
24 | if SPLIT_FLAG is not None:
25 | CONFIG_MAP["SPLIT_FLAG"] = SPLIT_FLAG
26 | if POOL_FROM_PROMPT is not None:
27 | CONFIG_MAP["POOL_FROM_PROMPT"] = POOL_FROM_PROMPT
28 | if DIST_WORKERS is not None and DIST_WORKERS > 1:
29 | CONFIG_MAP["DIST_WORKERS"] = DIST_WORKERS
30 | CONFIG_MAP["LOCAL_RANK"] = int(os.environ["LOCAL_RANK"])
31 | dist.init_process_group(backend, rank=CONFIG_MAP["LOCAL_RANK"])
32 | torch.cuda.set_device(CONFIG_MAP["LOCAL_RANK"])
33 |         assert dist.get_world_size() == DIST_WORKERS, "DIST_WORKERS must equal the torch.distributed world size"
34 | if USE_FLASH is not None:
35 | CONFIG_MAP["USE_FLASH"] = USE_FLASH
36 |
37 | CONFIG_MAP["log"] = []
38 |
39 |
40 | def inject_module(lade_module, original_module):
41 | s = {}
42 | for name, cls in inspect.getmembers(original_module, inspect.isclass):
43 | s[name] = cls
44 | for name, cls in inspect.getmembers(lade_module, inspect.isclass):
45 | if str(cls.__module__).startswith("lade") and name in s:
46 | tc = s[name]
47 | for method_name in dir(cls):
48 | if callable(getattr(cls, method_name)):
49 | try:
50 | setattr(tc, method_name, getattr(cls, method_name))
51 |                     except Exception:  # some attributes (e.g. descriptors) cannot be reassigned; skip them
52 |                         pass
53 |
54 |
55 | def augment_llama():
56 | inject_module(lade_modeling_llama, modeling_llama)
57 | #llama.modeling_llama.LlamaForCausalLM = lade_modeling_llama.LlamaForCausalLM
58 | #modeling_llama.LlamaForCausalLM.jforward_multilevel = lookahead_llama.jforward_multilevel
59 | #modeling_llama.LlamaModel.LlamaModeljforward = lookahead_llama.LlamaModeljforward
60 | #modeling_llama.LlamaModel.j_prepare_decoder_attention_mask = lookahead_llama.j_prepare_decoder_attention_mask
61 |
62 | def augment_generate():
63 | FUNC_MAP["greedy_search"] = GenerationMixin.greedy_search
64 | FUNC_MAP["sample"] = GenerationMixin.sample
65 | GenerationMixin.greedy_search = greedy_search_proxy
66 | GenerationMixin.sample = sample_proxy
67 | #FUNC_MAP["sample"] = GenerationMixin.sample
68 | #GenerationMixin.sample = sample_proxy
69 |
70 | def augment_all():
71 | augment_llama()
72 | augment_generate()
73 |
74 | def log_history(clear=False):
75 | gen = 0
76 | step = 0
77 | if "log" in CONFIG_MAP:
78 | for log in CONFIG_MAP["log"]:
79 | gen += log[0]
80 | step += log[1]
81 | if clear:
82 | CONFIG_MAP["log"] = []
83 | print("LADE LOG - OVERALL GEN: ", gen, " STEPS: ", step, " AVG COMPRESS RATIO: ", (gen / step) if step > 0 else 0)
84 |
85 | def save_log(log_dir):
86 | if "log" in CONFIG_MAP:
87 | torch.save(CONFIG_MAP["log"], log_dir)
88 |
89 | def get_hf_model(model_path, quant, dtype, device, cache_dir):
90 |     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
91 | model_config = AutoConfig.from_pretrained(model_path)
92 | assert quant is None or len(quant) == 0
93 |
94 | model = AutoModelForCausalLM.from_pretrained(
95 | model_path, torch_dtype=dtype, device_map=device, cache_dir=cache_dir if len(cache_dir) > 0 else None)
96 | model = model.eval()
97 | model.tokenizer = tokenizer
98 |
99 | return model, tokenizer
100 |
101 | def get_model(model_path, quant, dtype, device, cache_dir, use_ds, native_offload = False):
102 | return get_hf_model(model_path, quant, dtype, device, cache_dir)
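103 |
104 | #Usage sketch for the logging helpers (illustrative; assumes the decoding proxies append
105 | #(generated_tokens, steps) entries to CONFIG_MAP["log"], and the save path is arbitrary):
106 | #
107 | #    import lade
108 | #    from lade.utils import log_history, save_log
109 | #
110 | #    lade.augment_all()
111 | #    lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=1)
112 | #    # ... run model.generate(...) one or more times ...
113 | #    log_history()            # prints overall generated tokens, steps and average compression ratio
114 | #    save_log("lade_log.pt")  # persists CONFIG_MAP["log"] via torch.save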
--------------------------------------------------------------------------------
/media/acc-demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/acc-demo.gif
--------------------------------------------------------------------------------
/media/jacobi-iteration.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/jacobi-iteration.gif
--------------------------------------------------------------------------------
/media/lookahead-decoding.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/lookahead-decoding.gif
--------------------------------------------------------------------------------
/media/lookahead-perf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/lookahead-perf.png
--------------------------------------------------------------------------------
/media/mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/mask.png
--------------------------------------------------------------------------------
/minimal-flash.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | import torch
3 | import time
4 | import os
5 | if int(os.environ.get("LOAD_LADE", 0)):
6 | import lade
7 | lade.augment_all()
8 | #For a 7B model, set LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7
9 | lade.config_lade(LEVEL=7, WINDOW_SIZE=20, GUESS_SET_SIZE=20, DEBUG=1, USE_FLASH=True, POOL_FROM_PROMPT=True)
10 |
11 | assert torch.cuda.is_available()
12 |
13 | torch_device = "cuda"
14 |
15 | model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
16 |
17 | tokenizer = AutoTokenizer.from_pretrained(model_name)
18 |
19 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device, attn_implementation="flash_attention_2")
20 | model.tokenizer = tokenizer
21 | prompt = "How do you fine tune a large language model?"
22 | input_text = (
23 | f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.\n<|user|>\n{prompt}\n<|assistant|>"
24 | )
25 |
26 |
27 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
28 |
29 | #warm up
30 | greedy_output = model.generate(**model_inputs, max_new_tokens=1)
31 | #end warm up
32 |
33 | # generate 256 new tokens
34 | torch.cuda.synchronize()
35 | t0s = time.time()
36 | sample_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=True, temperature=0.7,
37 | top_k=50, top_p=0.9)
38 | torch.cuda.synchronize()
39 | t1s = time.time()
40 |
41 | torch.cuda.synchronize()
42 | t0g = time.time()
43 | greedy_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
44 | torch.cuda.synchronize()
45 | t1g = time.time()
46 |
47 | print("Output:\n" + 100 * '-')
48 | print("Greedy output: ", tokenizer.decode(greedy_output[0], skip_special_tokens=False))
49 | print("Sample output: ", tokenizer.decode(sample_output[0], skip_special_tokens=False))
50 |
51 | print("Greedy Generated Tokens:", (greedy_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (greedy_output.numel() - model_inputs['input_ids'].numel()) / (t1g - t0g), " tokens/s")
52 | print("Sample Generated Tokens:", (sample_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (sample_output.numel() - model_inputs['input_ids'].numel()) / (t1s - t0s), " tokens/s")
53 |
54 | #python minimal-flash.py #44 tokens/s
55 | #LOAD_LADE=1 USE_LADE=1 python minimal-flash.py #74 tokens/s, 1.6x throughput without changing output distribution!
56 |
57 |
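58 | #Note: attn_implementation="flash_attention_2" needs the separate flash-attn package, which is not
59 | #pinned in requirements.txt (typically installed with `pip install flash-attn --no-build-isolation`),
60 | #plus a CUDA model in fp16/bf16; this script already loads the weights in torch.float16.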
--------------------------------------------------------------------------------
/minimal.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | import torch
3 | import time
4 | import os
5 | if int(os.environ.get("LOAD_LADE", 0)):
6 | import lade
7 | lade.augment_all()
8 | #For a 7B model, set LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7
9 | lade.config_lade(LEVEL=7, WINDOW_SIZE=20, GUESS_SET_SIZE=20, DEBUG=1, POOL_FROM_PROMPT=True)
10 |
11 | assert torch.cuda.is_available()
12 |
13 | torch_device = "cuda"
14 |
15 | model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
16 |
17 | tokenizer = AutoTokenizer.from_pretrained(model_name)
18 |
19 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device)
20 | model.tokenizer = tokenizer
21 | prompt = "How do you fine tune a large language model?"
22 | input_text = (
23 | f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.\n<|user|>\n{prompt}\n<|assistant|>"
24 | )
25 |
26 |
27 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device)
28 |
29 | #warm up
30 | greedy_output = model.generate(**model_inputs, max_new_tokens=1)
31 | #end warm up
32 |
33 | # generate 256 new tokens
34 | torch.cuda.synchronize()
35 | t0s = time.time()
36 | sample_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=True, temperature=0.7,
37 | top_k=50, top_p=0.9)
38 | torch.cuda.synchronize()
39 | t1s = time.time()
40 |
41 | torch.cuda.synchronize()
42 | t0g = time.time()
43 | greedy_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
44 | torch.cuda.synchronize()
45 | t1g = time.time()
46 |
47 | print("Output:\n" + 100 * '-')
48 | print("Greedy output: ", tokenizer.decode(greedy_output[0], skip_special_tokens=False))
49 | print("Sample output: ", tokenizer.decode(sample_output[0], skip_special_tokens=False))
50 |
51 | print("Greedy Generated Tokens:", (greedy_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (greedy_output.numel() - model_inputs['input_ids'].numel()) / (t1g - t0g), " tokens/s")
52 | print("Sample Generated Tokens:", (sample_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (sample_output.numel() - model_inputs['input_ids'].numel()) / (t1s - t0s), " tokens/s")
53 |
54 | #python minimal.py #44 tokens/s
55 | #LOAD_LADE=1 USE_LADE=1 python minimal.py #74 tokens/s, 1.6x throughput without changing output distribution!
56 |
57 |
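58 | #Hypothetical 7B variant (per the "For a 7B model" comment above). The checkpoint name is only an
59 | #example, it needs far more GPU memory, and the <|system|>/<|user|> prompt template above is
60 | #TinyLlama-specific, so the prompt formatting would also have to change:
61 | #
62 | #    model_name = "meta-llama/Llama-2-7b-chat-hf"
63 | #    lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=1, POOL_FROM_PROMPT=True)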
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.36.2
2 | accelerate==0.23.0
3 | fschat==0.2.31
4 | openai
5 | anthropic
6 | einops==0.7.0
7 | torch<2.1.1 # with torch>=2.1.1, transformers defaults to SDPA attention, which is not supported yet
8 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | if __name__ == '__main__':
4 | setuptools.setup(
5 | name='lade',
6 | version='0.0.2',
7 | description='Lookahead Decoding Implementation',
8 | author='Fu Yichao',
9 | author_email='yichaofu2000@outlook.com',
10 |         license='Apache-2.0',
11 | url='https://github.com/hao-ai-lab/LookaheadDecoding.git',
12 | packages=['lade', 'lade.models']
13 | )
14 |
15 |
--------------------------------------------------------------------------------