├── LICENSE ├── README.md ├── applications ├── chatbot.py ├── eval_classeval.py ├── eval_cnndm.py ├── eval_humaneval.py ├── eval_mtbench.py ├── eval_xsum.py ├── run_chat.sh └── run_mtbench.sh ├── lade ├── __init__.py ├── decoding.py ├── lade_distributed.py ├── models │ └── modeling_llama.py └── utils.py ├── media ├── acc-demo.gif ├── jacobi-iteration.gif ├── lookahead-decoding.gif ├── lookahead-perf.png └── mask.png ├── minimal-flash.py ├── minimal.py ├── requirements.txt └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Break the Sequential Dependency of LLM Inference Using Lookahead Decoding

| Paper | Blog | Roadmap |
6 | 7 | --- 8 | *News* 🔥 9 | - [2024/2] The Lookahead Decoding paper is now available on [arXiv](https://arxiv.org/abs/2402.02057). [Sampling](#use-lookahead-decoding-in-your-own-code) and [FlashAttention](#flashattention-support) are supported. Advanced features for better token prediction have been added. 10 | 11 | --- 12 | ## Introduction 13 | We introduce lookahead decoding: 14 | - A parallel decoding algorithm that accelerates LLM inference. 15 | - It needs no draft model or data store. 16 | - It linearly decreases the number of decoding steps relative to log(FLOPs) used per decoding step. 17 | 18 | Below is a demo of lookahead decoding accelerating LLaMA-2-Chat 7B generation: 19 | 20 |
![Lookahead decoding demo](media/acc-demo.gif)
*Demo of speedups by lookahead decoding on LLaMA-2-Chat 7B generation. Blue fonts are tokens generated in parallel in a decoding step.*
30 | 31 | ### Background: Parallel LLM Decoding Using Jacobi Iteration 32 | 33 | Lookahead decoding is motivated by [Jacobi decoding](https://arxiv.org/pdf/2305.10427.pdf), which views autoregressive decoding as solving nonlinear systems and decodes all future tokens simultaneously using a fixed-point iteration method. Below is a Jacobi decoding example. 34 | 35 |
![Jacobi iteration](media/jacobi-iteration.gif)
*Illustration of applying the Jacobi iteration method for parallel LLM decoding.*
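For intuition, here is a minimal, self-contained sketch of Jacobi (fixed-point) decoding. It is only an illustration, not the implementation in this repo: `greedy_next` is a hypothetical stand-in for a greedy LLM call, and in practice all positions are refined in one batched forward pass rather than a Python loop.

```python
# Toy sketch of Jacobi decoding: treat the m future tokens y_1..y_m as unknowns of a
# nonlinear system y_i = f(prompt, y_1..y_{i-1}) and update all positions in parallel
# from the previous iterate until a fixed point is reached.
def jacobi_decode(greedy_next, prompt_ids, m=5, max_iters=20):
    y = [0] * m                                    # arbitrary initial guesses
    for _ in range(max_iters):
        # Every position is updated simultaneously using last iteration's guesses.
        y_new = [greedy_next(prompt_ids + y[:i]) for i in range(m)]
        if y_new == y:                             # fixed point: same output as autoregressive greedy decoding
            break
        y = y_new
    return y
```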
45 | 46 | However, Jacobi decoding rarely achieves any wall-clock speedup in real-world LLM applications. 47 | 48 | ### Lookahead Decoding: Make Jacobi Decoding Feasible 49 | 50 | Lookahead decoding exploits Jacobi decoding's parallel generation by collecting and caching the n-grams produced along Jacobi iteration trajectories. 51 | 52 | The following GIF shows the process of collecting 2-grams via Jacobi decoding and verifying them to accelerate decoding. 53 | 54 |
![Lookahead decoding with 2-grams](media/lookahead-decoding.gif)
*Illustration of lookahead decoding with 2-grams.*
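The collection step can be sketched as follows. This is an illustration of the idea only, not the code in `lade/decoding.py`; `window_history` is an assumed data structure holding the guess window from the last few Jacobi steps, and an n-gram is read off along a diagonal of that trajectory (each position's token one step after its left neighbor's).

```python
from collections import defaultdict

# Toy sketch: harvest n-grams from a Jacobi iteration trajectory.
# window_history[t][i] is the token guessed for window slot i at Jacobi step t.
def collect_ngrams(window_history, ngram_size=2):
    pool = defaultdict(set)                        # first token -> candidate n-grams
    for t in range(ngram_size - 1, len(window_history)):
        width = len(window_history[t])
        for i in range(width - ngram_size + 1):
            # Walk a diagonal: step t-(N-1) slot i, step t-(N-2) slot i+1, ..., step t slot i+N-1.
            ngram = tuple(window_history[t - (ngram_size - 1) + k][i + k]
                          for k in range(ngram_size))
            pool[ngram[0]].add(ngram)              # index by first token for fast lookup at verification time
    return pool
```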
64 | 65 | To enhance the efficiency of this process, each lookahead decoding step is divided into two parallel branches: the lookahead branch and the verification branch. The lookahead branch maintains a fixed-size, 2D window to generate n-grams from the Jacobi iteration trajectory. Simultaneously, the verification branch selects and verifies promising n-gram candidates. 66 | 67 | ### Lookahead Branch and Verification Branch 68 | 69 | The lookahead branch aims to generate new n-grams. It operates with a two-dimensional window defined by two parameters: 70 | - Window size W: How far ahead we look in future token positions to conduct parallel decoding. 71 | - N-gram size N: How many steps we look back into the past Jacobi iteration trajectory to retrieve n-grams. 72 | 73 | In the verification branch, we identify n-grams whose first token matches the last input token, determined via a simple string match. Once identified, these n-grams are appended to the current input and verified in a single LLM forward pass. 74 | 75 | We implement both branches in a single attention mask to further exploit the GPU's parallel computing power. 76 | 77 |
78 | 79 | 80 | 81 |
82 |
83 | Attention mask for lookahead decoding with 4-grams and window size 5. In this mask, two 4-gram candidates (bottom right) are verified concurrently with parallel decoding. 84 |
85 |
86 |
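To make the acceptance rule concrete, here is a minimal sketch of how one retrieved candidate could be checked. `greedy_next` is again a hypothetical stand-in for a greedy LLM call; the actual implementation verifies all candidates in the guess set concurrently in one forward pass using the attention mask shown above, not one token at a time.

```python
# Toy sketch of the verification branch's acceptance rule for a single candidate n-gram.
# ngram[0] must equal the last accepted token; ngram[1:] is the speculated continuation.
def verify_candidate(greedy_next, input_ids, ngram):
    assert ngram[0] == input_ids[-1]
    accepted, context = [], list(input_ids)
    for tok in ngram[1:]:
        if greedy_next(context) == tok:            # the model itself reproduces the guess
            accepted.append(tok)
            context.append(tok)
        else:                                      # stop at the first mismatch
            break
    return accepted                                # tokens appended "for free" in this decoding step
```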
87 | 88 | ### Experimental Results 89 | 90 | Our experiments show that lookahead decoding substantially reduces latency, achieving speedups of 1.5x to 2.3x across different datasets on a single GPU. See the figure below. 91 | 92 |
93 | 94 | 95 | 96 |
97 |
98 | Speedup of lookahead decoding on different models and datasets. 99 |
100 |
101 |
102 | 103 | ## Contents 104 | - [Introduction](#introduction) 105 | - [Contents](#contents) 106 | - [Installation](#installation) 107 | - [Install With Pip](#install-with-pip) 108 | - [Install From The Source](#install-from-the-source) 109 | - [Inference](#inference-with-lookahead-decoding) 110 | - [Use In Your Own Code](#use-lookahead-decoding-in-your-own-code) 111 | - [Citation](#citation) 112 | - [Guidance](#guidance) 113 | 114 | 115 | ## Installation 116 | ### Install with pip 117 | ```bash 118 | pip install lade 119 | ``` 120 | ### Install from the source 121 | ```bash 122 | git clone https://github.com/hao-ai-lab/LookaheadDecoding.git 123 | cd LookaheadDecoding 124 | pip install -r requirements.txt 125 | pip install -e . 126 | ``` 127 | 128 | ### Inference With Lookahead decoding 129 | You can run the minimal example to see the speedup that Lookahead decoding brings. 130 | ```bash 131 | python minimal.py #no Lookahead decoding 132 | USE_LADE=1 LOAD_LADE=1 python minimal.py #use Lookahead decoding, 1.6x speedup 133 | ``` 134 | 135 | You can also enjoy chatting with your own chatbots with Lookahead decoding. 136 | ```bash 137 | USE_LADE=1 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat #chat, with lookahead 138 | USE_LADE=0 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat #chat, without lookahead 139 | 140 | 141 | USE_LADE=1 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug #no chat, with lookahead 142 | USE_LADE=0 python applications/chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug #no chat, without lookahead 143 | ``` 144 | 145 | ### Use Lookahead decoding in your own code 146 | You can import and use Lookahead decoding in your own code in three LoCs. You also need to set ```USE_LADE=1``` in command line or set ```os.environ["USE_LADE"]="1"``` in Python script. Note that Lookahead decoding only support LLaMA yet. 147 | 148 | ```python 149 | import lade 150 | lade.augment_all() 151 | lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=0) 152 | #LEVEL, WINDOW_SIZE and GUESS_SET_SIZE are three important configurations (N,W,G) in lookahead decoding, please refer to our blog! 153 | #You can obtain a better performance by tuning LEVEL/WINDOW_SIZE/GUESS_SET_SIZE on your own device. 154 | ``` 155 | 156 | Then you can speedup the decoding process. 
Here is an example using greedy search: 157 | ``` 158 | tokenizer = AutoTokenizer.from_pretrained(model_name) 159 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device) 160 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device) 161 | greedy_output = model.generate(**model_inputs, max_new_tokens=1024) #speedup obtained 162 | ``` 163 | 164 | Here is an example using sampling: 165 | ``` 166 | tokenizer = AutoTokenizer.from_pretrained(model_name) 167 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device) 168 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device) 169 | sample_output = model.generate(**model_inputs, max_new_tokens=1024, temperature=0.7) #speedup obtained 170 | ``` 171 | 172 | ### FlashAttention Support 173 | Install the original FlashAttention 174 | ```bash 175 | pip install flash-attn==2.3.3 #original FlashAttention 176 | ``` 177 | Two ways to install FlashAttention specialized for Lookahead Decoding 178 | 1) Download a pre-built package on https://github.com/Viol2000/flash-attention-lookahead/releases/tag/v2.3.3 and install (fast, recommended). 179 | For example, I have cuda==11.8, python==3.9 and torch==2.1, I should do the following: 180 | ```bash 181 | wget https://github.com/Viol2000/flash-attention-lookahead/releases/download/v2.3.3/flash_attn_lade-2.3.3+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl 182 | pip install flash_attn_lade-2.3.3+cu118torch2.1cxx11abiFALSE-cp39-cp39-linux_x86_64.whl 183 | ``` 184 | 2) Install from the source (slow, not recommended) 185 | ```bash 186 | git clone https://github.com/Viol2000/flash-attention-lookahead.git 187 | cd flash-attention-lookahead && python setup.py install 188 | ``` 189 | 190 | Here is an example script to run the models with FlashAttention: 191 | ```bash 192 | python minimal-flash.py #no Lookahead decoding, w/ FlashAttention 193 | USE_LADE=1 LOAD_LADE=1 python minimal-flash.py #use Lookahead decoding, w/ FlashAttention, 20% speedup than w/o FlashAttention 194 | ``` 195 | 196 | In your own code, you need to set ```USE_FLASH=True``` when calling ```config_lade```, and set ```attn_implementation="flash_attention_2"``` when calling ```AutoModelForCausalLM.from_pretrained```. 197 | ```python 198 | import lade 199 | lade.augment_all() 200 | lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, USE_FLASH=True, DEBUG=0) 201 | tokenizer = AutoTokenizer.from_pretrained(model_name) 202 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device, attn_implementation="flash_attention_2") 203 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device) 204 | greedy_output = model.generate(**model_inputs, max_new_tokens=1024) #speedup obtained 205 | ``` 206 | We will integrate FlashAttention directly into this repo for simple installation and usage. 207 | 208 | ## Citation 209 | ```bibtex 210 | @article{fu2024break, 211 | title={Break the sequential dependency of llm inference using lookahead decoding}, 212 | author={Fu, Yichao and Bailis, Peter and Stoica, Ion and Zhang, Hao}, 213 | journal={arXiv preprint arXiv:2402.02057}, 214 | year={2024} 215 | } 216 | ``` 217 | ## Guidance 218 | The core implementation is in decoding.py. Lookahead decoding requires an adaptation for each specific model. An example is in models/llama.py. 
219 | 220 | -------------------------------------------------------------------------------- /applications/chatbot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import subprocess 4 | import lade 5 | from lade.utils import get_model 6 | import time, os 7 | 8 | if __name__ == "__main__": 9 | if int(os.environ.get("USE_LADE", 0)): 10 | lade.augment_all() 11 | lade.config_lade(LEVEL=5, WINDOW_SIZE=15, GUESS_SET_SIZE=15, DEBUG=1) #A100 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--local_rank", type=int, default=0) 15 | parser.add_argument("--model_path", type=str, help="model path", default="meta-llama/Llama-2-7b-chat-hf") #tiiuae/falcon-7b-instruct #"TheBloke/Falcon-180B-Chat-GPTQ" 16 | parser.add_argument("--model_type", type=str, default="llama") 17 | parser.add_argument("--quant", type=str, default="") 18 | parser.add_argument("--use_ds", action="store_true") 19 | parser.add_argument("--debug", action="store_true") 20 | parser.add_argument("--chat", action="store_true") 21 | parser.add_argument("--dtype", type=str, default="float16") 22 | parser.add_argument("--device", type=str, default="cuda:0") 23 | parser.add_argument("--cache_dir", type=str, default="") 24 | parser.add_argument( 25 | "--max_new_tokens", 26 | type=int, 27 | default=128, 28 | help="Maximum new tokens to generate per response", 29 | ) 30 | args = parser.parse_args() 31 | 32 | if args.dtype == "float16": 33 | args.dtype = torch.float16 34 | elif args.dtype == "bfloat16": 35 | args.dtype = torch.bfloat16 36 | 37 | #if args.use_ds: 38 | model, tokenizer = get_model(args.model_path, args.quant, args.dtype, args.device, args.cache_dir, args.use_ds, False) 39 | 40 | user_input = "" 41 | num_rounds = 0 42 | if args.model_type == "llama": 43 | roles = ("[INST]", "[/INST]") #support llama2 only 44 | else: 45 | assert False 46 | 47 | user_input = "" 48 | if args.model_type == "llama": 49 | system_prompt = "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\n<>\n\n" 50 | 51 | first_time = True 52 | while True: 53 | num_rounds += 1 54 | if args.chat: 55 | model_input = input("User: ") 56 | else: 57 | model_input = '''Which methods did Socrates employ to challenge the prevailing thoughts of his time?''' 58 | print("User: " + model_input) 59 | if system_prompt is not None and first_time: 60 | if args.model_type == "llama": 61 | new_inputs = system_prompt + f"{model_input}\n {roles[1]} " 62 | new_inputs = "[INST]" + f"{model_input}\n {roles[1]} " 63 | first_time = False 64 | else: 65 | new_inputs = f"{roles[0]}: {model_input}\n {roles[1]}: " 66 | user_input += new_inputs 67 | 68 | generate_kwargs = dict(max_new_tokens=1024, do_sample=False, stop_token=None, top_p=1.0, temperature=1.0) #greedy 69 | 70 | print("Assistant: " , flush=True, end="") 71 | input_ids = tokenizer(user_input, return_tensors="pt", 72 | max_length=1024, truncation=True).input_ids.to(args.device) 73 | 74 | if not args.chat: 75 | lade.config_lade(DEBUG=0) 76 | tmp_kwargs = dict(max_new_tokens=1, do_sample=False, stop_token=None, top_p=1.0, temperature=1.0) 77 | tmp_greedy_output = model.generate(input_ids=input_ids, **tmp_kwargs).tolist() #warmup 78 | lade.config_lade(DEBUG=1) 79 | 80 | os.environ["CHAT"] = "0" 81 | t0 = time.time() 82 | greedy_output = model.generate(input_ids=input_ids, **generate_kwargs).tolist() 83 | 84 | t1 = time.time() 85 | os.environ["CHAT"] = "0" 86 | output = tokenizer.decode(greedy_output[0], skip_special_tokens=False) 87 | print(output) 88 | user_input = f"{output}\n\n" 89 | 90 | if args.debug: 91 | generated_tokens = len(greedy_output[0]) - input_ids.numel() 92 | print() 93 | print("======================================SUMMARY=========================================") 94 | print("Input tokens: ", input_ids.numel() ,"Generated tokens: ", generated_tokens,"Time: ", round(t1 - t0, 2), "s Throughput: ", round(generated_tokens / (t1 - t0), 2), "tokens/s") 95 | print("======================================================================================") 96 | #print("\n\n\n\n") 97 | if not args.chat: 98 | break 99 | -------------------------------------------------------------------------------- /applications/eval_classeval.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 
2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py 7 | import argparse 8 | import json 9 | import os 10 | import random 11 | import time 12 | 13 | import shortuuid 14 | import torch 15 | from tqdm import tqdm 16 | from typing import Dict, List, Optional 17 | from fastchat.llm_judge.common import load_questions, temperature_config 18 | from fastchat.model import get_conversation_template 19 | from fastchat.utils import str_to_torch_dtype 20 | import time 21 | import lade 22 | from datasets import load_dataset 23 | 24 | def run_eval( 25 | model_path, 26 | model_id, 27 | question_file, 28 | question_begin, 29 | question_end, 30 | answer_file, 31 | max_new_token, 32 | num_choices, 33 | num_gpus_per_model, 34 | num_gpus_total, 35 | max_gpu_memory, 36 | dtype, 37 | debug, 38 | cache_dir, 39 | cpu_offloading, 40 | use_pp, 41 | use_tp, 42 | use_tp_ds, 43 | use_flash, 44 | do_sample 45 | ): 46 | #questions = load_questions(question_file, question_begin, question_end) 47 | ClassEval = load_dataset("FudanSELab/ClassEval") 48 | questions = ClassEval["test"] 49 | # random shuffle the questions to balance the loading 50 | ###not shuffle 51 | #random.shuffle(questions) 52 | 53 | # Split the question file into `num_gpus` files 54 | assert num_gpus_total % num_gpus_per_model == 0 55 | 56 | get_answers_func = get_model_answers 57 | 58 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) 59 | ans_handles = [] 60 | for i in range(0, len(questions), chunk_size): 61 | ans_handles.append( 62 | get_answers_func( 63 | model_path, 64 | model_id, 65 | questions[i : i + chunk_size], 66 | question_end, 67 | answer_file, 68 | max_new_token, 69 | num_choices, 70 | num_gpus_per_model, 71 | max_gpu_memory, 72 | dtype=dtype, 73 | debug=debug, 74 | cache_dir=cache_dir, 75 | cpu_offloading=cpu_offloading, 76 | use_tp=use_tp, 77 | use_pp=use_pp, 78 | use_tp_ds=use_tp_ds, 79 | use_flash=use_flash, 80 | do_sample=do_sample 81 | ) 82 | ) 83 | 84 | 85 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM 86 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration 87 | 88 | def load_model( 89 | model_path: str, 90 | device: str = "cuda", 91 | device_map: str= "", 92 | num_gpus: int = 1, 93 | max_gpu_memory: Optional[str] = None, 94 | dtype: Optional[torch.dtype] = None, 95 | load_8bit: bool = False, 96 | cpu_offloading: bool = False, 97 | revision: str = "main", 98 | debug: bool = False, 99 | use_flash:bool = False 100 | ): 101 | """Load a model from Hugging Face.""" 102 | # get model adapter 103 | adapter = Llama2Adapter() 104 | # Handle device mapping 105 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( 106 | device, load_8bit, cpu_offloading 107 | ) 108 | if device == "cpu": 109 | kwargs = {"torch_dtype": torch.float32} 110 | if CPU_ISA in ["avx512_bf16", "amx"]: 111 | try: 112 | import intel_extension_for_pytorch as ipex 113 | 114 | kwargs = {"torch_dtype": torch.bfloat16} 115 | except ImportError: 116 | warnings.warn( 117 | "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" 118 | ) 119 | elif device.startswith("cuda"): 120 | kwargs = {"torch_dtype": torch.float16} 121 | if num_gpus != 1: 122 | kwargs["device_map"] = "auto" 123 | if max_gpu_memory is 
None: 124 | kwargs[ 125 | "device_map" 126 | ] = "sequential" # This is important for not the same VRAM sizes 127 | available_gpu_memory = get_gpu_memory(num_gpus) 128 | kwargs["max_memory"] = { 129 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" 130 | for i in range(num_gpus) 131 | } 132 | else: 133 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} 134 | 135 | if cpu_offloading: 136 | # raises an error on incompatible platforms 137 | from transformers import BitsAndBytesConfig 138 | 139 | if "max_memory" in kwargs: 140 | kwargs["max_memory"]["cpu"] = ( 141 | str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" 142 | ) 143 | kwargs["quantization_config"] = BitsAndBytesConfig( 144 | load_in_8bit_fp32_cpu_offload=cpu_offloading 145 | ) 146 | kwargs["load_in_8bit"] = load_8bit 147 | elif load_8bit: 148 | if num_gpus != 1: 149 | warnings.warn( 150 | "8-bit quantization is not supported for multi-gpu inference." 151 | ) 152 | else: 153 | model, tokenizer = adapter.load_compress_model( 154 | model_path=model_path, 155 | device=device, 156 | torch_dtype=kwargs["torch_dtype"], 157 | revision=revision, 158 | ) 159 | if debug: 160 | print(model) 161 | return model, tokenizer 162 | kwargs["revision"] = revision 163 | 164 | if dtype is not None: # Overwrite dtype if it is provided in the arguments. 165 | kwargs["torch_dtype"] = dtype 166 | if use_flash: 167 | kwargs["use_flash_attention_2"] = use_flash 168 | if len(device_map) > 0: 169 | kwargs["device_map"] = device_map 170 | # Load model 171 | model, tokenizer = adapter.load_model(model_path, kwargs) 172 | 173 | if len(device_map) > 0: 174 | return model, tokenizer 175 | 176 | if ( 177 | device == "cpu" 178 | and kwargs["torch_dtype"] is torch.bfloat16 179 | and CPU_ISA is not None 180 | ): 181 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) 182 | 183 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in ( 184 | "mps", 185 | "xpu", 186 | "npu", 187 | ): 188 | model.to(device) 189 | 190 | if device == "xpu": 191 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) 192 | 193 | if debug: 194 | print(model) 195 | 196 | return model, tokenizer 197 | 198 | #@torch.inference_mode() 199 | def get_model_answers( 200 | model_path, 201 | model_id, 202 | questions, 203 | question_end, 204 | answer_file, 205 | max_new_token, 206 | num_choices, 207 | num_gpus_per_model, 208 | max_gpu_memory, 209 | dtype, 210 | debug, 211 | cache_dir, 212 | cpu_offloading, 213 | use_pp, 214 | use_tp_ds, 215 | use_tp, 216 | use_flash, 217 | do_sample 218 | ): 219 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 220 | 221 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices) 222 | #tokenizer = AutoTokenizer.from_pretrained(model_path) 223 | #cfg = AutoConfig.from_pretrained(model_path) 224 | #cfg._flash_attn_2_enabled= use_flash 225 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 226 | if use_pp: 227 | model, tokenizer = load_model( 228 | model_path, 229 | use_flash=use_flash, 230 | device=f"cuda", 231 | device_map="balanced", 232 | num_gpus=num_gpus_per_model, 233 | max_gpu_memory=max_gpu_memory, 234 | dtype=dtype, 235 | load_8bit=False, 236 | cpu_offloading=cpu_offloading, 237 | debug=debug, 238 | ) 239 | 240 | elif use_tp_ds: 241 | import deepspeed 242 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0'))) 243 | model, tokenizer = load_model( 244 | model_path, 245 | 
use_flash=use_flash, 246 | device_map="cpu", 247 | num_gpus=num_gpus_per_model, 248 | max_gpu_memory=max_gpu_memory, 249 | dtype=dtype, 250 | load_8bit=False, 251 | cpu_offloading=cpu_offloading, 252 | debug=debug, 253 | ) 254 | model = deepspeed.init_inference( 255 | model, 256 | mp_size=int(os.getenv("WORLD_SIZE", "1")), 257 | dtype=torch.half 258 | ) 259 | else: 260 | model, tokenizer = load_model( 261 | model_path, 262 | use_flash=use_flash, 263 | device=f"cuda:{lade.get_device()}", 264 | num_gpus=num_gpus_per_model, 265 | max_gpu_memory=max_gpu_memory, 266 | dtype=dtype, 267 | load_8bit=False, 268 | cpu_offloading=cpu_offloading, 269 | debug=debug, 270 | ) 271 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device()) 272 | model.tokenizer = tokenizer 273 | 274 | overall_time = 0 275 | overall_tp = 0 276 | overall_gen = 0 277 | count_gen = 0 278 | stats = {} 279 | for question_idx, description in enumerate(tqdm(questions["class_description"][:question_end])): 280 | if not do_sample: 281 | temperature = 0.0 #force greedy 282 | 283 | stats[question_idx] = {} # 284 | choices = [] 285 | for i in range(num_choices): 286 | torch.manual_seed(i) 287 | conv = get_conversation_template(model_id) 288 | turns = [] 289 | prompts = [] 290 | 291 | for j in range(1): 292 | qs = "" 293 | 294 | import_stat = '\n'.join(questions["import_statement"][question_idx]) 295 | qs += import_stat 296 | 297 | class_init = questions["class_constructor"][question_idx] 298 | class_init_list = class_init.split('\n') 299 | class_init_list[0] += " \n" + description 300 | class_init = '\n'.join(class_init_list) 301 | 302 | qs = qs + "\n" + class_init 303 | prompt = qs 304 | 305 | input_ids = tokenizer(prompt, return_tensors="pt", 306 | max_length=1024, truncation=True).input_ids.to("cuda") 307 | 308 | if temperature < 1e-4: 309 | do_sample = False 310 | else: 311 | do_sample = True 312 | 313 | 314 | # some models may error out when generating long outputs 315 | if True: 316 | start_time = time.time() 317 | output_ids = model.generate( 318 | input_ids, 319 | do_sample=do_sample, 320 | temperature=temperature, 321 | max_new_tokens=max_new_token, 322 | ) 323 | end_time = time.time() 324 | gap_time = end_time - start_time 325 | tokens = output_ids.numel() - input_ids.numel() 326 | overall_time += gap_time 327 | overall_gen += tokens 328 | overall_tp += tokens / gap_time 329 | count_gen += 1 330 | 331 | stats[question_idx][j] = [gap_time, tokens] 332 | if lade.get_device() == 0 and ds_local_rank == 0: 333 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time]) 334 | 335 | output = tokenizer.decode( 336 | output_ids[0].tolist(), 337 | skip_special_tokens=False, 338 | ) 339 | 340 | turns.append(output) 341 | prompts.append(prompt) 342 | 343 | choices.append({"index": i, "turns": turns, "prompts" : prompts}) 344 | 345 | if lade.get_device() == 0 and ds_local_rank == 0: 346 | # Dump answers 347 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 348 | with open(os.path.expanduser(answer_file), "a") as fout: 349 | ans_json = { 350 | "question_id": question_idx, 351 | "answer_id": shortuuid.uuid(), 352 | "model_id": model_id, 353 | "choices": choices, 354 | "tstamp": time.time(), 355 | } 356 | fout.write(json.dumps(ans_json) + "\n") 357 | #if question_idx == 1: 358 | # break 359 | 360 | if lade.get_device() == 0 and ds_local_rank == 0: 361 | torch.save(stats[question_idx], answer_file + ".pt") 
362 | print("LOG SAVE TO ", answer_file + ".pt") 363 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}") 364 | lade.log_history() 365 | lade.save_log(answer_file + "-lade-log.pt") 366 | 367 | 368 | def reorg_answer_file(answer_file): 369 | """Sort by question id and de-duplication""" 370 | answers = {} 371 | with open(answer_file, "r") as fin: 372 | for l in fin: 373 | qid = json.loads(l)["question_id"] 374 | answers[qid] = l 375 | 376 | qids = sorted(list(answers.keys())) 377 | with open(answer_file, "w") as fout: 378 | for qid in qids: 379 | fout.write(answers[qid]) 380 | 381 | 382 | if __name__ == "__main__": 383 | parser = argparse.ArgumentParser() 384 | parser.add_argument( 385 | "--model-path", 386 | type=str, 387 | required=True, 388 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 389 | ) 390 | parser.add_argument( 391 | "--model-id", type=str, required=True, help="A custom name for the model." 392 | ) 393 | parser.add_argument( 394 | "--cache-dir", 395 | type=str, 396 | default="", 397 | ) 398 | parser.add_argument( 399 | "--debug", 400 | action="store_true", 401 | ) 402 | parser.add_argument( 403 | "--bench-name", 404 | type=str, 405 | default="classeval", 406 | help="The name of the benchmark question set.", 407 | ) 408 | parser.add_argument( 409 | "--question-begin", 410 | type=int, 411 | help="A debug option. The begin index of questions.", 412 | ) 413 | parser.add_argument( 414 | "--question-end", type=int, help="A debug option. The end index of questions." 415 | ) 416 | parser.add_argument( 417 | "--cpu_offloading", action="store_true" 418 | ) 419 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 420 | parser.add_argument( 421 | "--max-new-token", 422 | type=int, 423 | default=2048, 424 | help="The maximum number of new generated tokens.", 425 | ) 426 | parser.add_argument( 427 | "--num-choices", 428 | type=int, 429 | default=1, 430 | help="How many completion choices to generate.", 431 | ) 432 | parser.add_argument( 433 | "--num-gpus-per-model", 434 | type=int, 435 | default=1, 436 | help="The number of GPUs per model.", 437 | ) 438 | parser.add_argument( 439 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 440 | ) 441 | parser.add_argument( 442 | "--max-gpu-memory", 443 | type=str, 444 | help="Maxmum GPU memory used for model weights per GPU.", 445 | ) 446 | parser.add_argument( 447 | "--dtype", 448 | type=str, 449 | choices=["float32", "float64", "float16", "bfloat16"], 450 | help="Override the default dtype. 
If not set, it will use float16 on GPU and float32 on CPU.", 451 | default=None, 452 | ) 453 | parser.add_argument( 454 | "--local_rank", 455 | type=int, 456 | default=0, 457 | ) 458 | parser.add_argument( 459 | "--local-rank", 460 | type=int, 461 | default=0, 462 | ) 463 | parser.add_argument( 464 | "--level", 465 | type=int, 466 | default=3, 467 | ) 468 | parser.add_argument( 469 | "--window", 470 | type=int, 471 | default=10, 472 | ) 473 | parser.add_argument( 474 | "--guess", 475 | type=int, 476 | default=10, 477 | ) 478 | parser.add_argument( 479 | "--use-tp", 480 | type=int, 481 | default=0, 482 | ) 483 | parser.add_argument( 484 | "--use-pp", 485 | type=int, 486 | default=0, 487 | ) 488 | parser.add_argument( 489 | "--use-tp-ds", 490 | type=int, 491 | default=0, 492 | ) 493 | parser.add_argument( 494 | "--use-flash", 495 | type=int, 496 | default=0, 497 | ) 498 | parser.add_argument( 499 | "--do-sample", 500 | type=int, 501 | default=0, 502 | ) 503 | 504 | args = parser.parse_args() 505 | if int(os.environ.get("USE_LADE", 0)): 506 | 507 | lade.augment_all() 508 | lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES").split(","))) 509 | print("lade activated config: ", lade.decoding.CONFIG_MAP) 510 | 511 | question_file = f"mtbench.jsonl" 512 | if args.answer_file: 513 | answer_file = args.answer_file 514 | else: 515 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 516 | 517 | print(f"Output to {answer_file}") 518 | 519 | run_eval( 520 | model_path=args.model_path, 521 | model_id=args.model_id, 522 | question_file=question_file, 523 | question_begin=args.question_begin, 524 | question_end=args.question_end, 525 | answer_file=answer_file, 526 | max_new_token=args.max_new_token, 527 | num_choices=args.num_choices, 528 | num_gpus_per_model=args.num_gpus_per_model, 529 | num_gpus_total=args.num_gpus_total, 530 | max_gpu_memory=args.max_gpu_memory, 531 | dtype=str_to_torch_dtype(args.dtype), 532 | debug=args.debug, 533 | cache_dir=args.cache_dir, 534 | cpu_offloading=args.cpu_offloading, 535 | use_pp=args.use_pp, 536 | use_tp_ds=args.use_tp_ds, 537 | use_tp=args.use_tp, 538 | use_flash=args.use_flash, 539 | do_sample=args.do_sample 540 | ) 541 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 542 | if lade.get_device() == 0 and ds_local_rank == 0: 543 | reorg_answer_file(answer_file) 544 | -------------------------------------------------------------------------------- /applications/eval_cnndm.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 
2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py 7 | import argparse 8 | import json 9 | import os 10 | import random 11 | import time 12 | 13 | import shortuuid 14 | import torch 15 | from tqdm import tqdm 16 | from typing import Dict, List, Optional 17 | from fastchat.llm_judge.common import load_questions, temperature_config 18 | from fastchat.model import get_conversation_template 19 | from fastchat.utils import str_to_torch_dtype 20 | import time 21 | import lade 22 | from datasets import load_dataset 23 | 24 | def run_eval( 25 | model_path, 26 | model_id, 27 | question_file, 28 | question_begin, 29 | question_end, 30 | answer_file, 31 | max_new_token, 32 | num_choices, 33 | num_gpus_per_model, 34 | num_gpus_total, 35 | max_gpu_memory, 36 | dtype, 37 | debug, 38 | cache_dir, 39 | cpu_offloading, 40 | use_pp, 41 | use_tp, 42 | use_tp_ds, 43 | use_flash, 44 | do_sample 45 | ): 46 | questions = load_dataset("cnn_dailymail", "3.0.0", split="validation", streaming=False)["article"][question_begin:question_end] 47 | # random shuffle the questions to balance the loading 48 | ###not shuffle 49 | #random.shuffle(questions) 50 | 51 | # Split the question file into `num_gpus` files 52 | assert num_gpus_total % num_gpus_per_model == 0 53 | 54 | get_answers_func = get_model_answers 55 | 56 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) 57 | ans_handles = [] 58 | for i in range(0, len(questions), chunk_size): 59 | ans_handles.append( 60 | get_answers_func( 61 | model_path, 62 | model_id, 63 | questions[i : i + chunk_size], 64 | answer_file, 65 | max_new_token, 66 | num_choices, 67 | num_gpus_per_model, 68 | max_gpu_memory, 69 | dtype=dtype, 70 | debug=debug, 71 | cache_dir=cache_dir, 72 | cpu_offloading=cpu_offloading, 73 | use_tp=use_tp, 74 | use_pp=use_pp, 75 | use_tp_ds=use_tp_ds, 76 | use_flash=use_flash, 77 | do_sample=do_sample 78 | ) 79 | ) 80 | 81 | 82 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM 83 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration 84 | 85 | def load_model( 86 | model_path: str, 87 | device: str = "cuda", 88 | device_map: str= "", 89 | num_gpus: int = 1, 90 | max_gpu_memory: Optional[str] = None, 91 | dtype: Optional[torch.dtype] = None, 92 | load_8bit: bool = False, 93 | cpu_offloading: bool = False, 94 | revision: str = "main", 95 | debug: bool = False, 96 | use_flash:bool = False 97 | ): 98 | """Load a model from Hugging Face.""" 99 | # get model adapter 100 | adapter = Llama2Adapter() 101 | # Handle device mapping 102 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( 103 | device, load_8bit, cpu_offloading 104 | ) 105 | if device == "cpu": 106 | kwargs = {"torch_dtype": torch.float32} 107 | if CPU_ISA in ["avx512_bf16", "amx"]: 108 | try: 109 | import intel_extension_for_pytorch as ipex 110 | 111 | kwargs = {"torch_dtype": torch.bfloat16} 112 | except ImportError: 113 | warnings.warn( 114 | "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" 115 | ) 116 | elif device.startswith("cuda"): 117 | kwargs = {"torch_dtype": torch.float16} 118 | if num_gpus != 1: 119 | kwargs["device_map"] = "auto" 120 | if max_gpu_memory is None: 121 | kwargs[ 122 | "device_map" 123 | ] = 
"sequential" # This is important for not the same VRAM sizes 124 | available_gpu_memory = get_gpu_memory(num_gpus) 125 | kwargs["max_memory"] = { 126 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" 127 | for i in range(num_gpus) 128 | } 129 | else: 130 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} 131 | 132 | if cpu_offloading: 133 | # raises an error on incompatible platforms 134 | from transformers import BitsAndBytesConfig 135 | 136 | if "max_memory" in kwargs: 137 | kwargs["max_memory"]["cpu"] = ( 138 | str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" 139 | ) 140 | kwargs["quantization_config"] = BitsAndBytesConfig( 141 | load_in_8bit_fp32_cpu_offload=cpu_offloading 142 | ) 143 | kwargs["load_in_8bit"] = load_8bit 144 | elif load_8bit: 145 | if num_gpus != 1: 146 | warnings.warn( 147 | "8-bit quantization is not supported for multi-gpu inference." 148 | ) 149 | else: 150 | model, tokenizer = adapter.load_compress_model( 151 | model_path=model_path, 152 | device=device, 153 | torch_dtype=kwargs["torch_dtype"], 154 | revision=revision, 155 | ) 156 | if debug: 157 | print(model) 158 | return model, tokenizer 159 | kwargs["revision"] = revision 160 | 161 | if dtype is not None: # Overwrite dtype if it is provided in the arguments. 162 | kwargs["torch_dtype"] = dtype 163 | if use_flash: 164 | kwargs["use_flash_attention_2"] = use_flash 165 | if len(device_map) > 0: 166 | kwargs["device_map"] = device_map 167 | # Load model 168 | model, tokenizer = adapter.load_model(model_path, kwargs) 169 | 170 | if len(device_map) > 0: 171 | return model, tokenizer 172 | 173 | if ( 174 | device == "cpu" 175 | and kwargs["torch_dtype"] is torch.bfloat16 176 | and CPU_ISA is not None 177 | ): 178 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) 179 | 180 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in ( 181 | "mps", 182 | "xpu", 183 | "npu", 184 | ): 185 | model.to(device) 186 | 187 | if device == "xpu": 188 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) 189 | 190 | if debug: 191 | print(model) 192 | 193 | return model, tokenizer 194 | 195 | #@torch.inference_mode() 196 | def get_model_answers( 197 | model_path, 198 | model_id, 199 | questions, 200 | answer_file, 201 | max_new_token, 202 | num_choices, 203 | num_gpus_per_model, 204 | max_gpu_memory, 205 | dtype, 206 | debug, 207 | cache_dir, 208 | cpu_offloading, 209 | use_pp, 210 | use_tp_ds, 211 | use_tp, 212 | use_flash, 213 | do_sample 214 | ): 215 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 216 | 217 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices) 218 | 219 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 220 | if use_pp: 221 | model, tokenizer = load_model( 222 | model_path, 223 | use_flash=use_flash, 224 | device=f"cuda", 225 | device_map="balanced", 226 | num_gpus=num_gpus_per_model, 227 | max_gpu_memory=max_gpu_memory, 228 | dtype=dtype, 229 | load_8bit=False, 230 | cpu_offloading=cpu_offloading, 231 | debug=debug, 232 | ) 233 | 234 | elif use_tp_ds: 235 | import deepspeed 236 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0'))) 237 | model, tokenizer = load_model( 238 | model_path, 239 | use_flash=use_flash, 240 | device_map="cpu", 241 | num_gpus=num_gpus_per_model, 242 | max_gpu_memory=max_gpu_memory, 243 | dtype=dtype, 244 | load_8bit=False, 245 | cpu_offloading=cpu_offloading, 246 | debug=debug, 247 | ) 248 | model = 
deepspeed.init_inference( 249 | model, 250 | mp_size=int(os.getenv("WORLD_SIZE", "1")), 251 | dtype=torch.half 252 | ) 253 | else: 254 | model, tokenizer = load_model( 255 | model_path, 256 | use_flash=use_flash, 257 | device=f"cuda:{lade.get_device()}", 258 | num_gpus=num_gpus_per_model, 259 | max_gpu_memory=max_gpu_memory, 260 | dtype=dtype, 261 | load_8bit=False, 262 | cpu_offloading=cpu_offloading, 263 | debug=debug, 264 | ) 265 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device()) 266 | model.tokenizer = tokenizer 267 | 268 | overall_time = 0 269 | overall_tp = 0 270 | overall_gen = 0 271 | count_gen = 0 272 | stats = {} 273 | for question_idx, question in enumerate(tqdm(questions)): 274 | 275 | stats[question_idx] = {} # 276 | choices = [] 277 | for i in range(num_choices): 278 | torch.manual_seed(i) 279 | conv = get_conversation_template(model_id) 280 | turns = [] 281 | prompts = [] 282 | 283 | for j in range(1): 284 | 285 | prompt = f'''[INST] <> 286 | You are an intelligent chatbot. Answer the questions only using the following context: 287 | 288 | {question} 289 | 290 | Here are some rules you always follow: 291 | 292 | - Generate human readable output, avoid creating output with gibberish text. 293 | - Generate only the requested output, don't include any other language before or after the requested output. 294 | - Never say thank you, that you are happy to help, that you are an AI agent, etc. Just answer directly. 295 | - Generate professional language typically used in business documents in North America. 296 | - Never generate offensive or foul language. 297 | 298 | <> 299 | 300 | Briefly summarize the given context. [/INST] 301 | Summary: ''' 302 | 303 | prompts.append(prompt) 304 | 305 | input_ids = tokenizer([prompt]).input_ids 306 | 307 | #print("len: ", len(input_ids[0])) 308 | if len(input_ids[0]) > 2048: #skip input len > 2048 tokens 309 | continue 310 | 311 | # some models may error out when generating long outputs 312 | if True: 313 | if do_sample: 314 | start_time = time.time() 315 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=True, top_k=0, temperature=1.0, top_p=1.0) 316 | end_time = time.time() 317 | else: 318 | start_time = time.time() 319 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=False, top_k=0) 320 | end_time = time.time() 321 | 322 | gap_time = end_time - start_time 323 | tokens = output_ids.numel() - len(input_ids[0]) 324 | overall_time += gap_time 325 | overall_gen += tokens 326 | overall_tp += tokens / gap_time 327 | count_gen += 1 328 | 329 | stats[question_idx][j] = [gap_time, tokens] 330 | if lade.get_device() == 0 and ds_local_rank == 0: 331 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time]) 332 | 333 | if model.config.is_encoder_decoder: 334 | output_ids = output_ids[0] 335 | else: 336 | output_ids = output_ids[0][len(input_ids[0]) :] 337 | 338 | # be consistent with the template's stop_token_ids 339 | if conv.stop_token_ids: 340 | stop_token_ids_index = [ 341 | i 342 | for i, id in enumerate(output_ids) 343 | if id in conv.stop_token_ids 344 | ] 345 | if len(stop_token_ids_index) > 0: 346 | output_ids = output_ids[: stop_token_ids_index[0]] 347 | 348 | output = tokenizer.decode( 349 | output_ids, 350 | spaces_between_special_tokens=False, 351 | ) 352 | if conv.stop_str and 
output.find(conv.stop_str) > 0: 353 | output = output[: output.find(conv.stop_str)] 354 | for special_token in tokenizer.special_tokens_map.values(): 355 | if isinstance(special_token, list): 356 | for special_tok in special_token: 357 | output = output.replace(special_tok, "") 358 | else: 359 | output = output.replace(special_token, "") 360 | output = output.strip() 361 | 362 | if conv.name == "xgen" and output.startswith("Assistant:"): 363 | output = output.replace("Assistant:", "", 1).strip() 364 | 365 | #print("output: ", output) 366 | ''' 367 | except RuntimeError as e: 368 | print("ERROR question ID: ", question["question_id"]) 369 | output = "ERROR" 370 | ''' 371 | turns.append(output) 372 | 373 | 374 | choices.append({"index": i, "turns": turns, "prompts" : prompts}) 375 | 376 | if lade.get_device() == 0 and ds_local_rank == 0: 377 | # Dump answers 378 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 379 | with open(os.path.expanduser(answer_file), "a") as fout: 380 | ans_json = { 381 | "question_id": question_idx, 382 | "answer_id": shortuuid.uuid(), 383 | "model_id": model_id, 384 | "choices": choices, 385 | "tstamp": time.time(), 386 | } 387 | fout.write(json.dumps(ans_json) + "\n") 388 | #if question_idx == 1: 389 | # break 390 | 391 | if lade.get_device() == 0 and ds_local_rank == 0: 392 | torch.save(stats[question_idx], answer_file + ".pt") 393 | print("LOG SAVE TO ", answer_file + ".pt") 394 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}") 395 | lade.log_history() 396 | lade.save_log(answer_file + "-lade-log.pt") 397 | 398 | 399 | def reorg_answer_file(answer_file): 400 | """Sort by question id and de-duplication""" 401 | answers = {} 402 | with open(answer_file, "r") as fin: 403 | for l in fin: 404 | qid = json.loads(l)["question_id"] 405 | answers[qid] = l 406 | 407 | qids = sorted(list(answers.keys())) 408 | with open(answer_file, "w") as fout: 409 | for qid in qids: 410 | fout.write(answers[qid]) 411 | 412 | 413 | if __name__ == "__main__": 414 | parser = argparse.ArgumentParser() 415 | parser.add_argument( 416 | "--model-path", 417 | type=str, 418 | required=True, 419 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 420 | ) 421 | parser.add_argument( 422 | "--model-id", type=str, required=True, help="A custom name for the model." 423 | ) 424 | parser.add_argument( 425 | "--cache-dir", 426 | type=str, 427 | default="", 428 | ) 429 | parser.add_argument( 430 | "--debug", 431 | action="store_true", 432 | ) 433 | parser.add_argument( 434 | "--bench-name", 435 | type=str, 436 | default="cnndm", 437 | help="The name of the benchmark question set.", 438 | ) 439 | parser.add_argument( 440 | "--question-begin", 441 | type=int, 442 | help="A debug option. The begin index of questions.", 443 | ) 444 | parser.add_argument( 445 | "--question-end", type=int, help="A debug option. The end index of questions." 
446 | ) 447 | parser.add_argument( 448 | "--cpu_offloading", action="store_true" 449 | ) 450 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 451 | parser.add_argument( 452 | "--max-new-token", 453 | type=int, 454 | default=1024, 455 | help="The maximum number of new generated tokens.", 456 | ) 457 | parser.add_argument( 458 | "--num-choices", 459 | type=int, 460 | default=1, 461 | help="How many completion choices to generate.", 462 | ) 463 | parser.add_argument( 464 | "--num-gpus-per-model", 465 | type=int, 466 | default=1, 467 | help="The number of GPUs per model.", 468 | ) 469 | parser.add_argument( 470 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 471 | ) 472 | parser.add_argument( 473 | "--max-gpu-memory", 474 | type=str, 475 | help="Maxmum GPU memory used for model weights per GPU.", 476 | ) 477 | parser.add_argument( 478 | "--dtype", 479 | type=str, 480 | choices=["float32", "float64", "float16", "bfloat16"], 481 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", 482 | default=None, 483 | ) 484 | parser.add_argument( 485 | "--local_rank", 486 | type=int, 487 | default=0, 488 | ) 489 | parser.add_argument( 490 | "--local-rank", 491 | type=int, 492 | default=0, 493 | ) 494 | parser.add_argument( 495 | "--level", 496 | type=int, 497 | default=3, 498 | ) 499 | parser.add_argument( 500 | "--window", 501 | type=int, 502 | default=10, 503 | ) 504 | parser.add_argument( 505 | "--guess", 506 | type=int, 507 | default=10, 508 | ) 509 | parser.add_argument( 510 | "--use-tp", 511 | type=int, 512 | default=0, 513 | ) 514 | parser.add_argument( 515 | "--use-pp", 516 | type=int, 517 | default=0, 518 | ) 519 | parser.add_argument( 520 | "--use-tp-ds", 521 | type=int, 522 | default=0, 523 | ) 524 | parser.add_argument( 525 | "--use-flash", 526 | type=int, 527 | default=0, 528 | ) 529 | parser.add_argument( 530 | "--do-sample", 531 | type=int, 532 | default=0, 533 | ) 534 | 535 | args = parser.parse_args() 536 | if int(os.environ.get("USE_LADE", 0)): 537 | 538 | lade.augment_all() 539 | lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES").split(","))) 540 | print("lade activated config: ", lade.decoding.CONFIG_MAP) 541 | 542 | question_file = f"mtbench.jsonl" 543 | if args.answer_file: 544 | answer_file = args.answer_file 545 | else: 546 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 547 | 548 | print(f"Output to {answer_file}") 549 | 550 | run_eval( 551 | model_path=args.model_path, 552 | model_id=args.model_id, 553 | question_file=question_file, 554 | question_begin=args.question_begin, 555 | question_end=args.question_end, 556 | answer_file=answer_file, 557 | max_new_token=args.max_new_token, 558 | num_choices=args.num_choices, 559 | num_gpus_per_model=args.num_gpus_per_model, 560 | num_gpus_total=args.num_gpus_total, 561 | max_gpu_memory=args.max_gpu_memory, 562 | dtype=str_to_torch_dtype(args.dtype), 563 | debug=args.debug, 564 | cache_dir=args.cache_dir, 565 | cpu_offloading=args.cpu_offloading, 566 | use_pp=args.use_pp, 567 | use_tp_ds=args.use_tp_ds, 568 | use_tp=args.use_tp, 569 | use_flash=args.use_flash, 570 | do_sample=args.do_sample 571 | ) 572 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 573 | if lade.get_device() == 0 and ds_local_rank == 0: 574 | reorg_answer_file(answer_file) 575 | 
-------------------------------------------------------------------------------- /applications/eval_humaneval.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py 7 | import argparse 8 | import json 9 | import os 10 | import random 11 | import time 12 | 13 | import shortuuid 14 | import torch 15 | from tqdm import tqdm 16 | from typing import Dict, List, Optional 17 | from fastchat.llm_judge.common import load_questions, temperature_config 18 | from fastchat.model import get_conversation_template 19 | from fastchat.utils import str_to_torch_dtype 20 | import time 21 | import lade 22 | from human_eval.data import write_jsonl, read_problems 23 | 24 | def run_eval( 25 | model_path, 26 | model_id, 27 | question_file, 28 | question_begin, 29 | question_end, 30 | answer_file, 31 | max_new_token, 32 | num_choices, 33 | num_gpus_per_model, 34 | num_gpus_total, 35 | max_gpu_memory, 36 | dtype, 37 | debug, 38 | cache_dir, 39 | cpu_offloading, 40 | use_pp, 41 | use_tp, 42 | use_tp_ds, 43 | use_flash, 44 | do_sample 45 | ): 46 | #questions = load_questions(question_file, question_begin, question_end) 47 | questions = read_problems() 48 | questions = list(questions.values())[question_begin:question_end] 49 | 50 | # random shuffle the questions to balance the loading 51 | ###not shuffle 52 | #random.shuffle(questions) 53 | 54 | # Split the question file into `num_gpus` files 55 | assert num_gpus_total % num_gpus_per_model == 0 56 | 57 | get_answers_func = get_model_answers 58 | 59 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) 60 | ans_handles = [] 61 | for i in range(0, len(questions), chunk_size): 62 | ans_handles.append( 63 | get_answers_func( 64 | model_path, 65 | model_id, 66 | questions[i : i + chunk_size], 67 | answer_file, 68 | max_new_token, 69 | num_choices, 70 | num_gpus_per_model, 71 | max_gpu_memory, 72 | dtype=dtype, 73 | debug=debug, 74 | cache_dir=cache_dir, 75 | cpu_offloading=cpu_offloading, 76 | use_tp=use_tp, 77 | use_pp=use_pp, 78 | use_tp_ds=use_tp_ds, 79 | use_flash=use_flash, 80 | do_sample=do_sample 81 | ) 82 | ) 83 | 84 | 85 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM 86 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration 87 | 88 | def load_model( 89 | model_path: str, 90 | device: str = "cuda", 91 | device_map: str= "", 92 | num_gpus: int = 1, 93 | max_gpu_memory: Optional[str] = None, 94 | dtype: Optional[torch.dtype] = None, 95 | load_8bit: bool = False, 96 | cpu_offloading: bool = False, 97 | revision: str = "main", 98 | debug: bool = False, 99 | use_flash:bool = False 100 | ): 101 | """Load a model from Hugging Face.""" 102 | # get model adapter 103 | adapter = Llama2Adapter() 104 | # Handle device mapping 105 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( 106 | device, load_8bit, cpu_offloading 107 | ) 108 | if device == "cpu": 109 | kwargs = {"torch_dtype": torch.float32} 110 | if CPU_ISA in ["avx512_bf16", "amx"]: 111 | try: 112 | import intel_extension_for_pytorch as ipex 113 | 114 | kwargs = {"torch_dtype": torch.bfloat16} 115 | except ImportError: 116 | warnings.warn( 117 | "Intel Extension 
for PyTorch is not installed, it can be installed to accelerate cpu inference" 118 | ) 119 | elif device.startswith("cuda"): 120 | kwargs = {"torch_dtype": torch.float16} 121 | if num_gpus != 1: 122 | kwargs["device_map"] = "auto" 123 | if max_gpu_memory is None: 124 | kwargs[ 125 | "device_map" 126 | ] = "sequential" # This is important for not the same VRAM sizes 127 | available_gpu_memory = get_gpu_memory(num_gpus) 128 | kwargs["max_memory"] = { 129 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" 130 | for i in range(num_gpus) 131 | } 132 | else: 133 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} 134 | 135 | if cpu_offloading: 136 | # raises an error on incompatible platforms 137 | from transformers import BitsAndBytesConfig 138 | 139 | if "max_memory" in kwargs: 140 | kwargs["max_memory"]["cpu"] = ( 141 | str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" 142 | ) 143 | kwargs["quantization_config"] = BitsAndBytesConfig( 144 | load_in_8bit_fp32_cpu_offload=cpu_offloading 145 | ) 146 | kwargs["load_in_8bit"] = load_8bit 147 | elif load_8bit: 148 | if num_gpus != 1: 149 | warnings.warn( 150 | "8-bit quantization is not supported for multi-gpu inference." 151 | ) 152 | else: 153 | model, tokenizer = adapter.load_compress_model( 154 | model_path=model_path, 155 | device=device, 156 | torch_dtype=kwargs["torch_dtype"], 157 | revision=revision, 158 | ) 159 | if debug: 160 | print(model) 161 | return model, tokenizer 162 | kwargs["revision"] = revision 163 | 164 | if dtype is not None: # Overwrite dtype if it is provided in the arguments. 165 | kwargs["torch_dtype"] = dtype 166 | if use_flash: 167 | kwargs["use_flash_attention_2"] = use_flash 168 | if len(device_map) > 0: 169 | kwargs["device_map"] = device_map 170 | # Load model 171 | model, tokenizer = adapter.load_model(model_path, kwargs) 172 | 173 | if len(device_map) > 0: 174 | return model, tokenizer 175 | 176 | if ( 177 | device == "cpu" 178 | and kwargs["torch_dtype"] is torch.bfloat16 179 | and CPU_ISA is not None 180 | ): 181 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) 182 | 183 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in ( 184 | "mps", 185 | "xpu", 186 | "npu", 187 | ): 188 | model.to(device) 189 | 190 | if device == "xpu": 191 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) 192 | 193 | if debug: 194 | print(model) 195 | 196 | return model, tokenizer 197 | 198 | #@torch.inference_mode() 199 | def get_model_answers( 200 | model_path, 201 | model_id, 202 | questions, 203 | answer_file, 204 | max_new_token, 205 | num_choices, 206 | num_gpus_per_model, 207 | max_gpu_memory, 208 | dtype, 209 | debug, 210 | cache_dir, 211 | cpu_offloading, 212 | use_pp, 213 | use_tp_ds, 214 | use_tp, 215 | use_flash, 216 | do_sample 217 | ): 218 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 219 | 220 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices) 221 | #tokenizer = AutoTokenizer.from_pretrained(model_path) 222 | #cfg = AutoConfig.from_pretrained(model_path) 223 | #cfg._flash_attn_2_enabled= use_flash 224 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 225 | if use_pp: 226 | model, tokenizer = load_model( 227 | model_path, 228 | use_flash=use_flash, 229 | device=f"cuda", 230 | device_map="balanced", 231 | num_gpus=num_gpus_per_model, 232 | max_gpu_memory=max_gpu_memory, 233 | dtype=dtype, 234 | load_8bit=False, 235 | 
cpu_offloading=cpu_offloading, 236 | debug=debug, 237 | ) 238 | 239 | elif use_tp_ds: 240 | import deepspeed 241 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0'))) 242 | model, tokenizer = load_model( 243 | model_path, 244 | use_flash=use_flash, 245 | device_map="cpu", 246 | num_gpus=num_gpus_per_model, 247 | max_gpu_memory=max_gpu_memory, 248 | dtype=dtype, 249 | load_8bit=False, 250 | cpu_offloading=cpu_offloading, 251 | debug=debug, 252 | ) 253 | model = deepspeed.init_inference( 254 | model, 255 | mp_size=int(os.getenv("WORLD_SIZE", "1")), 256 | dtype=torch.half 257 | ) 258 | else: 259 | model, tokenizer = load_model( 260 | model_path, 261 | use_flash=use_flash, 262 | device=f"cuda:{lade.get_device()}", 263 | num_gpus=num_gpus_per_model, 264 | max_gpu_memory=max_gpu_memory, 265 | dtype=dtype, 266 | load_8bit=False, 267 | cpu_offloading=cpu_offloading, 268 | debug=debug, 269 | ) 270 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device()) 271 | model.tokenizer = tokenizer 272 | 273 | overall_time = 0 274 | overall_tp = 0 275 | overall_gen = 0 276 | count_gen = 0 277 | stats = {} 278 | for question_idx, question in enumerate(tqdm(questions)): 279 | 280 | if not do_sample: 281 | temperature = 0.0 #force greedy 282 | 283 | stats[question_idx] = {} # 284 | choices = [] 285 | for i in range(num_choices): 286 | torch.manual_seed(i) 287 | conv = get_conversation_template(model_id) 288 | turns = [] 289 | prompts = [] 290 | 291 | for j in range(1): 292 | qs = question["prompt"] 293 | prompt = qs 294 | 295 | input_ids = tokenizer(prompt, return_tensors="pt", 296 | max_length=1024, truncation=True).input_ids.to("cuda") 297 | 298 | if temperature < 1e-4: 299 | do_sample = False 300 | else: 301 | do_sample = True 302 | 303 | 304 | # some models may error out when generating long outputs 305 | if True: 306 | start_time = time.time() 307 | output_ids = model.generate( 308 | input_ids, 309 | do_sample=do_sample, 310 | temperature=temperature, 311 | max_new_tokens=max_new_token, 312 | ) 313 | end_time = time.time() 314 | gap_time = end_time - start_time 315 | tokens = output_ids.numel() - input_ids.numel() 316 | overall_time += gap_time 317 | overall_gen += tokens 318 | overall_tp += tokens / gap_time 319 | count_gen += 1 320 | 321 | stats[question_idx][j] = [gap_time, tokens] 322 | if lade.get_device() == 0 and ds_local_rank == 0: 323 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time]) 324 | 325 | output = tokenizer.decode( 326 | output_ids[0].tolist(), 327 | skip_special_tokens=False, 328 | ) 329 | 330 | turns.append(output) 331 | prompts.append(prompt) 332 | 333 | choices.append({"index": i, "turns": turns, "prompts" : prompts}) 334 | 335 | if lade.get_device() == 0 and ds_local_rank == 0: 336 | # Dump answers 337 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 338 | with open(os.path.expanduser(answer_file), "a") as fout: 339 | ans_json = { 340 | "question_id": question_idx, 341 | "answer_id": shortuuid.uuid(), 342 | "model_id": model_id, 343 | "choices": choices, 344 | "tstamp": time.time(), 345 | } 346 | fout.write(json.dumps(ans_json) + "\n") 347 | #if question_idx == 1: 348 | # break 349 | 350 | if lade.get_device() == 0 and ds_local_rank == 0: 351 | torch.save(stats[question_idx], answer_file + ".pt") 352 | print("LOG SAVE TO ", answer_file + ".pt") 353 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / 
overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}") 354 | lade.log_history() 355 | lade.save_log(answer_file + "-lade-log.pt") 356 | 357 | 358 | def reorg_answer_file(answer_file): 359 | """Sort by question id and de-duplication""" 360 | answers = {} 361 | with open(answer_file, "r") as fin: 362 | for l in fin: 363 | qid = json.loads(l)["question_id"] 364 | answers[qid] = l 365 | 366 | qids = sorted(list(answers.keys())) 367 | with open(answer_file, "w") as fout: 368 | for qid in qids: 369 | fout.write(answers[qid]) 370 | 371 | 372 | if __name__ == "__main__": 373 | parser = argparse.ArgumentParser() 374 | parser.add_argument( 375 | "--model-path", 376 | type=str, 377 | required=True, 378 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 379 | ) 380 | parser.add_argument( 381 | "--model-id", type=str, required=True, help="A custom name for the model." 382 | ) 383 | parser.add_argument( 384 | "--cache-dir", 385 | type=str, 386 | default="", 387 | ) 388 | parser.add_argument( 389 | "--debug", 390 | action="store_true", 391 | ) 392 | parser.add_argument( 393 | "--bench-name", 394 | type=str, 395 | default="humaneval", 396 | help="The name of the benchmark question set.", 397 | ) 398 | parser.add_argument( 399 | "--question-begin", 400 | type=int, 401 | help="A debug option. The begin index of questions.", 402 | ) 403 | parser.add_argument( 404 | "--question-end", type=int, help="A debug option. The end index of questions." 405 | ) 406 | parser.add_argument( 407 | "--cpu_offloading", action="store_true" 408 | ) 409 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 410 | parser.add_argument( 411 | "--max-new-token", 412 | type=int, 413 | default=512, 414 | help="The maximum number of new generated tokens.", 415 | ) 416 | parser.add_argument( 417 | "--num-choices", 418 | type=int, 419 | default=1, 420 | help="How many completion choices to generate.", 421 | ) 422 | parser.add_argument( 423 | "--num-gpus-per-model", 424 | type=int, 425 | default=1, 426 | help="The number of GPUs per model.", 427 | ) 428 | parser.add_argument( 429 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 430 | ) 431 | parser.add_argument( 432 | "--max-gpu-memory", 433 | type=str, 434 | help="Maxmum GPU memory used for model weights per GPU.", 435 | ) 436 | parser.add_argument( 437 | "--dtype", 438 | type=str, 439 | choices=["float32", "float64", "float16", "bfloat16"], 440 | help="Override the default dtype. 
If not set, it will use float16 on GPU and float32 on CPU.", 441 | default=None, 442 | ) 443 | parser.add_argument( 444 | "--local_rank", 445 | type=int, 446 | default=0, 447 | ) 448 | parser.add_argument( 449 | "--local-rank", 450 | type=int, 451 | default=0, 452 | ) 453 | parser.add_argument( 454 | "--level", 455 | type=int, 456 | default=3, 457 | ) 458 | parser.add_argument( 459 | "--window", 460 | type=int, 461 | default=10, 462 | ) 463 | parser.add_argument( 464 | "--guess", 465 | type=int, 466 | default=10, 467 | ) 468 | parser.add_argument( 469 | "--use-tp", 470 | type=int, 471 | default=0, 472 | ) 473 | parser.add_argument( 474 | "--use-pp", 475 | type=int, 476 | default=0, 477 | ) 478 | parser.add_argument( 479 | "--use-tp-ds", 480 | type=int, 481 | default=0, 482 | ) 483 | parser.add_argument( 484 | "--use-flash", 485 | type=int, 486 | default=0, 487 | ) 488 | parser.add_argument( 489 | "--do-sample", 490 | type=int, 491 | default=0, 492 | ) 493 | 494 | args = parser.parse_args() 495 | if int(os.environ.get("USE_LADE", 0)): 496 | 497 | lade.augment_all() 498 | lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES").split(","))) 499 | print("lade activated config: ", lade.decoding.CONFIG_MAP) 500 | 501 | question_file = f"mtbench.jsonl" 502 | if args.answer_file: 503 | answer_file = args.answer_file 504 | else: 505 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 506 | 507 | print(f"Output to {answer_file}") 508 | 509 | run_eval( 510 | model_path=args.model_path, 511 | model_id=args.model_id, 512 | question_file=question_file, 513 | question_begin=args.question_begin, 514 | question_end=args.question_end, 515 | answer_file=answer_file, 516 | max_new_token=args.max_new_token, 517 | num_choices=args.num_choices, 518 | num_gpus_per_model=args.num_gpus_per_model, 519 | num_gpus_total=args.num_gpus_total, 520 | max_gpu_memory=args.max_gpu_memory, 521 | dtype=str_to_torch_dtype(args.dtype), 522 | debug=args.debug, 523 | cache_dir=args.cache_dir, 524 | cpu_offloading=args.cpu_offloading, 525 | use_pp=args.use_pp, 526 | use_tp_ds=args.use_tp_ds, 527 | use_tp=args.use_tp, 528 | use_flash=args.use_flash, 529 | do_sample=args.do_sample 530 | ) 531 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 532 | if lade.get_device() == 0 and ds_local_rank == 0: 533 | reorg_answer_file(answer_file) 534 | -------------------------------------------------------------------------------- /applications/eval_mtbench.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 
2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py 7 | import argparse 8 | import json 9 | import os 10 | import random 11 | import time 12 | 13 | import shortuuid 14 | import torch 15 | from tqdm import tqdm 16 | from typing import Dict, List, Optional 17 | from fastchat.llm_judge.common import load_questions, temperature_config 18 | from fastchat.model import get_conversation_template 19 | from fastchat.utils import str_to_torch_dtype 20 | import time 21 | import lade 22 | 23 | def run_eval( 24 | model_path, 25 | model_id, 26 | question_file, 27 | question_begin, 28 | question_end, 29 | answer_file, 30 | max_new_token, 31 | num_choices, 32 | num_gpus_per_model, 33 | num_gpus_total, 34 | max_gpu_memory, 35 | dtype, 36 | debug, 37 | cache_dir, 38 | cpu_offloading, 39 | use_pp, 40 | use_tp, 41 | use_tp_ds, 42 | use_flash, 43 | do_sample 44 | ): 45 | questions = load_questions(question_file, question_begin, question_end) 46 | # random shuffle the questions to balance the loading 47 | ###not shuffle 48 | #random.shuffle(questions) 49 | 50 | # Split the question file into `num_gpus` files 51 | assert num_gpus_total % num_gpus_per_model == 0 52 | 53 | get_answers_func = get_model_answers 54 | 55 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) 56 | ans_handles = [] 57 | for i in range(0, len(questions), chunk_size): 58 | ans_handles.append( 59 | get_answers_func( 60 | model_path, 61 | model_id, 62 | questions[i : i + chunk_size], 63 | answer_file, 64 | max_new_token, 65 | num_choices, 66 | num_gpus_per_model, 67 | max_gpu_memory, 68 | dtype=dtype, 69 | debug=debug, 70 | cache_dir=cache_dir, 71 | cpu_offloading=cpu_offloading, 72 | use_tp=use_tp, 73 | use_pp=use_pp, 74 | use_tp_ds=use_tp_ds, 75 | use_flash=use_flash, 76 | do_sample=do_sample 77 | ) 78 | ) 79 | 80 | 81 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM 82 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration 83 | 84 | def load_model( 85 | model_path: str, 86 | device: str = "cuda", 87 | device_map: str= "", 88 | num_gpus: int = 1, 89 | max_gpu_memory: Optional[str] = None, 90 | dtype: Optional[torch.dtype] = None, 91 | load_8bit: bool = False, 92 | cpu_offloading: bool = False, 93 | revision: str = "main", 94 | debug: bool = False, 95 | use_flash:bool = False 96 | ): 97 | """Load a model from Hugging Face.""" 98 | # get model adapter 99 | adapter = Llama2Adapter() 100 | # Handle device mapping 101 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( 102 | device, load_8bit, cpu_offloading 103 | ) 104 | if device == "cpu": 105 | kwargs = {"torch_dtype": torch.float32} 106 | if CPU_ISA in ["avx512_bf16", "amx"]: 107 | try: 108 | import intel_extension_for_pytorch as ipex 109 | 110 | kwargs = {"torch_dtype": torch.bfloat16} 111 | except ImportError: 112 | warnings.warn( 113 | "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" 114 | ) 115 | elif device.startswith("cuda"): 116 | kwargs = {"torch_dtype": torch.float16} 117 | if num_gpus != 1: 118 | kwargs["device_map"] = "auto" 119 | if max_gpu_memory is None: 120 | kwargs[ 121 | "device_map" 122 | ] = "sequential" # This is important for not the same VRAM sizes 123 | available_gpu_memory = 
get_gpu_memory(num_gpus) 124 | kwargs["max_memory"] = { 125 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" 126 | for i in range(num_gpus) 127 | } 128 | else: 129 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} 130 | 131 | if cpu_offloading: 132 | # raises an error on incompatible platforms 133 | from transformers import BitsAndBytesConfig 134 | 135 | if "max_memory" in kwargs: 136 | kwargs["max_memory"]["cpu"] = ( 137 | str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" 138 | ) 139 | kwargs["quantization_config"] = BitsAndBytesConfig( 140 | load_in_8bit_fp32_cpu_offload=cpu_offloading 141 | ) 142 | kwargs["load_in_8bit"] = load_8bit 143 | elif load_8bit: 144 | if num_gpus != 1: 145 | warnings.warn( 146 | "8-bit quantization is not supported for multi-gpu inference." 147 | ) 148 | else: 149 | model, tokenizer = adapter.load_compress_model( 150 | model_path=model_path, 151 | device=device, 152 | torch_dtype=kwargs["torch_dtype"], 153 | revision=revision, 154 | ) 155 | if debug: 156 | print(model) 157 | return model, tokenizer 158 | kwargs["revision"] = revision 159 | 160 | if dtype is not None: # Overwrite dtype if it is provided in the arguments. 161 | kwargs["torch_dtype"] = dtype 162 | if use_flash: 163 | kwargs["use_flash_attention_2"] = use_flash 164 | if len(device_map) > 0: 165 | kwargs["device_map"] = device_map 166 | # Load model 167 | model, tokenizer = adapter.load_model(model_path, kwargs) 168 | 169 | if len(device_map) > 0: 170 | return model, tokenizer 171 | 172 | if ( 173 | device == "cpu" 174 | and kwargs["torch_dtype"] is torch.bfloat16 175 | and CPU_ISA is not None 176 | ): 177 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) 178 | 179 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in ( 180 | "mps", 181 | "xpu", 182 | "npu", 183 | ): 184 | model.to(device) 185 | 186 | if device == "xpu": 187 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) 188 | 189 | if debug: 190 | print(model) 191 | 192 | return model, tokenizer 193 | 194 | #@torch.inference_mode() 195 | def get_model_answers( 196 | model_path, 197 | model_id, 198 | questions, 199 | answer_file, 200 | max_new_token, 201 | num_choices, 202 | num_gpus_per_model, 203 | max_gpu_memory, 204 | dtype, 205 | debug, 206 | cache_dir, 207 | cpu_offloading, 208 | use_pp, 209 | use_tp_ds, 210 | use_tp, 211 | use_flash, 212 | do_sample 213 | ): 214 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 215 | 216 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices) 217 | 218 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 219 | if use_pp: 220 | model, tokenizer = load_model( 221 | model_path, 222 | use_flash=use_flash, 223 | device=f"cuda", 224 | device_map="balanced", 225 | num_gpus=num_gpus_per_model, 226 | max_gpu_memory=max_gpu_memory, 227 | dtype=dtype, 228 | load_8bit=False, 229 | cpu_offloading=cpu_offloading, 230 | debug=debug, 231 | ) 232 | 233 | elif use_tp_ds: 234 | import deepspeed 235 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0'))) 236 | model, tokenizer = load_model( 237 | model_path, 238 | use_flash=use_flash, 239 | device_map="cpu", 240 | num_gpus=num_gpus_per_model, 241 | max_gpu_memory=max_gpu_memory, 242 | dtype=dtype, 243 | load_8bit=False, 244 | cpu_offloading=cpu_offloading, 245 | debug=debug, 246 | ) 247 | model = deepspeed.init_inference( 248 | model, 249 | mp_size=int(os.getenv("WORLD_SIZE", "1")), 
250 | dtype=torch.half 251 | ) 252 | else: 253 | model, tokenizer = load_model( 254 | model_path, 255 | use_flash=use_flash, 256 | device=f"cuda:{lade.get_device()}", 257 | num_gpus=num_gpus_per_model, 258 | max_gpu_memory=max_gpu_memory, 259 | dtype=dtype, 260 | load_8bit=False, 261 | cpu_offloading=cpu_offloading, 262 | debug=debug, 263 | ) 264 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device()) 265 | model.tokenizer = tokenizer 266 | 267 | overall_time = 0 268 | overall_tp = 0 269 | overall_gen = 0 270 | count_gen = 0 271 | stats = {} 272 | for question_idx, question in enumerate(tqdm(questions)): 273 | if question["category"] in temperature_config: 274 | temperature = temperature_config[question["category"]] 275 | else: 276 | temperature = 0.7 277 | 278 | if not do_sample: 279 | temperature = 0.0 #force greedy 280 | 281 | stats[question_idx] = {} # 282 | choices = [] 283 | for i in range(num_choices): 284 | torch.manual_seed(i) 285 | conv = get_conversation_template(model_id) 286 | turns = [] 287 | prompts = [] 288 | 289 | for j in range(len(question["turns"])): 290 | qs = question["turns"][j] 291 | conv.append_message(conv.roles[0], qs) 292 | conv.append_message(conv.roles[1], None) 293 | prompt = conv.get_prompt() 294 | prompts.append(prompt) 295 | input_ids = tokenizer([prompt]).input_ids 296 | 297 | if temperature < 1e-4: 298 | do_sample = False 299 | else: 300 | do_sample = True 301 | 302 | 303 | # some models may error out when generating long outputs 304 | if True: 305 | start_time = time.time() 306 | output_ids = model.generate( 307 | torch.as_tensor(input_ids).cuda(), 308 | do_sample=do_sample, 309 | temperature=temperature, 310 | max_new_tokens=max_new_token, 311 | top_k=0.0, top_p=1.0, 312 | ) 313 | end_time = time.time() 314 | gap_time = end_time - start_time 315 | tokens = output_ids.numel() - len(input_ids[0]) 316 | overall_time += gap_time 317 | overall_gen += tokens 318 | overall_tp += tokens / gap_time 319 | count_gen += 1 320 | 321 | stats[question_idx][j] = [gap_time, tokens] 322 | if lade.get_device() == 0 and ds_local_rank == 0: 323 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time]) 324 | 325 | if model.config.is_encoder_decoder: 326 | output_ids = output_ids[0] 327 | else: 328 | output_ids = output_ids[0][len(input_ids[0]) :] 329 | 330 | # be consistent with the template's stop_token_ids 331 | if conv.stop_token_ids: 332 | stop_token_ids_index = [ 333 | i 334 | for i, id in enumerate(output_ids) 335 | if id in conv.stop_token_ids 336 | ] 337 | if len(stop_token_ids_index) > 0: 338 | output_ids = output_ids[: stop_token_ids_index[0]] 339 | 340 | output = tokenizer.decode( 341 | output_ids, 342 | spaces_between_special_tokens=False, 343 | ) 344 | if conv.stop_str and output.find(conv.stop_str) > 0: 345 | output = output[: output.find(conv.stop_str)] 346 | for special_token in tokenizer.special_tokens_map.values(): 347 | if isinstance(special_token, list): 348 | for special_tok in special_token: 349 | output = output.replace(special_tok, "") 350 | else: 351 | output = output.replace(special_token, "") 352 | output = output.strip() 353 | 354 | if conv.name == "xgen" and output.startswith("Assistant:"): 355 | output = output.replace("Assistant:", "", 1).strip() 356 | ''' 357 | except RuntimeError as e: 358 | print("ERROR question ID: ", question["question_id"]) 359 | output = "ERROR" 360 | ''' 361 | turns.append(output) 362 | 
conv.messages[-1][-1] = output 363 | 364 | choices.append({"index": i, "turns": turns, "prompts" : prompts}) 365 | 366 | if lade.get_device() == 0 and ds_local_rank == 0: 367 | # Dump answers 368 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 369 | with open(os.path.expanduser(answer_file), "a") as fout: 370 | ans_json = { 371 | "question_id": question["question_id"], 372 | "answer_id": shortuuid.uuid(), 373 | "model_id": model_id, 374 | "choices": choices, 375 | "tstamp": time.time(), 376 | } 377 | fout.write(json.dumps(ans_json) + "\n") 378 | #if question_idx == 1: 379 | # break 380 | 381 | if lade.get_device() == 0 and ds_local_rank == 0: 382 | torch.save(stats[question_idx], answer_file + ".pt") 383 | print("LOG SAVE TO ", answer_file + ".pt") 384 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}") 385 | lade.log_history() 386 | lade.save_log(answer_file + "-lade-log.pt") 387 | 388 | 389 | def reorg_answer_file(answer_file): 390 | """Sort by question id and de-duplication""" 391 | answers = {} 392 | with open(answer_file, "r") as fin: 393 | for l in fin: 394 | qid = json.loads(l)["question_id"] 395 | answers[qid] = l 396 | 397 | qids = sorted(list(answers.keys())) 398 | with open(answer_file, "w") as fout: 399 | for qid in qids: 400 | fout.write(answers[qid]) 401 | 402 | 403 | if __name__ == "__main__": 404 | parser = argparse.ArgumentParser() 405 | parser.add_argument( 406 | "--model-path", 407 | type=str, 408 | required=True, 409 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 410 | ) 411 | parser.add_argument( 412 | "--model-id", type=str, required=True, help="A custom name for the model." 413 | ) 414 | parser.add_argument( 415 | "--cache-dir", 416 | type=str, 417 | default="", 418 | ) 419 | parser.add_argument( 420 | "--debug", 421 | action="store_true", 422 | ) 423 | parser.add_argument( 424 | "--bench-name", 425 | type=str, 426 | default="mt_bench", 427 | help="The name of the benchmark question set.", 428 | ) 429 | parser.add_argument( 430 | "--question-begin", 431 | type=int, 432 | help="A debug option. The begin index of questions.", 433 | ) 434 | parser.add_argument( 435 | "--question-end", type=int, help="A debug option. The end index of questions." 436 | ) 437 | parser.add_argument( 438 | "--cpu_offloading", action="store_true" 439 | ) 440 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 441 | parser.add_argument( 442 | "--max-new-token", 443 | type=int, 444 | default=1024, 445 | help="The maximum number of new generated tokens.", 446 | ) 447 | parser.add_argument( 448 | "--num-choices", 449 | type=int, 450 | default=1, 451 | help="How many completion choices to generate.", 452 | ) 453 | parser.add_argument( 454 | "--num-gpus-per-model", 455 | type=int, 456 | default=1, 457 | help="The number of GPUs per model.", 458 | ) 459 | parser.add_argument( 460 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 461 | ) 462 | parser.add_argument( 463 | "--max-gpu-memory", 464 | type=str, 465 | help="Maxmum GPU memory used for model weights per GPU.", 466 | ) 467 | parser.add_argument( 468 | "--dtype", 469 | type=str, 470 | choices=["float32", "float64", "float16", "bfloat16"], 471 | help="Override the default dtype. 
If not set, it will use float16 on GPU and float32 on CPU.", 472 | default=None, 473 | ) 474 | parser.add_argument( 475 | "--local_rank", 476 | type=int, 477 | default=0, 478 | ) 479 | parser.add_argument( 480 | "--local-rank", 481 | type=int, 482 | default=0, 483 | ) 484 | parser.add_argument( 485 | "--level", 486 | type=int, 487 | default=3, 488 | ) 489 | parser.add_argument( 490 | "--window", 491 | type=int, 492 | default=10, 493 | ) 494 | parser.add_argument( 495 | "--guess", 496 | type=int, 497 | default=10, 498 | ) 499 | parser.add_argument( 500 | "--use-tp", 501 | type=int, 502 | default=0, 503 | ) 504 | parser.add_argument( 505 | "--use-pp", 506 | type=int, 507 | default=0, 508 | ) 509 | parser.add_argument( 510 | "--use-tp-ds", 511 | type=int, 512 | default=0, 513 | ) 514 | parser.add_argument( 515 | "--use-flash", 516 | type=int, 517 | default=0, 518 | ) 519 | parser.add_argument( 520 | "--do-sample", 521 | type=int, 522 | default=0, 523 | ) 524 | 525 | args = parser.parse_args() 526 | if int(os.environ.get("USE_LADE", 0)): 527 | 528 | lade.augment_all() 529 | lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES").split(","))) 530 | print("lade activated config: ", lade.decoding.CONFIG_MAP) 531 | 532 | question_file = f"mtbench.jsonl" 533 | if args.answer_file: 534 | answer_file = args.answer_file 535 | else: 536 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 537 | 538 | print(f"Output to {answer_file}") 539 | 540 | run_eval( 541 | model_path=args.model_path, 542 | model_id=args.model_id, 543 | question_file=question_file, 544 | question_begin=args.question_begin, 545 | question_end=args.question_end, 546 | answer_file=answer_file, 547 | max_new_token=args.max_new_token, 548 | num_choices=args.num_choices, 549 | num_gpus_per_model=args.num_gpus_per_model, 550 | num_gpus_total=args.num_gpus_total, 551 | max_gpu_memory=args.max_gpu_memory, 552 | dtype=str_to_torch_dtype(args.dtype), 553 | debug=args.debug, 554 | cache_dir=args.cache_dir, 555 | cpu_offloading=args.cpu_offloading, 556 | use_pp=args.use_pp, 557 | use_tp_ds=args.use_tp_ds, 558 | use_tp=args.use_tp, 559 | use_flash=args.use_flash, 560 | do_sample=args.do_sample 561 | ) 562 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 563 | if lade.get_device() == 0 and ds_local_rank == 0: 564 | reorg_answer_file(answer_file) 565 | -------------------------------------------------------------------------------- /applications/eval_xsum.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 
2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | #adapted from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/gen_model_answer.py 7 | import argparse 8 | import json 9 | import os 10 | import random 11 | import time 12 | 13 | import shortuuid 14 | import torch 15 | from tqdm import tqdm 16 | from typing import Dict, List, Optional 17 | from fastchat.llm_judge.common import load_questions, temperature_config 18 | from fastchat.model import get_conversation_template 19 | from fastchat.utils import str_to_torch_dtype 20 | import time 21 | import lade 22 | from datasets import load_dataset 23 | 24 | def run_eval( 25 | model_path, 26 | model_id, 27 | question_file, 28 | question_begin, 29 | question_end, 30 | answer_file, 31 | max_new_token, 32 | num_choices, 33 | num_gpus_per_model, 34 | num_gpus_total, 35 | max_gpu_memory, 36 | dtype, 37 | debug, 38 | cache_dir, 39 | cpu_offloading, 40 | use_pp, 41 | use_tp, 42 | use_tp_ds, 43 | use_flash, 44 | do_sample 45 | ): 46 | questions = load_dataset("EdinburghNLP/xsum", split="validation", streaming=False)["document"][question_begin:question_end] 47 | # random shuffle the questions to balance the loading 48 | ###not shuffle 49 | #random.shuffle(questions) 50 | 51 | # Split the question file into `num_gpus` files 52 | assert num_gpus_total % num_gpus_per_model == 0 53 | 54 | get_answers_func = get_model_answers 55 | 56 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) 57 | ans_handles = [] 58 | for i in range(0, len(questions), chunk_size): 59 | ans_handles.append( 60 | get_answers_func( 61 | model_path, 62 | model_id, 63 | questions[i : i + chunk_size], 64 | answer_file, 65 | max_new_token, 66 | num_choices, 67 | num_gpus_per_model, 68 | max_gpu_memory, 69 | dtype=dtype, 70 | debug=debug, 71 | cache_dir=cache_dir, 72 | cpu_offloading=cpu_offloading, 73 | use_tp=use_tp, 74 | use_pp=use_pp, 75 | use_tp_ds=use_tp_ds, 76 | use_flash=use_flash, 77 | do_sample=do_sample 78 | ) 79 | ) 80 | 81 | 82 | from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, LlamaForCausalLM 83 | from fastchat.model.model_adapter import Llama2Adapter, raise_warning_for_incompatible_cpu_offloading_configuration 84 | 85 | def load_model( 86 | model_path: str, 87 | device: str = "cuda", 88 | device_map: str= "", 89 | num_gpus: int = 1, 90 | max_gpu_memory: Optional[str] = None, 91 | dtype: Optional[torch.dtype] = None, 92 | load_8bit: bool = False, 93 | cpu_offloading: bool = False, 94 | revision: str = "main", 95 | debug: bool = False, 96 | use_flash:bool = False 97 | ): 98 | """Load a model from Hugging Face.""" 99 | # get model adapter 100 | adapter = Llama2Adapter() 101 | # Handle device mapping 102 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( 103 | device, load_8bit, cpu_offloading 104 | ) 105 | if device == "cpu": 106 | kwargs = {"torch_dtype": torch.float32} 107 | if CPU_ISA in ["avx512_bf16", "amx"]: 108 | try: 109 | import intel_extension_for_pytorch as ipex 110 | 111 | kwargs = {"torch_dtype": torch.bfloat16} 112 | except ImportError: 113 | warnings.warn( 114 | "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" 115 | ) 116 | elif device.startswith("cuda"): 117 | kwargs = {"torch_dtype": torch.float16} 118 | if num_gpus != 1: 119 | kwargs["device_map"] = "auto" 120 | if max_gpu_memory is None: 121 | kwargs[ 122 | "device_map" 123 | ] = "sequential" 
# This is important for not the same VRAM sizes 124 | available_gpu_memory = get_gpu_memory(num_gpus) 125 | kwargs["max_memory"] = { 126 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" 127 | for i in range(num_gpus) 128 | } 129 | else: 130 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} 131 | 132 | if cpu_offloading: 133 | # raises an error on incompatible platforms 134 | from transformers import BitsAndBytesConfig 135 | 136 | if "max_memory" in kwargs: 137 | kwargs["max_memory"]["cpu"] = ( 138 | str(math.floor(psutil.virtual_memory().available / 2**20)) + "Mib" 139 | ) 140 | kwargs["quantization_config"] = BitsAndBytesConfig( 141 | load_in_8bit_fp32_cpu_offload=cpu_offloading 142 | ) 143 | kwargs["load_in_8bit"] = load_8bit 144 | elif load_8bit: 145 | if num_gpus != 1: 146 | warnings.warn( 147 | "8-bit quantization is not supported for multi-gpu inference." 148 | ) 149 | else: 150 | model, tokenizer = adapter.load_compress_model( 151 | model_path=model_path, 152 | device=device, 153 | torch_dtype=kwargs["torch_dtype"], 154 | revision=revision, 155 | ) 156 | if debug: 157 | print(model) 158 | return model, tokenizer 159 | kwargs["revision"] = revision 160 | 161 | if dtype is not None: # Overwrite dtype if it is provided in the arguments. 162 | kwargs["torch_dtype"] = dtype 163 | if use_flash: 164 | kwargs["use_flash_attention_2"] = use_flash 165 | if len(device_map) > 0: 166 | kwargs["device_map"] = device_map 167 | # Load model 168 | model, tokenizer = adapter.load_model(model_path, kwargs) 169 | 170 | if len(device_map) > 0: 171 | return model, tokenizer 172 | 173 | if ( 174 | device == "cpu" 175 | and kwargs["torch_dtype"] is torch.bfloat16 176 | and CPU_ISA is not None 177 | ): 178 | model = ipex.optimize(model, dtype=kwargs["torch_dtype"]) 179 | 180 | if (device.startswith("cuda") and num_gpus == 1 and not cpu_offloading) or device in ( 181 | "mps", 182 | "xpu", 183 | "npu", 184 | ): 185 | model.to(device) 186 | 187 | if device == "xpu": 188 | model = torch.xpu.optimize(model, dtype=kwargs["torch_dtype"], inplace=True) 189 | 190 | if debug: 191 | print(model) 192 | 193 | return model, tokenizer 194 | 195 | #@torch.inference_mode() 196 | def get_model_answers( 197 | model_path, 198 | model_id, 199 | questions, 200 | answer_file, 201 | max_new_token, 202 | num_choices, 203 | num_gpus_per_model, 204 | max_gpu_memory, 205 | dtype, 206 | debug, 207 | cache_dir, 208 | cpu_offloading, 209 | use_pp, 210 | use_tp_ds, 211 | use_tp, 212 | use_flash, 213 | do_sample 214 | ): 215 | devices = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",") 216 | 217 | print("configuration: ", "flash attn: ", use_flash, " HF PP: ", use_pp, " DS TP: ", use_tp_ds, " GPUS: ", devices) 218 | 219 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 220 | if use_pp: 221 | model, tokenizer = load_model( 222 | model_path, 223 | use_flash=use_flash, 224 | device=f"cuda", 225 | device_map="balanced", 226 | num_gpus=num_gpus_per_model, 227 | max_gpu_memory=max_gpu_memory, 228 | dtype=dtype, 229 | load_8bit=False, 230 | cpu_offloading=cpu_offloading, 231 | debug=debug, 232 | ) 233 | 234 | elif use_tp_ds: 235 | import deepspeed 236 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0'))) 237 | model, tokenizer = load_model( 238 | model_path, 239 | use_flash=use_flash, 240 | device_map="cpu", 241 | num_gpus=num_gpus_per_model, 242 | max_gpu_memory=max_gpu_memory, 243 | dtype=dtype, 244 | load_8bit=False, 245 | cpu_offloading=cpu_offloading, 246 | debug=debug, 247 | ) 248 | model = 
deepspeed.init_inference( 249 | model, 250 | mp_size=int(os.getenv("WORLD_SIZE", "1")), 251 | dtype=torch.half 252 | ) 253 | else: 254 | model, tokenizer = load_model( 255 | model_path, 256 | use_flash=use_flash, 257 | device=f"cuda:{lade.get_device()}", 258 | num_gpus=num_gpus_per_model, 259 | max_gpu_memory=max_gpu_memory, 260 | dtype=dtype, 261 | load_8bit=False, 262 | cpu_offloading=cpu_offloading, 263 | debug=debug, 264 | ) 265 | #model = AutoModelForCausalLM.from_pretrained(model_path, config=cfg, torch_dtype=torch.float16, device_map=lade.get_device()) 266 | model.tokenizer = tokenizer 267 | 268 | overall_time = 0 269 | overall_tp = 0 270 | overall_gen = 0 271 | count_gen = 0 272 | stats = {} 273 | for question_idx, question in enumerate(tqdm(questions)): 274 | 275 | stats[question_idx] = {} # 276 | choices = [] 277 | for i in range(num_choices): 278 | torch.manual_seed(i) 279 | conv = get_conversation_template(model_id) 280 | turns = [] 281 | prompts = [] 282 | 283 | for j in range(1): 284 | 285 | prompt = f'''[INST] <> 286 | You are an intelligent chatbot. Answer the questions only using the following context: 287 | 288 | {question} 289 | 290 | Here are some rules you always follow: 291 | 292 | - Generate human readable output, avoid creating output with gibberish text. 293 | - Generate only the requested output, don't include any other language before or after the requested output. 294 | - Never say thank you, that you are happy to help, that you are an AI agent, etc. Just answer directly. 295 | - Generate professional language typically used in business documents in North America. 296 | - Never generate offensive or foul language. 297 | 298 | <> 299 | 300 | Briefly summarize the given context. [/INST] 301 | Summary: ''' 302 | 303 | prompts.append(prompt) 304 | 305 | input_ids = tokenizer([prompt]).input_ids 306 | 307 | #print("len: ", len(input_ids[0])) 308 | if len(input_ids[0]) > 2048: #skip input len > 2048 tokens 309 | continue 310 | 311 | # some models may error out when generating long outputs 312 | if True: 313 | if do_sample: 314 | start_time = time.time() 315 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=True, top_k=0, temperature=1.0, top_p=1.0) 316 | end_time = time.time() 317 | else: 318 | start_time = time.time() 319 | output_ids = model.generate(torch.as_tensor(input_ids).cuda(), max_new_tokens=max_new_token, do_sample=False, top_k=0) 320 | end_time = time.time() 321 | 322 | gap_time = end_time - start_time 323 | tokens = output_ids.numel() - len(input_ids[0]) 324 | overall_time += gap_time 325 | overall_gen += tokens 326 | overall_tp += tokens / gap_time 327 | count_gen += 1 328 | 329 | stats[question_idx][j] = [gap_time, tokens] 330 | if lade.get_device() == 0 and ds_local_rank == 0: 331 | print([f"step {i} turn {j} time: ", gap_time, " generated tokens: ", tokens, " throughput: " , tokens / gap_time]) 332 | 333 | if model.config.is_encoder_decoder: 334 | output_ids = output_ids[0] 335 | else: 336 | output_ids = output_ids[0][len(input_ids[0]) :] 337 | 338 | # be consistent with the template's stop_token_ids 339 | if conv.stop_token_ids: 340 | stop_token_ids_index = [ 341 | i 342 | for i, id in enumerate(output_ids) 343 | if id in conv.stop_token_ids 344 | ] 345 | if len(stop_token_ids_index) > 0: 346 | output_ids = output_ids[: stop_token_ids_index[0]] 347 | 348 | output = tokenizer.decode( 349 | output_ids, 350 | spaces_between_special_tokens=False, 351 | ) 352 | if conv.stop_str and 
output.find(conv.stop_str) > 0: 353 | output = output[: output.find(conv.stop_str)] 354 | for special_token in tokenizer.special_tokens_map.values(): 355 | if isinstance(special_token, list): 356 | for special_tok in special_token: 357 | output = output.replace(special_tok, "") 358 | else: 359 | output = output.replace(special_token, "") 360 | output = output.strip() 361 | 362 | if conv.name == "xgen" and output.startswith("Assistant:"): 363 | output = output.replace("Assistant:", "", 1).strip() 364 | 365 | #print("output: ", output) 366 | ''' 367 | except RuntimeError as e: 368 | print("ERROR question ID: ", question["question_id"]) 369 | output = "ERROR" 370 | ''' 371 | turns.append(output) 372 | 373 | 374 | choices.append({"index": i, "turns": turns, "prompts" : prompts}) 375 | 376 | if lade.get_device() == 0 and ds_local_rank == 0: 377 | # Dump answers 378 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 379 | with open(os.path.expanduser(answer_file), "a") as fout: 380 | ans_json = { 381 | "question_id": question_idx, 382 | "answer_id": shortuuid.uuid(), 383 | "model_id": model_id, 384 | "choices": choices, 385 | "tstamp": time.time(), 386 | } 387 | fout.write(json.dumps(ans_json) + "\n") 388 | #if question_idx == 1: 389 | # break 390 | 391 | if lade.get_device() == 0 and ds_local_rank == 0: 392 | torch.save(stats[question_idx], answer_file + ".pt") 393 | print("LOG SAVE TO ", answer_file + ".pt") 394 | print(f"AVERAGE THROUGHPUT1 {overall_tp / count_gen} AVERAGE THROUGHPUT2 {overall_gen / overall_time} STAT {[overall_tp, count_gen, overall_gen, overall_time]}") 395 | lade.log_history() 396 | lade.save_log(answer_file + "-lade-log.pt") 397 | 398 | 399 | def reorg_answer_file(answer_file): 400 | """Sort by question id and de-duplication""" 401 | answers = {} 402 | with open(answer_file, "r") as fin: 403 | for l in fin: 404 | qid = json.loads(l)["question_id"] 405 | answers[qid] = l 406 | 407 | qids = sorted(list(answers.keys())) 408 | with open(answer_file, "w") as fout: 409 | for qid in qids: 410 | fout.write(answers[qid]) 411 | 412 | 413 | if __name__ == "__main__": 414 | parser = argparse.ArgumentParser() 415 | parser.add_argument( 416 | "--model-path", 417 | type=str, 418 | required=True, 419 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 420 | ) 421 | parser.add_argument( 422 | "--model-id", type=str, required=True, help="A custom name for the model." 423 | ) 424 | parser.add_argument( 425 | "--cache-dir", 426 | type=str, 427 | default="", 428 | ) 429 | parser.add_argument( 430 | "--debug", 431 | action="store_true", 432 | ) 433 | parser.add_argument( 434 | "--bench-name", 435 | type=str, 436 | default="xsum", 437 | help="The name of the benchmark question set.", 438 | ) 439 | parser.add_argument( 440 | "--question-begin", 441 | type=int, 442 | help="A debug option. The begin index of questions.", 443 | ) 444 | parser.add_argument( 445 | "--question-end", type=int, help="A debug option. The end index of questions." 
446 | ) 447 | parser.add_argument( 448 | "--cpu_offloading", action="store_true" 449 | ) 450 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 451 | parser.add_argument( 452 | "--max-new-token", 453 | type=int, 454 | default=1024, 455 | help="The maximum number of new generated tokens.", 456 | ) 457 | parser.add_argument( 458 | "--num-choices", 459 | type=int, 460 | default=1, 461 | help="How many completion choices to generate.", 462 | ) 463 | parser.add_argument( 464 | "--num-gpus-per-model", 465 | type=int, 466 | default=1, 467 | help="The number of GPUs per model.", 468 | ) 469 | parser.add_argument( 470 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 471 | ) 472 | parser.add_argument( 473 | "--max-gpu-memory", 474 | type=str, 475 | help="Maxmum GPU memory used for model weights per GPU.", 476 | ) 477 | parser.add_argument( 478 | "--dtype", 479 | type=str, 480 | choices=["float32", "float64", "float16", "bfloat16"], 481 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", 482 | default=None, 483 | ) 484 | parser.add_argument( 485 | "--local_rank", 486 | type=int, 487 | default=0, 488 | ) 489 | parser.add_argument( 490 | "--local-rank", 491 | type=int, 492 | default=0, 493 | ) 494 | parser.add_argument( 495 | "--level", 496 | type=int, 497 | default=3, 498 | ) 499 | parser.add_argument( 500 | "--window", 501 | type=int, 502 | default=10, 503 | ) 504 | parser.add_argument( 505 | "--guess", 506 | type=int, 507 | default=10, 508 | ) 509 | parser.add_argument( 510 | "--use-tp", 511 | type=int, 512 | default=0, 513 | ) 514 | parser.add_argument( 515 | "--use-pp", 516 | type=int, 517 | default=0, 518 | ) 519 | parser.add_argument( 520 | "--use-tp-ds", 521 | type=int, 522 | default=0, 523 | ) 524 | parser.add_argument( 525 | "--use-flash", 526 | type=int, 527 | default=0, 528 | ) 529 | parser.add_argument( 530 | "--do-sample", 531 | type=int, 532 | default=0, 533 | ) 534 | 535 | args = parser.parse_args() 536 | if int(os.environ.get("USE_LADE", 0)): 537 | 538 | lade.augment_all() 539 | lade.config_lade(LEVEL=args.level, WINDOW_SIZE=args.window, GUESS_SET_SIZE=args.guess, DEBUG=1, USE_FLASH=args.use_flash, DIST_WORKERS=len(os.environ.get("CUDA_VISIBLE_DEVICES").split(","))) 540 | print("lade activated config: ", lade.decoding.CONFIG_MAP) 541 | 542 | question_file = f"mtbench.jsonl" 543 | if args.answer_file: 544 | answer_file = args.answer_file 545 | else: 546 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 547 | 548 | print(f"Output to {answer_file}") 549 | 550 | run_eval( 551 | model_path=args.model_path, 552 | model_id=args.model_id, 553 | question_file=question_file, 554 | question_begin=args.question_begin, 555 | question_end=args.question_end, 556 | answer_file=answer_file, 557 | max_new_token=args.max_new_token, 558 | num_choices=args.num_choices, 559 | num_gpus_per_model=args.num_gpus_per_model, 560 | num_gpus_total=args.num_gpus_total, 561 | max_gpu_memory=args.max_gpu_memory, 562 | dtype=str_to_torch_dtype(args.dtype), 563 | debug=args.debug, 564 | cache_dir=args.cache_dir, 565 | cpu_offloading=args.cpu_offloading, 566 | use_pp=args.use_pp, 567 | use_tp_ds=args.use_tp_ds, 568 | use_tp=args.use_tp, 569 | use_flash=args.use_flash, 570 | do_sample=args.do_sample 571 | ) 572 | ds_local_rank = int(os.getenv('LOCAL_RANK', '0')) 573 | if lade.get_device() == 0 and ds_local_rank == 0: 574 | reorg_answer_file(answer_file) 575 | 
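The eval scripts above all enable lookahead decoding the same way: USE_LADE=1 in the environment, lade.augment_all() to patch the Hugging Face generation loop, and lade.config_lade(...) with the level / window / guess-set sizes, after which an ordinary model.generate() call is accelerated. The sketch below condenses that pattern into a standalone script; it is illustrative only — the checkpoint id comes from run_chat.sh, the LEVEL / WINDOW_SIZE / GUESS_SET_SIZE values mirror run_mtbench.sh, and it is not a substitute for the repository's own entry points.

import os
os.environ.setdefault("USE_LADE", "1")   # the proxies in lade/decoding.py read this at generate() time

import torch
import lade
from transformers import AutoModelForCausalLM, AutoTokenizer

lade.augment_all()                                            # patch HF generation entry points
lade.config_lade(LEVEL=5, WINDOW_SIZE=15, GUESS_SET_SIZE=15,  # values mirror run_mtbench.sh
                 DEBUG=0, USE_FLASH=0, DIST_WORKERS=1)        # the eval scripts pass DEBUG=1 for per-step logs

model_path = "meta-llama/Llama-2-7b-chat-hf"                  # same checkpoint as run_chat.sh
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
model = model.to(f"cuda:{lade.get_device()}")                 # same device selection as load_model() above
model.tokenizer = tokenizer                                   # the eval scripts attach the tokenizer the same way

prompt = "Briefly explain lookahead decoding."
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
output_ids = model.generate(input_ids, do_sample=False, max_new_tokens=128)
print(tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True))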
-------------------------------------------------------------------------------- /applications/run_chat.sh: -------------------------------------------------------------------------------- 1 | USE_LADE=1 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug 2 | USE_LADE=0 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug 3 | 4 | USE_LADE=1 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat 5 | USE_LADE=0 python chatbot.py --model_path meta-llama/Llama-2-7b-chat-hf --debug --chat 6 | -------------------------------------------------------------------------------- /applications/run_mtbench.sh: -------------------------------------------------------------------------------- 1 | #download data 2 | wget https://raw.githubusercontent.com/lm-sys/FastChat/v0.2.31/fastchat/llm_judge/data/mt_bench/question.jsonl -O mtbench.jsonl 3 | 4 | export CUDA=0 5 | export LADE=0 6 | export LEVEL=0 7 | export WIN=0 8 | export GUESS=0 9 | export FLASH=0 10 | export PP=0 11 | CUDA_VISIBLE_DEVICES=$CUDA USE_LADE=$LADE python eval_mtbench.py \ 12 | --model-path meta-llama/Llama-2-7b-chat-hf --model-id \ 13 | llama-2-7b-level-$LEVEL-win-$WIN-guess-$GUESS-f$FLASH-pp$CUDA \ 14 | --level $LEVEL --window $WIN --guess $GUESS --use-flash $FLASH --use-pp $PP 15 | 16 | export CUDA=0 17 | export LADE=1 18 | export LEVEL=5 19 | export WIN=15 20 | export GUESS=15 21 | export FLASH=0 22 | export PP=0 23 | CUDA_VISIBLE_DEVICES=$CUDA USE_LADE=$LADE python eval_mtbench.py \ 24 | --model-path meta-llama/Llama-2-7b-chat-hf --model-id \ 25 | llama-2-7b-level-$LEVEL-win-$WIN-guess-$GUESS-f$FLASH-pp$CUDA \ 26 | --level $LEVEL --window $WIN --guess $GUESS --use-flash $FLASH --use-pp $PP 27 | 28 | export GPUS=1 29 | export LEVEL=0 30 | export WIN=0 31 | export GUESS=0 32 | export FLASH=0 33 | deepspeed --num_gpus $GPUS eval_mtbench.py --model-path meta-llama/Llama-2-7b-chat-hf \ 34 | --model-id llama-2-7b-level-$LEVEL-win-$WIN-guess-$GUESS-f$FLASH-ds$GPUS \ 35 | --level $LEVEL --window $WIN --guess $GUESS --use-flash $FLASH --use-tp-ds 1 36 | -------------------------------------------------------------------------------- /lade/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import augment_llama 2 | from .utils import augment_generate 3 | from .utils import augment_all 4 | from .utils import config_lade, save_log, log_history 5 | from .lade_distributed import * 6 | -------------------------------------------------------------------------------- /lade/decoding.py: -------------------------------------------------------------------------------- 1 | from transformers import GenerationMixin 2 | import torch 3 | import copy 4 | import inspect 5 | import warnings 6 | from dataclasses import dataclass 7 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 8 | from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GreedySearchOutput, SampleOutput, TemperatureLogitsWarper, TopPLogitsWarper, TopKLogitsWarper 9 | import torch.distributed as dist 10 | import os, time, random 11 | FUNC_MAP = {} 12 | CONFIG_MAP = {} 13 | COLOR_PRINT = int(os.environ.get("COLOR_PRINT", 0)) 14 | 15 | def greedy_search_proxy(self, *args, **kwargs): 16 | USE_LADE = int(os.environ.get("USE_LADE", 0)) 17 | CHAT = int(os.environ.get("CHAT", 0)) 18 | if CHAT and USE_LADE: 19 | return jacobi_greedy_search_multilevel(self, chat=True, *args, **kwargs) 20 | elif CHAT: 21 | return 
greedy_search_chat(self, chat=True, *args, **kwargs) 22 | 23 | if USE_LADE: 24 | return jacobi_greedy_search_multilevel(self, chat=False, *args, **kwargs) 25 | else: 26 | return FUNC_MAP["greedy_search"](self, *args, **kwargs) 27 | 28 | def sample_proxy(self, *args, **kwargs): 29 | USE_LADE = int(os.environ.get("USE_LADE", 0)) 30 | 31 | if USE_LADE: 32 | return jacobi_sample_multilevel(self, chat=int(os.environ.get("CHAT", 0)), *args, **kwargs) 33 | else: 34 | return FUNC_MAP["greedy_search"](self, *args, **kwargs) 35 | 36 | 37 | def update_token_map(token_map, lst_token, past_tokens, new_results, LEVEL, WINDOW_SIZE, GUESS_SET_SIZE): 38 | if GUESS_SET_SIZE != -1: #limited guess set size for each key, lru policy 39 | if lst_token not in token_map: 40 | token_map[lst_token] = [] 41 | tup = tuple(past_tokens[ll][0] for ll in range(1, LEVEL - 1)) + (new_results[0],) 42 | if tup in token_map[lst_token]: 43 | token_map[lst_token].remove(tup) 44 | token_map[lst_token].append(tup) 45 | elif len(token_map[lst_token]) < GUESS_SET_SIZE: 46 | token_map[lst_token].append(tup) 47 | else: 48 | assert len(token_map[lst_token]) == GUESS_SET_SIZE 49 | token_map[lst_token] = token_map[lst_token][1:] + [tup] 50 | 51 | for i in range(1, WINDOW_SIZE): 52 | if past_tokens[0][i - 1] not in token_map: 53 | token_map[past_tokens[0][i - 1]] = [] 54 | tup = tuple(past_tokens[ll][i] for ll in range(1, LEVEL - 1)) + (new_results[i],) 55 | 56 | if tup in token_map[past_tokens[0][i - 1]]: 57 | token_map[past_tokens[0][i - 1]].remove(tup) 58 | token_map[past_tokens[0][i - 1]].append(tup) 59 | elif len(token_map[past_tokens[0][i - 1]]) < GUESS_SET_SIZE: 60 | token_map[past_tokens[0][i - 1]].append(tup) 61 | else: 62 | assert len(token_map[past_tokens[0][i - 1]]) == GUESS_SET_SIZE 63 | token_map[past_tokens[0][i - 1]] = token_map[past_tokens[0][i - 1]][1:] + [tup] 64 | 65 | else: #unlimited guess set size for each key 66 | #first add 67 | if lst_token not in token_map: 68 | token_map[lst_token] = set() 69 | #build tuple 70 | tup = tuple(past_tokens[ll][0] for ll in range(1, LEVEL - 1)) + (new_results[0],) 71 | #add tuple 72 | token_map[lst_token].add(tup) 73 | 74 | for i in range(1, WINDOW_SIZE): 75 | if past_tokens[0][i - 1] not in token_map: 76 | token_map[past_tokens[0][i - 1]] = set() 77 | tup = tuple(past_tokens[ll][i] for ll in range(1, LEVEL - 1)) + (new_results[i],) 78 | token_map[past_tokens[0][i - 1]].add(tup) 79 | 80 | def append_new_generated_pool(tokens, token_map, LEVEL, GUESS_SET_SIZE): 81 | if len(tokens) != LEVEL: 82 | return 83 | lst_token = tokens[0] 84 | tup = tuple(tokens[1:]) 85 | 86 | if GUESS_SET_SIZE != -1: #limited guess set size for each key, lru policy 87 | if lst_token not in token_map: 88 | token_map[lst_token] = [] 89 | if tup in token_map[lst_token]: 90 | token_map[lst_token].remove(tup) 91 | token_map[lst_token].append(tup) 92 | elif len(token_map[lst_token]) < GUESS_SET_SIZE: 93 | token_map[lst_token].append(tup) 94 | else: 95 | assert len(token_map[lst_token]) == GUESS_SET_SIZE 96 | token_map[lst_token] = token_map[lst_token][1:] + [tup] 97 | else: #unlimited guess set size for each key 98 | #first add 99 | if lst_token not in token_map: 100 | token_map[lst_token] = set() 101 | token_map[lst_token].add(tup) 102 | 103 | 104 | def fill_pool_with_prompt(prompts, token_map, LEVEL, GUESS_SET_SIZE): 105 | for start_idx in range(len(prompts) - LEVEL + 1): 106 | lst_token = prompts[start_idx] 107 | tup = tuple(prompts[start_idx+1:start_idx+LEVEL]) 108 | 109 | if len(tup) != LEVEL - 1: 110 | return 
111 | 112 | if GUESS_SET_SIZE != -1: #limited guess set size for each key, lru policy 113 | if lst_token not in token_map: 114 | token_map[lst_token] = [] 115 | if tup in token_map[lst_token]: 116 | token_map[lst_token].remove(tup) 117 | token_map[lst_token].append(tup) 118 | elif len(token_map[lst_token]) < GUESS_SET_SIZE: 119 | token_map[lst_token].append(tup) 120 | else: 121 | assert len(token_map[lst_token]) == GUESS_SET_SIZE 122 | token_map[lst_token] = token_map[lst_token][1:] + [tup] 123 | else: #unlimited guess set size for each key 124 | #first add 125 | if lst_token not in token_map: 126 | token_map[lst_token] = set() 127 | token_map[lst_token].add(tup) 128 | 129 | 130 | 131 | def filter_window(level_window, eos_token_id, reset_func): 132 | 133 | for idx in range(len(level_window)): 134 | if level_window[idx] == eos_token_id: 135 | level_window[idx] = reset_func() 136 | 137 | def jacobi_sample_multilevel( 138 | self, 139 | input_ids: torch.LongTensor, 140 | logits_processor: Optional[LogitsProcessorList] = None, 141 | stopping_criteria: Optional[StoppingCriteriaList] = None, 142 | logits_warper: Optional[LogitsProcessorList] = None, 143 | max_length: Optional[int] = None, 144 | pad_token_id: Optional[int] = None, 145 | eos_token_id: Optional[Union[int, List[int]]] = None, 146 | output_attentions: Optional[bool] = None, 147 | output_hidden_states: Optional[bool] = None, 148 | output_scores: Optional[bool] = None, 149 | return_dict_in_generate: Optional[bool] = None, 150 | synced_gpus: bool = False, 151 | streamer: Optional["BaseStreamer"] = None, 152 | chat: bool = False, 153 | **model_kwargs, 154 | ) -> Union[SampleOutput, torch.LongTensor]: 155 | r""" 156 | Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and 157 | can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 158 | 159 | 160 | 161 | In most cases, you do not need to call [`~generation.GenerationMixin.sample`] directly. Use generate() instead. 162 | For an overview of generation strategies and code examples, check the [following 163 | guide](../generation_strategies). 164 | 165 | 166 | 167 | Parameters: 168 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 169 | The sequence used as a prompt for the generation. 170 | logits_processor (`LogitsProcessorList`, *optional*): 171 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 172 | used to modify the prediction scores of the language modeling head applied at each generation step. 173 | stopping_criteria (`StoppingCriteriaList`, *optional*): 174 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 175 | used to tell if the generation loop should stop. 176 | logits_warper (`LogitsProcessorList`, *optional*): 177 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used 178 | to warp the prediction score distribution of the language modeling head applied before multinomial 179 | sampling at each generation step. 180 | max_length (`int`, *optional*, defaults to 20): 181 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated 182 | tokens. The maximum length of the sequence to be generated. 183 | pad_token_id (`int`, *optional*): 184 | The id of the *padding* token. 185 | eos_token_id (`Union[int, List[int]]`, *optional*): 186 | The id of the *end-of-sequence* token. 
Optionally, use a list to set multiple *end-of-sequence* tokens. 187 | output_attentions (`bool`, *optional*, defaults to `False`): 188 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 189 | returned tensors for more details. 190 | output_hidden_states (`bool`, *optional*, defaults to `False`): 191 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 192 | for more details. 193 | output_scores (`bool`, *optional*, defaults to `False`): 194 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 195 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 196 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 197 | synced_gpus (`bool`, *optional*, defaults to `False`): 198 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 199 | streamer (`BaseStreamer`, *optional*): 200 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed 201 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing. 202 | model_kwargs: 203 | Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is 204 | an encoder-decoder model the kwargs should include `encoder_outputs`. 205 | 206 | Return: 207 | [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`: 208 | A `torch.LongTensor` containing the generated tokens (default behaviour) or a 209 | [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 210 | `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if 211 | `model.config.is_encoder_decoder=True`. 212 | 213 | Examples: 214 | 215 | ```python 216 | >>> from transformers import ( 217 | ... AutoTokenizer, 218 | ... AutoModelForCausalLM, 219 | ... LogitsProcessorList, 220 | ... MinLengthLogitsProcessor, 221 | ... TopKLogitsWarper, 222 | ... TemperatureLogitsWarper, 223 | ... StoppingCriteriaList, 224 | ... MaxLengthCriteria, 225 | ... ) 226 | >>> import torch 227 | 228 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") 229 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") 230 | 231 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a EOS token 232 | >>> model.config.pad_token_id = model.config.eos_token_id 233 | >>> model.generation_config.pad_token_id = model.config.eos_token_id 234 | 235 | >>> input_prompt = "Today is a beautiful day, and" 236 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids 237 | 238 | >>> # instantiate logits processors 239 | >>> logits_processor = LogitsProcessorList( 240 | ... [ 241 | ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id), 242 | ... ] 243 | ... ) 244 | >>> # instantiate logits processors 245 | >>> logits_warper = LogitsProcessorList( 246 | ... [ 247 | ... TopKLogitsWarper(50), 248 | ... TemperatureLogitsWarper(0.7), 249 | ... ] 250 | ... ) 251 | 252 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) 253 | 254 | >>> torch.manual_seed(0) # doctest: +IGNORE_RESULT 255 | >>> outputs = model.sample( 256 | ... input_ids, 257 | ... logits_processor=logits_processor, 258 | ... logits_warper=logits_warper, 259 | ... stopping_criteria=stopping_criteria, 260 | ... 
) 261 | 262 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) 263 | ['Today is a beautiful day, and we must do everything possible to make it a day of celebration.'] 264 | ```""" 265 | # init values 266 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 267 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 268 | if max_length is not None: 269 | warnings.warn( 270 | "`max_length` is deprecated in this function, use" 271 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", 272 | UserWarning, 273 | ) 274 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 275 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() 276 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 277 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 278 | if isinstance(eos_token_id, int): 279 | eos_token_id = [eos_token_id] 280 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 281 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 282 | output_attentions = ( 283 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 284 | ) 285 | output_hidden_states = ( 286 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 287 | ) 288 | return_dict_in_generate = ( 289 | return_dict_in_generate 290 | if return_dict_in_generate is not None 291 | else self.generation_config.return_dict_in_generate 292 | ) 293 | 294 | # init attention / hidden states / scores tuples 295 | scores = () if (return_dict_in_generate and output_scores) else None 296 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 297 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 298 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 299 | 300 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 301 | if return_dict_in_generate and self.config.is_encoder_decoder: 302 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 303 | encoder_hidden_states = ( 304 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 305 | ) 306 | 307 | # keep track of which sequences are already finished 308 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 309 | 310 | this_peer_finished = False # used by synced_gpus only 311 | 312 | WINDOW_SIZE = CONFIG_MAP.get("WINDOW_SIZE", 60) 313 | GUESS_SET_SIZE = CONFIG_MAP.get("GUESS_SET_SIZE", 60) 314 | ALWAYS_FWD_ONE = CONFIG_MAP.get("ALWAYS_FWD_ONE", 1) 315 | LEVEL = CONFIG_MAP.get("LEVEL", 8) 316 | DEBUG = CONFIG_MAP.get("DEBUG", 0) 317 | DIST_WORKERS = CONFIG_MAP.get("DIST_WORKERS", 1) 318 | LOCAL_RANK = CONFIG_MAP.get("LOCAL_RANK", 0) 319 | USE_FLASH = CONFIG_MAP.get("USE_FLASH", 0) #not use flash by default 320 | POOL_FROM_PROMPT = CONFIG_MAP.get("POOL_FROM_PROMPT", 0) 321 | USE_AWQ = False #not support AWQ 322 | #IN FLASH ATTENTION WE REORDERED LOOKAHEAD WINDOW 323 | 324 | GUESS_SIZE = LEVEL - 1 325 | NOT_SEQ = 0 326 | CONTINUE_ALL = 0 327 | 
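These settings determine how much speculation is packed into every forward pass: the lookahead branch carries on the order of WINDOW_SIZE tokens for each of the LEVEL - 1 window rows, and the verification branch adds up to GUESS_SET_SIZE guesses of GUESS_SIZE = LEVEL - 1 tokens for the current key. The sketch below is a rough back-of-the-envelope estimate of that budget under the defaults read above; it is for intuition only and is not a quantity this function computes.

```python
# Rough per-step speculation budget implied by the defaults above (an estimate only).
WINDOW_SIZE = 60               # W: positions tracked in each window row
LEVEL = 8                      # N: n-gram level, so each guess is N - 1 tokens long
GUESS_SET_SIZE = 60            # G: at most G stored guesses verified per step

GUESS_SIZE = LEVEL - 1
lookahead_tokens = WINDOW_SIZE * (LEVEL - 1)   # ~420 tokens in the 2-D lookahead window
verify_tokens = GUESS_SET_SIZE * GUESS_SIZE    # ~420 tokens in the worst-case verification branch
print(lookahead_tokens + verify_tokens)        # ~840 speculative tokens alongside each forward pass
```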
TEMP_FOR_GUESS = 0.0 328 | 329 | assert TEMP_FOR_GUESS == 0 330 | #assert LEVEL <= 8 331 | def random_set(): 332 | return random.randint(0,self.vocab_size - 1) 333 | 334 | all_old_tokens = input_ids[0].tolist() 335 | init_len = len(all_old_tokens) 336 | #print("original: ", init_len, input_ids.numel()) 337 | 338 | def copy_from(): 339 | return random.choice(all_old_tokens) 340 | 341 | order_copy_from_idx = [0] 342 | 343 | def order_copy_from(): 344 | if order_copy_from_idx[0] >= len(all_old_tokens): 345 | order_copy_from_idx[0] = 0 346 | ret = all_old_tokens[order_copy_from_idx[0]] 347 | order_copy_from_idx[0] = 1 + order_copy_from_idx[0] 348 | return ret 349 | 350 | def copy_from_last(): 351 | return all_old_tokens[-1] 352 | 353 | set_token = copy_from 354 | 355 | past_tokens = [[set_token() for _ in range(WINDOW_SIZE + LEVEL - 3)]] + [None for _ in range(LEVEL - 2)] 356 | 357 | if DIST_WORKERS > 1: 358 | dist.broadcast_object_list(past_tokens, src=0) #keep past_tokens always the same on different GPUs 359 | 360 | ###############end Init methods 361 | fill_level = 0 362 | guess_tokens = None 363 | token_map = {} 364 | steps = 0 365 | guess_skip_dist = 0 366 | 367 | if POOL_FROM_PROMPT: 368 | fill_pool_with_prompt(all_old_tokens, token_map, LEVEL, GUESS_SET_SIZE) 369 | 370 | if chat: 371 | init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \ 372 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 373 | prev = len(init) 374 | 375 | for warper in logits_warper: 376 | #assert type(warper) == TemperatureLogitsWarper or type(warper) == TopPLogitsWarper or type(warper) == TopKLogitsWarper, f"please set top_k=0 {warper}" 377 | assert type(warper) == TemperatureLogitsWarper or type(warper) == TopKLogitsWarper or type(warper) == TopPLogitsWarper, f"please set top_k=0.0 and top_p=1.0 {warper}" 378 | 379 | # auto-regressive generation 380 | while True: 381 | if synced_gpus: 382 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 383 | # The following logic allows an early break if all peers finished generating their sequence 384 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 385 | # send 0.0 if we finished, 1.0 otherwise 386 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 387 | # did all peers finish? 
the reduced sum will be 0.0 then 388 | if this_peer_finished_flag.item() == 0.0: 389 | break 390 | 391 | # prepare model inputs 392 | #this only support llama, check compatibility with other models 393 | past_key_values = model_kwargs.pop("past_key_values", None) 394 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 395 | if past_key_values is None: 396 | model_inputs["input_ids"] = input_ids 397 | else: 398 | model_inputs["input_ids"] = model_inputs["input_ids"][:, -1 - guess_skip_dist:] 399 | model_inputs["position_ids"] = model_inputs["position_ids"][:, -1 - guess_skip_dist:] 400 | model_inputs["past_key_values"] = past_key_values 401 | 402 | if past_tokens[LEVEL - 2] is not None and lst_token in token_map and GUESS_SET_SIZE > 0: 403 | guess_tokens_ = token_map[lst_token] 404 | guess_tokens = [] 405 | for tok in list(guess_tokens_): 406 | guess_tokens += list(tok) 407 | else: 408 | guess_tokens = None 409 | 410 | #not support logits_processor yet 411 | assert return_dict_in_generate == False 412 | assert len(logits_processor) == 0 413 | 414 | # forward pass to get next token 415 | outputs = self.jforward_multilevel( 416 | **model_inputs, 417 | past_tokens=past_tokens, 418 | guess_tokens=guess_tokens, 419 | return_dict=True, 420 | not_seq=NOT_SEQ, 421 | continue_all=CONTINUE_ALL, 422 | output_attentions=output_attentions, 423 | output_hidden_states=output_hidden_states, 424 | level=LEVEL, 425 | WINDOWS_SIZE=WINDOW_SIZE, 426 | guess_size=GUESS_SIZE, 427 | fill_level=fill_level, 428 | dist_workers=DIST_WORKERS, 429 | la_mask_offset=0, 430 | local_rank=LOCAL_RANK, 431 | use_flash=USE_FLASH 432 | ) 433 | 434 | steps += 1 435 | 436 | if synced_gpus and this_peer_finished: 437 | continue # don't waste resources running the code we don't need 438 | 439 | next_token_logits = outputs.out_logits #outputs.logits[:, -1, :] 440 | 441 | #not support logits_processor and only support temperature w/o top-p top-k, I will support these two later 442 | # pre-process distribution 443 | next_token_scores = logits_warper(input_ids, next_token_logits) 444 | 445 | #delete return_dict_in_generate here, we set it to False 446 | # Store scores, attentions and hidden_states when required 447 | 448 | # finished sentences should have their next token be a padding token 449 | #if eos_token_id is not None: 450 | # if pad_token_id is None: 451 | # raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 452 | # next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 453 | #for bs > 1, so I comment these out 454 | 455 | #handling output 456 | max_hit = 0 457 | 458 | if past_tokens[1] is None: 459 | #first fill, not use verification branch 460 | assert fill_level == 0 461 | probs = torch.nn.functional.softmax(next_token_scores, dim=-1) 462 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 463 | hits = [next_tokens.item()] 464 | 465 | past_tokens[0] = past_tokens[0][1:] 466 | past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist() #fill window with argmax 467 | 468 | fill_level += 1 469 | elif past_tokens[LEVEL - 2] is None: 470 | #fill other levels, not use verification branch 471 | probs = torch.nn.functional.softmax(next_token_scores, dim=-1) 472 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 473 | hits = [next_tokens.item()] 474 | 475 | for level in range(fill_level + 1): 476 | past_tokens[level] = past_tokens[level][1:] 477 | 478 | past_tokens[fill_level + 1] = 
torch.argmax(outputs.inp_logits, dim=-1)[0].tolist()[1:] #fill window with argmax 479 | 480 | fill_level += 1 481 | else: 482 | 483 | 484 | if guess_tokens is not None: 485 | probs_next = torch.nn.functional.softmax(next_token_scores, dim=-1)[0] 486 | hits = [] 487 | #= original model output 488 | guess_logits = logits_warper(input_ids, outputs.guess_logits[0]) 489 | guess_probs = torch.nn.functional.softmax(guess_logits, dim=-1) # 490 | #guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist() 491 | guess_indices = list(range(outputs.guess_logits.size(1) // GUESS_SIZE)) 492 | #algorithm modified from specinfer 493 | for idx_in_ngram in range(GUESS_SIZE): 494 | 495 | g_idx = 0 496 | is_accept = False 497 | #print("gues: ", guess_indices) 498 | 499 | while g_idx < len(guess_indices): 500 | guess_idx = guess_indices[g_idx] 501 | guess_offset = guess_idx * GUESS_SIZE 502 | 503 | 504 | #draft_guess is draft model (by lookahead) generation 505 | draft_guess = guess_tokens[guess_offset + idx_in_ngram] 506 | prob_accept = min(1, probs_next[draft_guess].item()) #min(1, prob_llm/prob_draft) #use argmax, prob_draft is 1 507 | sample_prob = random.random() 508 | 509 | if sample_prob < prob_accept: 510 | #accept 511 | hits.append(draft_guess) 512 | is_accept = True 513 | max_hit_idx = guess_idx 514 | new_guess_indices = [] 515 | for guess_idx_n in guess_indices: 516 | guess_offset_n = guess_idx_n * GUESS_SIZE 517 | new_draft_guess = guess_tokens[guess_offset_n + idx_in_ngram] 518 | if new_draft_guess == draft_guess: 519 | new_guess_indices.append(guess_idx_n) 520 | guess_indices = new_guess_indices 521 | break 522 | else: 523 | #not accept 524 | #max norm (argmax) 525 | probs_next[draft_guess] = 0 526 | probs_next = probs_next / probs_next.sum() 527 | g_idx += 1 528 | 529 | if is_accept: 530 | probs_next = guess_probs[guess_offset + idx_in_ngram] 531 | continue 532 | else: 533 | new_token_gen = torch.multinomial(probs_next, num_samples=1).item() 534 | #print("non accept: ", probs_next.size(), new_token_gen) 535 | hits.append(new_token_gen) 536 | break 537 | 538 | #hits.append(new_token_gen) 539 | 540 | max_hit = len(hits) - 1 541 | 542 | else: 543 | probs_next = torch.nn.functional.softmax(next_token_scores, dim=-1) 544 | next_tokens = torch.multinomial(probs_next, num_samples=1).squeeze(1) 545 | hits = [next_tokens.item()] 546 | 547 | 548 | #new window level, use argmax to generate 549 | new_results = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist() 550 | 551 | assert len(past_tokens[LEVEL - 2]) == WINDOW_SIZE and len(new_results) == WINDOW_SIZE 552 | 553 | update_token_map(token_map, lst_token, past_tokens, new_results, LEVEL, WINDOW_SIZE, GUESS_SET_SIZE) 554 | 555 | #update windows when max_hit > 1 556 | if ALWAYS_FWD_ONE: 557 | past_tokens[0] = past_tokens[1][1:] 558 | for level in range(1, LEVEL - 2): 559 | past_tokens[level] = past_tokens[level + 1][:] 560 | 561 | past_tokens[LEVEL - 2] = new_results 562 | else: 563 | past_tokens[0] = past_tokens[1][1 + max_hit:] 564 | for level in range(1, LEVEL - 2): 565 | past_tokens[level] = past_tokens[level + 1][max_hit:] 566 | 567 | past_tokens[LEVEL - 2] = new_results[max_hit:] 568 | 569 | 570 | if max_hit > 0: 571 | if not ALWAYS_FWD_ONE: 572 | for level in range(LEVEL - 1): 573 | past_tokens[level] = past_tokens[level] + [set_token() for _ in range(max_hit)] 574 | 575 | attention_mask = model_kwargs["attention_mask"] 576 | model_kwargs["attention_mask"] = torch.cat((attention_mask, torch.ones(1, max_hit, device=attention_mask.device, 
dtype=attention_mask.dtype)), dim=1) 577 | 578 | if eos_token_id is not None: 579 | #filter (we find too many in window lead to numerical error) 580 | filter_window(past_tokens[LEVEL - 2], eos_token_id[0], set_token) 581 | 582 | #update kv cache of correctly speculated tokens 583 | past_key_values = [] 584 | for idx, kv in enumerate(outputs.past_key_values): 585 | for hh in range(max_hit): 586 | assert outputs.step_len == kv[0].size(2) 587 | kv[0][:,:,outputs.kvcache_len + hh,:] = kv[0][:,:,outputs.step_len-len(guess_tokens)+max_hit_idx * GUESS_SIZE + hh,:] 588 | kv[1][:,:,outputs.kvcache_len + hh,:] = kv[1][:,:,outputs.step_len-len(guess_tokens)+max_hit_idx * GUESS_SIZE + hh,:] 589 | past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) ) 590 | outputs.past_key_values = past_key_values 591 | 592 | lst_token = hits[max_hit] 593 | 594 | for hit_ids in range(max_hit + 1): 595 | if eos_token_id is not None and hits[hit_ids] == eos_token_id[0]: 596 | all_old_tokens.append(hits[hit_ids]) 597 | next_tokens = eos_token_id_tensor 598 | max_hit = hit_ids 599 | break 600 | else: 601 | all_old_tokens.append(hits[hit_ids]) 602 | if POOL_FROM_PROMPT: 603 | append_new_generated_pool(all_old_tokens[-LEVEL:], token_map, LEVEL, GUESS_SET_SIZE) 604 | 605 | if chat: 606 | 607 | all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \ 608 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 609 | if COLOR_PRINT: 610 | from termcolor import colored 611 | if max_hit > 1: 612 | not_hit = self.tokenizer.decode(all_old_tokens[:-max_hit + 1], skip_special_tokens=True, \ 613 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 614 | pt = colored(not_hit[prev:],"blue") + colored(all_str[len(not_hit):], "blue") 615 | else: 616 | pt = all_str[prev:] 617 | print(pt, flush=True, end="") 618 | else: 619 | print(all_str[prev:], flush=True, end="") 620 | prev = len(all_str) 621 | 622 | # update generated ids, model inputs, and length for next step 623 | input_ids = torch.cat([input_ids, torch.tensor(hits[:max_hit + 1], device=input_ids.device, dtype=input_ids.dtype).unsqueeze(0)], dim=-1) 624 | 625 | #input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 626 | 627 | 628 | ###not change codes below 629 | if streamer is not None: 630 | streamer.put(next_tokens.cpu()) 631 | model_kwargs = self._update_model_kwargs_for_generation( 632 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 633 | ) 634 | 635 | # if eos_token was found in one sentence, set sentence to finished 636 | if eos_token_id_tensor is not None: 637 | unfinished_sequences = unfinished_sequences.mul( 638 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 639 | ) 640 | 641 | # stop when each sentence is finished 642 | if unfinished_sequences.max() == 0: 643 | this_peer_finished = True 644 | 645 | # stop if we exceed the maximum length 646 | if stopping_criteria(input_ids, scores): 647 | this_peer_finished = True 648 | 649 | if this_peer_finished and not synced_gpus: 650 | break 651 | 652 | #if predict more tokens than max_length, remove them 653 | for criteria in stopping_criteria: 654 | if hasattr(criteria, "max_length"): 655 | all_old_tokens = all_old_tokens[:criteria.max_length] 656 | input_ids = input_ids[:,:criteria.max_length] 657 | 658 | if max_length is not None: 659 | all_old_tokens = all_old_tokens[:init_len + max_length] 660 | input_ids = 
input_ids[:][:init_len + max_length] 661 | #end handling 662 | if DEBUG and LOCAL_RANK == 0: 663 | print("\n==========================ACCELERATION===SUMMARY======================================") 664 | print("Generated tokens: ", len(all_old_tokens) - init_len, "Total steps: ", steps, " Compression ratio: ", round((len(all_old_tokens) - init_len) / steps, 2)) 665 | print("======================================================================================", end="") 666 | CONFIG_MAP["log"].append([len(all_old_tokens) - init_len, steps, round((len(all_old_tokens) - init_len) / steps, 2)]) 667 | 668 | if streamer is not None: 669 | streamer.end() 670 | 671 | if return_dict_in_generate: 672 | if self.config.is_encoder_decoder: 673 | return SampleEncoderDecoderOutput( 674 | sequences=input_ids, 675 | scores=scores, 676 | encoder_attentions=encoder_attentions, 677 | encoder_hidden_states=encoder_hidden_states, 678 | decoder_attentions=decoder_attentions, 679 | cross_attentions=cross_attentions, 680 | decoder_hidden_states=decoder_hidden_states, 681 | past_key_values=model_kwargs.get("past_key_values"), 682 | ) 683 | else: 684 | return SampleDecoderOnlyOutput( 685 | sequences=input_ids, 686 | scores=scores, 687 | attentions=decoder_attentions, 688 | hidden_states=decoder_hidden_states, 689 | past_key_values=model_kwargs.get("past_key_values"), 690 | ) 691 | else: 692 | return input_ids 693 | 694 | 695 | 696 | 697 | def jacobi_greedy_search_multilevel( 698 | self, 699 | input_ids: torch.LongTensor, 700 | logits_processor: Optional[LogitsProcessorList] = None, 701 | stopping_criteria: Optional[StoppingCriteriaList] = None, 702 | max_length: Optional[int] = None, 703 | pad_token_id: Optional[int] = None, 704 | eos_token_id: Optional[Union[int, List[int]]] = None, 705 | output_attentions: Optional[bool] = None, 706 | output_hidden_states: Optional[bool] = None, 707 | output_scores: Optional[bool] = None, 708 | return_dict_in_generate: Optional[bool] = None, 709 | synced_gpus: bool = False, 710 | streamer: Optional["BaseStreamer"] = None, 711 | 712 | chat: bool = False, 713 | stop_token: Optional[str]= None, 714 | **model_kwargs, 715 | ) -> Union[GreedySearchOutput, torch.LongTensor]: 716 | r""" 717 | Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be 718 | used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 719 | 720 | 721 | 722 | In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() 723 | instead. For an overview of generation strategies and code examples, check the [following 724 | guide](../generation_strategies). 725 | 726 | 727 | 728 | 729 | Parameters: 730 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 731 | The sequence used as a prompt for the generation. 732 | logits_processor (`LogitsProcessorList`, *optional*): 733 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 734 | used to modify the prediction scores of the language modeling head applied at each generation step. 735 | stopping_criteria (`StoppingCriteriaList`, *optional*): 736 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 737 | used to tell if the generation loop should stop. 738 | 739 | max_length (`int`, *optional*, defaults to 20): 740 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated 741 | tokens. 
The maximum length of the sequence to be generated. 742 | pad_token_id (`int`, *optional*): 743 | The id of the *padding* token. 744 | eos_token_id (`Union[int, List[int]]`, *optional*): 745 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 746 | output_attentions (`bool`, *optional*, defaults to `False`): 747 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 748 | returned tensors for more details. 749 | output_hidden_states (`bool`, *optional*, defaults to `False`): 750 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 751 | for more details. 752 | output_scores (`bool`, *optional*, defaults to `False`): 753 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 754 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 755 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 756 | synced_gpus (`bool`, *optional*, defaults to `False`): 757 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 758 | streamer (`BaseStreamer`, *optional*): 759 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed 760 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing. 761 | model_kwargs: 762 | Additional model specific keyword arguments will be forwarded to the `forward` function of the model. 763 | If model is an encoder-decoder model the kwargs should include `encoder_outputs`. 764 | 765 | Return: 766 | [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or 767 | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a 768 | [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 769 | `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if 770 | `model.config.is_encoder_decoder=True`. 771 | 772 | Examples: 773 | 774 | ```python 775 | >>> from transformers import ( 776 | ... AutoTokenizer, 777 | ... AutoModelForCausalLM, 778 | ... LogitsProcessorList, 779 | ... MinLengthLogitsProcessor, 780 | ... StoppingCriteriaList, 781 | ... MaxLengthCriteria, 782 | ... ) 783 | 784 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") 785 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") 786 | 787 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token 788 | >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id 789 | 790 | >>> input_prompt = "It might be possible to" 791 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids 792 | 793 | >>> # instantiate logits processors 794 | >>> logits_processor = LogitsProcessorList( 795 | ... [ 796 | ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), 797 | ... ] 798 | ... ) 799 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) 800 | 801 | >>> outputs = model.greedy_search( 802 | ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria 803 | ... 
) 804 | 805 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) 806 | ["It might be possible to get a better understanding of the nature of the problem, but it's not"] 807 | ```""" 808 | # init values 809 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 810 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 811 | if max_length is not None: 812 | warnings.warn( 813 | "`max_length` is deprecated in this function, use" 814 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", 815 | UserWarning, 816 | ) 817 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 818 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 819 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 820 | if isinstance(eos_token_id, int): 821 | eos_token_id = [eos_token_id] 822 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 823 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 824 | output_attentions = ( 825 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 826 | ) 827 | output_hidden_states = ( 828 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 829 | ) 830 | return_dict_in_generate = ( 831 | return_dict_in_generate 832 | if return_dict_in_generate is not None 833 | else self.generation_config.return_dict_in_generate 834 | ) 835 | 836 | # init attention / hidden states / scores tuples 837 | scores = () if (return_dict_in_generate and output_scores) else None 838 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 839 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 840 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 841 | 842 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 843 | if return_dict_in_generate and self.config.is_encoder_decoder: 844 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 845 | encoder_hidden_states = ( 846 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 847 | ) 848 | 849 | # keep track of which sequences are already finished 850 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 851 | 852 | this_peer_finished = False # used by synced_gpus only 853 | ############### configurations 854 | WINDOW_SIZE = CONFIG_MAP.get("WINDOW_SIZE", 60) 855 | GUESS_SET_SIZE = CONFIG_MAP.get("GUESS_SET_SIZE", 60) 856 | ALWAYS_FWD_ONE = CONFIG_MAP.get("ALWAYS_FWD_ONE", 1) 857 | LEVEL = CONFIG_MAP.get("LEVEL", 8) 858 | DEBUG = CONFIG_MAP.get("DEBUG", 0) 859 | DIST_WORKERS = CONFIG_MAP.get("DIST_WORKERS", 1) 860 | LOCAL_RANK = CONFIG_MAP.get("LOCAL_RANK", 0) 861 | USE_FLASH = CONFIG_MAP.get("USE_FLASH", 0) #not use flash by default 862 | POOL_FROM_PROMPT = CONFIG_MAP.get("POOL_FROM_PROMPT", 0) 863 | USE_AWQ = False #not support AWQ 864 | #IN FLASH ATTENTION WE REORDERED LOOKAHEAD WINDOW 865 | 866 | GUESS_SIZE = LEVEL - 1 867 | NOT_SEQ = 0 868 | CONTINUE_ALL = 0 869 | TEMP_FOR_GUESS = 0.0 870 | USE_AWQ = False 871 | import random 872 | 
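The initialization below seeds the two-dimensional lookahead window: past_tokens[0] gets WINDOW_SIZE + LEVEL - 3 tokens copied at random from the prompt, and the remaining LEVEL - 2 rows start as None and are filled one per step while fill_level ramps up. A small sketch of that layout with toy values (the token ids and sizes here are made up for illustration):

```python
import random

# Toy sketch of the lookahead window layout (made-up sizes and token ids).
WINDOW_SIZE, LEVEL = 5, 4            # small values for readability
prompt = [101, 7592, 2088, 102]      # stand-in prompt token ids

def set_token():
    return random.choice(prompt)     # mirrors the copy_from() initializer used below

# Row 0 holds WINDOW_SIZE + LEVEL - 3 seed tokens; rows 1..LEVEL-2 start as None
# and are filled one per decoding step while fill_level ramps up.
past_tokens = [[set_token() for _ in range(WINDOW_SIZE + LEVEL - 3)]] + [None] * (LEVEL - 2)
print(len(past_tokens), len(past_tokens[0]))   # 3 rows in total, 6 seed tokens in row 0
```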
assert TEMP_FOR_GUESS == 0 873 | assert ALWAYS_FWD_ONE == 1 874 | assert USE_AWQ == False 875 | 876 | ############### Init methods 877 | #random.seed(10) #unset this random seed later 878 | 879 | all_old_tokens = input_ids[0].tolist() 880 | init_len = len(all_old_tokens) 881 | order_copy_from_idx = [0] 882 | 883 | 884 | def random_set(): 885 | return random.randint(0,self.vocab_size - 1) 886 | 887 | def copy_from(): 888 | return random.choice(all_old_tokens) 889 | 890 | def order_copy_from(): 891 | if order_copy_from_idx[0] >= len(all_old_tokens): 892 | order_copy_from_idx[0] = 0 893 | ret = all_old_tokens[order_copy_from_idx[0]] 894 | order_copy_from_idx[0] = 1 + order_copy_from_idx[0] 895 | return ret 896 | 897 | def copy_from_last(): 898 | return all_old_tokens[-1] 899 | 900 | set_token = copy_from 901 | 902 | past_tokens = [[set_token() for _ in range(WINDOW_SIZE + LEVEL - 3)]] + [None for _ in range(LEVEL - 2)] 903 | #past_tokens is the lookahead window. Current we initialize it with random copy from prompts 904 | 905 | if DIST_WORKERS > 1: 906 | dist.broadcast_object_list(past_tokens, src=0) #keep past_tokens always the same on different GPUs 907 | 908 | ###############end Init methods 909 | fill_level = 0 910 | guess_tokens = None 911 | token_map = {} 912 | steps = 0 913 | guess_skip_dist = 0 914 | 915 | if POOL_FROM_PROMPT: 916 | fill_pool_with_prompt(all_old_tokens, token_map, LEVEL, GUESS_SET_SIZE) 917 | 918 | if chat: 919 | init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \ 920 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 921 | prev = len(init) 922 | 923 | while True: 924 | if synced_gpus: 925 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 926 | # The following logic allows an early break if all peers finished generating their sequence 927 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 928 | # send 0.0 if we finished, 1.0 otherwise 929 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 930 | # did all peers finish? 
the reduced sum will be 0.0 then 931 | if this_peer_finished_flag.item() == 0.0: 932 | break 933 | 934 | # prepare model inputs 935 | #this only support llama, check compatibility with other models 936 | past_key_values = model_kwargs.pop("past_key_values", None) 937 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 938 | if past_key_values is None: 939 | model_inputs["input_ids"] = input_ids 940 | else: 941 | model_inputs["input_ids"] = model_inputs["input_ids"][:, -1 - guess_skip_dist:] 942 | model_inputs["position_ids"] = model_inputs["position_ids"][:, -1 - guess_skip_dist:] 943 | model_inputs["past_key_values"] = past_key_values 944 | 945 | ori_guess = None 946 | #set up guess_tokens for verification branch 947 | # past_tokens[LEVEL - 2] is not None means we are still in warmup stage filling multi-level window 948 | if past_tokens[LEVEL - 2] is not None and lst_token in token_map and GUESS_SET_SIZE > 0: 949 | ###############NOT ENTER CURRENTLY 950 | guess_tokens_ = token_map[lst_token] 951 | guess_tokens = [] 952 | for tok in list(guess_tokens_): 953 | guess_tokens += list(tok) 954 | ori_guess = guess_tokens 955 | #shards guess_tokens on different GPUs 956 | if DIST_WORKERS > 1 and guess_tokens is not None: 957 | assert len(guess_tokens) % GUESS_SIZE == 0 958 | cnt_guess = (len(guess_tokens) // GUESS_SIZE + DIST_WORKERS - 1) // DIST_WORKERS 959 | guess_base = cnt_guess * LOCAL_RANK 960 | guess_end = cnt_guess * (LOCAL_RANK + 1) 961 | guess_tokens = guess_tokens[GUESS_SIZE * guess_base: GUESS_SIZE * guess_end] 962 | if len(guess_tokens) == 0: 963 | guess_tokens = None 964 | else: 965 | guess_tokens = None 966 | 967 | assert return_dict_in_generate == False 968 | assert len(logits_processor) == 0 969 | # forward pass to get next token 970 | #if LOCAL_RANK == 0: 971 | # print("position: ", model_inputs["input_ids"], model_inputs["position_ids"], ori_guess, guess_tokens) 972 | #forward 973 | if DIST_WORKERS > 1: 974 | window_len = len(past_tokens[0]) + 1 975 | split_window_len = (window_len + DIST_WORKERS - 1) // DIST_WORKERS 976 | window_start = min(split_window_len * LOCAL_RANK, window_len) 977 | window_end = min(split_window_len * (LOCAL_RANK + 1), window_len) 978 | 979 | if LOCAL_RANK == DIST_WORKERS - 1: 980 | assert len(past_tokens[0]) == window_end - 1 981 | past_tokens_inp = [past_tokens[0][: window_end - 1]] 982 | for l in range(1, len(past_tokens)): 983 | tokens = past_tokens[l] 984 | past_tokens_inp.append(tokens[window_start: window_end] if tokens is not None else None) 985 | else: 986 | past_tokens_inp = past_tokens 987 | 988 | outputs = self.jforward_multilevel( 989 | **model_inputs, 990 | past_tokens=past_tokens_inp, 991 | guess_tokens=guess_tokens, 992 | return_dict=True, 993 | not_seq=NOT_SEQ, 994 | continue_all=CONTINUE_ALL, 995 | output_attentions=output_attentions, 996 | output_hidden_states=output_hidden_states, 997 | level=LEVEL, 998 | WINDOWS_SIZE=WINDOW_SIZE, 999 | guess_size=GUESS_SIZE, 1000 | fill_level=fill_level, 1001 | dist_workers=DIST_WORKERS, 1002 | la_mask_offset=0, 1003 | local_rank=LOCAL_RANK, 1004 | use_flash=USE_FLASH 1005 | ) 1006 | 1007 | steps += 1 1008 | 1009 | if synced_gpus and this_peer_finished: 1010 | continue # don't waste resources running the code we don't need 1011 | 1012 | if past_tokens[LEVEL - 2] is None: #prefill 1013 | next_token_logits = outputs.out_logits 1014 | else: 1015 | next_token_logits = outputs.out_logits #outputs.logits[:, -1, :] 1016 | 1017 | # pre-process distribution 1018 | 
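The guess list assembled above flattens every stored GUESS_SIZE-token tuple for lst_token into one sequence and, when DIST_WORKERS > 1, hands each rank a contiguous slice of whole guesses; the verification branch further down then keeps the longest prefix of a guess that matches the model's argmax continuation. A simplified, self-contained sketch of those two steps (toy token ids, and a plain prefix match instead of the exact loop used below):

```python
# Simplified sketch of building and checking the verification branch (toy values).
GUESS_SIZE = 3                      # stands in for LEVEL - 1
DIST_WORKERS, LOCAL_RANK = 2, 0

lst_token = 42
token_map = {42: [(7, 8, 9), (7, 8, 4), (1, 2, 3)]}   # n-grams pooled under lst_token

# 1) flatten every stored tuple into one guess list, GUESS_SIZE tokens per guess
guess_tokens = [t for tup in token_map[lst_token] for t in tup]

# 2) shard whole guesses across ranks; each rank verifies a contiguous slice
cnt_guess = (len(guess_tokens) // GUESS_SIZE + DIST_WORKERS - 1) // DIST_WORKERS
local = guess_tokens[GUESS_SIZE * cnt_guess * LOCAL_RANK: GUESS_SIZE * cnt_guess * (LOCAL_RANK + 1)]

# 3) keep the longest prefix of any local guess that matches the argmax continuation
correct = [7, 8, 5]                 # pretend argmax continuation after lst_token
max_hit = 0
for start in range(0, len(local), GUESS_SIZE):
    guess = local[start:start + GUESS_SIZE]
    n = 0
    while n < GUESS_SIZE and guess[n] == correct[n]:
        n += 1
    max_hit = max(max_hit, n)
print(local, max_hit)               # [7, 8, 9, 7, 8, 4] 2  (two tokens accepted)
```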
#next_tokens_scores = logits_processor(input_ids, next_token_logits) 1019 | next_tokens_scores = next_token_logits 1020 | # argmax 1021 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 1022 | 1023 | if DIST_WORKERS > 1: 1024 | torch.distributed.broadcast(next_tokens, src=0) 1025 | 1026 | # finished sentences should have their next token be a padding token 1027 | if eos_token_id is not None: 1028 | if pad_token_id is None: 1029 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 1030 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 1031 | 1032 | first_guess = next_tokens.item() 1033 | max_hit = 0 1034 | hits = [first_guess] + [0] * (GUESS_SIZE - 1) 1035 | 1036 | new_results = [] 1037 | 1038 | if past_tokens[1] is None: #filling multi-level window, the very first step is different 1039 | assert fill_level == 0 1040 | past_tokens[0] = past_tokens[0][1:] 1041 | past_tokens[1] = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist() 1042 | 1043 | if DIST_WORKERS > 1: 1044 | nn_past_tokens = [copy.deepcopy(past_tokens[1])] 1045 | torch.distributed.broadcast_object_list(nn_past_tokens, src=DIST_WORKERS - 1) 1046 | past_tokens[1] = nn_past_tokens[0] 1047 | 1048 | fill_level += 1 1049 | elif past_tokens[LEVEL - 2] is None: #filling multi-level window 1050 | for level in range(fill_level + 1): 1051 | past_tokens[level] = past_tokens[level][1:] 1052 | current_past_tokens = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist() 1053 | 1054 | 1055 | if DIST_WORKERS > 1: 1056 | nn_past_tokens = [None] * DIST_WORKERS 1057 | torch.distributed.all_gather_object(nn_past_tokens, current_past_tokens) 1058 | current_past_tokens = sum(nn_past_tokens, []) 1059 | 1060 | 1061 | #time.sleep(10000) 1062 | past_tokens[fill_level + 1] = current_past_tokens[1:] 1063 | #print("new past: ", (LOCAL_RANK, past_tokens)) 1064 | 1065 | 1066 | fill_level += 1 1067 | else: 1068 | #time.sleep(10000) 1069 | #multi-level window is filled 1070 | #match guess tokens 1071 | if guess_tokens is not None: 1072 | guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist() 1073 | for eg in range(len(guess_results) // GUESS_SIZE): 1074 | egx = eg * GUESS_SIZE 1075 | correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE] 1076 | myguess = guess_tokens[egx:egx + GUESS_SIZE] 1077 | gg = 0 1078 | for gg in range(len(myguess)): 1079 | if myguess[gg] != correct[gg]: 1080 | break 1081 | if gg > max_hit: 1082 | max_hit = gg 1083 | max_hit_idx = eg 1084 | hits[:max_hit + 1] = correct[:max_hit + 1] 1085 | #max_hit is the length of longest accepted sequence in verification branch 1086 | 1087 | #sync max_hit if we have multi-GPUs 1088 | if DIST_WORKERS > 1: 1089 | max_hit_all_ranks = [0] * DIST_WORKERS 1090 | torch.distributed.all_gather_object(max_hit_all_ranks, max_hit) 1091 | max_hit = max(max_hit_all_ranks) 1092 | max_hit_rank = max_hit_all_ranks.index(max_hit) 1093 | 1094 | if max_hit > 0: 1095 | hit_info = [hits] 1096 | torch.distributed.broadcast_object_list(hit_info, src=max_hit_rank) 1097 | hits = hit_info[0] 1098 | #print("rank: ", [hits, torch.distributed.get_rank(), max_hit, LOCAL_RANK, max_hit_rank]) 1099 | #if LOCAL_RANK == 0: 1100 | # print("rank: ",hits, max_hit) 1101 | #sync new_results 1102 | new_results = torch.argmax(outputs.inp_logits, dim=-1)[0].tolist() 1103 | 1104 | if DIST_WORKERS > 1: 1105 | nn_past_tokens = [None] * DIST_WORKERS 1106 | torch.distributed.all_gather_object(nn_past_tokens, new_results) 1107 | new_results 
= sum(nn_past_tokens, []) 1108 | #else: 1109 | # current_past_tokens = new_results 1110 | #print("brand new past: ", (LOCAL_RANK, past_tokens, new_results)) 1111 | 1112 | #time.sleep(1000) 1113 | 1114 | assert len(past_tokens[LEVEL - 2]) == WINDOW_SIZE and len(new_results) == WINDOW_SIZE 1115 | 1116 | update_token_map(token_map, lst_token, past_tokens, new_results, LEVEL, WINDOW_SIZE, GUESS_SET_SIZE) 1117 | 1118 | 1119 | if ALWAYS_FWD_ONE: 1120 | past_tokens[0] = past_tokens[1][1:] 1121 | for level in range(1, LEVEL - 2): 1122 | past_tokens[level] = past_tokens[level + 1][:] 1123 | 1124 | past_tokens[LEVEL - 2] = new_results 1125 | else: 1126 | past_tokens[0] = past_tokens[1][1 + max_hit:] 1127 | for level in range(1, LEVEL - 2): 1128 | past_tokens[level] = past_tokens[level + 1][max_hit:] 1129 | 1130 | past_tokens[LEVEL - 2] = new_results[max_hit:] 1131 | 1132 | 1133 | 1134 | if max_hit > 0: 1135 | if not ALWAYS_FWD_ONE: 1136 | for level in range(LEVEL - 1): 1137 | past_tokens[level] = past_tokens[level] + [set_token() for _ in range(max_hit)] 1138 | 1139 | attention_mask = model_kwargs["attention_mask"] 1140 | model_kwargs["attention_mask"] = torch.cat((attention_mask, torch.ones(1, max_hit, device=attention_mask.device, dtype=attention_mask.dtype)), dim=1) 1141 | 1142 | #not support awq 1143 | assert USE_AWQ == False 1144 | 1145 | past_key_values = [] 1146 | 1147 | #plan to remove kv-cache copy and set tokens into next input when dist_workers > 1, as communication is costly 1148 | if DIST_WORKERS > 1 and max_hit > 0: 1149 | 1150 | guess_skip_dist = max_hit 1151 | for idx, kv in enumerate(outputs.past_key_values): 1152 | past_key_values.append( (kv[0][:,:,:outputs.kvcache_len,:], kv[1][:,:,:outputs.kvcache_len,:]) ) 1153 | outputs.past_key_values = past_key_values 1154 | else: 1155 | guess_skip_dist = 0 1156 | offset_kv_cache = outputs.step_len-len(guess_tokens)+max_hit_idx * GUESS_SIZE if max_hit > 0 else 0 1157 | for idx, kv in enumerate(outputs.past_key_values): 1158 | #update kv-cache from verification branch 1159 | if max_hit > 0: 1160 | kv[0][:,:,outputs.kvcache_len:outputs.kvcache_len+max_hit,:] = kv[0][:,:,offset_kv_cache:offset_kv_cache+max_hit,:] 1161 | kv[1][:,:,outputs.kvcache_len:outputs.kvcache_len+max_hit,:] = kv[1][:,:,offset_kv_cache:offset_kv_cache+max_hit,:] 1162 | past_key_values.append( (kv[0][:,:,:outputs.kvcache_len + max_hit,:], kv[1][:,:,:outputs.kvcache_len + max_hit,:]) ) 1163 | outputs.past_key_values = past_key_values 1164 | 1165 | lst_token = hits[max_hit] 1166 | 1167 | #stopping condition 1168 | for hit_idx in range(max_hit + 1): 1169 | if eos_token_id is not None and hits[hit_idx] == eos_token_id[0]: 1170 | all_old_tokens.append(hits[hit_idx]) 1171 | next_tokens = eos_token_id_tensor 1172 | max_hit = hit_idx 1173 | break 1174 | else: 1175 | all_old_tokens.append(hits[hit_idx]) 1176 | if POOL_FROM_PROMPT: 1177 | append_new_generated_pool(all_old_tokens[-LEVEL:], token_map, LEVEL, GUESS_SET_SIZE) 1178 | 1179 | 1180 | if chat and LOCAL_RANK == 0: 1181 | all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \ 1182 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 1183 | if COLOR_PRINT: 1184 | from termcolor import colored 1185 | if max_hit > 1: 1186 | not_hit = self.tokenizer.decode(all_old_tokens[:-max_hit + 1], skip_special_tokens=True, \ 1187 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 1188 | pt = colored(not_hit[prev:],"blue") + colored(all_str[len(not_hit):], "blue") 1189 | else:
1190 | pt = all_str[prev:] 1191 | print(pt, flush=True, end="") 1192 | else: 1193 | print(all_str[prev:], flush=True, end="") 1194 | prev = len(all_str) 1195 | 1196 | input_ids = torch.cat([input_ids, torch.tensor(hits[:max_hit + 1], device=next_tokens.device, dtype=next_tokens.dtype).unsqueeze(0)], dim=-1) 1197 | 1198 | if streamer is not None: 1199 | streamer.put(next_tokens.cpu()) 1200 | model_kwargs = self._update_model_kwargs_for_generation( 1201 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 1202 | ) 1203 | 1204 | # if eos_token was found in one sentence, set sentence to finished 1205 | if eos_token_id_tensor is not None: 1206 | unfinished_sequences = unfinished_sequences.mul( 1207 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 1208 | ) 1209 | 1210 | # stop when each sentence is finished 1211 | if unfinished_sequences.max() == 0: 1212 | this_peer_finished = True 1213 | 1214 | # stop if we exceed the maximum length 1215 | if stopping_criteria(input_ids, scores): 1216 | this_peer_finished = True 1217 | 1218 | if this_peer_finished and not synced_gpus: 1219 | break 1220 | 1221 | for criteria in stopping_criteria: 1222 | if hasattr(criteria, "max_length"): 1223 | #print("steop: ", criteria.max_length, init_len, len(all_old_tokens), input_ids.size()) 1224 | all_old_tokens = all_old_tokens[:criteria.max_length] 1225 | input_ids = input_ids[:,:criteria.max_length] 1226 | if max_length is not None: 1227 | #print("max : ", max_length, init_len) 1228 | all_old_tokens = all_old_tokens[:init_len + max_length] 1229 | input_ids = input_ids[:][:init_len + max_length] 1230 | 1231 | if DEBUG and LOCAL_RANK == 0: 1232 | print("\n==========================ACCELERATION===SUMMARY======================================") 1233 | print("Generated tokens: ", len(all_old_tokens) - init_len, "Total steps: ", steps, " Compression ratio: ", round((len(all_old_tokens) - init_len) / steps, 2)) 1234 | print("======================================================================================", end="") 1235 | CONFIG_MAP["log"].append([len(all_old_tokens) - init_len, steps, round((len(all_old_tokens) - init_len) / steps, 2)]) 1236 | 1237 | if streamer is not None: 1238 | streamer.end() 1239 | 1240 | if return_dict_in_generate: 1241 | if self.config.is_encoder_decoder: 1242 | return GreedySearchEncoderDecoderOutput( 1243 | sequences=input_ids, 1244 | scores=scores, 1245 | encoder_attentions=encoder_attentions, 1246 | encoder_hidden_states=encoder_hidden_states, 1247 | decoder_attentions=decoder_attentions, 1248 | cross_attentions=cross_attentions, 1249 | decoder_hidden_states=decoder_hidden_states, 1250 | ) 1251 | else: 1252 | return GreedySearchDecoderOnlyOutput( 1253 | sequences=input_ids, 1254 | scores=scores, 1255 | attentions=decoder_attentions, 1256 | hidden_states=decoder_hidden_states, 1257 | ) 1258 | else: 1259 | return input_ids 1260 | 1261 | 1262 | 1263 | 1264 | 1265 | 1266 | def greedy_search_chat( 1267 | self, 1268 | input_ids: torch.LongTensor, 1269 | logits_processor: Optional[LogitsProcessorList] = None, 1270 | stopping_criteria: Optional[StoppingCriteriaList] = None, 1271 | max_length: Optional[int] = None, 1272 | pad_token_id: Optional[int] = None, 1273 | eos_token_id: Optional[Union[int, List[int]]] = None, 1274 | output_attentions: Optional[bool] = None, 1275 | output_hidden_states: Optional[bool] = None, 1276 | output_scores: Optional[bool] = None, 1277 | return_dict_in_generate: Optional[bool] = None, 1278 | 
synced_gpus: bool = False, 1279 | streamer: Optional["BaseStreamer"] = None, 1280 | chat: int=True, 1281 | stop_token: Optional[str] = None, 1282 | **model_kwargs, 1283 | ) -> Union[GreedySearchOutput, torch.LongTensor]: 1284 | r""" 1285 | Generates sequences of token ids for models with a language modeling head using **greedy decoding** and can be 1286 | used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 1287 | 1288 | 1289 | 1290 | In most cases, you do not need to call [`~generation.GenerationMixin.greedy_search`] directly. Use generate() 1291 | instead. For an overview of generation strategies and code examples, check the [following 1292 | guide](../generation_strategies). 1293 | 1294 | 1295 | 1296 | 1297 | Parameters: 1298 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1299 | The sequence used as a prompt for the generation. 1300 | logits_processor (`LogitsProcessorList`, *optional*): 1301 | An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] 1302 | used to modify the prediction scores of the language modeling head applied at each generation step. 1303 | stopping_criteria (`StoppingCriteriaList`, *optional*): 1304 | An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] 1305 | used to tell if the generation loop should stop. 1306 | 1307 | max_length (`int`, *optional*, defaults to 20): 1308 | **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated 1309 | tokens. The maximum length of the sequence to be generated. 1310 | pad_token_id (`int`, *optional*): 1311 | The id of the *padding* token. 1312 | eos_token_id (`Union[int, List[int]]`, *optional*): 1313 | The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 1314 | output_attentions (`bool`, *optional*, defaults to `False`): 1315 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1316 | returned tensors for more details. 1317 | output_hidden_states (`bool`, *optional*, defaults to `False`): 1318 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1319 | for more details. 1320 | output_scores (`bool`, *optional*, defaults to `False`): 1321 | Whether or not to return the prediction scores. See `scores` under returned tensors for more details. 1322 | return_dict_in_generate (`bool`, *optional*, defaults to `False`): 1323 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1324 | synced_gpus (`bool`, *optional*, defaults to `False`): 1325 | Whether to continue running the while loop until max_length (needed for ZeRO stage 3) 1326 | streamer (`BaseStreamer`, *optional*): 1327 | Streamer object that will be used to stream the generated sequences. Generated tokens are passed 1328 | through `streamer.put(token_ids)` and the streamer is responsible for any further processing. 1329 | model_kwargs: 1330 | Additional model specific keyword arguments will be forwarded to the `forward` function of the model. 1331 | If model is an encoder-decoder model the kwargs should include `encoder_outputs`. 
1332 | 1333 | Return: 1334 | [`~generation.GreedySearchDecoderOnlyOutput`], [`~generation.GreedySearchEncoderDecoderOutput`] or 1335 | `torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a 1336 | [`~generation.GreedySearchDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and 1337 | `return_dict_in_generate=True` or a [`~generation.GreedySearchEncoderDecoderOutput`] if 1338 | `model.config.is_encoder_decoder=True`. 1339 | 1340 | Examples: 1341 | 1342 | ```python 1343 | >>> from transformers import ( 1344 | ... AutoTokenizer, 1345 | ... AutoModelForCausalLM, 1346 | ... LogitsProcessorList, 1347 | ... MinLengthLogitsProcessor, 1348 | ... StoppingCriteriaList, 1349 | ... MaxLengthCriteria, 1350 | ... ) 1351 | 1352 | >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") 1353 | >>> model = AutoModelForCausalLM.from_pretrained("gpt2") 1354 | 1355 | >>> # set pad_token_id to eos_token_id because GPT2 does not have a PAD token 1356 | >>> model.generation_config.pad_token_id = model.generation_config.eos_token_id 1357 | 1358 | >>> input_prompt = "It might be possible to" 1359 | >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids 1360 | 1361 | >>> # instantiate logits processors 1362 | >>> logits_processor = LogitsProcessorList( 1363 | ... [ 1364 | ... MinLengthLogitsProcessor(10, eos_token_id=model.generation_config.eos_token_id), 1365 | ... ] 1366 | ... ) 1367 | >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)]) 1368 | 1369 | >>> outputs = model.greedy_search( 1370 | ... input_ids, logits_processor=logits_processor, stopping_criteria=stopping_criteria 1371 | ... ) 1372 | 1373 | >>> tokenizer.batch_decode(outputs, skip_special_tokens=True) 1374 | ["It might be possible to get a better understanding of the nature of the problem, but it's not"] 1375 | ```""" 1376 | # init values 1377 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 1378 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 1379 | if max_length is not None: 1380 | warnings.warn( 1381 | "`max_length` is deprecated in this function, use" 1382 | " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.", 1383 | UserWarning, 1384 | ) 1385 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 1386 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 1387 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 1388 | if isinstance(eos_token_id, int): 1389 | eos_token_id = [eos_token_id] 1390 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 1391 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 1392 | output_attentions = ( 1393 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 1394 | ) 1395 | output_hidden_states = ( 1396 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 1397 | ) 1398 | return_dict_in_generate = ( 1399 | return_dict_in_generate 1400 | if return_dict_in_generate is not None 1401 | else self.generation_config.return_dict_in_generate 1402 | ) 1403 | 1404 | # init attention / hidden states / scores tuples 1405 | scores = () if (return_dict_in_generate and 
output_scores) else None 1406 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 1407 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 1408 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 1409 | 1410 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 1411 | if return_dict_in_generate and self.config.is_encoder_decoder: 1412 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 1413 | encoder_hidden_states = ( 1414 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 1415 | ) 1416 | 1417 | # keep track of which sequences are already finished 1418 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 1419 | 1420 | assert input_ids.size(0) == 1 1421 | all_old_tokens = input_ids[0].tolist() 1422 | init_len = len(all_old_tokens) 1423 | init = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \ 1424 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 1425 | prev = len(init) 1426 | steps = 0 1427 | this_peer_finished = False # used by synced_gpus only 1428 | while True: 1429 | steps += 1 1430 | if synced_gpus: 1431 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 1432 | # The following logic allows an early break if all peers finished generating their sequence 1433 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 1434 | # send 0.0 if we finished, 1.0 otherwise 1435 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 1436 | # did all peers finish? 
the reduced sum will be 0.0 then 1437 | if this_peer_finished_flag.item() == 0.0: 1438 | break 1439 | 1440 | # prepare model inputs 1441 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 1442 | 1443 | # forward pass to get next token 1444 | outputs = self( 1445 | **model_inputs, 1446 | return_dict=True, 1447 | output_attentions=output_attentions, 1448 | output_hidden_states=output_hidden_states, 1449 | ) 1450 | 1451 | if synced_gpus and this_peer_finished: 1452 | continue # don't waste resources running the code we don't need 1453 | 1454 | next_token_logits = outputs.logits[:, -1, :] 1455 | 1456 | # pre-process distribution 1457 | next_tokens_scores = logits_processor(input_ids, next_token_logits) 1458 | 1459 | # Store scores, attentions and hidden_states when required 1460 | if return_dict_in_generate: 1461 | if output_scores: 1462 | scores += (next_tokens_scores,) 1463 | if output_attentions: 1464 | decoder_attentions += ( 1465 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) 1466 | ) 1467 | if self.config.is_encoder_decoder: 1468 | cross_attentions += (outputs.cross_attentions,) 1469 | 1470 | if output_hidden_states: 1471 | decoder_hidden_states += ( 1472 | (outputs.decoder_hidden_states,) 1473 | if self.config.is_encoder_decoder 1474 | else (outputs.hidden_states,) 1475 | ) 1476 | 1477 | # argmax 1478 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 1479 | 1480 | # finished sentences should have their next token be a padding token 1481 | if eos_token_id is not None: 1482 | if pad_token_id is None: 1483 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 1484 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 1485 | 1486 | all_old_tokens.append(next_tokens.item()) 1487 | all_str = self.tokenizer.decode(all_old_tokens, skip_special_tokens=True, \ 1488 | spaces_between_special_tokens=False, clean_up_tokenization_spaces=True,) 1489 | if chat: 1490 | print(all_str[prev:], flush=True, end="") 1491 | prev = len(all_str) 1492 | 1493 | 1494 | # update generated ids, model inputs, and length for next step 1495 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 1496 | if streamer is not None: 1497 | streamer.put(next_tokens.cpu()) 1498 | model_kwargs = self._update_model_kwargs_for_generation( 1499 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 1500 | ) 1501 | 1502 | # if eos_token was found in one sentence, set sentence to finished 1503 | if eos_token_id_tensor is not None: 1504 | unfinished_sequences = unfinished_sequences.mul( 1505 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 1506 | ) 1507 | 1508 | # stop when each sentence is finished 1509 | if unfinished_sequences.max() == 0: 1510 | this_peer_finished = True 1511 | 1512 | # stop if we exceed the maximum length 1513 | if stopping_criteria(input_ids, scores): 1514 | this_peer_finished = True 1515 | 1516 | if this_peer_finished and not synced_gpus: 1517 | break 1518 | DEBUG = CONFIG_MAP.get("DEBUG", 0) 1519 | if DEBUG: 1520 | #print("===DEBUG INFO===", " generated tokens: ", len(all_old_tokens) - init_len, "total step: ", steps, len(token_map.keys()), sum(len(value) for value in token_map.values()), input_ids.numel(), reps) 1521 | 1522 | print("\n==========================ACCELERATION===SUMMARY======================================") 1523 | print("Generated tokens: ", 
len(all_old_tokens) - init_len, "Total steps: ", steps, " Compression ratio: N/A ") 1524 | print("======================================================================================", end="") 1525 | 1526 | if streamer is not None: 1527 | streamer.end() 1528 | 1529 | if return_dict_in_generate: 1530 | if self.config.is_encoder_decoder: 1531 | return GreedySearchEncoderDecoderOutput( 1532 | sequences=input_ids, 1533 | scores=scores, 1534 | encoder_attentions=encoder_attentions, 1535 | encoder_hidden_states=encoder_hidden_states, 1536 | decoder_attentions=decoder_attentions, 1537 | cross_attentions=cross_attentions, 1538 | decoder_hidden_states=decoder_hidden_states, 1539 | ) 1540 | else: 1541 | return GreedySearchDecoderOnlyOutput( 1542 | sequences=input_ids, 1543 | scores=scores, 1544 | attentions=decoder_attentions, 1545 | hidden_states=decoder_hidden_states, 1546 | ) 1547 | else: 1548 | return input_ids 1549 | -------------------------------------------------------------------------------- /lade/lade_distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from .decoding import CONFIG_MAP 4 | 5 | def get_device(): 6 | if "LOCAL_RANK" not in CONFIG_MAP: 7 | return 0 8 | local_rank = CONFIG_MAP["LOCAL_RANK"] 9 | return local_rank 10 | 11 | def distributed(): 12 | return "DIST_WORKERS" in CONFIG_MAP and CONFIG_MAP["DIST_WORKERS"] > 1 13 | -------------------------------------------------------------------------------- /lade/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import GenerationMixin 3 | from transformers.models.llama import modeling_llama 4 | 5 | from lade.decoding import greedy_search_proxy, sample_proxy, FUNC_MAP, CONFIG_MAP 6 | from lade.models import modeling_llama as lade_modeling_llama 7 | #from .from lade.models import modeling_llama 8 | from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM 9 | import torch 10 | import torch.distributed as dist 11 | import inspect 12 | 13 | def config_lade(WINDOW_SIZE=None, LEVEL=None, DEBUG=None, GUESS_SET_SIZE=None, ALWAYS_FWD_ONE=None, SPLIT_FLAG=None, DIST_WORKERS=None, POOL_FROM_PROMPT=None, backend = 'nccl', USE_FLASH=None): 14 | if WINDOW_SIZE is not None: 15 | CONFIG_MAP["WINDOW_SIZE"] = WINDOW_SIZE 16 | if LEVEL is not None: 17 | CONFIG_MAP["LEVEL"] = LEVEL 18 | if GUESS_SET_SIZE is not None: 19 | CONFIG_MAP["GUESS_SET_SIZE"] = GUESS_SET_SIZE 20 | if ALWAYS_FWD_ONE is not None: 21 | CONFIG_MAP["ALWAYS_FWD_ONE"] = ALWAYS_FWD_ONE 22 | if DEBUG is not None: 23 | CONFIG_MAP["DEBUG"] = DEBUG 24 | if SPLIT_FLAG is not None: 25 | CONFIG_MAP["SPLIT_FLAG"] = SPLIT_FLAG 26 | if POOL_FROM_PROMPT is not None: 27 | CONFIG_MAP["POOL_FROM_PROMPT"] = POOL_FROM_PROMPT 28 | if DIST_WORKERS is not None and DIST_WORKERS > 1: 29 | CONFIG_MAP["DIST_WORKERS"] = DIST_WORKERS 30 | CONFIG_MAP["LOCAL_RANK"] = int(os.environ["LOCAL_RANK"]) 31 | dist.init_process_group(backend, rank=CONFIG_MAP["LOCAL_RANK"]) 32 | torch.cuda.set_device(CONFIG_MAP["LOCAL_RANK"]) 33 | assert dist.get_world_size() == DIST_WORKERS, "DIST_WORKERS config should be equal to work size" 34 | if USE_FLASH is not None: 35 | CONFIG_MAP["USE_FLASH"] = USE_FLASH 36 | 37 | CONFIG_MAP["log"] = [] 38 | 39 | 40 | def inject_module(lade_module, original_module): 41 | s = {} 42 | for name, cls in inspect.getmembers(original_module, inspect.isclass): 43 | s[name] = cls 44 | for name, cls in inspect.getmembers(lade_module, 
inspect.isclass): 45 | if str(cls.__module__).startswith("lade") and name in s: 46 | tc = s[name] 47 | for method_name in dir(cls): 48 | if callable(getattr(cls, method_name)): 49 | try: 50 | setattr(tc, method_name, getattr(cls, method_name)) 51 | except: 52 | pass 53 | 54 | 55 | def augment_llama(): 56 | inject_module(lade_modeling_llama, modeling_llama) 57 | #llama.modeling_llama.LlamaForCausalLM = lade_modeling_llama.LlamaForCausalLM 58 | #modeling_llama.LlamaForCausalLM.jforward_multilevel = lookahead_llama.jforward_multilevel 59 | #modeling_llama.LlamaModel.LlamaModeljforward = lookahead_llama.LlamaModeljforward 60 | #modeling_llama.LlamaModel.j_prepare_decoder_attention_mask = lookahead_llama.j_prepare_decoder_attention_mask 61 | 62 | def augment_generate(): 63 | FUNC_MAP["greedy_search"] = GenerationMixin.greedy_search 64 | FUNC_MAP["sample"] = GenerationMixin.sample 65 | GenerationMixin.greedy_search = greedy_search_proxy 66 | GenerationMixin.sample = sample_proxy 67 | #FUNC_MAP["sample"] = GenerationMixin.sample 68 | #GenerationMixin.sample = sample_proxy 69 | 70 | def augment_all(): 71 | augment_llama() 72 | augment_generate() 73 | 74 | def log_history(clear=False): 75 | gen = 0 76 | step = 0 77 | if "log" in CONFIG_MAP: 78 | for log in CONFIG_MAP["log"]: 79 | gen += log[0] 80 | step += log[1] 81 | if clear: 82 | CONFIG_MAP["log"] = [] 83 | print("LADE LOG - OVERALL GEN: ", gen, " STEPS: ", step, " AVG COMPRESS RATIO: ", (gen / step) if step > 0 else 0) 84 | 85 | def save_log(log_dir): 86 | if "log" in CONFIG_MAP: 87 | torch.save(CONFIG_MAP["log"], log_dir) 88 | 89 | def get_hf_model(model_path, quant, dtype, device, cache_dir): 90 | tokenizer = AutoTokenizer.from_pretrained(model_path, fast_tokenizer=True) 91 | model_config = AutoConfig.from_pretrained(model_path) 92 | assert quant is None or len(quant) == 0 93 | 94 | model = AutoModelForCausalLM.from_pretrained( 95 | model_path, torch_dtype=dtype, device_map=device, cache_dir=cache_dir if len(cache_dir) > 0 else None) 96 | model = model.eval() 97 | model.tokenizer = tokenizer 98 | 99 | return model, tokenizer 100 | 101 | def get_model(model_path, quant, dtype, device, cache_dir, use_ds, native_offload = False): 102 | return get_hf_model(model_path, quant, dtype, device, cache_dir) -------------------------------------------------------------------------------- /media/acc-demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/acc-demo.gif -------------------------------------------------------------------------------- /media/jacobi-iteration.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/jacobi-iteration.gif -------------------------------------------------------------------------------- /media/lookahead-decoding.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/lookahead-decoding.gif -------------------------------------------------------------------------------- /media/lookahead-perf.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/lookahead-perf.png -------------------------------------------------------------------------------- /media/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hao-ai-lab/LookaheadDecoding/eed010da9c7b1867912675d480feec5629e0c2d0/media/mask.png -------------------------------------------------------------------------------- /minimal-flash.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | import time 4 | import os 5 | if int(os.environ.get("LOAD_LADE", 0)): 6 | import lade 7 | lade.augment_all() 8 | #For a 7B model, set LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7 9 | lade.config_lade(LEVEL=7, WINDOW_SIZE=20, GUESS_SET_SIZE=20, DEBUG=1, USE_FLASH=True, POOL_FROM_PROMPT=True) 10 | 11 | assert torch.cuda.is_available() 12 | 13 | torch_device = "cuda" 14 | 15 | model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(model_name) 18 | 19 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device, attn_implementation="flash_attention_2") 20 | model.tokenizer = tokenizer 21 | prompt = "How do you fine tune a large language model?" 22 | input_text = ( 23 | f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.\n<|user|>\n{prompt}\n<|assistant|>" 24 | ) 25 | 26 | 27 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device) 28 | 29 | #warm up 30 | greedy_output = model.generate(**model_inputs, max_new_tokens=1) 31 | #end warm up 32 | 33 | # generate 256 new tokens 34 | torch.cuda.synchronize() 35 | t0s = time.time() 36 | sample_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=True, temperature=0.7, 37 | top_k=50, top_p=0.9) 38 | torch.cuda.synchronize() 39 | t1s = time.time() 40 | 41 | torch.cuda.synchronize() 42 | t0g = time.time() 43 | greedy_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=False) 44 | torch.cuda.synchronize() 45 | t1g = time.time() 46 | 47 | print("Output:\n" + 100 * '-') 48 | print("Greedy output: ", tokenizer.decode(greedy_output[0], skip_special_tokens=False)) 49 | print("Sample output: ", tokenizer.decode(sample_output[0], skip_special_tokens=False)) 50 | 51 | print("Greedy Generated Tokens:", (greedy_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (greedy_output.numel() - model_inputs['input_ids'].numel()) / (t1g - t0g), " tokens/s") 52 | print("Sample Generated Tokens:", (sample_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (sample_output.numel() - model_inputs['input_ids'].numel()) / (t1s - t0s), " tokens/s") 53 | 54 | #python minimal.py #44 tokens/s 55 | #LOAD_LADE=1 USE_LADE=1 python minimal.py #74 tokens/s, 1.6x throughput without changing output distribution! 
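# Usage note for minimal-flash.py (a sketch of assumptions, not taken verbatim from this repo):
# compared with minimal.py below, this script only adds attn_implementation="flash_attention_2"
# when loading the model and passes USE_FLASH=True to lade.config_lade. Transformers'
# flash_attention_2 backend additionally requires the separate flash-attn package and a
# half-precision dtype, consistent with the torch.float16 load above.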
56 | 57 | -------------------------------------------------------------------------------- /minimal.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | import time 4 | import os 5 | if int(os.environ.get("LOAD_LADE", 0)): 6 | import lade 7 | lade.augment_all() 8 | #For a 7B model, set LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7 9 | lade.config_lade(LEVEL=7, WINDOW_SIZE=20, GUESS_SET_SIZE=20, DEBUG=1, POOL_FROM_PROMPT=True) 10 | 11 | assert torch.cuda.is_available() 12 | 13 | torch_device = "cuda" 14 | 15 | model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" 16 | 17 | tokenizer = AutoTokenizer.from_pretrained(model_name) 18 | 19 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map=torch_device) 20 | model.tokenizer = tokenizer 21 | prompt = "How do you fine tune a large language model?" 22 | input_text = ( 23 | f"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate.\n<|user|>\n{prompt}\n<|assistant|>" 24 | ) 25 | 26 | 27 | model_inputs = tokenizer(input_text, return_tensors='pt').to(torch_device) 28 | 29 | #warm up 30 | greedy_output = model.generate(**model_inputs, max_new_tokens=1) 31 | #end warm up 32 | 33 | # generate 256 new tokens 34 | torch.cuda.synchronize() 35 | t0s = time.time() 36 | sample_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=True, temperature=0.7, 37 | top_k=50, top_p=0.9) 38 | torch.cuda.synchronize() 39 | t1s = time.time() 40 | 41 | torch.cuda.synchronize() 42 | t0g = time.time() 43 | greedy_output = model.generate(**model_inputs, max_new_tokens=256, do_sample=False) 44 | torch.cuda.synchronize() 45 | t1g = time.time() 46 | 47 | print("Output:\n" + 100 * '-') 48 | print("Greedy output: ", tokenizer.decode(greedy_output[0], skip_special_tokens=False)) 49 | print("Sample output: ", tokenizer.decode(sample_output[0], skip_special_tokens=False)) 50 | 51 | print("Greedy Generated Tokens:", (greedy_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (greedy_output.numel() - model_inputs['input_ids'].numel()) / (t1g - t0g), " tokens/s") 52 | print("Sample Generated Tokens:", (sample_output.numel() - model_inputs['input_ids'].numel()) ,"Generation Speed: ", (sample_output.numel() - model_inputs['input_ids'].numel()) / (t1s - t0s), " tokens/s") 53 | 54 | #python minimal.py #44 tokens/s 55 | #LOAD_LADE=1 USE_LADE=1 python minimal.py #74 tokens/s, 1.6x throughput without changing output distribution! 
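# Usage notes for minimal.py: LOAD_LADE=1 controls only the import-and-patch block at the top
# of this script; USE_LADE=1 in the run command above is assumed to be read inside lade's
# patched decoding path (it is not referenced in this file). The assignment
# `model.tokenizer = tokenizer` is needed because the patched greedy search in lade/decoding.py
# decodes incrementally via `self.tokenizer`.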
56 | 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.36.2 2 | accelerate==0.23.0 3 | fschat==0.2.31 4 | openai 5 | anthropic 6 | einops==0.7.0 7 | torch<2.1.1 #torch 2.1.1 uses sdpa, which is not supported yet 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | if __name__ == '__main__': 4 | setuptools.setup( 5 | name='lade', 6 | version='0.0.2', 7 | description='Lookahead Decoding Implementation', 8 | author='Fu Yichao', 9 | author_email='yichaofu2000@outlook.com', 10 | license='Apache-2', 11 | url='https://github.com/hao-ai-lab/LookaheadDecoding.git', 12 | packages=['lade', 'lade.models'] 13 | ) 14 | 15 | --------------------------------------------------------------------------------
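For quick reference, the following condensed usage sketch is assembled from minimal.py and lade/utils.py above. It is a sketch under stated assumptions rather than a verbatim recipe: it assumes lookahead decoding is enabled unconditionally (no LOAD_LADE gate), that USE_LADE=1 must still be set in the environment as the run commands in minimal.py suggest, and it reuses the LEVEL/WINDOW_SIZE/GUESS_SET_SIZE values from the 7B-model comment in minimal.py.

import os
os.environ.setdefault("USE_LADE", "1")  # assumption: mirrors LOAD_LADE=1 USE_LADE=1 in minimal.py's run command

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import lade

lade.augment_all()  # patch transformers' greedy_search/sample and the LLaMA modeling code
lade.config_lade(LEVEL=5, WINDOW_SIZE=7, GUESS_SET_SIZE=7, DEBUG=1, POOL_FROM_PROMPT=True)  # 7B-model settings from minimal.py's comment

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cuda")
model.tokenizer = tokenizer  # the patched decoding path reads self.tokenizer

prompt = "<|system|>\nYou are a friendly chatbot.\n<|user|>\nHow do you fine tune a large language model?\n<|assistant|>"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
output = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))

With DEBUG=1, the patched decoding path prints a short acceleration summary (generated tokens and total steps) after generation, matching the summary print in lade/decoding.py above.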