├── .gitignore ├── LICENSE ├── README.md ├── application ├── README.md ├── accept_length.py ├── common.py ├── gen_judgment.py ├── gen_model_answer_baseline.py ├── gen_model_answer_prompt_decoding.py ├── get_throughput_results.py ├── show_result.py ├── tree_latency.py └── webui.py ├── assets ├── .DS_Store ├── Overview.png ├── PPD_LOGO.png ├── Speed_Mem_Train.png ├── latency.png └── ppd_demo.gif ├── dataset_generation └── generate_dataset.py ├── examples └── tutorial.ipynb ├── poetry.lock ├── prompt ├── __init__.py ├── evaluation │ ├── __init__.py │ └── eval.py ├── hf_utils.py ├── inference │ ├── __init__.py │ ├── dynamic_sparse_trees_2_13b.py │ ├── dynamic_sparse_trees_2_7b.py │ ├── dynamic_sparse_trees_3_MobileLLaMA.py │ ├── dynamic_sparse_trees_3_vicuna_13b.py │ ├── dynamic_sparse_trees_3_vicuna_68m.py │ ├── dynamic_sparse_trees_3_vicuna_7b.py │ ├── full_sparse_trees_3_13b.py │ ├── full_sparse_trees_3_7b.py │ ├── random_sparse_trees_3_7b.py │ └── sparse_tree_builder.py ├── model │ ├── __init__.py │ ├── kv_cache.py │ ├── model.py │ └── modeling_llama_custom.py ├── train │ ├── __init__.py │ └── train.py └── utils.py ├── pyproject.toml └── script ├── eval └── eval-2-1-ensemble.sh ├── latency ├── optimal-sparse-tree.sh ├── vicuna-13b-gen.sh └── vicuna-7b-gen.sh └── train └── train-ensemble-attention-kd.sh /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | venv 3 | data 4 | ShareGPT_Vicuna_unfiltered/ 5 | test* 6 | log 7 | *egg-info 8 | *log -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | PPD

Parallel Prompt Decoding: Accelerate LLM Inference with Parallel Prompting

2 | 3 | **PPD** (Parallel Prompt Decoding) is a cost-efficient method to accelerate LLMs with trained appended prompt tokens. Our technique stands out for three key features: 4 | - *Orthogonal Optimization*: Orthogonal to speculative decoding, PPD provides the potential for synergistic integration. 5 | - *Memory Efficiency*: With a minimal runtime memory overhead of just 0.0004%, PPD is highly suitable for edge and mobile settings. 6 | - *Training Efficiency*: The training process is efficient, requiring only 16 hours on a single A100-40GB GPU. 7 | 8 |
9 | 10 | 11 | 12 |
13 |
14 | PPD on Vicuna-7b. 15 |
16 |
17 |
18 | 19 | The key intuition of **PPD** lies in the observation that if trained properly, prompt tokens appended to the input can approximate tokens generated at future timesteps, thereby partially recovering the missing conditional dependency information for multi-token generation. 20 | 21 |
22 | 23 | 24 | 25 |
26 |
27 | Inspired by the human natural language generation process where continuous words like common expressions and phrases are produced simultaneously, PPD introduces the use of prompt tokens for multi-token prediction. Specifically, these trained prompt tokens are appended to the original input sequence, enabling the concurrent generation of multiple output tokens in a single forward pass. 28 |
29 |
30 |
31 |
32 | 33 | 34 | 35 |
36 |
37 |
38 |
39 |
40 | 41 | **PPD** is designed to address the following challenges faced by the current speculative decoding methods: 42 | 43 | - Low step compression rate due to conditional independence assumption. 44 | - High complexity and cost of training/maintaining draft models. 45 | - Limited applicability in edge and mobile environments due to memory overhead. 46 | 47 | Through extensive experiments across LLMs ranging from MobileLlama to Vicuna-13B on a wide range of benchmarks, our approach demonstrates up to 2.49 $\times$ speedup and maintains a minimal runtime memory overhead of just 0.0004%. 48 | 49 |
50 | 51 | 52 | 53 |
54 |
55 | Evaluated on a single A100 with MT-Bench. 56 |
57 |
58 |
59 | 60 | Our paper is available ([here](https://arxiv.org/abs/2405.18628v1))! If you found it helpful, pls cite us: 61 | ``` 62 | @article{hao2024ppd, 63 | title={Hardware-aware parallel prompt decoding for memory-efficient acceleration of LLM inference}, 64 | author={Chen, Hao (Mark) and Luk, Wayne and Yiu, Ka Fai Cedric and Li, Rui and Mishchenko, Konstantin and Venieris, Stylianos I and Fan, Hongxiang, 65 | journal={arXiv preprint arXiv:2405.18628}, 66 | year={2024} 67 | } 68 | ``` 69 | 70 | ## Contents 71 | 72 | - [Installation](#installation) 73 | - [Model Weights](#model-weights) 74 | - [Dataset Generation](#dataset-generation) 75 | - [Truncated Dataset](#truncated-dataset) 76 | - [Distillation Dataset](#distillation-dataset) 77 | - [Training Special tokens](#special-tokens) 78 | - [Inference](#inference) 79 | - [Chat Application](#chat-application) 80 | - [MT Bench](#mt-bench) 81 | - [Alpaca Eval](#alpaca-eval) 82 | - [HumanEval](#humaneval) 83 | - [GSM8K](#gsm8k) 84 | 85 | ## Installation 86 | 87 | ```bash 88 | git clone https://github.com/hmarkc/prompt-decoding.git 89 | cd prompt-decoding 90 | pip install -e . 91 | ``` 92 | 93 | ## Model Weights 94 | 95 | | Original Model | PPD embedding weights | 96 | | ------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | 97 | | [lmsys/vicuna-7b-v1.3](https://huggingface.co/lmsys/vicuna-7b-v1.3) | [hmarkc/ppd-vicuna-7b-v1.3](https://huggingface.co/hmarkc/ppd-vicuna-7b-v1.3) | 98 | | [lmsys/vicuna-13b-v1.3](https://huggingface.co/lmsys/vicuna-13b-v1.3) | [hmarkc/ppd-vicuna-13b-v1.3](https://huggingface.co/hmarkc/ppd-vicuna-13b-v1.3) | 99 | | [mtgv/MobileLLaMA-1.4B-Chat](https://huggingface.co/mtgv/MobileLLaMA-1.4B-Chat) | [hmarkc/ppd-MobileLLaMA-1.4B-Chat](https://huggingface.co/hmarkc/ppd-MobileLLaMA-1.4B-Chat) | 100 | 101 | ## Dataset Generation 102 | 103 | ### Truncated Dataset 104 | 105 | With a given dataset, a random truncation is performed to reduce the contextual bias of the training of special tokens. Then, a distillation dataset is generated from the truncated dataset. 106 | 107 | The truncated datasets need to be generated first. Here is how a dataset for 3 special tokens can be generated for the ShareGPT dataset. 108 | 109 | ``` 110 | python generate_dataset.py --dataset_type finetune --num_special_tokens 3 --data_path ShareGPT_V4.3_unfiltered_cleaned_split.json --model_max_length 2048 111 | ``` 112 | 113 | ### Distillation Dataset 114 | 115 | Then, we can generate the distillation dataset from the truncated dataset. `--data_path` is the path to the previously generated truncated dataset and `--model_name_or_path` is the model the distribution of which we want to obtain. 116 | 117 | ``` 118 | python generate_dataset.py --dataset_type distillation --num_special_tokens 3 --data_path ShareGPT_training_dataset_3_finetune_2048.pt --model_max_length 2048 --model_name_or_path lmsys/vicuna-7b-v1.3 119 | ``` 120 | 121 | ## Training Special tokens 122 | 123 | Example script to train Vicuna-7b with distillation dataset named "ShareGPT_training_dataset_2_distillation.pt". 124 | 125 | ``` 126 | accelerate launch --num_processes 4 prompt/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \ 127 | --dataset_path "./ShareGPT_training_dataset_2_distillation.pt" \ 128 | --output_dir test/ \ 129 | --num_train_epochs 1 \ 130 | --save_steps 500 \ 131 | --model_max_length 2048 \ 132 | --num_special_tokens 3 \ 133 | --virtual_tokens_per_special_token 1 \ 134 | --per_device_train_batch_size 1 \ 135 | --per_device_eval_batch_size 1 \ 136 | --gradient_accumulation_steps 4 \ 137 | --evaluation_strategy "no" \ 138 | --learning_rate 1e-2 \ 139 | --weight_decay 0.0 \ 140 | --warmup_ratio 0.0 \ 141 | --lr_scheduler_type "cosine" \ 142 | --logging_steps 10 \ 143 | --load_in_4bit \ 144 | --vt_attention_type "ensemble" \ 145 | --trainer_type "distillation_trainer" 146 | ``` 147 | 148 | You need to change the `--dataset_path` to the location of the distillation dataset and specify `--trainer_type` as "distillation_trainer" to train with knowledge distillation. `--num_special_tokens` specifies the number of special tokens for training. `--virtual_tokens_per_special_token` is the number of virtual tokens used for 1 special token, which should be set to 1 to achieve the lowest latency results. 149 | 150 | ## Inference 151 | 152 | We employ a dynamically extended tree attention and top K candidates for inference. The supported evaluation datasets currently include [Alpaca Eval](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json), [MT Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge), and [HumanEval](https://github.com/openai/human-eval). 153 | 154 | Refer to this [README.md](application/README.md) on how to install libraries for these datasets. 155 | 156 | ### Chat Application 157 | 158 | We implemented a simple chat application using `gradio`. To start a server for the chat application, run `python application/webui.py --ppd-path `. 159 | 160 | ### MT Bench 161 | 162 | - To obtain the latency of the baseline model (without the use of special tokens), run 163 | 164 | ``` 165 | python3 gen_model_answer_baseline.py 166 | --model-path 167 | --model-id 168 | --answer-file 169 | --bench-name mt_bench 170 | --temperature 171 | ``` 172 | 173 | - To obtain the latency of the model with special tokens, run 174 | 175 | ``` 176 | python3 gen_model_answer_prompt_decoding.py 177 | --model-path 178 | --model-id 179 | --answer-file 180 | --tree-length 105 181 | --bench-name mt_bench 182 | --temperature 183 | ``` 184 | 185 | `--model-path` is the path to the trained special tokens and `--tree-length` is the length of the sparse tree used. 186 | 187 | - To view the latency results of a generated `.jsonl` file, run 188 | 189 | ``` 190 | python get_throughput_results.py data/mt_bench/experiments/vicuna-7b-faster1.jsonl --n 3 191 | ``` 192 | 193 | `--n` specifies the number of experiment runs to get the average of. 194 | 195 | ### Alpaca Eval 196 | 197 | We use Alpaca Eval dataset as the evaluation dataset. The latency results can be obtained using the same script as MT Bench and adding `--bench-name alpaca_eval`. 198 | 199 | - To compare the latencies and accept lengths between sparse trees with different sizes, run: 200 | 201 | ``` 202 | python accept_length.py \ 203 | --dir-path \ 204 | --file-name \ 205 | --model-name \ 206 | --eval-file-name gen_model_answer_prompt_decoding.py \ 207 | --n 1 \ 208 | --max-length 120 \ 209 | --min-length 60 \ 210 | --length-interval 9 \ 211 | --choices "[75, 105, 135, 165, 195, 225, 255, 285]" \ 212 | 213 | python3 tree_latency.py \ 214 | --model-path \ 215 | --model-id \ 216 | --answer-file \ 217 | --bench-name alpaca_eval \ 218 | --min-tree-length 60 \ 219 | --max-tree-length 120 \ 220 | --length-interval 3 \ 221 | --max-new-token 1024 222 | ``` 223 | 224 | This [script](script/latency/optimal-sparse-tree.sh) runs the latency tests on a range of sparse trees. 225 | 226 | ### HumanEval 227 | 228 | The latency results of HumanEval can be obtained using the same script as MT Bench and adding `--bench-name humaneval`. 229 | 230 | ### GSM8K 231 | 232 | The latency results of GSM8K can be obtained using the same script as MT Bench and adding `--bench-name gsm8k`. 233 | 234 | -------------------------------------------------------------------------------- /application/README.md: -------------------------------------------------------------------------------- 1 | ## LLM Judge 2 | 3 | | [installation](https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/README.md) 4 | 5 | ### Citation 6 | ``` 7 | @misc{zheng2023judging, 8 | title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, 9 | author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, 10 | year={2023}, 11 | eprint={2306.05685}, 12 | archivePrefix={arXiv}, 13 | primaryClass={cs.CL} 14 | } 15 | ``` 16 | 17 | ## Alpaca Eval 18 | 19 | | [installation](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json) 20 | 21 | ### Citation 22 | 23 | ``` 24 | @misc{alpaca_eval, 25 | author = {Xuechen Li and Tianyi Zhang and Yann Dubois and Rohan Taori and Ishaan Gulrajani and Carlos Guestrin and Percy Liang and Tatsunori B. Hashimoto }, 26 | title = {AlpacaEval: An Automatic Evaluator of Instruction-following Models}, 27 | year = {2023}, 28 | publisher = {GitHub}, 29 | journal = {GitHub repository}, 30 | howpublished = {\url{https://github.com/tatsu-lab/alpaca_eval}} 31 | } 32 | ``` 33 | 34 | ## HumanEval 35 | 36 | | [installation](https://github.com/openai/human-eval/blob/master/README.md) 37 | 38 | ### Citation 39 | 40 | ``` 41 | @article{chen2021codex, 42 | title={Evaluating Large Language Models Trained on Code}, 43 | author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba}, 44 | year={2021}, 45 | eprint={2107.03374}, 46 | archivePrefix={arXiv}, 47 | primaryClass={cs.LG} 48 | } 49 | ``` 50 | 51 | ## GSM8K 52 | 53 | | There is no need to install the ```GSM8K``` dataset manually. 54 | 55 | ### Citation 56 | 57 | ``` 58 | @article{cobbe2021gsm8k, 59 | title={Training Verifiers to Solve Math Word Problems}, 60 | author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John}, 61 | journal={arXiv preprint arXiv:2110.14168}, 62 | year={2021} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /application/accept_length.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pandas as pd 4 | import argparse 5 | from tqdm import tqdm 6 | import os 7 | import json 8 | import pandas as pd 9 | import argparse 10 | from tqdm import tqdm 11 | 12 | def get_throughput_results(input_file): 13 | with open(input_file, "r") as f: 14 | data = [json.loads(line) for line in f] 15 | 16 | new_tokens = [] 17 | wall_time = [] 18 | throughputs = [] 19 | for d in data: 20 | for choice in d["choices"]: 21 | new_tokens.extend(choice["new_tokens"]) 22 | wall_time.extend(choice["wall_time"]) 23 | for i in range(len(choice["new_tokens"])): 24 | throughputs.append(choice["new_tokens"][i] / choice["wall_time"][i]) 25 | 26 | return sum(new_tokens) / sum(wall_time) 27 | # return sum(throughputs) / len(throughputs) 28 | 29 | def main(model_name, dir_path, file_name, eval_file_name, max_length, min_length, length_interval, run_baseline, choices, n): 30 | max_throughput = 0 31 | # run baseline first 32 | if run_baseline: 33 | for j in range(n): 34 | if os.system(f"python3 gen_model_answer_baseline.py --model-path ../test/{model_name} --model-id vicuna --answer-file {dir_path}/baseline_{j}.jsonl --bench-name alpaca_eval --max-new-token 20"): 35 | raise ValueError("Failed to generate baseline") 36 | if choices is not None: 37 | r = eval(choices) 38 | else: 39 | r = range(min_length, max_length+1, length_interval) 40 | for i in tqdm(r): 41 | throughputs = [] 42 | for j in range(n): 43 | # if file does not exist, generate it 44 | if not os.path.exists(f"{dir_path}/{file_name}{i}_{j}.jsonl"): 45 | # use alpaca dataset for evaluation 46 | if os.system(f"python3 {eval_file_name} --model-path ../test/{model_name} --model-id vicuna_faster --answer-file {dir_path}/{file_name}{i}_{j}.jsonl --tree-length {i} --bench-name alpaca_eval --max-new-token 20"): 47 | raise ValueError("Failed to generate prompt decoding model") 48 | 49 | if os.path.exists(f"{dir_path}/{file_name}{i}_{j}.jsonl"): 50 | throughput = get_throughput_results(f"{dir_path}/{file_name}{i}_{j}.jsonl") 51 | throughputs.append(throughput) 52 | 53 | if len(throughputs) > 0: 54 | throughput = sum(throughputs) / len(throughputs) 55 | std = pd.Series(throughputs).std() 56 | 57 | if throughput > max_throughput: 58 | max_throughput = throughput 59 | best_sparse_tree = i 60 | 61 | print(f"Throughput for sparse tree {i}: {throughput:.3f} tokens/s", f"std: {std:.3f}") 62 | print(f"Best sparse tree: {best_sparse_tree}") 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--dir-path", type=str, default="data/mt_bench/dynamic_sparse_tree_search/3-1-7b", help="Path to the directory") 67 | parser.add_argument("--file-name", type=str, default="dynamic_sparse_tree", help="Name of the file") 68 | parser.add_argument("--model-name", type=str, default="vicuna-7b-3-1", help="Name of the model") 69 | parser.add_argument("--eval-file-name", type=str, default="gen_model_answer_prompt_decoding.py", help="Name of the evaluation file") 70 | parser.add_argument("--max-length", type=int, default=100, help="Max length of the sparse tree") 71 | parser.add_argument("--min-length", type=int, default=60, help="Min length of the sparse tree") 72 | parser.add_argument("--length-interval", type=int, default=1, help="Interval of the length of the sparse tree") 73 | parser.add_argument("--run-baseline", action="store_true", help="Run baseline first") 74 | parser.add_argument("--choices", type=str, default=None, help="Choices for the prompt decoding model") 75 | parser.add_argument("--n", type=int, default=1, help="Number of files to group") 76 | args = parser.parse_args() 77 | 78 | main(model_name=args.model_name, 79 | dir_path=args.dir_path, 80 | file_name=args.file_name, 81 | eval_file_name=args.eval_file_name, 82 | max_length=args.max_length, 83 | min_length=args.min_length, 84 | length_interval=args.length_interval, 85 | run_baseline=args.run_baseline, 86 | choices=args.choices, 87 | n=args.n) 88 | 89 | 90 | 91 | 92 | -------------------------------------------------------------------------------- /application/common.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common data structures and utilities. 3 | """ 4 | 5 | import ast 6 | import dataclasses 7 | import glob 8 | import json 9 | import os 10 | import re 11 | import time 12 | from typing import Optional 13 | 14 | import openai 15 | import anthropic 16 | 17 | from fastchat.model.model_adapter import ( 18 | get_conversation_template, 19 | ANTHROPIC_MODEL_LIST, 20 | OPENAI_MODEL_LIST, 21 | ) 22 | 23 | # API setting constants 24 | API_MAX_RETRY = 16 25 | API_RETRY_SLEEP = 10 26 | API_ERROR_OUTPUT = "$ERROR$" 27 | 28 | TIE_DELTA = 0.1 29 | 30 | # Categories that need reference answers 31 | NEED_REF_CATS = ["math", "reasoning", "coding", "arena-hard-200"] 32 | 33 | # Extract scores from judgments 34 | two_score_pattern = re.compile("\[\[(\d+\.?\d*),\s?(\d+\.?\d*)\]\]") 35 | two_score_pattern_backup = re.compile("\[(\d+\.?\d*),\s?(\d+\.?\d*)\]") 36 | one_score_pattern = re.compile("\[\[(\d+\.?\d*)\]\]") 37 | one_score_pattern_backup = re.compile("\[(\d+\.?\d*)\]") 38 | 39 | # Sampling temperature configs for 40 | temperature_config = { 41 | "writing": 0.7, 42 | "roleplay": 0.7, 43 | "extraction": 0.0, 44 | "math": 0.0, 45 | "coding": 0.0, 46 | "reasoning": 0.0, 47 | "stem": 0.1, 48 | "humanities": 0.1, 49 | "arena-hard-200": 0.0, 50 | } 51 | 52 | reverse_model_map = { 53 | "model_1": "model_2", 54 | "model_2": "model_1", 55 | } 56 | 57 | 58 | @dataclasses.dataclass 59 | class Judge: 60 | model_name: str 61 | prompt_template: dict 62 | ref_based: bool = False 63 | multi_turn: bool = False 64 | 65 | 66 | @dataclasses.dataclass 67 | class MatchSingle: 68 | question: dict 69 | model: str 70 | answer: dict 71 | judge: Judge 72 | ref_answer: dict = None 73 | multi_turn: bool = False 74 | 75 | 76 | @dataclasses.dataclass 77 | class MatchPair: 78 | question: dict 79 | model_1: str 80 | model_2: str 81 | answer_1: dict 82 | answer_2: dict 83 | judge: Judge 84 | ref_answer: dict = None 85 | multi_turn: bool = False 86 | 87 | 88 | def load_questions(question_file: str, begin: Optional[int], end: Optional[int]): 89 | """Load questions from a file.""" 90 | questions = [] 91 | with open(question_file, "r") as ques_file: 92 | for line in ques_file: 93 | if line: 94 | questions.append(json.loads(line)) 95 | questions = questions[begin:end] 96 | return questions 97 | 98 | 99 | def load_model_answers(answer_dir: str): 100 | """Load model answers. 101 | 102 | The return value is a python dict of type: 103 | Dict[model_name: str -> Dict[question_id: int -> answer: dict]] 104 | """ 105 | filenames = glob.glob(os.path.join(answer_dir, "*.jsonl")) 106 | filenames.sort() 107 | model_answers = {} 108 | 109 | for filename in filenames: 110 | model_name = os.path.basename(filename)[:-6] 111 | answer = {} 112 | with open(filename) as fin: 113 | for line in fin: 114 | line = json.loads(line) 115 | answer[line["question_id"]] = line 116 | model_answers[model_name] = answer 117 | 118 | return model_answers 119 | 120 | 121 | def load_judge_prompts(prompt_file: str): 122 | """Load judge prompts. 123 | 124 | The return value is a python dict of type: 125 | Dict[judge_name: str -> dict] 126 | """ 127 | prompts = {} 128 | with open(prompt_file) as fin: 129 | for line in fin: 130 | line = json.loads(line) 131 | prompts[line["name"]] = line 132 | return prompts 133 | 134 | 135 | def run_judge_single(question, answer, judge, ref_answer, multi_turn=False): 136 | kwargs = {} 137 | model = judge.model_name 138 | if ref_answer is not None: 139 | kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0] 140 | if multi_turn: 141 | kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1] 142 | 143 | if multi_turn: 144 | user_prompt = judge.prompt_template["prompt_template"].format( 145 | question_1=question["turns"][0], 146 | question_2=question["turns"][1], 147 | answer_1=answer["choices"][0]["turns"][0], 148 | answer_2=answer["choices"][0]["turns"][1], 149 | **kwargs, 150 | ) 151 | else: 152 | user_prompt = judge.prompt_template["prompt_template"].format( 153 | question=question["turns"][0], 154 | answer=answer["choices"][0]["turns"][0], 155 | **kwargs, 156 | ) 157 | 158 | rating = -1 159 | 160 | system_prompt = judge.prompt_template["system_prompt"] 161 | conv = get_conversation_template(model) 162 | conv.set_system_message(system_prompt) 163 | conv.append_message(conv.roles[0], user_prompt) 164 | conv.append_message(conv.roles[1], None) 165 | 166 | if model in OPENAI_MODEL_LIST: 167 | judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) 168 | elif model in ANTHROPIC_MODEL_LIST: 169 | judgment = chat_completion_anthropic( 170 | model, conv, temperature=0, max_tokens=1024 171 | ) 172 | else: 173 | raise ValueError(f"Invalid judge model name: {model}") 174 | 175 | if judge.prompt_template["output_format"] == "[[rating]]": 176 | match = re.search(one_score_pattern, judgment) 177 | if not match: 178 | match = re.search(one_score_pattern_backup, judgment) 179 | 180 | if match: 181 | rating = ast.literal_eval(match.groups()[0]) 182 | else: 183 | rating = -1 184 | else: 185 | raise ValueError( 186 | f"invalid output format: {judge.prompt_template['output_format']}" 187 | ) 188 | 189 | return rating, user_prompt, judgment 190 | 191 | 192 | def play_a_match_single(match: MatchSingle, output_file: str): 193 | question, model, answer, judge, ref_answer, multi_turn = ( 194 | match.question, 195 | match.model, 196 | match.answer, 197 | match.judge, 198 | match.ref_answer, 199 | match.multi_turn, 200 | ) 201 | 202 | if judge.prompt_template["type"] == "single": 203 | score, user_prompt, judgment = run_judge_single( 204 | question, answer, judge, ref_answer, multi_turn=multi_turn 205 | ) 206 | 207 | question_id = question["question_id"] 208 | turn = 1 if not multi_turn else 2 209 | result = { 210 | "question_id": question_id, 211 | "model": model, 212 | "judge": (judge.model_name, judge.prompt_template["name"]), 213 | "user_prompt": user_prompt, 214 | "judgment": judgment, 215 | "score": score, 216 | "turn": turn, 217 | "tstamp": time.time(), 218 | } 219 | print( 220 | f"question: {question_id}, turn: {turn}, model: {model}, " 221 | f"score: {score}, " 222 | f"judge: {(judge.model_name, judge.prompt_template['name'])}" 223 | ) 224 | else: 225 | raise ValueError(f"invalid judge type: {judge['type']}") 226 | 227 | if output_file: 228 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 229 | with open(output_file, "a") as fout: 230 | fout.write(json.dumps(result) + "\n") 231 | 232 | return result 233 | 234 | 235 | def run_judge_pair(question, answer_a, answer_b, judge, ref_answer, multi_turn=False): 236 | kwargs = {} 237 | model = judge.model_name 238 | if ref_answer is not None: 239 | kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0] 240 | if multi_turn: 241 | kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1] 242 | 243 | if multi_turn: 244 | system_prompt = judge.prompt_template["system_prompt"] 245 | user_prompt = judge.prompt_template["prompt_template"].format( 246 | question_1=question["turns"][0], 247 | question_2=question["turns"][1], 248 | answer_a_1=answer_a["choices"][0]["turns"][0], 249 | answer_b_1=answer_b["choices"][0]["turns"][0], 250 | answer_a_2=answer_a["choices"][0]["turns"][1], 251 | answer_b_2=answer_b["choices"][0]["turns"][1], 252 | **kwargs, 253 | ) 254 | else: 255 | system_prompt = judge.prompt_template["system_prompt"] 256 | user_prompt = judge.prompt_template["prompt_template"].format( 257 | question=question["turns"][0], 258 | answer_a=answer_a["choices"][0]["turns"][0], 259 | answer_b=answer_b["choices"][0]["turns"][0], 260 | **kwargs, 261 | ) 262 | 263 | winner = "error" 264 | 265 | conv = get_conversation_template(model) 266 | conv.append_message(conv.roles[0], user_prompt) 267 | conv.append_message(conv.roles[1], None) 268 | 269 | if model in OPENAI_MODEL_LIST: 270 | conv.set_system_message(system_prompt) 271 | judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048) 272 | elif model in ANTHROPIC_MODEL_LIST: 273 | if system_prompt != "You are a helpful assistant.": 274 | user_prompt = "[Instruction]\n" + system_prompt + "\n\n" + user_prompt 275 | conv.messages[0][1] = user_prompt 276 | judgment = chat_completion_anthropic( 277 | model, conv, temperature=0, max_tokens=1024 278 | ) 279 | else: 280 | raise ValueError(f"Invalid judge model name: {model}") 281 | 282 | if judge.prompt_template["output_format"] == "[[A]]": 283 | if "[[A]]" in judgment: 284 | winner = "A" 285 | elif "[[B]]" in judgment: 286 | winner = "B" 287 | elif "[[C]]" in judgment: 288 | winner = "tie" 289 | else: 290 | winner = "error" 291 | elif judge.prompt_template["output_format"] == "[[rating_a,rating_b]]": 292 | match = re.search(two_score_pattern, judgment) 293 | if not match: 294 | match = re.search(two_score_pattern_backup, judgment) 295 | if match: 296 | scores = [ast.literal_eval(s.strip()) for s in match.groups()] 297 | if abs(scores[0] - scores[1]) <= TIE_DELTA: 298 | winner = "tie" 299 | elif scores[0] > scores[1]: 300 | winner = "A" 301 | else: 302 | winner = "B" 303 | else: 304 | winner = "error" 305 | else: 306 | raise ValueError( 307 | f"invalid output format: {judge.prompt_template['output_format']}" 308 | ) 309 | 310 | return winner, user_prompt, judgment 311 | 312 | 313 | def play_a_match_pair(match: MatchPair, output_file: str): 314 | question, model_1, model_2, answer_1, answer_2, judge, ref_answer, multi_turn = ( 315 | match.question, 316 | match.model_1, 317 | match.model_2, 318 | match.answer_1, 319 | match.answer_2, 320 | match.judge, 321 | match.ref_answer, 322 | match.multi_turn, 323 | ) 324 | 325 | if judge.prompt_template["type"] == "pairwise": 326 | g1_winner, g1_user_prompt, g1_judgment = run_judge_pair( 327 | question, answer_1, answer_2, judge, ref_answer, multi_turn=multi_turn 328 | ) 329 | g2_winner, g2_user_prompt, g2_judgment = run_judge_pair( 330 | question, answer_2, answer_1, judge, ref_answer, multi_turn=multi_turn 331 | ) 332 | 333 | g1_map = {"A": "model_1", "B": "model_2"} 334 | g2_map = {"A": "model_2", "B": "model_1"} 335 | g1_winner = g1_map.get(g1_winner, g1_winner) 336 | g2_winner = g2_map.get(g2_winner, g2_winner) 337 | question_id = question["question_id"] 338 | turn = 1 if not multi_turn else 2 339 | 340 | result = { 341 | "question_id": question_id, 342 | "model_1": model_1, 343 | "model_2": model_2, 344 | "g1_winner": g1_winner, 345 | "g2_winner": g2_winner, 346 | "judge": (judge.model_name, judge.prompt_template["name"]), 347 | "g1_user_prompt": g1_user_prompt, 348 | "g1_judgment": g1_judgment, 349 | "g2_user_prompt": g2_user_prompt, 350 | "g2_judgment": g2_judgment, 351 | "turn": turn, 352 | "tstamp": time.time(), 353 | } 354 | 355 | print( 356 | f"question: {question_id}, turn: {turn}, model_1: {model_1}, model_2: {model_2}, " 357 | f"g1_winner: {g1_winner}, g2_winner: {g2_winner}, " 358 | f"judge: {(judge.model_name, judge.prompt_template['name'])}" 359 | ) 360 | elif judge.prompt_template["type"] == "single": 361 | m1_score, m1_user_prompt, m1_judgment = run_judge_single( 362 | question, answer_1, judge 363 | ) 364 | m2_score, m2_user_prompt, m2_judgment = run_judge_single( 365 | question, answer_2, judge 366 | ) 367 | 368 | if abs(m1_score - m2_score) <= TIE_DELTA: 369 | winner = "tie" 370 | elif m1_score > m2_score: 371 | winner = "model_1" 372 | else: 373 | winner = "model_2" 374 | 375 | question_id = question["question_id"] 376 | result = { 377 | "question_id": question_id, 378 | "model_1": model_1, 379 | "model_2": model_2, 380 | "g1_winner": winner, 381 | "g2_winner": winner, 382 | "judge": (judge.model_name, judge.prompt_template["name"]), 383 | "g1_user_prompt": m1_user_prompt, 384 | "g1_judgment": m1_judgment, 385 | "g2_user_prompt": m2_user_prompt, 386 | "g2_judgment": m2_judgment, 387 | "m1_score": m1_score, 388 | "m2_score": m2_score, 389 | "tstamp": time.time(), 390 | } 391 | print( 392 | f"question: {question_id}, model_1: {model_1}, model_2: {model_2}, " 393 | f"winner: {winner}, m1_score: {m1_score}, m2_score: {m2_score}, " 394 | f"judge: {(judge.model_name, judge.prompt_template['name'])}" 395 | ) 396 | else: 397 | raise ValueError(f"invalid judge type: {judge['type']}") 398 | 399 | if output_file: 400 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 401 | with open(output_file, "a") as fout: 402 | fout.write(json.dumps(result) + "\n") 403 | 404 | return result 405 | 406 | 407 | def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None): 408 | if api_dict is not None: 409 | openai.api_base = api_dict["api_base"] 410 | openai.api_key = api_dict["api_key"] 411 | output = API_ERROR_OUTPUT 412 | for _ in range(API_MAX_RETRY): 413 | try: 414 | messages = conv.to_openai_api_messages() 415 | response = openai.ChatCompletion.create( 416 | model=model, 417 | messages=messages, 418 | n=1, 419 | temperature=temperature, 420 | max_tokens=max_tokens, 421 | ) 422 | output = response["choices"][0]["message"]["content"] 423 | break 424 | except openai.error.OpenAIError as e: 425 | print(type(e), e) 426 | time.sleep(API_RETRY_SLEEP) 427 | 428 | return output 429 | 430 | 431 | def chat_completion_openai_azure(model, conv, temperature, max_tokens, api_dict=None): 432 | openai.api_type = "azure" 433 | openai.api_version = "2023-07-01-preview" 434 | if api_dict is not None: 435 | openai.api_base = api_dict["api_base"] 436 | openai.api_key = api_dict["api_key"] 437 | else: 438 | openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"] 439 | openai.api_key = os.environ["AZURE_OPENAI_KEY"] 440 | 441 | if "azure-" in model: 442 | model = model[6:] 443 | 444 | output = API_ERROR_OUTPUT 445 | for _ in range(API_MAX_RETRY): 446 | try: 447 | messages = conv.to_openai_api_messages() 448 | response = openai.ChatCompletion.create( 449 | engine=model, 450 | messages=messages, 451 | n=1, 452 | temperature=temperature, 453 | max_tokens=max_tokens, 454 | ) 455 | output = response["choices"][0]["message"]["content"] 456 | break 457 | except openai.error.OpenAIError as e: 458 | print(type(e), e) 459 | time.sleep(API_RETRY_SLEEP) 460 | except openai.error.InvalidRequestError as e: 461 | print(type(e), e) 462 | break 463 | except KeyError: 464 | print(response) 465 | break 466 | 467 | return output 468 | 469 | 470 | def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=None): 471 | if api_dict is not None and "api_key" in api_dict: 472 | api_key = api_dict["api_key"] 473 | else: 474 | api_key = os.environ["ANTHROPIC_API_KEY"] 475 | 476 | output = API_ERROR_OUTPUT 477 | for _ in range(API_MAX_RETRY): 478 | try: 479 | c = anthropic.Anthropic(api_key=api_key) 480 | prompt = conv.get_prompt() 481 | response = c.completions.create( 482 | model=model, 483 | prompt=prompt, 484 | stop_sequences=[anthropic.HUMAN_PROMPT], 485 | max_tokens_to_sample=max_tokens, 486 | temperature=temperature, 487 | ) 488 | output = response.completion 489 | break 490 | except anthropic.APIError as e: 491 | print(type(e), e) 492 | time.sleep(API_RETRY_SLEEP) 493 | return output.strip() 494 | 495 | 496 | def chat_completion_palm(chat_state, model, conv, temperature, max_tokens): 497 | from fastchat.serve.api_provider import init_palm_chat 498 | 499 | assert model == "palm-2-chat-bison-001" 500 | 501 | if chat_state is None: 502 | chat_state = init_palm_chat("chat-bison@001") 503 | 504 | parameters = { 505 | "temperature": temperature, 506 | "top_p": 0.8, 507 | "top_k": 40, 508 | "max_output_tokens": max_tokens, 509 | } 510 | output = API_ERROR_OUTPUT 511 | for _ in range(API_MAX_RETRY): 512 | try: 513 | response = chat_state.send_message(conv.messages[-2][1], **parameters) 514 | output = response.text 515 | break 516 | except Exception as e: 517 | print(type(e), e) 518 | time.sleep(API_RETRY_SLEEP) 519 | return chat_state, output 520 | 521 | 522 | def normalize_game_key_single(gamekey, result): 523 | """Make the model names sorted in a game key.""" 524 | qid, model_1, model_2 = gamekey 525 | if model_1 < model_2: 526 | return gamekey, result 527 | else: 528 | new_gamekey = (qid, model_2, model_1) 529 | new_result = { 530 | "winners": tuple(reverse_model_map.get(x, x) for x in result["winners"]), 531 | "g1_judgment": result["g2_judgment"], 532 | "g2_judgment": result["g1_judgment"], 533 | } 534 | return new_gamekey, new_result 535 | 536 | 537 | def normalize_game_key_dict(judgment_dict): 538 | """Make the model names sorted in the game keys.""" 539 | ret = {} 540 | for key, value in judgment_dict.items(): 541 | new_key, new_value = normalize_game_key_single(key, value) 542 | ret[new_key] = new_value 543 | return ret 544 | 545 | 546 | def load_pairwise_model_judgments(filename: str): 547 | """Load model judgments. 548 | 549 | The return value is a dict of type: 550 | Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict] 551 | """ 552 | judge_dict = {} 553 | 554 | for line in open(filename): 555 | obj = json.loads(line) 556 | judge = tuple(obj["judge"]) 557 | qid, model_1, model_2 = obj["question_id"], obj["model_1"], obj["model_2"] 558 | 559 | if judge not in judge_dict: 560 | judge_dict[judge] = {} 561 | 562 | if "winner" in obj: 563 | winner = obj["winner"] 564 | elif "g1_winner" in obj and "g2_winner" in obj: 565 | g1_winner, g2_winner = obj["g1_winner"], obj["g2_winner"] 566 | if g1_winner == g2_winner: 567 | winner = g1_winner 568 | else: 569 | winner = "inconsistent" 570 | else: 571 | raise ValueError(f"Invalid keys: {list(obj.keys())}") 572 | 573 | gamekey = (qid, model_1, model_2) 574 | winners = (winner,) 575 | 576 | judge_dict[judge][gamekey] = { 577 | "winners": winners, 578 | "g1_judgment": obj["g1_judgment"], 579 | "g2_judgment": obj["g2_judgment"], 580 | } 581 | 582 | # Make the model names sorted in the game keys 583 | normalized = {} 584 | for judge, value in judge_dict.items(): 585 | normalized[judge] = normalize_game_key_dict(value) 586 | return normalized 587 | 588 | 589 | def load_single_model_judgments(filename: str): 590 | """Load model judgments. 591 | 592 | The return value is a dict of type: 593 | Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict] 594 | """ 595 | judge_dict = {} 596 | 597 | for line in open(filename): 598 | obj = json.loads(line) 599 | judge = tuple(obj["judge"]) 600 | qid, model = obj["question_id"], obj["model"] 601 | 602 | if judge not in judge_dict: 603 | judge_dict[judge] = {} 604 | 605 | gamekey = (qid, model) 606 | 607 | judge_dict[judge][gamekey] = { 608 | "score": obj["score"], 609 | "judgment": obj["judgment"], 610 | } 611 | return judge_dict 612 | 613 | 614 | def resolve_pairwise_judgment_dict( 615 | question, model_judgments_normal, model_judgments_math, multi_turn=False 616 | ): 617 | """Return the correct pairwise judge.""" 618 | if multi_turn: 619 | if question["category"] in NEED_REF_CATS: 620 | return model_judgments_math[("gpt-4", "pair-math-v1-multi-turn")] 621 | return model_judgments_normal[("gpt-4", "pair-v2-multi-turn")] 622 | 623 | if question["category"] in NEED_REF_CATS: 624 | return model_judgments_math[("gpt-4", "pair-math-v1")] 625 | else: 626 | return model_judgments_normal[("gpt-4", "pair-v2")] 627 | 628 | 629 | def resolve_single_judgment_dict( 630 | question, model_judgments_normal, model_judgments_math, multi_turn=False 631 | ): 632 | """Return the correct single answer grading judge.""" 633 | if multi_turn: 634 | if question["category"] in NEED_REF_CATS: 635 | return model_judgments_math[("gpt-4", "single-math-v1-multi-turn")] 636 | return model_judgments_normal[("gpt-4", "single-v1-multi-turn")] 637 | 638 | if question["category"] in NEED_REF_CATS: 639 | return model_judgments_math[("gpt-4", "single-math-v1")] 640 | else: 641 | return model_judgments_normal[("gpt-4", "single-v1")] 642 | 643 | 644 | def get_pairwise_judge_explanation(gamekey, judgment_dict): 645 | """Get model judge explanation.""" 646 | try: 647 | qid, model_1, model_2 = gamekey 648 | if model_1 < model_2: 649 | res = judgment_dict[gamekey] 650 | g1_judgment, g2_judgment = res["g1_judgment"], res["g2_judgment"] 651 | else: 652 | new_gamekey = (qid, model_2, model_1) 653 | res = judgment_dict[new_gamekey] 654 | 655 | model_1, model_2 = model_1, model_2 656 | g1_judgment, g2_judgment = res["g2_judgment"], res["g1_judgment"] 657 | 658 | return ( 659 | f"**Game 1**. **A**: {model_1}, **B**: {model_2}\n\n" 660 | f"**Judgment**: {g1_judgment}" 661 | + f"\n\n`--------------------------`\n\n" 662 | + f"**Game 2**. **A**: {model_2}, **B**: {model_1}\n\n" 663 | f"**Judgment**: {g2_judgment}" 664 | ) 665 | except KeyError: 666 | return "N/A" 667 | 668 | 669 | def get_single_judge_explanation(gamekey, judgment_dict): 670 | """Get model judge explanation.""" 671 | try: 672 | qid, model = gamekey 673 | 674 | res = judgment_dict[gamekey] 675 | 676 | g1_judgment = res["judgment"] 677 | g1_score = res["score"] 678 | 679 | return ( 680 | f"**Game 1**. **A**: {model}, **Score**: {g1_score}\n\n" 681 | f"**Judgment**: {g1_judgment}" 682 | ) 683 | except KeyError: 684 | return "N/A" 685 | 686 | 687 | def check_data(questions, model_answers, ref_answers, models, judges): 688 | # check model answers 689 | for m in models: 690 | assert m in model_answers, f"Missing model answer for {m}" 691 | m_answer = model_answers[m] 692 | for q in questions: 693 | assert ( 694 | q["question_id"] in m_answer 695 | ), f"Missing model {m}'s answer to Question {q['question_id']}" 696 | # check ref answers 697 | for jg in judges.values(): 698 | if not jg.ref_based: 699 | continue 700 | for q in questions: 701 | if q["category"] not in NEED_REF_CATS: 702 | continue 703 | assert ( 704 | q["question_id"] in ref_answers[jg.model_name] 705 | ), f"Missing reference answer to Question {q['question_id']} for judge {jg.model_name}" 706 | 707 | 708 | def get_model_list(answer_dir): 709 | file_paths = glob.glob(f"{answer_dir}/*.jsonl") 710 | file_names = [os.path.splitext(os.path.basename(f))[0] for f in file_paths] 711 | return file_names 712 | -------------------------------------------------------------------------------- /application/gen_judgment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all] 4 | """ 5 | import argparse 6 | from concurrent.futures import ThreadPoolExecutor 7 | import json 8 | 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from fastchat.llm_judge.common import ( 13 | load_questions, 14 | load_model_answers, 15 | load_judge_prompts, 16 | check_data, 17 | play_a_match_pair, 18 | play_a_match_single, 19 | get_model_list, 20 | Judge, 21 | MatchPair, 22 | MatchSingle, 23 | NEED_REF_CATS, 24 | ) 25 | 26 | 27 | def make_match( 28 | questions, 29 | models, 30 | model_answers, 31 | judge, 32 | baseline_model, 33 | ref_answers=None, 34 | multi_turn=False, 35 | ): 36 | matches = [] 37 | for q in questions: 38 | if multi_turn and len(q["turns"]) != 2: 39 | continue 40 | for i in range(len(models)): 41 | q_id = q["question_id"] 42 | m_1 = models[i] 43 | m_2 = baseline_model 44 | if m_1 == m_2: 45 | continue 46 | a_1 = model_answers[m_1][q_id] 47 | a_2 = model_answers[baseline_model][q_id] 48 | if ref_answers is not None: 49 | ref = ref_answers[judge.model_name][q_id] 50 | match = MatchPair( 51 | dict(q), 52 | m_1, 53 | m_2, 54 | a_1, 55 | a_2, 56 | judge, 57 | ref_answer=ref, 58 | multi_turn=multi_turn, 59 | ) 60 | else: 61 | match = MatchPair( 62 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn 63 | ) 64 | matches.append(match) 65 | return matches 66 | 67 | 68 | def make_match_all_pairs( 69 | questions, 70 | models, 71 | model_answers, 72 | judge, 73 | baseline_model=None, 74 | ref_answers=None, 75 | multi_turn=False, 76 | ): 77 | matches = [] 78 | for q in questions: 79 | if multi_turn and len(q["turns"]) != 2: 80 | continue 81 | for i in range(len(models)): 82 | for j in range(i + 1, len(models)): 83 | q_id = q["question_id"] 84 | m_1 = models[i] 85 | m_2 = models[j] 86 | a_1 = model_answers[m_1][q_id] 87 | a_2 = model_answers[m_2][q_id] 88 | if ref_answers is not None: 89 | ref = ref_answers[judge.model_name][q_id] 90 | match = MatchPair( 91 | dict(q), 92 | m_1, 93 | m_2, 94 | a_1, 95 | a_2, 96 | judge, 97 | ref_answer=ref, 98 | multi_turn=multi_turn, 99 | ) 100 | else: 101 | match = MatchPair( 102 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn 103 | ) 104 | matches.append(match) 105 | return matches 106 | 107 | 108 | def make_match_single( 109 | questions, 110 | models, 111 | model_answers, 112 | judge, 113 | baseline_model=None, 114 | ref_answers=None, 115 | multi_turn=False, 116 | ): 117 | matches = [] 118 | for q in questions: 119 | if multi_turn and len(q["turns"]) != 2: 120 | continue 121 | for i in range(len(models)): 122 | q_id = q["question_id"] 123 | m = models[i] 124 | a = model_answers[m][q_id] 125 | if ref_answers is not None: 126 | ref = ref_answers[judge.model_name][q_id] 127 | matches.append( 128 | MatchSingle( 129 | dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn 130 | ) 131 | ) 132 | else: 133 | matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn)) 134 | return matches 135 | 136 | 137 | def make_judge_pairwise(judge_model, judge_prompts): 138 | judges = {} 139 | judges["default"] = Judge(judge_model, judge_prompts["pair-v2"]) 140 | judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True) 141 | judges["default-mt"] = Judge( 142 | judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True 143 | ) 144 | judges["math-mt"] = Judge( 145 | judge_model, 146 | judge_prompts["pair-math-v1-multi-turn"], 147 | ref_based=True, 148 | multi_turn=True, 149 | ) 150 | return judges 151 | 152 | 153 | def make_judge_single(judge_model, judge_prompts): 154 | judges = {} 155 | judges["default"] = Judge(judge_model, judge_prompts["single-v1"]) 156 | judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True) 157 | judges["default-mt"] = Judge( 158 | judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True 159 | ) 160 | judges["math-mt"] = Judge( 161 | judge_model, 162 | judge_prompts["single-math-v1-multi-turn"], 163 | ref_based=True, 164 | multi_turn=True, 165 | ) 166 | return judges 167 | 168 | 169 | if __name__ == "__main__": 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument( 172 | "--bench-name", 173 | type=str, 174 | default="mt_bench", 175 | help="The name of the benchmark question set.", 176 | ) 177 | parser.add_argument( 178 | "--judge-file", 179 | type=str, 180 | default="data/judge_prompts.jsonl", 181 | help="The file of judge prompts.", 182 | ) 183 | parser.add_argument("--judge-model", type=str, default="gpt-4") 184 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") 185 | parser.add_argument( 186 | "--mode", 187 | type=str, 188 | default="single", 189 | choices=["pairwise-baseline", "pairwise-all", "single"], 190 | help=( 191 | "Evaluation mode. " 192 | "`pairwise-baseline` runs pairwise comparision against a baseline. " 193 | "`pairwise-all` runs pairwise comparision between all pairs. " 194 | "`single` runs single answer grading." 195 | ), 196 | ) 197 | parser.add_argument( 198 | "--model-list", 199 | type=str, 200 | nargs="+", 201 | default=None, 202 | help="A list of models to be evaluated", 203 | ) 204 | parser.add_argument( 205 | "--parallel", type=int, default=1, help="The number of concurrent API calls." 206 | ) 207 | parser.add_argument( 208 | "--first-n", type=int, help="A debug option. Only run the first `n` judgments." 209 | ) 210 | args = parser.parse_args() 211 | 212 | question_file = f"data/{args.bench_name}/question.jsonl" 213 | answer_dir = f"data/{args.bench_name}/model_answer" 214 | ref_answer_dir = f"data/{args.bench_name}/reference_answer" 215 | 216 | # Load questions 217 | questions = load_questions(question_file, None, None) 218 | 219 | # Load answers 220 | model_answers = load_model_answers(answer_dir) 221 | ref_answers = load_model_answers(ref_answer_dir) 222 | 223 | # Load judge 224 | judge_prompts = load_judge_prompts(args.judge_file) 225 | 226 | if args.first_n: 227 | questions = questions[: args.first_n] 228 | 229 | if args.model_list is None: 230 | models = get_model_list(answer_dir) 231 | else: 232 | models = args.model_list 233 | 234 | if args.mode == "single": 235 | judges = make_judge_single(args.judge_model, judge_prompts) 236 | play_a_match_func = play_a_match_single 237 | output_file = ( 238 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" 239 | ) 240 | make_match_func = make_match_single 241 | baseline_model = None 242 | else: 243 | judges = make_judge_pairwise(args.judge_model, judge_prompts) 244 | play_a_match_func = play_a_match_pair 245 | output_file = ( 246 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" 247 | ) 248 | if args.mode == "pairwise-all": 249 | make_match_func = make_match_all_pairs 250 | baseline_model = None 251 | else: 252 | make_match_func = make_match 253 | baseline_model = args.baseline_model 254 | 255 | check_data(questions, model_answers, ref_answers, models, judges) 256 | 257 | question_math = [q for q in questions if q["category"] in NEED_REF_CATS] 258 | question_default = [q for q in questions if q["category"] not in NEED_REF_CATS] 259 | 260 | # Make matches 261 | matches = [] 262 | matches += make_match_func( 263 | question_default, models, model_answers, judges["default"], baseline_model 264 | ) 265 | matches += make_match_func( 266 | question_math, 267 | models, 268 | model_answers, 269 | judges["math"], 270 | baseline_model, 271 | ref_answers, 272 | ) 273 | matches += make_match_func( 274 | question_default, 275 | models, 276 | model_answers, 277 | judges["default-mt"], 278 | baseline_model, 279 | multi_turn=True, 280 | ) 281 | matches += make_match_func( 282 | question_math, 283 | models, 284 | model_answers, 285 | judges["math-mt"], 286 | baseline_model, 287 | ref_answers, 288 | multi_turn=True, 289 | ) 290 | 291 | match_stat = {} 292 | match_stat["bench_name"] = args.bench_name 293 | match_stat["mode"] = args.mode 294 | match_stat["judge"] = args.judge_model 295 | match_stat["baseline"] = baseline_model 296 | match_stat["model_list"] = models 297 | match_stat["total_num_questions"] = len(questions) 298 | match_stat["total_num_matches"] = len(matches) 299 | match_stat["output_path"] = output_file 300 | 301 | # Show match stats and prompt enter to continue 302 | print("Stats:") 303 | print(json.dumps(match_stat, indent=4)) 304 | input("Press Enter to confirm...") 305 | 306 | # Play matches 307 | if args.parallel == 1: 308 | for match in tqdm(matches): 309 | play_a_match_func(match, output_file=output_file) 310 | else: 311 | 312 | def play_a_match_wrapper(match): 313 | play_a_match_func(match, output_file=output_file) 314 | 315 | np.random.seed(0) 316 | np.random.shuffle(matches) 317 | 318 | with ThreadPoolExecutor(args.parallel) as executor: 319 | for match in tqdm( 320 | executor.map(play_a_match_wrapper, matches), total=len(matches) 321 | ): 322 | pass 323 | -------------------------------------------------------------------------------- /application/gen_model_answer_baseline.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import random 10 | import time 11 | import transformers 12 | 13 | import shortuuid 14 | import torch 15 | from tqdm import tqdm 16 | 17 | from fastchat.llm_judge.common import load_questions, temperature_config 18 | from fastchat.model import load_model, get_conversation_template 19 | from fastchat.utils import str_to_torch_dtype 20 | 21 | import sys 22 | 23 | from prompt.model.kv_cache import initialize_past_key_values 24 | from prompt.model.model import AutoPromptDecoder, PromptDecoder, PromptConfig 25 | from prompt.model.modeling_llama_custom import LlamaForCausalLM as CustomLlamaForCausalLM 26 | from human_eval.data import write_jsonl, read_problems 27 | from datasets import load_dataset 28 | from zeus.monitor import ZeusMonitor 29 | 30 | 31 | def infer(input_ids, model, tokenizer, choices, max_steps = 512, temperature=0.7, posterior_threshold = 0.09, posterior_alpha = 0.3, sampling='greedy', max_new_token=1024): 32 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 33 | # Avoid modifying the input_ids in-place 34 | input_ids = input_ids.clone() 35 | 36 | if not hasattr(model, "inference_buffers"): 37 | print('Generate buffers') 38 | model.generate_dynamic_buffers(choices) 39 | # Initialize the past key and value states 40 | if hasattr(model, "past_key_values"): 41 | past_key_values = model.past_key_values 42 | past_key_values_data = model.past_key_values_data 43 | current_length_data = model.current_length_data 44 | # Reset the past key and value states 45 | current_length_data.zero_() 46 | else: 47 | print('Initialize past key values') 48 | ( 49 | past_key_values, 50 | past_key_values_data, 51 | current_length_data, 52 | ) = initialize_past_key_values(model.base_model) 53 | model.past_key_values = past_key_values 54 | model.past_key_values_data = past_key_values_data 55 | model.current_length_data = current_length_data 56 | 57 | input_len = input_ids.shape[1] 58 | model.base_model.model.tree_mask = None 59 | model.base_model.model.vt_attention_mask = None 60 | model.base_model.model.prompt_token_indices = None 61 | outputs = model.base_model(input_ids, past_key_values = past_key_values, use_cache=True) 62 | new_token = 0 63 | 64 | for idx in range(max_steps): 65 | input_id = outputs.logits[:, -1:].argmax(dim=-1) 66 | outputs = model.base_model(input_id, use_cache=True, past_key_values = past_key_values) 67 | input_ids = torch.cat([input_ids, input_id], dim=-1) 68 | new_token += 1 69 | 70 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): 71 | break 72 | if new_token > max_new_token: 73 | break 74 | 75 | return input_ids, new_token, idx 76 | 77 | 78 | def run_eval( 79 | model_path, 80 | model_id, 81 | question_file, 82 | question_begin, 83 | question_end, 84 | answer_file, 85 | max_new_token, 86 | num_choices, 87 | num_gpus_per_model, 88 | num_gpus_total, 89 | max_gpu_memory, 90 | dtype, 91 | revision, 92 | benchname, 93 | warmup, 94 | tree_length, 95 | temperature, 96 | posterior_threshold, 97 | posterior_alpha, 98 | sampling, 99 | gpu_power 100 | ): 101 | if benchname == 'mt_bench': 102 | questions = load_questions(question_file, question_begin, question_end) 103 | elif benchname == 'humaneval': 104 | questions = read_problems() 105 | questions = list(questions.values())[question_begin:question_end] 106 | elif benchname == 'alpaca_eval': 107 | questions = json.load(open(question_file)) 108 | elif benchname == 'gsm8k': 109 | # only use the first 1000 questions from test set 110 | questions = load_dataset('gsm8k', 'main', streaming=False, split='test')['question'][:500] 111 | else: 112 | raise ValueError("Unknown benchmark name") 113 | 114 | # random shuffle the questions to balance the loading 115 | # random.shuffle(questions) 116 | 117 | # Split the question file into `num_gpus` files 118 | assert num_gpus_total % num_gpus_per_model == 0 119 | use_ray = num_gpus_total // num_gpus_per_model > 1 120 | 121 | if use_ray: 122 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)( 123 | get_model_answers 124 | ).remote 125 | else: 126 | get_answers_func = get_model_answers 127 | 128 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) 129 | ans_handles = [] 130 | for i in range(0, len(questions), chunk_size): 131 | ans_handles.append( 132 | get_answers_func( 133 | model_path, 134 | model_id, 135 | questions[i : i + chunk_size], 136 | answer_file, 137 | max_new_token, 138 | num_choices, 139 | num_gpus_per_model, 140 | max_gpu_memory, 141 | dtype=dtype, 142 | revision=revision, 143 | benchname=benchname, 144 | warmup=warmup, 145 | tree_length=tree_length, 146 | temperature=temperature, 147 | posterior_threshold=posterior_threshold, 148 | posterior_alpha=posterior_alpha, 149 | sampling=sampling, 150 | gpu_power=gpu_power 151 | ) 152 | ) 153 | 154 | if use_ray: 155 | ray.get(ans_handles) 156 | 157 | 158 | @torch.inference_mode() 159 | def get_model_answers( 160 | model_path, 161 | model_id, 162 | questions, 163 | answer_file, 164 | max_new_token, 165 | num_choices, 166 | num_gpus_per_model, 167 | max_gpu_memory, 168 | dtype, 169 | revision, 170 | benchname, 171 | warmup, 172 | tree_length, 173 | temperature, 174 | posterior_threshold, 175 | posterior_alpha, 176 | sampling, 177 | gpu_power 178 | ): 179 | model = AutoPromptDecoder.from_pretrained( 180 | model_path, 181 | low_cpu_mem_usage=True, 182 | # load_in_4bit=True, 183 | torch_dtype=torch.float16, 184 | # device_map="auto" 185 | ) 186 | model.cuda() 187 | tokenizer = model.tokenizer 188 | 189 | model.eval() 190 | print('Check model training state:',model.training) 191 | 192 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') 193 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices) 194 | 195 | # warmup 196 | for i in range(warmup): 197 | torch.manual_seed(0) 198 | question = questions[i] 199 | if benchname == 'mt_bench': 200 | num_turns = len(question["turns"]) 201 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k': 202 | num_turns = 1 203 | conv = get_conversation_template(model_id) 204 | turns = [] 205 | idxs = [] 206 | new_tokens = [] 207 | wall_time = [] 208 | for j in range(num_turns): 209 | if benchname == 'mt_bench': 210 | qs = question["turns"][j] 211 | conv.append_message(conv.roles[0], qs) 212 | conv.append_message(conv.roles[1], None) 213 | prompt = conv.get_prompt() 214 | elif benchname == 'humaneval': 215 | qs = question["prompt"] 216 | prompt = qs 217 | elif benchname == 'alpaca_eval': 218 | conv = get_conversation_template(model_id) 219 | conv.messages = [] 220 | conv.append_message(conv.roles[0], question["instruction"]) 221 | conv.append_message(conv.roles[1], "") 222 | prompt = conv.get_prompt() 223 | elif benchname == 'gsm8k': 224 | qs = question 225 | conv.append_message(conv.roles[0], qs) 226 | conv.append_message(conv.roles[1], None) 227 | prompt = conv.get_prompt() 228 | 229 | input_ids = tokenizer([prompt]).input_ids 230 | 231 | # try: 232 | torch.cuda.synchronize() 233 | start_time = time.time() 234 | output_ids, new_token, idx = infer( 235 | torch.as_tensor(input_ids).cuda(), 236 | model, 237 | tokenizer, 238 | tree_length, 239 | temperature=0, 240 | posterior_threshold=posterior_threshold, 241 | posterior_alpha=posterior_alpha, 242 | sampling=sampling, 243 | max_new_token=max_new_token 244 | ) 245 | torch.cuda.synchronize() 246 | total_time = time.time() - start_time 247 | if benchname == 'mt_bench': 248 | output_ids = output_ids[0][len(input_ids[0]) :] 249 | # be consistent with the template's stop_token_ids 250 | if conv.stop_token_ids: 251 | stop_token_ids_index = [ 252 | i 253 | for i, id in enumerate(output_ids) 254 | if id in conv.stop_token_ids 255 | ] 256 | if len(stop_token_ids_index) > 0: 257 | output_ids = output_ids[: stop_token_ids_index[0]] 258 | 259 | output = tokenizer.decode( 260 | output_ids, 261 | spaces_between_special_tokens=False, 262 | ) 263 | 264 | if conv.stop_str and output.find(conv.stop_str) > 0: 265 | output = output[: output.find(conv.stop_str)] 266 | for special_token in tokenizer.special_tokens_map.values(): 267 | if isinstance(special_token, list): 268 | for special_tok in special_token: 269 | output = output.replace(special_tok, "") 270 | else: 271 | output = output.replace(special_token, "") 272 | output = output.strip() 273 | 274 | if conv.name == "xgen" and output.startswith("Assistant:"): 275 | output = output.replace("Assistant:", "", 1).strip() 276 | conv.messages[-1][-1] = output 277 | 278 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k': 279 | output = tokenizer.decode( 280 | output_ids[0].tolist(), 281 | spaces_between_special_tokens=False, 282 | ) 283 | # except RuntimeError as e: 284 | # print(e) 285 | # print("ERROR question ID: ", question["question_id"]) 286 | # output = "ERROR" 287 | 288 | turns.append(output) 289 | idxs.append(int(idx)) 290 | new_tokens.append(int(new_token)) 291 | wall_time.append(total_time) 292 | print('Warmup done', warmup, 'steps') 293 | 294 | 295 | for i, question in tqdm(enumerate(questions), total=len(questions)): 296 | if benchname == 'mt_bench': 297 | question_id = question["question_id"] 298 | num_turns = len(question["turns"]) 299 | elif benchname == 'humaneval': 300 | question_id = question["task_id"] 301 | num_turns = 1 302 | elif benchname == 'alpaca_eval' or benchname == 'gsm8k': 303 | question_id = i 304 | num_turns = 1 305 | 306 | if "category" in question and question["category"] in temperature_config: 307 | if temperature is not None: 308 | print(f"Overwriting temperature with {temperature} from command line") 309 | temp = temperature 310 | else: 311 | print(f"Using temperature from config for category {question['category']}") 312 | temp = temperature_config[question["category"]] 313 | else: 314 | print(f"Unknown category, using default temperature 0.0") 315 | temp = 0.0 316 | 317 | choices = [] 318 | for i in range(num_choices): 319 | torch.manual_seed(0) 320 | conv = get_conversation_template(model_id) 321 | turns = [] 322 | idxs = [] 323 | new_tokens = [] 324 | wall_time = [] 325 | power = [] 326 | energy = [] 327 | for j in range(num_turns): 328 | if benchname == 'mt_bench': 329 | qs = question["turns"][j] 330 | conv.append_message(conv.roles[0], qs) 331 | conv.append_message(conv.roles[1], None) 332 | prompt = conv.get_prompt() 333 | elif benchname == 'humaneval': 334 | qs = question["prompt"] 335 | prompt = qs 336 | elif benchname == 'alpaca_eval': 337 | conv.messages = [] 338 | conv.append_message(conv.roles[0], question["instruction"]) 339 | conv.append_message(conv.roles[1], "") 340 | prompt = conv.get_prompt() 341 | elif benchname == 'gsm8k': 342 | qs = question 343 | conv.append_message(conv.roles[0], qs) 344 | conv.append_message(conv.roles[1], None) 345 | prompt = conv.get_prompt() 346 | input_ids = tokenizer([prompt]).input_ids 347 | 348 | try: 349 | torch.cuda.synchronize() 350 | if gpu_power: 351 | gpu_indices = [int(i) for i in os.environ['CUDA_VISIBLE_DEVICES'].split(',')] 352 | assert len(gpu_indices) == 1, "Only support single GPU for power measurement now" 353 | monitor = ZeusMonitor([0]) 354 | monitor.begin_window("infer") 355 | start_time = time.time() 356 | output_ids, new_token, idx = infer( 357 | torch.as_tensor(input_ids).cuda(), 358 | model, 359 | tokenizer, 360 | tree_length, 361 | temperature=temp, 362 | posterior_threshold=posterior_threshold, 363 | posterior_alpha=posterior_alpha, 364 | sampling=sampling, 365 | max_new_token=max_new_token 366 | ) 367 | 368 | torch.cuda.synchronize() 369 | if gpu_power: 370 | gpu_power_result = monitor.end_window("infer") 371 | total_time = time.time() - start_time 372 | if benchname == 'mt_bench': 373 | 374 | # if model.config.is_encoder_decoder: 375 | # output_ids = output_ids[0] 376 | # else: 377 | output_ids = output_ids[0][len(input_ids[0]) :] 378 | 379 | # be consistent with the template's stop_token_ids 380 | if conv.stop_token_ids: 381 | stop_token_ids_index = [ 382 | i 383 | for i, id in enumerate(output_ids) 384 | if id in conv.stop_token_ids 385 | ] 386 | if len(stop_token_ids_index) > 0: 387 | output_ids = output_ids[: stop_token_ids_index[0]] 388 | 389 | output = tokenizer.decode( 390 | output_ids, 391 | spaces_between_special_tokens=False, 392 | ) 393 | if conv.stop_str and output.find(conv.stop_str) > 0: 394 | output = output[: output.find(conv.stop_str)] 395 | for special_token in tokenizer.special_tokens_map.values(): 396 | if isinstance(special_token, list): 397 | for special_tok in special_token: 398 | output = output.replace(special_tok, "") 399 | else: 400 | output = output.replace(special_token, "") 401 | output = output.strip() 402 | 403 | if conv.name == "xgen" and output.startswith("Assistant:"): 404 | output = output.replace("Assistant:", "", 1).strip() 405 | conv.messages[-1][-1] = output 406 | 407 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k': 408 | output = tokenizer.decode( 409 | output_ids[0].tolist(), 410 | spaces_between_special_tokens=False, 411 | ) 412 | 413 | except RuntimeError as e: 414 | print("ERROR question ID: ", question["question_id"]) 415 | print(e) 416 | output = "ERROR" 417 | 418 | turns.append(output) 419 | idxs.append(int(idx)) 420 | new_tokens.append(int(new_token)) 421 | wall_time.append(total_time) 422 | if gpu_power: 423 | power.append(gpu_power_result.energy[0] / gpu_power_result.time) 424 | energy.append(gpu_power_result.energy[0]) 425 | # torch.cuda.empty_cache() 426 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time}) 427 | if gpu_power: 428 | choices[-1]["power"] = power 429 | choices[-1]["energy"] = energy 430 | 431 | # Dump answers 432 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 433 | with open(os.path.expanduser(answer_file), "a") as fout: 434 | ans_json = { 435 | "question_id": question_id, 436 | "answer_id": shortuuid.uuid(), 437 | "model_id": model_id, 438 | "choices": choices, 439 | "tstamp": time.time(), 440 | } 441 | fout.write(json.dumps(ans_json) + "\n") 442 | 443 | 444 | def reorg_answer_file(answer_file): 445 | """Sort by question id and de-duplication""" 446 | answers = {} 447 | with open(answer_file, "r") as fin: 448 | for l in fin: 449 | qid = json.loads(l)["question_id"] 450 | answers[qid] = l 451 | 452 | qids = sorted(list(answers.keys())) 453 | with open(answer_file, "w") as fout: 454 | for qid in qids: 455 | fout.write(answers[qid]) 456 | 457 | 458 | if __name__ == "__main__": 459 | parser = argparse.ArgumentParser() 460 | parser.add_argument( 461 | "--model-path", 462 | type=str, 463 | required=True, 464 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 465 | ) 466 | parser.add_argument( 467 | "--model-id", type=str, required=True, help="A custom name for the model." 468 | ) 469 | parser.add_argument( 470 | "--bench-name", 471 | type=str, 472 | default="mt_bench", 473 | help="The name of the benchmark question set.", 474 | ) 475 | parser.add_argument( 476 | "--question-begin", 477 | type=int, 478 | help="A debug option. The begin index of questions.", 479 | ) 480 | parser.add_argument( 481 | "--question-end", type=int, help="A debug option. The end index of questions." 482 | ) 483 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 484 | parser.add_argument( 485 | "--max-new-token", 486 | type=int, 487 | default=1024, 488 | help="The maximum number of new generated tokens.", 489 | ) 490 | parser.add_argument( 491 | "--num-choices", 492 | type=int, 493 | default=1, 494 | help="How many completion choices to generate.", 495 | ) 496 | parser.add_argument( 497 | "--num-gpus-per-model", 498 | type=int, 499 | default=1, 500 | help="The number of GPUs per model.", 501 | ) 502 | parser.add_argument( 503 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 504 | ) 505 | parser.add_argument( 506 | "--max-gpu-memory", 507 | type=str, 508 | help="Maxmum GPU memory used for model weights per GPU.", 509 | ) 510 | parser.add_argument( 511 | "--dtype", 512 | type=str, 513 | choices=["float32", "float16", "bfloat16"], 514 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", 515 | default=None, 516 | ) 517 | parser.add_argument( 518 | "--revision", 519 | type=str, 520 | default="main", 521 | help="The model revision to load.", 522 | ) 523 | parser.add_argument( 524 | "--warmup", 525 | type=int, 526 | default=3, 527 | help="The number of warmup steps.", 528 | ) 529 | parser.add_argument( 530 | "--tree-length", 531 | type=str, 532 | default="75", 533 | help="The choices for sampling.", 534 | ) 535 | parser.add_argument( 536 | "--temperature", 537 | type=float, 538 | default=None, 539 | help="The temperature for sampling.", 540 | ) 541 | parser.add_argument( 542 | "--posterior_threshold", 543 | type=float, 544 | default=0.09, 545 | help="The threshold for posterior sampling.", 546 | ) 547 | parser.add_argument( 548 | "--posterior_alpha", 549 | type=float, 550 | default=0.3, 551 | help="The alpha for posterior sampling.", 552 | ) 553 | parser.add_argument( 554 | "--sampling", 555 | type=str, 556 | default='greedy', 557 | help="The sampling method for decoding." 558 | ) 559 | parser.add_argument( 560 | "--gpu-power", 561 | action='store_true', 562 | default=False, 563 | help="Whether to measure power consumption." 564 | ) 565 | 566 | args = parser.parse_args() 567 | 568 | if 'vicuna-13b' in args.model_path.lower(): 569 | if '-2-' in args.model_path: 570 | from prompt.inference.dynamic_sparse_trees_2_13b import * 571 | print('Using 13b sparse trees') 572 | elif '-3-' in args.model_path: 573 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import * 574 | print('Using 13b 3-1 sparse trees') 575 | else: 576 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import * 577 | print('Using 13b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-') 578 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length) 579 | elif 'vicuna-7b' in args.model_path.lower(): 580 | if '-2-' in args.model_path: 581 | from prompt.inference.dynamic_sparse_trees_2_7b import * 582 | print('Using 7b 2-1 sparse trees') 583 | elif '-3-' in args.model_path: 584 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import * 585 | print('Using 7b 3-1 sparse trees') 586 | else: 587 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import * 588 | print('Using 7b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-') 589 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length) 590 | elif 'mobilellama' in args.model_path.lower(): 591 | from prompt.inference.dynamic_sparse_trees_3_MobileLLaMA import * 592 | print('Using MobileLLaMA 3-1 sparse trees') 593 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length) 594 | else: 595 | raise ValueError("Unknown model path") 596 | 597 | if args.num_gpus_total // args.num_gpus_per_model > 1: 598 | import ray 599 | 600 | ray.init() 601 | 602 | question_file = None 603 | if args.bench_name == 'mt_bench': 604 | question_file = f"data/{args.bench_name}/question.jsonl" 605 | elif args.bench_name == 'alpaca_eval': 606 | question_file = f"data/{args.bench_name}/alpaca_eval.json" 607 | if args.answer_file: 608 | answer_file = args.answer_file 609 | else: 610 | answer_file_name = args.model_id+"-temperature-"+str(args.temperature)+"-posterior_threshold-"+str(args.posterior_threshold)+"-posterior_alpha-"+str(args.posterior_alpha)+"-sampling-"+args.sampling 611 | answer_file = f"data/{args.bench_name}/model_answer/{answer_file_name}.jsonl" 612 | 613 | print(f"Output to {answer_file}") 614 | run_eval( 615 | model_path=args.model_path, 616 | model_id=args.model_id, 617 | question_file=question_file, 618 | question_begin=args.question_begin, 619 | question_end=args.question_end, 620 | answer_file=answer_file, 621 | max_new_token=args.max_new_token, 622 | num_choices=args.num_choices, 623 | num_gpus_per_model=args.num_gpus_per_model, 624 | num_gpus_total=args.num_gpus_total, 625 | max_gpu_memory=args.max_gpu_memory, 626 | dtype=str_to_torch_dtype(args.dtype), 627 | revision=args.revision, 628 | benchname=args.bench_name, 629 | warmup=args.warmup, 630 | tree_length=args.tree_length, 631 | temperature=args.temperature, 632 | posterior_threshold=args.posterior_threshold, 633 | posterior_alpha=args.posterior_alpha, 634 | sampling=args.sampling, 635 | gpu_power=args.gpu_power 636 | ) 637 | 638 | reorg_answer_file(answer_file) 639 | -------------------------------------------------------------------------------- /application/gen_model_answer_prompt_decoding.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import random 10 | import time 11 | import transformers 12 | 13 | import shortuuid 14 | import torch 15 | from tqdm import tqdm 16 | 17 | from fastchat.llm_judge.common import load_questions, temperature_config 18 | from fastchat.model import load_model, get_conversation_template 19 | from fastchat.utils import str_to_torch_dtype 20 | 21 | import sys 22 | 23 | from prompt.model.kv_cache import initialize_past_key_values 24 | from prompt.model.model import AutoPromptDecoder, PromptDecoder, PromptConfig 25 | from prompt.model.modeling_llama_custom import LlamaForCausalLM as CustomLlamaForCausalLM 26 | from human_eval.data import write_jsonl, read_problems 27 | from datasets import load_dataset 28 | from zeus.monitor import ZeusMonitor 29 | 30 | 31 | def infer(input_ids, model, tokenizer, tree_length, max_steps = 512, temperature=0.7, posterior_threshold = 0.09, posterior_alpha = 0.3, sampling='greedy', max_new_token=1024): 32 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 33 | # Avoid modifying the input_ids in-place 34 | input_ids = input_ids.clone() 35 | 36 | if not hasattr(model, "inference_buffers"): 37 | print('Generate buffers') 38 | model.generate_dynamic_buffers(tree_length) 39 | # Initialize the past key and value states 40 | if hasattr(model, "past_key_values"): 41 | past_key_values = model.past_key_values 42 | past_key_values_data = model.past_key_values_data 43 | current_length_data = model.current_length_data 44 | # Reset the past key and value states 45 | current_length_data.zero_() 46 | else: 47 | ( 48 | past_key_values, 49 | past_key_values_data, 50 | current_length_data, 51 | ) = initialize_past_key_values(model.base_model) 52 | model.past_key_values = past_key_values 53 | model.past_key_values_data = past_key_values_data 54 | model.current_length_data = current_length_data 55 | 56 | input_len = input_ids.shape[1] 57 | logits, prompt_logits = model.start_inference(input_ids, past_key_values, current_length_data) 58 | new_token = 0 59 | 60 | for idx in range(max_steps): 61 | candidates, tree_candidates_embeds = model.generate_candidates( 62 | logits, 63 | prompt_logits, 64 | temperature, 65 | posterior_threshold, 66 | posterior_alpha, 67 | sampling) 68 | logits, all_logits = model.tree_decoding(tree_candidates_embeds, past_key_values, input_ids) 69 | best_candidate, accept_length = model.evaluate_posterior( 70 | logits, 71 | candidates, 72 | temperature, 73 | posterior_threshold, 74 | posterior_alpha, 75 | sampling) 76 | input_ids, logits, prompt_logits, new_token = model.update_inference_inputs( 77 | input_ids, 78 | candidates, 79 | best_candidate, 80 | accept_length, 81 | logits, 82 | all_logits, 83 | new_token, 84 | past_key_values_data, 85 | current_length_data, 86 | ) 87 | 88 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): 89 | break 90 | if new_token > max_new_token: 91 | break 92 | 93 | return input_ids, new_token, idx 94 | 95 | def run_eval( 96 | model_path, 97 | model_id, 98 | question_file, 99 | question_begin, 100 | question_end, 101 | answer_file, 102 | max_new_token, 103 | num_choices, 104 | num_gpus_per_model, 105 | num_gpus_total, 106 | max_gpu_memory, 107 | dtype, 108 | revision, 109 | benchname, 110 | warmup, 111 | tree_length, 112 | temperature, 113 | posterior_threshold, 114 | posterior_alpha, 115 | sampling, 116 | gpu_power 117 | ): 118 | if benchname == 'mt_bench': 119 | questions = load_questions(question_file, question_begin, question_end) 120 | elif benchname == 'humaneval': 121 | questions = read_problems() 122 | questions = list(questions.values())[question_begin:question_end] 123 | elif benchname == 'alpaca_eval': 124 | questions = json.load(open(question_file)) 125 | elif benchname == 'gsm8k': 126 | # only use the first 1000 questions from test set 127 | questions = load_dataset('gsm8k', 'main', streaming=False, split='test')['question'][:500] 128 | else: 129 | raise ValueError("Unknown benchmark name") 130 | 131 | # random shuffle the questions to balance the loading 132 | # random.shuffle(questions) 133 | 134 | # Split the question file into `num_gpus` files 135 | assert num_gpus_total % num_gpus_per_model == 0 136 | use_ray = num_gpus_total // num_gpus_per_model > 1 137 | 138 | if use_ray: 139 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)( 140 | get_model_answers 141 | ).remote 142 | else: 143 | get_answers_func = get_model_answers 144 | 145 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) 146 | ans_handles = [] 147 | for i in range(0, len(questions), chunk_size): 148 | ans_handles.append( 149 | get_answers_func( 150 | model_path, 151 | model_id, 152 | questions[i : i + chunk_size], 153 | answer_file, 154 | max_new_token, 155 | num_choices, 156 | num_gpus_per_model, 157 | max_gpu_memory, 158 | dtype=dtype, 159 | revision=revision, 160 | benchname=benchname, 161 | warmup=warmup, 162 | tree_length=tree_length, 163 | temperature=temperature, 164 | posterior_threshold=posterior_threshold, 165 | posterior_alpha=posterior_alpha, 166 | sampling=sampling, 167 | gpu_power=gpu_power 168 | ) 169 | ) 170 | 171 | if use_ray: 172 | ray.get(ans_handles) 173 | 174 | 175 | @torch.inference_mode() 176 | def get_model_answers( 177 | model_path, 178 | model_id, 179 | questions, 180 | answer_file, 181 | max_new_token, 182 | num_choices, 183 | num_gpus_per_model, 184 | max_gpu_memory, 185 | dtype, 186 | revision, 187 | benchname, 188 | warmup, 189 | tree_length, 190 | temperature, 191 | posterior_threshold, 192 | posterior_alpha, 193 | sampling, 194 | gpu_power 195 | ): 196 | model = AutoPromptDecoder.from_pretrained( 197 | model_path, 198 | low_cpu_mem_usage=True, 199 | # load_in_4bit=True, 200 | torch_dtype=torch.float16, 201 | # device_map="auto" 202 | ) 203 | model.cuda() 204 | tokenizer = model.tokenizer 205 | 206 | model.eval() 207 | print('Check model training state:',model.training) 208 | 209 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') 210 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices) 211 | 212 | # warmup 213 | for i in range(warmup): 214 | torch.manual_seed(0) 215 | question = questions[i] 216 | if benchname == 'mt_bench': 217 | num_turns = len(question["turns"]) 218 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k': 219 | num_turns = 1 220 | conv = get_conversation_template(model_id) 221 | turns = [] 222 | idxs = [] 223 | new_tokens = [] 224 | wall_time = [] 225 | for j in range(num_turns): 226 | if benchname == 'mt_bench': 227 | qs = question["turns"][j] 228 | conv.append_message(conv.roles[0], qs) 229 | conv.append_message(conv.roles[1], None) 230 | prompt = conv.get_prompt() 231 | elif benchname == 'humaneval': 232 | qs = question["prompt"] 233 | prompt = qs 234 | elif benchname == 'alpaca_eval': 235 | conv = get_conversation_template(model_id) 236 | conv.messages = [] 237 | conv.append_message(conv.roles[0], question["instruction"]) 238 | conv.append_message(conv.roles[1], "") 239 | prompt = conv.get_prompt() 240 | elif benchname == 'gsm8k': 241 | qs = question 242 | conv.append_message(conv.roles[0], qs) 243 | conv.append_message(conv.roles[1], None) 244 | prompt = conv.get_prompt() 245 | 246 | input_ids = tokenizer([prompt]).input_ids 247 | 248 | # try: 249 | torch.cuda.synchronize() 250 | start_time = time.time() 251 | output_ids, new_token, idx = infer( 252 | torch.as_tensor(input_ids).cuda(), 253 | model, 254 | tokenizer, 255 | tree_length, 256 | temperature=0, 257 | posterior_threshold=posterior_threshold, 258 | posterior_alpha=posterior_alpha, 259 | sampling=sampling, 260 | max_new_token=max_new_token 261 | ) 262 | torch.cuda.synchronize() 263 | total_time = time.time() - start_time 264 | if benchname == 'mt_bench': 265 | output_ids = output_ids[0][len(input_ids[0]) :] 266 | # be consistent with the template's stop_token_ids 267 | if conv.stop_token_ids: 268 | stop_token_ids_index = [ 269 | i 270 | for i, id in enumerate(output_ids) 271 | if id in conv.stop_token_ids 272 | ] 273 | if len(stop_token_ids_index) > 0: 274 | output_ids = output_ids[: stop_token_ids_index[0]] 275 | 276 | output = tokenizer.decode( 277 | output_ids, 278 | spaces_between_special_tokens=False, 279 | ) 280 | 281 | if conv.stop_str and output.find(conv.stop_str) > 0: 282 | output = output[: output.find(conv.stop_str)] 283 | for special_token in tokenizer.special_tokens_map.values(): 284 | if isinstance(special_token, list): 285 | for special_tok in special_token: 286 | output = output.replace(special_tok, "") 287 | else: 288 | output = output.replace(special_token, "") 289 | output = output.strip() 290 | 291 | if conv.name == "xgen" and output.startswith("Assistant:"): 292 | output = output.replace("Assistant:", "", 1).strip() 293 | conv.messages[-1][-1] = output 294 | 295 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k': 296 | output = tokenizer.decode( 297 | output_ids[0].tolist(), 298 | spaces_between_special_tokens=False, 299 | ) 300 | # except RuntimeError as e: 301 | # print(e) 302 | # print("ERROR question ID: ", question["question_id"]) 303 | # output = "ERROR" 304 | 305 | turns.append(output) 306 | idxs.append(int(idx)) 307 | new_tokens.append(int(new_token)) 308 | wall_time.append(total_time) 309 | print('Warmup done', warmup, 'steps') 310 | 311 | 312 | for i, question in tqdm(enumerate(questions), total=len(questions)): 313 | if benchname == 'mt_bench': 314 | question_id = question["question_id"] 315 | num_turns = len(question["turns"]) 316 | elif benchname == 'humaneval': 317 | question_id = question["task_id"] 318 | num_turns = 1 319 | elif benchname == 'alpaca_eval' or benchname == 'gsm8k': 320 | question_id = i 321 | num_turns = 1 322 | 323 | if "category" in question and question["category"] in temperature_config: 324 | if temperature is not None: 325 | print(f"Overwriting temperature with {temperature} from command line") 326 | temp = temperature 327 | else: 328 | print(f"Using temperature from config for category {question['category']}") 329 | temp = temperature_config[question["category"]] 330 | else: 331 | print(f"Unknown category, using default temperature 0.0") 332 | temp = 0.0 333 | 334 | choices = [] 335 | for i in range(num_choices): 336 | torch.manual_seed(0) 337 | conv = get_conversation_template(model_id) 338 | turns = [] 339 | idxs = [] 340 | new_tokens = [] 341 | wall_time = [] 342 | power = [] 343 | energy = [] 344 | for j in range(num_turns): 345 | if benchname == 'mt_bench': 346 | qs = question["turns"][j] 347 | conv.append_message(conv.roles[0], qs) 348 | conv.append_message(conv.roles[1], None) 349 | prompt = conv.get_prompt() 350 | elif benchname == 'humaneval': 351 | qs = question["prompt"] 352 | prompt = qs 353 | elif benchname == 'alpaca_eval': 354 | conv.messages = [] 355 | conv.append_message(conv.roles[0], question["instruction"]) 356 | conv.append_message(conv.roles[1], "") 357 | prompt = conv.get_prompt() 358 | elif benchname == 'gsm8k': 359 | qs = question 360 | conv.append_message(conv.roles[0], qs) 361 | conv.append_message(conv.roles[1], None) 362 | prompt = conv.get_prompt() 363 | input_ids = tokenizer([prompt]).input_ids 364 | 365 | try: 366 | torch.cuda.synchronize() 367 | if gpu_power: 368 | gpu_indices = [int(i) for i in os.environ['CUDA_VISIBLE_DEVICES'].split(',')] 369 | assert len(gpu_indices) == 1, "Only support single GPU for power measurement now" 370 | monitor = ZeusMonitor([0]) 371 | monitor.begin_window("infer") 372 | start_time = time.time() 373 | output_ids, new_token, idx = infer( 374 | torch.as_tensor(input_ids).cuda(), 375 | model, 376 | tokenizer, 377 | tree_length, 378 | temperature=temp, 379 | posterior_threshold=posterior_threshold, 380 | posterior_alpha=posterior_alpha, 381 | sampling=sampling, 382 | max_new_token=max_new_token 383 | ) 384 | 385 | torch.cuda.synchronize() 386 | if gpu_power: 387 | gpu_power_result = monitor.end_window("infer") 388 | total_time = time.time() - start_time 389 | if benchname == 'mt_bench': 390 | 391 | # if model.config.is_encoder_decoder: 392 | # output_ids = output_ids[0] 393 | # else: 394 | output_ids = output_ids[0][len(input_ids[0]) :] 395 | 396 | # be consistent with the template's stop_token_ids 397 | if conv.stop_token_ids: 398 | stop_token_ids_index = [ 399 | i 400 | for i, id in enumerate(output_ids) 401 | if id in conv.stop_token_ids 402 | ] 403 | if len(stop_token_ids_index) > 0: 404 | output_ids = output_ids[: stop_token_ids_index[0]] 405 | 406 | output = tokenizer.decode( 407 | output_ids, 408 | spaces_between_special_tokens=False, 409 | ) 410 | if conv.stop_str and output.find(conv.stop_str) > 0: 411 | output = output[: output.find(conv.stop_str)] 412 | for special_token in tokenizer.special_tokens_map.values(): 413 | if isinstance(special_token, list): 414 | for special_tok in special_token: 415 | output = output.replace(special_tok, "") 416 | else: 417 | output = output.replace(special_token, "") 418 | output = output.strip() 419 | 420 | if conv.name == "xgen" and output.startswith("Assistant:"): 421 | output = output.replace("Assistant:", "", 1).strip() 422 | conv.messages[-1][-1] = output 423 | 424 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k': 425 | output = tokenizer.decode( 426 | output_ids[0].tolist(), 427 | spaces_between_special_tokens=False, 428 | ) 429 | 430 | except RuntimeError as e: 431 | print("ERROR question ID: ", question["question_id"]) 432 | print(e) 433 | output = "ERROR" 434 | 435 | turns.append(output) 436 | idxs.append(int(idx)) 437 | new_tokens.append(int(new_token)) 438 | wall_time.append(total_time) 439 | if gpu_power: 440 | power.append(gpu_power_result.energy[0] / gpu_power_result.time) 441 | energy.append(gpu_power_result.energy[0]) 442 | # torch.cuda.empty_cache() 443 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time}) 444 | if gpu_power: 445 | choices[-1]["power"] = power 446 | choices[-1]["energy"] = energy 447 | 448 | # Dump answers 449 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 450 | with open(os.path.expanduser(answer_file), "a") as fout: 451 | ans_json = { 452 | "question_id": question_id, 453 | "answer_id": shortuuid.uuid(), 454 | "model_id": model_id, 455 | "choices": choices, 456 | "tstamp": time.time(), 457 | } 458 | fout.write(json.dumps(ans_json) + "\n") 459 | 460 | 461 | def reorg_answer_file(answer_file): 462 | """Sort by question id and de-duplication""" 463 | answers = {} 464 | with open(answer_file, "r") as fin: 465 | for l in fin: 466 | qid = json.loads(l)["question_id"] 467 | answers[qid] = l 468 | 469 | qids = sorted(list(answers.keys())) 470 | with open(answer_file, "w") as fout: 471 | for qid in qids: 472 | fout.write(answers[qid]) 473 | 474 | 475 | if __name__ == "__main__": 476 | parser = argparse.ArgumentParser() 477 | parser.add_argument( 478 | "--model-path", 479 | type=str, 480 | required=True, 481 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 482 | ) 483 | parser.add_argument( 484 | "--model-id", type=str, required=True, help="A custom name for the model." 485 | ) 486 | parser.add_argument( 487 | "--bench-name", 488 | type=str, 489 | default="mt_bench", 490 | help="The name of the benchmark question set.", 491 | ) 492 | parser.add_argument( 493 | "--question-begin", 494 | type=int, 495 | help="A debug option. The begin index of questions.", 496 | ) 497 | parser.add_argument( 498 | "--question-end", type=int, help="A debug option. The end index of questions." 499 | ) 500 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 501 | parser.add_argument( 502 | "--max-new-token", 503 | type=int, 504 | default=1024, 505 | help="The maximum number of new generated tokens.", 506 | ) 507 | parser.add_argument( 508 | "--num-choices", 509 | type=int, 510 | default=1, 511 | help="How many completion choices to generate.", 512 | ) 513 | parser.add_argument( 514 | "--num-gpus-per-model", 515 | type=int, 516 | default=1, 517 | help="The number of GPUs per model.", 518 | ) 519 | parser.add_argument( 520 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 521 | ) 522 | parser.add_argument( 523 | "--max-gpu-memory", 524 | type=str, 525 | help="Maxmum GPU memory used for model weights per GPU.", 526 | ) 527 | parser.add_argument( 528 | "--dtype", 529 | type=str, 530 | choices=["float32", "float16", "bfloat16"], 531 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.", 532 | default=None, 533 | ) 534 | parser.add_argument( 535 | "--revision", 536 | type=str, 537 | default="main", 538 | help="The model revision to load.", 539 | ) 540 | parser.add_argument( 541 | "--warmup", 542 | type=int, 543 | default=3, 544 | help="The number of warmup steps.", 545 | ) 546 | parser.add_argument( 547 | "--tree-length", 548 | type=str, 549 | default="63", 550 | help="The choices for sampling.", 551 | ) 552 | parser.add_argument( 553 | "--temperature", 554 | type=float, 555 | default=None, 556 | help="The temperature for sampling.", 557 | ) 558 | parser.add_argument( 559 | "--posterior_threshold", 560 | type=float, 561 | default=0.09, 562 | help="The threshold for posterior sampling.", 563 | ) 564 | parser.add_argument( 565 | "--posterior_alpha", 566 | type=float, 567 | default=0.3, 568 | help="The alpha for posterior sampling.", 569 | ) 570 | parser.add_argument( 571 | "--sampling", 572 | type=str, 573 | default='greedy', 574 | help="The sampling method for decoding." 575 | ) 576 | parser.add_argument( 577 | "--gpu-power", 578 | action='store_true', 579 | default=False, 580 | help="Whether to measure power consumption." 581 | ) 582 | 583 | args = parser.parse_args() 584 | 585 | if 'vicuna-13b' in args.model_path.lower(): 586 | if '-2-' in args.model_path: 587 | from prompt.inference.dynamic_sparse_trees_2_13b import * 588 | print('Using 13b sparse trees') 589 | elif '-3-' in args.model_path: 590 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import * 591 | print('Using 13b 3-1 sparse trees') 592 | else: 593 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import * 594 | print('Using 13b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-') 595 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length) 596 | elif 'vicuna-7b' in args.model_path.lower(): 597 | if '-2-' in args.model_path: 598 | from prompt.inference.dynamic_sparse_trees_2_7b import * 599 | print('Using 7b 2-1 sparse trees') 600 | elif '-3-' in args.model_path: 601 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import * 602 | print('Using 7b 3-1 sparse trees') 603 | else: 604 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import * 605 | print('Using 7b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-') 606 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length) 607 | elif 'mobilellama' in args.model_path.lower(): 608 | from prompt.inference.dynamic_sparse_trees_3_MobileLLaMA import * 609 | print('Using MobileLLaMA 3-1 sparse trees') 610 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length) 611 | elif 'vicuna-68m' in args.model_path.lower(): 612 | from prompt.inference.dynamic_sparse_trees_3_vicuna_68m import * 613 | print('Using Vicuna-68m sparse trees') 614 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length) 615 | else: 616 | raise ValueError("Unknown model path") 617 | if args.num_gpus_total // args.num_gpus_per_model > 1: 618 | import ray 619 | 620 | ray.init() 621 | 622 | question_file = None 623 | if args.bench_name == 'mt_bench': 624 | question_file = f"data/{args.bench_name}/question.jsonl" 625 | elif args.bench_name == 'alpaca_eval': 626 | question_file = f"data/{args.bench_name}/alpaca_eval.json" 627 | if args.answer_file: 628 | answer_file = args.answer_file 629 | else: 630 | answer_file_name = args.model_id+"-temperature-"+str(args.temperature)+"-posterior_threshold-"+str(args.posterior_threshold)+"-posterior_alpha-"+str(args.posterior_alpha)+"-sampling-"+args.sampling 631 | answer_file = f"data/{args.bench_name}/model_answer/{answer_file_name}.jsonl" 632 | 633 | print(f"Output to {answer_file}") 634 | run_eval( 635 | model_path=args.model_path, 636 | model_id=args.model_id, 637 | question_file=question_file, 638 | question_begin=args.question_begin, 639 | question_end=args.question_end, 640 | answer_file=answer_file, 641 | max_new_token=args.max_new_token, 642 | num_choices=args.num_choices, 643 | num_gpus_per_model=args.num_gpus_per_model, 644 | num_gpus_total=args.num_gpus_total, 645 | max_gpu_memory=args.max_gpu_memory, 646 | dtype=str_to_torch_dtype(args.dtype), 647 | revision=args.revision, 648 | benchname=args.bench_name, 649 | warmup=args.warmup, 650 | tree_length=args.tree_length, 651 | temperature=args.temperature, 652 | posterior_threshold=args.posterior_threshold, 653 | posterior_alpha=args.posterior_alpha, 654 | sampling=args.sampling, 655 | gpu_power=args.gpu_power 656 | ) 657 | 658 | reorg_answer_file(answer_file) 659 | -------------------------------------------------------------------------------- /application/get_throughput_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import pandas as pd 4 | import os 5 | import argparse 6 | 7 | def get_throughput_results(input_file): 8 | with open(input_file, "r") as f: 9 | data = [json.loads(line) for line in f] 10 | 11 | new_tokens = [] 12 | wall_time = [] 13 | throughputs = [] 14 | accept_lengths = [] 15 | idxs = [] 16 | for d in data: 17 | for choice in d["choices"]: 18 | new_tokens.extend(choice["new_tokens"]) 19 | wall_time.extend(choice["wall_time"]) 20 | for i in range(len(choice["new_tokens"])): 21 | throughputs.append(choice["new_tokens"][i] / choice["wall_time"][i]) 22 | accept_lengths.append(choice["new_tokens"][i] / (choice["idxs"][i]+1)) 23 | idxs.extend([idx+1 for idx in choice["idxs"]]) 24 | 25 | return sum(new_tokens) / sum(wall_time), sum(throughputs) / len(throughputs), sum(new_tokens) / sum(idxs), sum(wall_time) / sum(idxs) 26 | 27 | 28 | def get_tree_latency(input_file): 29 | with open(input_file, "r") as f: 30 | data = [json.loads(line) for line in f] 31 | 32 | latencies = {} 33 | for d in data: 34 | latencies[d['tree_length']] = sum(d['choices'][0]['wall_time']) 35 | 36 | return latencies 37 | 38 | 39 | def get_gpu_power(input_file): 40 | with open(input_file, "r") as f: 41 | data = [json.loads(line) for line in f] 42 | 43 | gpu_power = [] 44 | gpu_energy = [] 45 | new_tokens = [] 46 | power_per_token = [] 47 | energy_per_token = [] 48 | for d in data: 49 | for choice in d["choices"]: 50 | gpu_power.extend(choice["power"]) 51 | gpu_energy.extend(choice["energy"]) 52 | new_tokens.extend(choice["new_tokens"]) 53 | power_per_token.append(sum(choice["power"]) / sum(choice["new_tokens"])) 54 | energy_per_token.append(sum(choice["energy"]) / sum(choice["new_tokens"])) 55 | 56 | return sum(gpu_power) / sum(new_tokens), \ 57 | sum(gpu_power) / len(gpu_power), \ 58 | sum(gpu_energy) / sum(new_tokens), \ 59 | sum(gpu_energy) / len(gpu_energy) 60 | 61 | 62 | def print_average_throughputs(input_files): 63 | throughputs1, throughputs2, accepth_lengths, forward_pass_time = list(zip(*[get_throughput_results(input_file) for input_file in input_files])) 64 | print(f"Macro-Average throughput: {sum(throughputs1) / len(throughputs1):.3f} tokens/s") 65 | print(f"std: {pd.Series(throughputs1).std():.3f}") 66 | print(f"Micro-Average throughput: {sum(throughputs2) / len(throughputs2):.3f} tokens/s") 67 | print(f"std: {pd.Series(throughputs2).std():.3f}") 68 | print(f"Average accept lengths: {sum(accepth_lengths) / len(accepth_lengths):.5f}") 69 | print(f"std: {pd.Series(accepth_lengths).std():.5f}") 70 | print(f"Average forward pass time: {sum(forward_pass_time) / len(forward_pass_time):.5f} s") 71 | print(f"std: {pd.Series(forward_pass_time).std():.5f}") 72 | 73 | return (sum(accepth_lengths) / len(accepth_lengths)) / (sum(forward_pass_time) / len(forward_pass_time)) 74 | 75 | 76 | def print_gpu_power(input_files): 77 | power_per_token, power, energy_per_token, energy = list(zip(*[get_gpu_power(input_file) for input_file in input_files])) 78 | print(f"Power per token: {sum(power_per_token) / len(power_per_token):.3f} W/token") 79 | print(f"std: {pd.Series(power_per_token).std():.3f}") 80 | print(f"Total power: {sum(power) / 1000:.3f} W") 81 | print(f"std: {pd.Series(power).std():.3f}") 82 | print(f"Energy per token: {sum(energy_per_token) / len(energy_per_token):.3f} J/token") 83 | print(f"std: {pd.Series(energy_per_token).std():.3f}") 84 | print(f"Total energy: {sum(energy) / len(energy):.3f} J") 85 | print(f"std: {pd.Series(energy).std():.3f}") 86 | 87 | 88 | 89 | def parse_file_name(file_name): 90 | # file name is in the format prefix{ddd}_0.json or prefix{dd}_1.json or prefix{d}_2.json where d is a digit 91 | prefix = file_name.split("_")[-2] 92 | rst = 0 93 | for i in range(1, 4): 94 | if prefix[-i:].isdigit(): 95 | rst = prefix[-i:] 96 | return int(rst) 97 | 98 | 99 | def main(args): 100 | input_files = args.input_files 101 | # if the input is a directory, iterate over all files in the directory 102 | if os.path.isdir(input_files[0]): 103 | input_files = [os.path.join(input_files[0], f) for f in os.listdir(input_files[0]) if 'tree_latency.jsonl' not in f] 104 | # group files by prefix. Assume that the files are named as prefix1_0.json, prefix1_1.json, prefix2_0.json, prefix2_1.json, etc. The resulting list should be prefix0_0.json, prefix0_1.json, prefix0_2.json, prefix1_0.json, prefix1_1.json, etc. 105 | input_files = sorted(input_files, key=lambda x: parse_file_name(x)) 106 | # for each group, get the average throughput using print_average_throughputs 107 | max_accepth_lenght_to_forward_pass_time = 0 108 | best_tree = None 109 | for i in range(0, len(input_files), args.n): 110 | # print prefix 111 | print(">>>", input_files[i]) 112 | if args.gpu_power: 113 | print_gpu_power(input_files[i:i+args.n]) 114 | else: 115 | ratio = print_average_throughputs(input_files[i:i+args.n]) 116 | if ratio > max_accepth_lenght_to_forward_pass_time: 117 | max_accepth_lenght_to_forward_pass_time = ratio 118 | best_tree = input_files[i].split("_")[-2] 119 | if not args.gpu_power: 120 | print(f"Best sparse tree: {best_tree}, ratio: {max_accepth_lenght_to_forward_pass_time}") 121 | elif 'tree_latency.jsonl' in input_files[0]: 122 | latencies = get_tree_latency(input_files[0]) 123 | input_files = [os.path.join(input_files[1], f) for f in os.listdir(input_files[1]) if 'tree_latency.jsonl' not in f] 124 | input_files = sorted(input_files, key=lambda x: parse_file_name(x)) 125 | # for each group, get the average throughput using print_average_throughputs 126 | max_accepth_lenght_to_forward_pass_time = 0 127 | best_tree = None 128 | for i in range(0, len(input_files), 1): 129 | # print prefix 130 | print(">>>", input_files[i]) 131 | tree_length = parse_file_name(input_files[i]) 132 | if tree_length not in latencies: 133 | continue 134 | _, _, accepth_lengths, _ = get_throughput_results(input_files[i]) 135 | ratio = accepth_lengths / latencies[tree_length] 136 | if ratio > max_accepth_lenght_to_forward_pass_time: 137 | max_accepth_lenght_to_forward_pass_time = ratio 138 | best_tree = input_files[i].split("_")[-2] 139 | print("Ratio: ", ratio, 'Accept length: ', accepth_lengths, 'Latency: ', latencies[tree_length]) 140 | print(f"Best sparse tree: {best_tree}, ratio: {max_accepth_lenght_to_forward_pass_time}") 141 | else: 142 | if args.gpu_power: 143 | print_gpu_power(input_files) 144 | else: 145 | print_average_throughputs(input_files) 146 | 147 | 148 | if __name__ == "__main__": 149 | # parse arguments 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("input_files", nargs="+", help="Input files to get throughput results") 152 | parser.add_argument("--n", type=int, default=1, help="Number of files to group") 153 | parser.add_argument("--gpu-power", action="store_true", help="Get GPU power", default=False) 154 | args = parser.parse_args() 155 | main(args) -------------------------------------------------------------------------------- /application/show_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 show_result.py --mode [single|pairwise-baseline|pairwise-all] 4 | """ 5 | import argparse 6 | import pandas as pd 7 | 8 | 9 | def display_result_single(args): 10 | if args.input_file is None: 11 | input_file = ( 12 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" 13 | ) 14 | else: 15 | input_file = args.input_file 16 | 17 | print(f"Input file: {input_file}") 18 | df_all = pd.read_json(input_file, lines=True) 19 | df = df_all[["model", "score", "turn"]] 20 | df = df[df["score"] != -1] 21 | 22 | if args.model_list is not None: 23 | df = df[df["model"].isin(args.model_list)] 24 | 25 | print("\n########## First turn ##########") 26 | df_1 = df[df["turn"] == 1].groupby(["model", "turn"]).mean() 27 | print(df_1.sort_values(by="score", ascending=False)) 28 | 29 | if args.bench_name == "mt_bench": 30 | print("\n########## Second turn ##########") 31 | df_2 = df[df["turn"] == 2].groupby(["model", "turn"]).mean() 32 | print(df_2.sort_values(by="score", ascending=False)) 33 | 34 | print("\n########## Average ##########") 35 | df_3 = df[["model", "score"]].groupby(["model"]).mean() 36 | print(df_3.sort_values(by="score", ascending=False)) 37 | 38 | 39 | def display_result_pairwise(args): 40 | if args.input_file is None: 41 | input_file = ( 42 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" 43 | ) 44 | else: 45 | input_file = args.input_file 46 | 47 | print(f"Input file: {input_file}") 48 | df_all = pd.read_json(input_file, lines=True) 49 | df_all = df_all[(df_all["g1_winner"] != "error") & (df_all["g2_winner"] != "error")] 50 | 51 | model_list = ( 52 | df_all["model_1"].unique().tolist() + df_all["model_2"].unique().tolist() 53 | ) 54 | model_list = list(set(model_list)) 55 | 56 | list_res = [] 57 | # traverse df row by row 58 | for index, row in df_all.iterrows(): 59 | if args.model_list is not None and row["model_1"] not in args.model_list: 60 | continue 61 | if args.baseline_model is not None: 62 | if args.baseline_model not in [row["model_1"], row["model_2"]]: 63 | continue 64 | if row["g1_winner"] == "tie" or row["g1_winner"] != row["g2_winner"]: 65 | list_res.append({"model": row["model_1"], "win": 0, "loss": 0, "tie": 1}) 66 | list_res.append({"model": row["model_2"], "win": 0, "loss": 0, "tie": 1}) 67 | else: 68 | if row["g1_winner"] == "model_1": 69 | winner = row["model_1"] 70 | loser = row["model_2"] 71 | else: 72 | winner = row["model_2"] 73 | loser = row["model_1"] 74 | list_res.append({"model": winner, "win": 1, "loss": 0, "tie": 0}) 75 | list_res.append({"model": loser, "win": 0, "loss": 1, "tie": 0}) 76 | 77 | df = pd.DataFrame(list_res) 78 | df = df.groupby(["model"]).sum() 79 | 80 | # remove baseline model 81 | if args.baseline_model is not None: 82 | df = df[df.index != args.baseline_model] 83 | # add win rate 84 | df["win_rate"] = df["win"] / (df["win"] + df["loss"] + df["tie"]) 85 | df["loss_rate"] = df["loss"] / (df["win"] + df["loss"] + df["tie"]) 86 | # each tie counts as 0.5 win + 0.5 loss 87 | df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / ( 88 | df["win"] + df["loss"] + df["tie"] 89 | ) 90 | # print(df.sort_values(by="win_rate", ascending=False)) 91 | # print(df.sort_values(by="loss_rate", ascending=True)) 92 | print(df.sort_values(by="win_rate_adjusted", ascending=False)) 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--bench-name", type=str, default="mt_bench") 98 | parser.add_argument("--input-file", type=str) 99 | parser.add_argument("--judge-model", type=str, default="gpt-4") 100 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") 101 | parser.add_argument( 102 | "--model-list", 103 | type=str, 104 | nargs="+", 105 | default=None, 106 | help="A list of models to be evaluated", 107 | ) 108 | parser.add_argument( 109 | "--mode", 110 | type=str, 111 | default="single", 112 | choices=["pairwise-baseline", "pairwise-all", "single"], 113 | help=( 114 | "Evaluation mode. " 115 | "`pairwise-baseline` runs pairwise comparision against a baseline. " 116 | "`pairwise-all` runs pairwise comparision between all pairs. " 117 | "`single` runs single answer grading." 118 | ), 119 | ) 120 | args = parser.parse_args() 121 | 122 | if args.mode == "single": 123 | display_result_func = display_result_single 124 | else: 125 | if args.mode == "pairwise-all": 126 | args.baseline_model = None 127 | display_result_func = display_result_pairwise 128 | 129 | print(f"Mode: {args.mode}") 130 | display_result_func(args) 131 | -------------------------------------------------------------------------------- /application/webui.py: -------------------------------------------------------------------------------- 1 | # adapted from https://github.com/SafeAILab/EAGLE/blob/d08fe3f23e5f1d986bb50f786af60e5c4f7f757e/eagle/application/webui.py#L4 2 | import os 3 | import time 4 | 5 | import gradio as gr 6 | import argparse 7 | from prompt.utils import * 8 | from prompt.model.model import PromptDecoder, AutoPromptDecoder, PromptConfig 9 | from prompt.model.kv_cache import * 10 | import torch 11 | from fastchat.model import get_conversation_template 12 | import re 13 | 14 | 15 | def truncate_list(lst, num): 16 | if num not in lst: 17 | return lst 18 | 19 | 20 | first_index = lst.index(num) 21 | 22 | 23 | return lst[:first_index + 1] 24 | 25 | 26 | def find_list_markers(text): 27 | 28 | pattern = re.compile(r'(?m)(^\d+\.\s|\n)') 29 | matches = pattern.finditer(text) 30 | 31 | 32 | return [(match.start(), match.end()) for match in matches] 33 | 34 | 35 | def checkin(pointer,start,marker): 36 | for b,e in marker: 37 | if b<=pointer{text[pointer:start]}" 63 | 64 | result += sub_text 65 | 66 | pointer = end 67 | 68 | if pointer < len(text): 69 | result += f"{text[pointer:]}" 70 | 71 | return result 72 | 73 | 74 | def warmup(model): 75 | conv = get_conversation_template('vicuna') 76 | conv.append_message(conv.roles[0], "Hello") 77 | conv.append_message(conv.roles[1], None) 78 | prompt = conv.get_prompt() 79 | input_ids = model.tokenizer([prompt]).input_ids 80 | input_ids = torch.as_tensor(input_ids).cuda() 81 | for output_ids in model.ppd_generate(input_ids): 82 | ol=output_ids.shape[1] 83 | 84 | 85 | def bot(history, temperature, use_ppd, highlight_ppd,session_state,): 86 | if not history: 87 | return history, "0.00 tokens/s", "0.00", session_state 88 | pure_history = session_state.get("pure_history", []) 89 | conv = get_conversation_template('vicuna') 90 | 91 | for query, response in pure_history: 92 | conv.append_message(conv.roles[0], query) 93 | conv.append_message(conv.roles[1], response) 94 | 95 | prompt = conv.get_prompt() 96 | 97 | input_ids = model.tokenizer([prompt]).input_ids 98 | input_ids = torch.as_tensor(input_ids).cuda() 99 | input_len = input_ids.shape[1] 100 | naive_text = [] 101 | cu_len = input_len 102 | totaltime=0 103 | start_time=time.time() 104 | total_ids=0 105 | if use_ppd: 106 | 107 | for output_ids in model.ppd_generate(input_ids, temperature=temperature, max_steps=args.max_new_token): 108 | totaltime+=(time.time()-start_time) 109 | total_ids+=1 110 | decode_ids = output_ids[0, input_len:].tolist() 111 | decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id) 112 | text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False, 113 | clean_up_tokenization_spaces=True, ) 114 | naive_text.append(model.tokenizer.decode(output_ids[0, cu_len], skip_special_tokens=True, 115 | spaces_between_special_tokens=False, 116 | clean_up_tokenization_spaces=True, )) 117 | 118 | cu_len = output_ids.shape[1] 119 | colored_text = highlight_text(text, naive_text, "orange") 120 | if highlight_ppd: 121 | history[-1][1] = colored_text 122 | else: 123 | history[-1][1] = text 124 | pure_history[-1][1] = text 125 | session_state["pure_history"] = pure_history 126 | new_tokens = cu_len-input_len 127 | yield history,f"{new_tokens/totaltime:.2f} tokens/s",f"{new_tokens/total_ids:.2f}",session_state 128 | start_time = time.time() 129 | 130 | 131 | else: 132 | for output_ids in model.naive_generate(input_ids, temperature=temperature, max_steps=args.max_new_token): 133 | totaltime += (time.time() - start_time) 134 | total_ids+=1 135 | decode_ids = output_ids[0, input_len:].tolist() 136 | decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id) 137 | text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False, 138 | clean_up_tokenization_spaces=True, ) 139 | naive_text.append(model.tokenizer.decode(output_ids[0, cu_len], skip_special_tokens=True, 140 | spaces_between_special_tokens=False, 141 | clean_up_tokenization_spaces=True, )) 142 | cu_len = output_ids.shape[1] 143 | colored_text = highlight_text(text, naive_text, "orange") 144 | if highlight_ppd and use_ppd: 145 | history[-1][1] = colored_text 146 | else: 147 | history[-1][1] = text 148 | history[-1][1] = text 149 | pure_history[-1][1] = text 150 | new_tokens = cu_len - input_len 151 | yield history,f"{new_tokens/totaltime:.2f} tokens/s",f"{new_tokens/total_ids:.2f}",session_state 152 | start_time = time.time() 153 | 154 | 155 | def user(user_message, history,session_state): 156 | if history==None: 157 | history=[] 158 | pure_history = session_state.get("pure_history", []) 159 | pure_history += [[user_message, None]] 160 | session_state["pure_history"] = pure_history 161 | return "", history + [[user_message, None]],session_state 162 | 163 | 164 | def regenerate(history,session_state): 165 | if not history: 166 | return history, None,"0.00 tokens/s","0.00",session_state 167 | pure_history = session_state.get("pure_history", []) 168 | pure_history[-1][-1] = None 169 | session_state["pure_history"]=pure_history 170 | if len(history) > 1: # Check if there's more than one entry in history (i.e., at least one bot response) 171 | new_history = history[:-1] # Remove the last bot response 172 | last_user_message = history[-1][0] # Get the last user message 173 | return new_history + [[last_user_message, None]], None,"0.00 tokens/s","0.00",session_state 174 | history[-1][1] = None 175 | return history, None,"0.00 tokens/s","0.00",session_state 176 | 177 | 178 | def clear(history,session_state): 179 | pure_history = session_state.get("pure_history", []) 180 | pure_history = [] 181 | session_state["pure_history"] = pure_history 182 | return [],"0.00 tokens/s","0.00",session_state 183 | 184 | 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument( 187 | "--ppd-path", 188 | type=str, 189 | default="hmarkc/ppd-vicuna-7b-v1.3", 190 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 191 | ) 192 | parser.add_argument( 193 | "--load-in-8bit", action="store_true", help="Use 8-bit quantization" 194 | ) 195 | parser.add_argument( 196 | "--load-in-4bit", action="store_true", help="Use 4-bit quantization" 197 | ) 198 | parser.add_argument( 199 | "--max-new-token", 200 | type=int, 201 | default=512, 202 | help="The maximum number of new generated tokens.", 203 | ) 204 | args = parser.parse_args() 205 | 206 | model = AutoPromptDecoder.from_pretrained( 207 | args.ppd_path, 208 | low_cpu_mem_usage=True, 209 | torch_dtype=torch.float16, 210 | ) 211 | model.cuda() 212 | model.eval() 213 | warmup(model) 214 | 215 | custom_css = """ 216 | #speed textarea { 217 | color: red; 218 | font-size: 30px; 219 | }""" 220 | 221 | with gr.Blocks(css=custom_css) as demo: 222 | gs = gr.State({"pure_history": []}) 223 | gr.Markdown('''## PPD Chatbot''') 224 | with gr.Row(): 225 | speed_box = gr.Textbox(label="Speed", elem_id="speed", interactive=False, value="0.00 tokens/s") 226 | compression_box = gr.Textbox(label="Compression Ratio", elem_id="speed", interactive=False, value="0.00") 227 | 228 | chatbot = gr.Chatbot(height=600,show_label=False) 229 | 230 | 231 | msg = gr.Textbox(label="Your input") 232 | with gr.Row(): 233 | send_button = gr.Button("Send") 234 | stop_button = gr.Button("Stop") 235 | regenerate_button = gr.Button("Regenerate") 236 | clear_button = gr.Button("Clear") 237 | 238 | with gr.Row(): 239 | with gr.Column(): 240 | use_ppd = gr.Checkbox(label="Use PPD", value=True) 241 | highlight_ppd = gr.Checkbox(label="Highlight the tokens generated by PPD", value=True) 242 | temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="temperature", value=0.5) 243 | note=gr.Markdown(show_label=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by PPD" is checked, the tokens correctly guessed by PPD 244 | will be displayed in orange. Note: Checking this option may cause special formatting rendering issues in a few cases, especially when generating code''') 245 | enter_event=msg.submit(user, [msg, chatbot,gs], [msg, chatbot,gs], queue=True).then( 246 | bot, [chatbot, temperature, use_ppd, highlight_ppd,gs], [chatbot,speed_box,compression_box,gs] 247 | ) 248 | clear_button.click(clear, [chatbot,gs], [chatbot,speed_box,compression_box,gs], queue=True) 249 | 250 | send_event=send_button.click(user, [msg, chatbot,gs], [msg, chatbot,gs],queue=True).then( 251 | bot, [chatbot, temperature, use_ppd, highlight_ppd,gs], [chatbot,speed_box,compression_box,gs] 252 | ) 253 | regenerate_event=regenerate_button.click(regenerate, [chatbot,gs], [chatbot, msg,speed_box,compression_box,gs],queue=True).then( 254 | bot, [chatbot, temperature, use_ppd, highlight_ppd,gs], [chatbot,speed_box,compression_box,gs] 255 | ) 256 | stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event]) 257 | demo.queue() 258 | demo.launch(share=True) -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/.DS_Store -------------------------------------------------------------------------------- /assets/Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/Overview.png -------------------------------------------------------------------------------- /assets/PPD_LOGO.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/PPD_LOGO.png -------------------------------------------------------------------------------- /assets/Speed_Mem_Train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/Speed_Mem_Train.png -------------------------------------------------------------------------------- /assets/latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/latency.png -------------------------------------------------------------------------------- /assets/ppd_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/ppd_demo.gif -------------------------------------------------------------------------------- /dataset_generation/generate_dataset.py: -------------------------------------------------------------------------------- 1 | from prompt.utils import * 2 | import argparse 3 | import math 4 | from transformers import LlamaForCausalLM 5 | from transformers import BitsAndBytesConfig 6 | 7 | def generate_fine_tune_dataset(args): 8 | tokenizer = transformers.AutoTokenizer.from_pretrained( 9 | args.model_name_or_path, 10 | model_max_length=args.model_max_length, 11 | padding_side="left", 12 | use_fast=False, 13 | truncation=True 14 | ) 15 | tokenizer.pad_token = tokenizer.unk_token 16 | data = get_finetune_dataset(tokenizer=tokenizer, data_path=args.data_path, size=args.size, offset=args.num_special_tokens+1) 17 | 18 | torch.save(data, f"{args.save_path}_{args.num_special_tokens}_finetune_{args.model_max_length}.pt") 19 | 20 | 21 | def generate_self_distillation_dataset(args): 22 | # Set RoPE scaling factor 23 | config = transformers.AutoConfig.from_pretrained(args.model_name_or_path) 24 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 25 | if orig_ctx_len and args.model_max_length > orig_ctx_len: 26 | scaling_factor = float(math.ceil(args.model_max_length / orig_ctx_len)) 27 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 28 | config.use_cache = False 29 | 30 | config = transformers.AutoConfig.from_pretrained(args.model_name_or_path) 31 | 32 | quantization_config = BitsAndBytesConfig( 33 | load_in_4bit=True, 34 | bnb_4bit_compute_dtype=torch.bfloat16, 35 | bnb_4bit_use_double_quant=True, 36 | bnb_4bit_quant_type="nf4", 37 | ) 38 | 39 | if config.model_type == "llama": 40 | model = LlamaForCausalLM.from_pretrained( 41 | args.model_name_or_path, 42 | config=config, 43 | low_cpu_mem_usage=True, 44 | quantization_config=quantization_config, 45 | ) 46 | else: 47 | raise ValueError("Only support llama for now") 48 | data = get_self_distillation_dataset(model=model, data_path=args.data_path, num_special_tokens=args.num_special_tokens) 49 | 50 | model_name = args.model_name_or_path.split("/")[-1] 51 | torch.save(data, f"{args.save_path}_{args.num_special_tokens}_{model_name}_distillation__{args.model_max_length}.pt") 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--model_name_or_path", type=str, default="lmsys/vicuna-7b-v1.3") 57 | parser.add_argument("--model_max_length", type=int, default=2048) 58 | parser.add_argument("--save_path", type=str, default="data/ShareGPT_training_dataset") 59 | parser.add_argument("--data_path", type=str, default="data/ShareGPT_training_dataset_2.pt") 60 | parser.add_argument("--size", type=int, default=None) 61 | parser.add_argument("--num_special_tokens", type=int, default=2) 62 | parser.add_argument("--dataset_type", type=str, default="finetune", choices=["finetune", "distillation"]) 63 | args = parser.parse_args() 64 | 65 | if args.dataset_type == "finetune": 66 | generate_fine_tune_dataset(args) 67 | elif args.dataset_type == "distillation": 68 | generate_self_distillation_dataset(args) 69 | -------------------------------------------------------------------------------- /prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/__init__.py -------------------------------------------------------------------------------- /prompt/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/evaluation/__init__.py -------------------------------------------------------------------------------- /prompt/evaluation/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | import numpy as np 5 | from prompt.model.model import * 6 | import matplotlib.pyplot as plt 7 | import torch.nn.functional as F 8 | from fastchat.model.model_adapter import get_conversation_template 9 | from tqdm import tqdm 10 | import argparse 11 | import matplotlib.pyplot as plt 12 | 13 | # Once 14 | # 1st iteration: Once upon [a time] 15 | # 2nd iteration: Once upon a [time there] 16 | # 3rd iteration: Once upon a time [there was] 17 | def get_accuracies(approximate_ids, logit): 18 | results = [] 19 | _, _, num_special_tokens, _ = approximate_ids.shape 20 | for i in range(num_special_tokens): 21 | match = approximate_ids[:-1-i, :, i].eq(logit[1+i:, :, :1]) 22 | results.append(match) 23 | # print(match.shape) 24 | # accuracy = match.any(dim=-1).sum().float() / (match.shape[0] * match.shape[1]) 25 | # print(approximate_ids.shape, logit.shape) 26 | return results 27 | 28 | def plot_accuracies(eval_data, save_path): 29 | plt.figure() 30 | for i, data in enumerate(eval_data): 31 | results= [] 32 | for K in range(1, 11): 33 | results.append((data[:, :, :K].any(dim=-1).sum().float() / (data.shape[0] * data.shape[1])).cpu()) 34 | plt.plot(results, label=f"{i}th prediction") 35 | print(f"{i+1}th accuracy - {', '.join(['Top '+str(i+1)+' : '+str(result.item()) for i, result in enumerate(results)])}") 36 | plt.xlabel("K") 37 | plt.ylabel("Accuracy") 38 | plt.xticks(range(10)) 39 | plt.yticks(np.arange(0, 1.1, 0.1)) 40 | plt.legend() 41 | plt.show() 42 | plt.savefig(save_path) 43 | 44 | def main(args): 45 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 46 | data = json.load(open(args.data_path)) 47 | 48 | if args.eval_result_path: 49 | eval_data = torch.load(args.eval_result_path) 50 | plot_accuracies(eval_data, os.path.join(args.save_dir, args.model_name + "_accuracy.png")) 51 | return 52 | 53 | 54 | model = AutoPromptDecoder.from_pretrained( 55 | args.model_path, 56 | torch_dtype=torch.float16, 57 | low_cpu_mem_usage=True 58 | ) 59 | tokenizer = model.tokenizer 60 | model = model.to(device) 61 | 62 | config = model.active_peft_config 63 | num_special_tokens = config.num_special_tokens 64 | virtual_tokens_per_special_token = config.virtual_tokens_per_special_token 65 | total_virtual_tokens = num_special_tokens * virtual_tokens_per_special_token 66 | # TODO: KV Cache 67 | results = None 68 | 69 | for sample in tqdm((data)): 70 | conv = get_conversation_template("vicuna") 71 | conv.messages = [] 72 | conv.append_message(conv.roles[0], sample["instruction"]) 73 | conv.append_message(conv.roles[1], "") 74 | prompt = conv.get_prompt() 75 | steps = args.steps 76 | logits_ids = [] 77 | approximate_ids = [] 78 | 79 | with torch.inference_mode(): 80 | input_ids = tokenizer([prompt]).input_ids 81 | input_ids = torch.as_tensor(input_ids).to(device) 82 | outputs = model(input_ids) 83 | logits = outputs.logits 84 | pred = torch.argmax(logits[:, -num_special_tokens-1, :], dim=-1) 85 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous() 86 | _, approximate_tokens = prompt_logits.topk(10, dim=-1) 87 | # print(pred.device, input_ids.device, model.device) 88 | preds = torch.cat((input_ids, pred.unsqueeze(0)), dim=-1) 89 | # print(f"Exact token: {tokenizer.batch_decode(pred)}, approximate tokens: {tokenizer.batch_decode(approximate_tokens.squeeze(0))}") 90 | logits_ids.append(preds[:, -1:].detach()) 91 | approximate_ids.append(approximate_tokens.detach()) 92 | for _ in range(steps): 93 | outputs= model(preds) 94 | logits = outputs.logits 95 | pred = torch.argmax(logits[:, -num_special_tokens-1, :], dim=-1) 96 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous() 97 | _, approximate_tokens = prompt_logits.topk(10, dim=-1) 98 | # print(f"Exact token: {tokenizer.batch_decode(pred)}, approximate tokens: {tokenizer.batch_decode(approximate_tokens.squeeze(0))}") 99 | preds = torch.cat((preds, pred.unsqueeze(0)), dim=-1) 100 | logits_ids.append(preds[:, -1:].detach()) 101 | approximate_ids.append(approximate_tokens.detach()) 102 | logits_ids = torch.stack(logits_ids, dim=0) 103 | approximate_ids = torch.stack(approximate_ids, dim=0).squeeze(2) 104 | if results is None: 105 | results = get_accuracies(approximate_ids, logits_ids) 106 | else: 107 | # cat sub results 108 | cur_results = get_accuracies(approximate_ids, logits_ids) 109 | for i in range(len(results)): 110 | results[i] = torch.cat((results[i], cur_results[i]), dim=0) 111 | 112 | save_path = os.path.join(args.save_dir, args.model_name + "_accuracy.pt") 113 | torch.save(results, save_path) 114 | plot_accuracies(results, os.path.join(args.save_dir, args.model_name + "_accuracy.png")) 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser(description="Model Evaluator") 118 | 119 | parser.add_argument("--model_path", type=str, required=True, 120 | help="Path to the pre-trained model.") 121 | parser.add_argument("--model_name", type=str, required=True, 122 | help="Name of the model.") 123 | parser.add_argument("--data_path", type=str, required=True, 124 | help="Path to the evaluation data in JSON format.") 125 | parser.add_argument("--save_dir", type=str, default="./", 126 | help="Directory to save the results.") 127 | parser.add_argument("--steps", type=int, default=20, 128 | help="Number of steps to run the model.") 129 | parser.add_argument("--eval_result_path", type=str, default=None, required=False, 130 | help="Path to the evaluation result.") 131 | args = parser.parse_args() 132 | 133 | # If the save directory doesn't exist, create it 134 | if not os.path.exists(args.save_dir): 135 | os.makedirs(args.save_dir) 136 | main(args) -------------------------------------------------------------------------------- /prompt/hf_utils.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser("Upload PPD model to HuggingFace Hub") 5 | parser.add_argument("--folder", type=str, help="Path to model folder") 6 | parser.add_argument("--repo", type=str, help="Repo name") 7 | parser.add_argument("--private", action="store_true", help="Make repo private") 8 | 9 | args = parser.parse_args() 10 | 11 | api = HfApi() 12 | 13 | api.create_repo( 14 | repo_id=args.repo, 15 | private=args.private, 16 | exist_ok=True, 17 | ) 18 | 19 | api.upload_folder( 20 | folder_path=args.folder, 21 | repo_id=args.repo, 22 | ) -------------------------------------------------------------------------------- /prompt/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/inference/__init__.py -------------------------------------------------------------------------------- /prompt/inference/sparse_tree_builder.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | import argparse 6 | 7 | from pprint import pprint 8 | from copy import deepcopy 9 | import numpy as np 10 | import concurrent.futures 11 | 12 | 13 | def load_accuracy(file_name): 14 | eval_data = torch.load(file_name) 15 | accuracies = [] 16 | for i, data in enumerate(eval_data): 17 | results= [] 18 | for K in range(1, 11): 19 | results.append((data[:, :, :K].any(dim=-1).sum().float() / (data.shape[0] * data.shape[1])).cpu()) 20 | print(f"{i+1}th accuracy - {', '.join(['Top '+str(i+1)+' : '+str(result.item()) for i, result in enumerate(results)])}") 21 | accuracy = [results[0]] 22 | for i in range(1, len(results)): 23 | accuracy.append(results[i] - results[i-1]) 24 | accuracies.append(accuracy) 25 | accuracies = torch.tensor(accuracies) 26 | return accuracies 27 | 28 | 29 | def expected_accuracy(es, vs, print_output=False): 30 | # only calculate depth of 3 31 | e10, e20, e30 = es[0] 32 | e11, e21, e31 = es[1] 33 | e12, e22, e32 = es[2] 34 | e13, e23, e33 = es[3] 35 | v1, v2, v3 = vs 36 | a = np.array([[0, e11-1, e21, e31], 37 | [0, e12, e22-1, e32], 38 | [1, e13, e23, e33-1], 39 | [1, 1, 1, 1]]) 40 | output = np.linalg.solve(a, [0, 0, 0, 1]) 41 | if print_output: 42 | print("Expected Probability:") 43 | pprint(output) 44 | return output[1]*v1 + output[2]*v2 + output[3]*v3 45 | 46 | 47 | def find_all_candidates(parent, width, depth): 48 | if depth == 0: 49 | return [] 50 | candidates = [] 51 | for i in range(width): 52 | candidates.append(parent+[i]) 53 | candidates.extend(find_all_candidates(parent+[i], width, depth-1)) 54 | return candidates 55 | 56 | 57 | def find_optimal_sparse_tree(accuracies, num_candidates, max_depth=None): 58 | # Generate all possible candidates of varying lengths 59 | if max_depth: 60 | accuracies = accuracies[:max_depth] 61 | candidates = find_all_candidates([], accuracies.shape[1], accuracies.shape[0]) 62 | 63 | # Calculate cumulative accuracy for each candidate 64 | candidate_accuracies = [] 65 | for candidate in candidates: 66 | cumulative_accuracy = 1.0 67 | for idx, top_i in enumerate(candidate): 68 | cumulative_accuracy *= accuracies[idx, top_i] 69 | candidate_accuracies.append((cumulative_accuracy, candidate)) 70 | 71 | # Sort candidates by their cumulative accuracy in descending order and select top n 72 | top_candidates = sorted(candidate_accuracies, key=lambda x: x[0], reverse=True)[:num_candidates] 73 | 74 | # Extract just the candidate paths 75 | top_candidate_paths = [list(candidate) for _, candidate in top_candidates] 76 | top_candidate_accs = [round(acc.cpu().item(), 5) for acc, _ in top_candidates] 77 | 78 | return top_candidate_paths, top_candidate_accs 79 | 80 | 81 | def find_optimal_extended_sparse_tree(accuracies, input_length_limit): 82 | # input_length_limit = num_candidates + sum(candidate_accuracy_n * num_special_tokens_n) 83 | # n is the depth of accuracies 84 | n = accuracies.shape[0] 85 | # generate and store the optimal sparse tree for each num_candidates and each depth 86 | optimal_sparse_trees = {} 87 | optimal_sparse_tree_accuracies = {} 88 | for depth in range(1, n+1): 89 | candidates, accs = find_optimal_sparse_tree(accuracies, input_length_limit, depth) 90 | candidate_acc_pairs = list(zip(candidates, accs)) 91 | for length in range(1, input_length_limit+1): 92 | ls = [] 93 | for candidate, acc in candidate_acc_pairs[:length]: 94 | sum_children_acc = sum([a for c, a in candidate_acc_pairs[:length] if c[:-1] == candidate]) 95 | ls.append((candidate, acc - sum_children_acc)) 96 | optimal_sparse_trees[(depth, length)] = ls 97 | for size in range(1, input_length_limit+1): 98 | optimal_sparse_tree_accuracies[(depth, size)] = sum(accs[:size]) 99 | best_sparse_trees = None 100 | best_expected_acc = -1 101 | best_num_tree_nodes = None 102 | # only calculate depth of 3 for now 103 | for tree_node2 in range(input_length_limit//(n+1), input_length_limit//2 + 1): 104 | for tree_node3 in range(input_length_limit//(n+1), input_length_limit//2 + 1): 105 | tree_nodes = {1: 10, 2: tree_node2, 3: tree_node3} 106 | sparse_trees, expected_acc = find_extended_sparse_tree_fixed_tree_nodes(tree_nodes, input_length_limit, optimal_sparse_tree_accuracies, deepcopy(optimal_sparse_trees), n) 107 | if expected_acc > best_expected_acc: 108 | best_sparse_trees = sparse_trees 109 | best_expected_acc = expected_acc 110 | best_num_tree_nodes = tree_nodes 111 | print("Input limit", input_length_limit, "Tree nodes:", best_num_tree_nodes, "Expected Acc:", best_expected_acc) 112 | return best_sparse_trees 113 | 114 | 115 | def find_extended_sparse_tree_fixed_tree_nodes(tree_nodes, input_length_limit, optimal_sparse_tree_accuracies, candidate_acc_pairs, n): 116 | optimal_sparse_trees = {} 117 | for depth in range(1, n+1): 118 | num_nodes = n * len(candidate_acc_pairs[(depth, tree_nodes[depth])]) 119 | optimal_sparse_trees[depth] = [[candidate, acc, n] for candidate, acc in candidate_acc_pairs[(depth, tree_nodes[depth])]] 120 | while num_nodes > input_length_limit - tree_nodes[depth]: 121 | min_accuracy_loss = float('-inf') 122 | min_index = 0 123 | for i, (_, candidate_acc, num_special_token) in enumerate(optimal_sparse_trees[depth]): 124 | if num_special_token == 1: 125 | continue 126 | accuracy_loss = candidate_acc * (optimal_sparse_tree_accuracies[(num_special_token-1, tree_nodes[num_special_token-1])] - \ 127 | optimal_sparse_tree_accuracies[(num_special_token, tree_nodes[num_special_token])]) 128 | if accuracy_loss > min_accuracy_loss: 129 | min_accuracy_loss = accuracy_loss 130 | min_index = i 131 | if min_accuracy_loss == float('inf'): 132 | break 133 | optimal_sparse_trees[depth][min_index][2] = optimal_sparse_trees[depth][min_index][2] - 1 134 | num_nodes -= 1 135 | 136 | es = [] 137 | vs = [] 138 | for depth in range(1, n+1): 139 | e_depth = [0] * (n+1) 140 | for (_, candidate_acc, num_special_token) in optimal_sparse_trees[depth]: 141 | e_depth[num_special_token] += candidate_acc 142 | e_depth[0] = 1 - sum(e_depth[1:]) 143 | es.append(e_depth) 144 | vs.append(optimal_sparse_tree_accuracies[(depth, tree_nodes[depth])]) 145 | es = np.array(es).T.tolist() 146 | acc = expected_accuracy(es, vs) 147 | 148 | return optimal_sparse_trees, acc 149 | 150 | 151 | def sparse_tree_info(best_sparse_trees): 152 | print("Best Sparse Trees:") 153 | pprint(best_sparse_trees) 154 | 155 | es = [] 156 | vs = [] 157 | for depth, best_sparse_tree in best_sparse_trees.items(): 158 | print(f"Depth: {depth}") 159 | print("Number of tree nodes:", len(best_sparse_tree)) 160 | print("Number of special tokens:", sum([num_special_token for _, _, num_special_token in best_sparse_tree])) 161 | acc_1 = sum([accuracy for _, accuracy, num_special_token in best_sparse_tree if num_special_token == 1]) 162 | acc_2 = sum([accuracy for _, accuracy, num_special_token in best_sparse_tree if num_special_token == 2]) 163 | acc_3 = sum([accuracy for _, accuracy, num_special_token in best_sparse_tree if num_special_token == 3]) 164 | print("Probabilities to 1 special token:", acc_1) 165 | print("Probabilities to 2 special tokens:", acc_2) 166 | print("Probabilities to 3 special tokens:", acc_3) 167 | print("Probability to None", 1 - (acc_1 + acc_2 + acc_3)) 168 | 169 | 170 | def write_sparse_tree_to_file(file_name, min_input_length, max_input_length, accuracies): 171 | def task(i): 172 | # This is the task that will be executed by each thread. 173 | # It returns a tuple of (i, result) so we know which iteration it belongs to. 174 | return i, find_optimal_extended_sparse_tree(accuracies, i) 175 | 176 | # Prepare to collect the results 177 | results = [] 178 | 179 | # Using ThreadPoolExecutor to execute tasks concurrently 180 | with concurrent.futures.ThreadPoolExecutor() as executor: 181 | # Map the task function to the range of values concurrently 182 | future_to_i = {executor.submit(task, i): i for i in range(min_input_length, max_input_length+1)} 183 | 184 | for future in concurrent.futures.as_completed(future_to_i): 185 | i = future_to_i[future] 186 | try: 187 | # Collect results as they are completed 188 | results.append(future.result()) 189 | except Exception as exc: 190 | print(f'Generated an exception: {exc}') 191 | 192 | # Sorting results to ensure they are in order 193 | results.sort(key=lambda x: x[0]) 194 | 195 | # Writing results to file sequentially 196 | with open(file_name, "w") as f: 197 | for i, result in results: 198 | f.write(f"# Dynamic Sparse Trees for input length limit of {i}\n") 199 | f.write(f"dynamic_sparse_trees_{i} = {result}\n") 200 | 201 | 202 | def main(args): 203 | accuracies = load_accuracy(args.accuracies_file) 204 | write_sparse_tree_to_file(args.output_file, args.min_input_length, args.max_input_length, accuracies) 205 | 206 | 207 | if __name__ == "__main__": 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument("--accuracies-file", type=str, required=True, help="Path to the accuracies file.") 210 | parser.add_argument("--output-file", type=str, required=True, help="Path to the output file.") 211 | parser.add_argument("--min-input-length", type=int, required=True, help="Minimum input length limit.") 212 | parser.add_argument("--max-input-length", type=int, required=True, help="Maximum input length limit.") 213 | args = parser.parse_args() 214 | main(args) 215 | 216 | 217 | -------------------------------------------------------------------------------- /prompt/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/model/__init__.py -------------------------------------------------------------------------------- /prompt/model/kv_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | # adapted from Medusa: https://github.com/FasterDecoding/Medusa/blob/5e980538695096e7e372c1e27a6bcf142bfeab11/medusa/model/kv_cache.py 5 | class KVCache: 6 | """ 7 | A key-value cache for the model. 8 | 9 | This class provides a mechanism to maintain a growing cache of keys and values, 10 | particularly useful for models that benefit from caching previous states, 11 | like transformers during autoregressive decoding. 12 | 13 | Attributes: 14 | data (torch.Tensor): The tensor storing keys and values. 15 | current_length (int): Current length of the data being stored. 16 | """ 17 | 18 | def __init__(self, data: torch.Tensor, current_length: int): 19 | """ 20 | Initialize the KVCache. 21 | 22 | Args: 23 | data (torch.Tensor): Initial tensor to store the keys and values. 24 | current_length (int): Initial length of the data. 25 | normal_token_indices (torch.Tensor): Indices of the normal tokens in the data tensor. 26 | """ 27 | self.data = data 28 | self.current_length = current_length 29 | 30 | @property 31 | def shape(self): 32 | """Return the shape of the data tensor with updated length.""" 33 | return ( 34 | self.data.shape[0], 35 | self.data.shape[1], 36 | self.current_length.item(), 37 | self.data.shape[3], 38 | ) 39 | 40 | def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2): 41 | """ 42 | Copy values from the current data at specified indices to a new location. 43 | 44 | Args: 45 | indices (torch.Tensor): Indices of the data tensor to be copied. 46 | prev_length (int): Previous length before adding new data. 47 | dim (int, optional): Dimension along which copying should be performed. Default is 2. 48 | """ 49 | tgt = self.data.index_select(dim, indices) 50 | dst = self.data.narrow(dim, prev_length, tgt.shape[dim]) 51 | dst.copy_(tgt, non_blocking=True) 52 | self.current_length.fill_(prev_length + tgt.shape[dim]) 53 | 54 | def cat(self, tensor: torch.Tensor, dim: int = 2): 55 | """ 56 | Concatenate the given tensor with the current data. 57 | 58 | Args: 59 | tensor (torch.Tensor): The tensor to be concatenated. 60 | dim (int, optional): The dimension along which concatenation should be done. Default is 2. 61 | 62 | Returns: 63 | torch.Tensor: The data tensor after concatenation up to the current length. 64 | """ 65 | # if normal_token_indices is None: 66 | # dst = self.data.narrow(dim, self.current_length, tensor.shape[dim]) 67 | # dst.copy_(tensor, non_blocking=True) 68 | # self.current_length.add_(tensor.shape[dim]) 69 | # return torch.narrow(self.data, 2, 0, self.current_length) 70 | # dst = self.data.narrow(dim, self.current_length, len(normal_token_indices)) 71 | # dst.copy_(tensor.index_select(dim, normal_token_indices), non_blocking=True) 72 | # rst = torch.cat([torch.narrow(self.data, 2, 0, self.current_length), tensor], dim) 73 | # self.current_length.add_(len(normal_token_indices)) 74 | # return rst 75 | dst = self.data.narrow(dim, self.current_length, tensor.shape[dim]) 76 | dst.copy_(tensor, non_blocking=True) 77 | self.current_length.add_(tensor.shape[dim]) 78 | return torch.narrow(self.data, 2, 0, self.current_length) 79 | 80 | 81 | def initialize_past_key_values(model): 82 | """ 83 | Initialize past key and value states for a given transformer model. 84 | 85 | This function prepares key-value cache structures for the model, allowing it to store and reuse 86 | past key and value states during autoregressive decoding, which can improve efficiency. 87 | 88 | Args: 89 | model (nn.Module): The transformer model for which past key-value states need to be initialized. 90 | 91 | Returns: 92 | tuple: 93 | - past_key_values (list): A list of KVCache objects for each layer in the model. 94 | - past_key_values_data (torch.Tensor): The tensor that will store all keys and values. 95 | - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache. 96 | """ 97 | # Extracting configuration from the model 98 | config = model.config 99 | # Initializing the batch size to 1, this can be modified if different batch sizes are required 100 | batch_size = 1 101 | # Initializing a tensor to store past keys and values for all layers 102 | past_key_values_data = torch.zeros( 103 | config.num_hidden_layers * 2, 104 | batch_size, 105 | config.num_key_value_heads, 106 | # llama max_position_embeddings is 4096 instead of 2048 107 | config.max_position_embeddings*2, 108 | config.hidden_size // config.num_attention_heads, 109 | device=model.device, 110 | dtype=model.dtype, 111 | ) 112 | # Initialize tensor to store the current length of the cached data for all layers. 113 | # [IMPORTANT] It needs to be kept on CPU for quick access and updates. 114 | current_length_data = torch.zeros( 115 | config.num_hidden_layers * 2, dtype=torch.long, device="cpu" 116 | ) 117 | # Creating a KVCache for each pair of key and value in all layers 118 | past_key_values = [] * config.num_hidden_layers 119 | for i in range(config.num_hidden_layers): 120 | past_key_values.append( 121 | [ 122 | KVCache(past_key_values_data[i * 2 + j], current_length_data[i * 2 + j]) 123 | for j in range(2) 124 | ] 125 | ) 126 | return past_key_values, past_key_values_data, current_length_data -------------------------------------------------------------------------------- /prompt/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/train/__init__.py -------------------------------------------------------------------------------- /prompt/train/train.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import warnings 3 | import math 4 | import pathlib 5 | from typing import Dict, Optional 6 | 7 | import torch 8 | import transformers 9 | from transformers import Trainer, BitsAndBytesConfig 10 | from transformers.trainer_pt_utils import LabelSmoother 11 | 12 | from torch.nn import CrossEntropyLoss 13 | import torch.nn.functional as F 14 | 15 | from prompt.utils import * 16 | from prompt.model.model import PromptDecoder, PromptConfig, AutoPromptDecoder 17 | from prompt.model.modeling_llama_custom import LlamaForCausalLM as CustomLlamaForCausalLM 18 | 19 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 20 | 21 | class ParamEfficientFineTuner(Trainer): 22 | def compute_loss(self, model, inputs, return_outputs=False): 23 | """ 24 | Compute the training loss for the model. 25 | 26 | Args: 27 | model (torch.nn.Module): The model for which to compute the loss. 28 | inputs (dict): The input data, including input IDs, attention mask, and labels. 29 | return_outputs (bool): Whether to return model outputs along with the loss. 30 | 31 | Returns: 32 | Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs. 33 | """ 34 | num_special_tokens = self.model.active_peft_config.num_special_tokens 35 | if torch.any(inputs["input_ids"][:, -1] == self.tokenizer.eos_token_id): 36 | warnings.warn("Input ends with EOS token.") 37 | input_ids = inputs["input_ids"] 38 | attention_mask = inputs["attention_mask"] 39 | labels = inputs["labels"] 40 | 41 | outputs = model( 42 | input_ids=input_ids, attention_mask=attention_mask 43 | ) 44 | logits = outputs.logits 45 | 46 | # Calculate loss on the prompt tokens 47 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous() 48 | prompt_labels = labels[..., -num_special_tokens:].contiguous() 49 | prompt_labels = prompt_labels.to(logits.device) 50 | loss = 0 51 | loss_fn = CrossEntropyLoss() 52 | decay_coefficient = 0.8 53 | for i in range(num_special_tokens): 54 | loss += loss_fn(prompt_logits[:, i, :], prompt_labels[:, i]) * (decay_coefficient ** i) 55 | if num_special_tokens > 0: 56 | loss /= num_special_tokens 57 | return (loss, outputs) if return_outputs else loss 58 | 59 | 60 | class DistillationTrainer(Trainer): 61 | def compute_loss(self, model, inputs, return_outputs=False): 62 | """ 63 | Compute the training loss for the model. 64 | 65 | Args: 66 | model (torch.nn.Module): The model for which to compute the loss. 67 | inputs (dict): The input data, including input IDs, attention mask, and labels. 68 | return_outputs (bool): Whether to return model outputs along with the loss. 69 | 70 | Returns: 71 | Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs. 72 | """ 73 | num_special_tokens = self.model.active_peft_config.num_special_tokens 74 | if torch.any(inputs["input_ids"][:, -1] == self.tokenizer.eos_token_id): 75 | warnings.warn("Input ends with EOS token.") 76 | input_ids = inputs["input_ids"] 77 | attention_mask = inputs["attention_mask"] 78 | labels = inputs["labels"] 79 | 80 | outputs = model( 81 | input_ids=input_ids, attention_mask=attention_mask 82 | ) 83 | logits = outputs.logits 84 | 85 | # Calculate loss on the prompt tokens 86 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous() 87 | prompt_labels = labels.contiguous() 88 | prompt_labels = prompt_labels.to(logits.device) 89 | loss = 0 90 | decay_coefficient = 0.8 91 | for i in range(num_special_tokens): 92 | loss_i = F.kl_div( 93 | F.log_softmax(prompt_logits[:, i, :], dim=-1), 94 | F.softmax(prompt_labels[:, i, :], dim=-1), 95 | reduction='sum' 96 | ) / prompt_logits.shape[0] 97 | loss += loss_i * (decay_coefficient ** i) 98 | if num_special_tokens > 0: 99 | loss /= num_special_tokens 100 | return (loss, outputs) if return_outputs else loss 101 | 102 | 103 | @dataclass 104 | class ModelArguments: 105 | model_name_or_path: str = field(default="lmsys/vicuna-7b-v1.3") 106 | num_special_tokens: int = field(default=1) 107 | virtual_tokens_per_special_token: int = field(default=1) 108 | use_custom_lm_head: bool = field(default=False) 109 | use_prefix_tuning: bool = field(default=False) 110 | prefix_virtual_tokens: int = field(default=10) 111 | vt_attention_type: str = field(default="decoder") 112 | aggregation_type: str = field(default="mean") 113 | num_exits: int = field(default=1) 114 | load_in_4bit: bool = field( 115 | default=False, 116 | metadata={"help": "Load in 4 bit."}, 117 | ) 118 | load_in_8bit: bool = field( 119 | default=False, 120 | metadata={"help": "Load in 8 bit."}, 121 | ) 122 | 123 | 124 | @dataclass 125 | class DataArguments: 126 | dataset_path: Optional[str] = field( 127 | default=None, metadata={"help": "Path to the saved dataset."}, 128 | ) 129 | size: Optional[int] = field( 130 | default=None, metadata={"help": "Number of examples to use."} 131 | ) 132 | use_chunked: bool = field( 133 | default=False, metadata={"help": "Whether to use chunked dataset."} 134 | ) 135 | 136 | 137 | @dataclass 138 | class TrainingArguments(transformers.TrainingArguments): 139 | cache_dir: str = field(default=None) 140 | optim: str = field(default="adamw_torch") 141 | trainer_type: str = field(default="param_efficient_fine_tuner", metadata={"help": "Trainer type: param_efficient_fine_tuner, distillation_trainer"}) 142 | stage1_model_path: Optional[str] = field( 143 | default=None, 144 | metadata={"help": "Path to the stage 1 model."}, 145 | ) 146 | lm_head_lr_multiplier: float = field(default=0.1) 147 | model_max_length: int = field( 148 | default=1024, 149 | metadata={ 150 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 151 | }, 152 | ) 153 | 154 | 155 | 156 | def train(): 157 | parser = transformers.HfArgumentParser( 158 | (ModelArguments, DataArguments, TrainingArguments) 159 | ) 160 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 161 | 162 | quantization_config = BitsAndBytesConfig( 163 | load_in_4bit=True, 164 | bnb_4bit_compute_dtype=torch.bfloat16, 165 | bnb_4bit_use_double_quant=True, 166 | bnb_4bit_quant_type="nf4", 167 | ) 168 | 169 | print("load_in_4_bits", model_args.load_in_4bit) 170 | 171 | # Create model 172 | peft_config = PromptConfig( 173 | tokenizer_name_or_path=model_args.model_name_or_path, 174 | base_model_name_or_path=model_args.model_name_or_path, 175 | num_special_tokens=model_args.num_special_tokens, 176 | virtual_tokens_per_special_token=model_args.virtual_tokens_per_special_token, 177 | use_prefix_tuning=model_args.use_prefix_tuning, 178 | prefix_virtual_tokens=model_args.prefix_virtual_tokens, 179 | vt_attention_type=VTAttentionType.from_str(model_args.vt_attention_type), 180 | aggregation_type=AggregationType.from_str(model_args.aggregation_type), 181 | use_custom_lm_head=model_args.use_custom_lm_head, 182 | num_exits=model_args.num_exits, 183 | ) 184 | if training_args.stage1_model_path: 185 | model = AutoPromptDecoder.from_pretrained( 186 | training_args.stage1_model_path, 187 | low_cpu_mem_usage=True, 188 | cache_dir=training_args.cache_dir, 189 | quantization_config=quantization_config if model_args.load_in_4bit else None, 190 | new_config=peft_config, 191 | ) 192 | else: 193 | # Set RoPE scaling factor 194 | config = transformers.AutoConfig.from_pretrained( 195 | model_args.model_name_or_path, 196 | cache_dir=training_args.cache_dir, 197 | ) 198 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 199 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: 200 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 201 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 202 | config.use_cache = False 203 | 204 | config = transformers.AutoConfig.from_pretrained( 205 | model_args.model_name_or_path, 206 | cache_dir=training_args.cache_dir 207 | ) 208 | 209 | if config.model_type == "llama": 210 | base_model = CustomLlamaForCausalLM.from_pretrained( 211 | model_args.model_name_or_path, 212 | config=config, 213 | cache_dir=training_args.cache_dir, 214 | low_cpu_mem_usage=True, 215 | quantization_config=quantization_config if model_args.load_in_4bit else None, 216 | # load_in_4bit=model_args.load_in_4bit, 217 | # load_in_8bit=model_args.load_in_8bit, 218 | ) 219 | else: 220 | raise ValueError("Only support llama for now") 221 | 222 | for param in base_model.base_model.parameters(): 223 | param.requires_grad = False 224 | model = PromptDecoder(base_model, peft_config) 225 | print(model.print_trainable_parameters(), model) 226 | 227 | # Output dir 228 | training_args.output_dir = f"{training_args.output_dir}/prompt_{model_args.model_name_or_path.split('/')[-1]}_{model_args.num_special_tokens}_{model_args.virtual_tokens_per_special_token}_cl{training_args.model_max_length}_{model_args.vt_attention_type.upper()}_{model_args.aggregation_type}{'_custom_lm_head' if model_args.use_custom_lm_head else ''}{'_prefix' + str(model_args.prefix_virtual_tokens) if model_args.use_prefix_tuning else ''}_exits{model_args.num_exits}" 229 | 230 | tokenizer = transformers.AutoTokenizer.from_pretrained( 231 | model_args.model_name_or_path, 232 | cache_dir=training_args.cache_dir, 233 | model_max_length=training_args.model_max_length, 234 | padding_side="left", 235 | use_fast=False, 236 | truncation=True 237 | ) 238 | tokenizer.pad_token = tokenizer.unk_token 239 | 240 | # Load data 241 | if data_args.use_chunked: 242 | data = ChunkDataset(data_args.dataset_path) 243 | else: 244 | data = torch.load(data_args.dataset_path) 245 | data.set_size(data_args.size) 246 | 247 | # Set up optimizer 248 | optimizer_grouped_parameters = [ 249 | { 250 | "params": [ 251 | p for n, p in model.named_parameters() if (p.requires_grad and "lm_head" in n) 252 | ], 253 | "lr": training_args.learning_rate * training_args.lm_head_lr_multiplier, 254 | "weight_decay": training_args.weight_decay, 255 | }, 256 | { 257 | "params": [ 258 | p for n, p in model.named_parameters() if (p.requires_grad and "prompt_encoder" in n) 259 | ], 260 | "lr": training_args.learning_rate, 261 | "weight_decay": training_args.weight_decay, 262 | }, 263 | ] 264 | optim_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) 265 | optimizer = optim_cls(optimizer_grouped_parameters, **optim_kwargs) 266 | 267 | # Start trainner 268 | if training_args.trainer_type == "distillation_trainer": 269 | trainer = DistillationTrainer( 270 | model=model, tokenizer=tokenizer, args=training_args, train_dataset=data, eval_dataset=None, optimizers=(optimizer, None) 271 | ) 272 | elif training_args.trainer_type == "param_efficient_fine_tuner": 273 | trainer = ParamEfficientFineTuner( 274 | model=model, tokenizer=tokenizer, args=training_args, train_dataset=data, eval_dataset=None, optimizers=(optimizer, None) 275 | ) 276 | else: 277 | raise ValueError(f"Trainer type {training_args.trainer_type} not supported.") 278 | 279 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 280 | print("Resuming training...") 281 | trainer.train(resume_from_checkpoint=True) 282 | else: 283 | trainer.train() 284 | 285 | # Save model 286 | model.save_pretrained(training_args.output_dir) 287 | 288 | if __name__ == "__main__": 289 | train() -------------------------------------------------------------------------------- /prompt/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import json 3 | import os 4 | from enum import Enum 5 | from typing import Dict, Optional, Sequence 6 | 7 | from tqdm import tqdm 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset, DataLoader 11 | import transformers 12 | from transformers.trainer_pt_utils import LabelSmoother 13 | from fastchat.conversation import SeparatorStyle 14 | from fastchat.model.model_adapter import get_conversation_template 15 | 16 | import prompt.inference.dynamic_sparse_trees_3_vicuna_13b as dynamic_sparse_trees_3_vicuna_13b 17 | import prompt.inference.dynamic_sparse_trees_3_vicuna_7b as dynamic_sparse_trees_3_vicuna_7b 18 | import prompt.inference.dynamic_sparse_trees_3_MobileLLaMA as dynamic_sparse_trees_3_mobilellama 19 | 20 | 21 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 22 | 23 | 24 | class VTAttentionType(str, Enum): 25 | """Attention type for VicunaTuning 26 | """ 27 | DECODER = "decoder" 28 | ENCODER = "encoder" 29 | ENSEMBLE = "ensemble" 30 | 31 | def __str__(self): 32 | return self.value 33 | 34 | @staticmethod 35 | def from_str(s): 36 | s = s.lower() 37 | if s == "decoder": 38 | return VTAttentionType.DECODER 39 | elif s == "encoder": 40 | return VTAttentionType.ENCODER 41 | elif s == "ensemble": 42 | return VTAttentionType.ENSEMBLE 43 | else: 44 | raise ValueError(f"Invalid attention type: {s}") 45 | 46 | 47 | class AggregationType(str, Enum): 48 | """Aggregation type for VicunaTuning 49 | """ 50 | MEAN = "mean" 51 | WEIGHTED = "weighted" 52 | ADAPTIVAE_WEIGHTED = "adaptive_weighted" 53 | 54 | def __str__(self): 55 | return self.value 56 | 57 | @staticmethod 58 | def from_str(s): 59 | s = s.lower() 60 | if s == "mean": 61 | return AggregationType.MEAN 62 | elif s == "weighted": 63 | return AggregationType.WEIGHTED 64 | elif s == "adaptive_weighted": 65 | return AggregationType.ADAPTIVAE_WEIGHTED 66 | else: 67 | raise ValueError(f"Invalid aggregation type: {s}") 68 | 69 | 70 | def preprocess( 71 | sources, 72 | tokenizer: transformers.PreTrainedTokenizer, 73 | ) -> Dict: 74 | """ 75 | Preprocesses conversation data and tokenizes it for model input. 76 | 77 | Args: 78 | sources: A list of conversation sources. 79 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for tokenization. 80 | 81 | Returns: 82 | Dict: A dictionary containing tokenized inputs, labels, and attention mask. 83 | """ 84 | conv = get_conversation_template("vicuna") 85 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 86 | 87 | # Apply prompt templates 88 | conversations = [] 89 | for i, source in enumerate(sources): 90 | if roles[source[0]["from"]] != conv.roles[0]: 91 | # Skip the first one if it is not from human 92 | source = source[1:] 93 | 94 | conv.messages = [] 95 | for j, sentence in enumerate(source): 96 | role = roles[sentence["from"]] 97 | assert role == conv.roles[j % 2], f"{i}, {j}, {role}, {conv.roles[j % 2]}" 98 | conv.append_message(role, sentence["value"]) 99 | conversations.append(conv.get_prompt()) 100 | 101 | # Tokenize conversations 102 | input_ids = tokenizer( 103 | conversations, 104 | return_tensors="pt", 105 | padding="max_length", 106 | max_length=tokenizer.model_max_length, 107 | truncation=True 108 | ).input_ids 109 | targets = input_ids.clone() 110 | 111 | assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO 112 | 113 | # Mask targets. Only compute loss on the assistant outputs. 114 | sep = conv.sep + conv.roles[1] + ": " 115 | # print("sep", sep) 116 | for conversation, target in zip(conversations, targets): 117 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 118 | 119 | turns = conversation.split(conv.sep2) 120 | # the number of preceding padding tokens 121 | cur_len = 1 122 | for p in target: 123 | if p == tokenizer.pad_token_id: 124 | cur_len += 1 125 | else: 126 | break 127 | target[:cur_len] = IGNORE_TOKEN_ID 128 | # target_imm = target.clone() 129 | # target_imm[target_imm == -100] = 0 130 | # print("target1", tokenizer.decode(target_imm)) 131 | for i, turn in enumerate(turns): 132 | if turn == "": 133 | break 134 | turn_len = len(tokenizer(turn).input_ids) 135 | 136 | parts = turn.split(sep) 137 | if len(parts) != 2: 138 | break 139 | parts[0] += sep 140 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct. 141 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 142 | 143 | # Ignore the user instructions 144 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID 145 | # print(cur_len, cur_len + instruction_len) 146 | # target_imm = target.clone() 147 | # target_imm[target_imm == -100] = 0 148 | # print("target2", tokenizer.decode(target_imm)) 149 | cur_len += turn_len 150 | 151 | target[cur_len:] = IGNORE_TOKEN_ID 152 | 153 | if cur_len < tokenizer.model_max_length: 154 | if cur_len != total_len: 155 | target[:] = IGNORE_TOKEN_ID 156 | 157 | # a= (input_ids[0, :] != targets[0, :]).nonzero(as_tuple=False) 158 | # print("input_ids compare to targets", a) 159 | # print("targets compare to input_ids", a.shape) 160 | # print(targets[0, input_ids[0, :] != targets[0, :]]) 161 | return dict( 162 | input_ids=input_ids, 163 | labels=targets, 164 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 165 | ) 166 | 167 | 168 | 169 | class FineTuningDataset(Dataset): 170 | """Dataset for fine-tuning. 171 | 172 | Args: 173 | raw_data (list): A list of raw data examples. 174 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. 175 | """ 176 | 177 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, offset): 178 | super(FineTuningDataset, self).__init__() 179 | 180 | sources = [example["conversations"] for example in raw_data] 181 | data_dict = preprocess(sources, tokenizer) 182 | block_indices = find_last_positive_block(data_dict["labels"], IGNORE_TOKEN_ID, offset) 183 | input_ids, attention_mask, labels = randomly_truncate(data_dict["input_ids"], 184 | data_dict["attention_mask"], 185 | data_dict["labels"], 186 | block_indices, 187 | offset, 188 | tokenizer.pad_token_id, 189 | IGNORE_TOKEN_ID) 190 | 191 | self.input_ids = input_ids 192 | self.labels = labels 193 | self.attention_mask = attention_mask 194 | 195 | def __len__(self): 196 | return len(self.input_ids) 197 | 198 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 199 | return dict( 200 | input_ids=self.input_ids[i], 201 | labels=self.labels[i], 202 | attention_mask=self.attention_mask[i], 203 | ) 204 | 205 | def set_size(self, size): 206 | if size is None: 207 | return 208 | self.input_ids = self.input_ids[:size] 209 | self.labels = self.labels[:size] 210 | self.attention_mask = self.attention_mask[:size] 211 | 212 | 213 | def get_finetune_dataset( 214 | tokenizer: transformers.PreTrainedTokenizer, data_path, size: Optional[int] = None, offset=0 215 | ) -> Dict: 216 | """Make dataset and collator for supervised fine-tuning. 217 | 218 | Args: 219 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. 220 | data_args: Data arguments. 221 | 222 | Returns: 223 | dict: A dictionary containing train and eval datasets. 224 | """ 225 | 226 | json_file = json.load(open(data_path, "r")) 227 | size = size or len(json_file) 228 | dataset = FineTuningDataset(json_file[:size], tokenizer=tokenizer, offset=offset) 229 | return dataset 230 | 231 | 232 | def find_last_positive_block(A, ignored_id, n): 233 | """ 234 | Find the start and end index of the last block of positive numbers of at least size n in each row of A. 235 | 236 | Args: 237 | - A (torch.Tensor): Input tensor of shape [N, L] 238 | - n (int): Minimum size of the block 239 | 240 | Returns: 241 | - torch.Tensor: Tensor of shape [N, 2] containing start and end indices of the last block of positive numbers of at least size n 242 | """ 243 | N, L = A.shape 244 | indices = torch.full((N, 2), -1, dtype=torch.long) # Initialize with -1 245 | 246 | for i in range(N): 247 | last_pos_end = -1 248 | block_size = 0 249 | 250 | for j in range(L-1, -1, -1): 251 | if A[i, j] != ignored_id: 252 | if last_pos_end == -1: 253 | last_pos_end = j # Mark the end of a positive block 254 | block_size += 1 255 | else: 256 | if last_pos_end != -1: 257 | if block_size >= n: 258 | indices[i, 0] = j + 1 # Start of the last positive block 259 | indices[i, 1] = last_pos_end 260 | break 261 | else: 262 | # Reset for next block search 263 | last_pos_end = -1 264 | block_size = 0 265 | if j == 0 and last_pos_end != -1 and block_size >= n: 266 | indices[i, 0] = 0 267 | indices[i, 1] = last_pos_end 268 | 269 | return indices 270 | 271 | 272 | def randomly_truncate(input_ids, attention_mask, labels, positions, k, pad_token_id=0, IGNORE_TOKEN_ID=IGNORE_TOKEN_ID): 273 | N, L = input_ids.shape 274 | # Initialize the tensor that will hold the truncated sequences 275 | truncated_batch = torch.full_like(input_ids, pad_token_id) 276 | truncated_attention_mask = torch.full_like(attention_mask, 0) 277 | truncated_labels = torch.full_like(labels, IGNORE_TOKEN_ID) 278 | 279 | for i in range(N): 280 | start, end = positions[i] 281 | # The cut has to leave at least k elements truncated, so we adjust the end accordingly 282 | # Also, ensure the cut is at least at the start position or further to the right 283 | if start == -1 or end == -1: 284 | cut = L-k 285 | else: 286 | valid_end = max(start + 1, end - k + 1) 287 | # Randomly choose a cut point from start to the valid_end 288 | cut = torch.randint(start, valid_end, (1,)).item() 289 | # print(start, cut, L-cut) 290 | # Truncate the sequence and pad from the left 291 | try: 292 | truncated_batch[i, L-cut:] = input_ids[i, :cut] 293 | truncated_attention_mask[i, L-cut:] = attention_mask[i, :cut] 294 | truncated_labels[i, L-cut-k:] = labels[i, :cut+k] 295 | except: 296 | print(valid_end, cut, start, end) 297 | print(i, L-cut, cut, L, input_ids[i, :cut].shape, truncated_batch[i, L-cut:].shape) 298 | print(i, L-cut, cut, L, attention_mask[i, :cut].shape, truncated_attention_mask[i, L-cut:].shape) 299 | print(i, L-cut-k, cut+k, L, labels[i, :cut+k].shape, truncated_labels[i, L-cut-k:].shape) 300 | raise Exception("Error in truncation") 301 | 302 | return truncated_batch, truncated_attention_mask, truncated_labels 303 | 304 | 305 | class DistillationDataset(Dataset): 306 | """Dataset for fine-tuning. 307 | 308 | Args: 309 | data (list): A list of data containing input_ids, labels, and attention_mask. 310 | """ 311 | 312 | def __init__(self, data): 313 | super(DistillationDataset, self).__init__() 314 | self.input_ids = [d["input_ids"] for d in data] 315 | self.labels = [d["labels"] for d in data] 316 | self.attention_mask = [d["attention_mask"] for d in data] 317 | 318 | def __len__(self): 319 | return len(self.input_ids) 320 | 321 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 322 | return dict( 323 | input_ids=self.input_ids[i], 324 | labels=self.labels[i], 325 | attention_mask=self.attention_mask[i], 326 | ) 327 | 328 | def set_size(self, size): 329 | if size is None: 330 | return 331 | self.input_ids = self.input_ids[:size] 332 | self.labels = self.labels[:size] 333 | self.attention_mask = self.attention_mask[:size] 334 | 335 | 336 | def get_self_distillation_dataset(model, data_path, num_special_tokens): 337 | dataset = torch.load(data_path) 338 | dataloader = DataLoader(dataset, batch_size=4, shuffle=False) 339 | data = [] 340 | model.eval() 341 | # dataloader is faster but batched input need more memory 342 | for batch in tqdm(dataloader): 343 | input_ids = batch["input_ids"] 344 | attention_mask = batch["attention_mask"] 345 | batch_size, seq_length = input_ids.shape 346 | preds = input_ids.clone() 347 | batch_labels = [] 348 | 349 | for j in range(num_special_tokens+1): 350 | with torch.inference_mode(): 351 | outputs = model(input_ids=preds, attention_mask=attention_mask) 352 | logits = outputs.logits 353 | input_id = logits[:, -1:, :].argmax(-1) 354 | 355 | if j > 0: 356 | batch_labels.append(logits[:, -1, :]) 357 | 358 | preds = torch.cat([preds, input_id], dim=1) 359 | attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1).to(attention_mask.device)], dim=1) 360 | 361 | labels = torch.stack(batch_labels, dim=1) 362 | for i in range(batch_size): 363 | data.append({"input_ids": preds[i, :-num_special_tokens-1], "labels": labels[i], "attention_mask": attention_mask[i, :-num_special_tokens-1]}) 364 | return DistillationDataset(data) 365 | 366 | 367 | def chunk_dataset(dataset_path, chunk_size, output_dir): 368 | dataset = torch.load(dataset_path) 369 | total_size = len(dataset) 370 | print(f"Total size: {total_size}") 371 | for i in tqdm(range(0, total_size, chunk_size)): 372 | chunk = dataset[i:i+chunk_size] 373 | torch.save(chunk, os.path.join(output_dir, f'dataset_chunk_{i//chunk_size}.pt')) 374 | 375 | 376 | class ChunkDataset(Dataset): 377 | def __init__(self, chunk_dir): 378 | super(ChunkDataset, self).__init__() 379 | self.chunk_dir = chunk_dir 380 | # List all chunk files 381 | self.chunk_files = [os.path.join(chunk_dir, f) for f in os.listdir(chunk_dir) if f.startswith('dataset_chunk_')] 382 | self.chunk_files.sort(key=lambda x: (len(x), x)) 383 | # Calculate offsets and total length 384 | self.lengths = [torch.load(f, map_location='cpu')['input_ids'].__len__() for f in self.chunk_files] 385 | self.cumulative_lengths = [sum(self.lengths[:i+1]) for i in range(len(self.lengths))] 386 | 387 | def __len__(self): 388 | return self.cumulative_lengths[-1] 389 | 390 | def __getitem__(self, idx): 391 | # Find which chunk contains the item 392 | chunk_idx = next(i for i, total in enumerate(self.cumulative_lengths) if total > idx) 393 | if chunk_idx > 0: 394 | idx -= self.cumulative_lengths[chunk_idx-1] # Adjust index relative to the chunk 395 | 396 | # Load the chunk 397 | chunk = torch.load(self.chunk_files[chunk_idx], map_location='cpu') 398 | 399 | # Extract and return the item 400 | return dict( 401 | input_ids=chunk['input_ids'][idx], 402 | labels=chunk['labels'][idx], 403 | attention_mask=chunk['attention_mask'][idx], 404 | ) 405 | 406 | 407 | def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha): 408 | original_logit = logit.clone() 409 | logit = logit / temperature 410 | probs = torch.softmax(logit, dim=-1) 411 | entropy = -torch.sum( 412 | probs * torch.log(probs + 1e-5), dim=-1 413 | ) 414 | threshold = torch.minimum( 415 | torch.ones_like(entropy) * posterior_threshold, 416 | torch.exp(-entropy) * posterior_alpha, 417 | ) 418 | indices_to_remove = probs < threshold.unsqueeze(-1) 419 | logit[indices_to_remove] = float('-inf') 420 | prob = F.softmax(logit, dim=-1) 421 | try: 422 | sampled_tokens = torch.multinomial(prob, 1) 423 | except: 424 | print(prob.max(), prob.min()) 425 | print(logit.max(), logit.min()) 426 | print(original_logit.max(), original_logit.min()) 427 | print(temperature, original_logit.max()/ temperature, original_logit.min()/ temperature) 428 | print(indices_to_remove.any()) 429 | raise Exception("Error in sampling") 430 | return sampled_tokens 431 | 432 | 433 | def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha): 434 | logits = logits[:, :-1] / temperature 435 | n_samples, n_tokens = logits.shape[0], logits.shape[1] 436 | logits = logits.view(n_samples*n_tokens, -1) 437 | probs = F.softmax(logits, dim=-1) 438 | entropy = -torch.sum( 439 | probs * torch.log(probs + 1e-5), dim=-1 440 | ) 441 | threshold = torch.minimum( 442 | torch.ones_like(entropy) * posterior_threshold, 443 | torch.exp(-entropy) * posterior_alpha, 444 | ) 445 | indices_to_remove = probs < threshold.unsqueeze(-1) 446 | logits[indices_to_remove] = float('-1e4') 447 | sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1) 448 | sampled_tokens = sampled_tokens.view(n_samples, n_tokens) 449 | posterior_mask = (candidates[:, 1:] == sampled_tokens).int() 450 | return posterior_mask 451 | 452 | 453 | def pad_path(path, length, pad_value=-2): 454 | """ 455 | Pad the given path list with a specific value up to a specified length. 456 | 457 | Parameters: 458 | - path (list): The original list that needs padding. 459 | - length (int): The desired length of the padded list. 460 | - pad_value (optional, default=-2): The value to use for padding. 461 | 462 | Returns: 463 | - list: A new list based on the original path but padded to the desired length. 464 | 465 | Example: 466 | >>> pad_path([1,2,3], 5) 467 | [1, 2, 3, -2, -2] 468 | 469 | Note: 470 | If the given path is already longer than the specified length, 471 | then no padding occurs, and the original path is returned. 472 | """ 473 | 474 | # Calculate the number of padding values needed by subtracting the length 475 | # of the path from the desired length. 476 | # Append the padding values to the original path and return the new list. 477 | return path + [pad_value] * (length - len(path)) 478 | 479 | def get_dynamic_sparse_tree(model_path): 480 | if 'vicuna-13b' in model_path.lower(): 481 | print('Using 13b 3-1 sparse trees') 482 | tree = dynamic_sparse_trees_3_vicuna_13b.dynamic_sparse_trees_60 483 | elif 'vicuna-7b' in model_path.lower(): 484 | print('Using 7b 3-1 sparse trees') 485 | tree = dynamic_sparse_trees_3_vicuna_7b.dynamic_sparse_trees_105 486 | elif 'mobilellama' in model_path.lower(): 487 | print('Using MobileLLaMA 3-1 sparse trees') 488 | tree = dynamic_sparse_trees_3_mobilellama.dynamic_sparse_trees_285 489 | else: 490 | raise ValueError("Unknown model path") 491 | return tree -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["poetry-core"] 3 | build-backend = "poetry.core.masonry.api" 4 | 5 | [tool.poetry] 6 | name = "parallel-prompt-decoding" 7 | version = "1.0" 8 | description = "Cheap and efficient LLM inference Acceleration using Prompting" 9 | packages = [ 10 | { include = "prompt", from = "." } 11 | ] 12 | authors = ["Mark (Hao) Chen "] 13 | readme = "README.md" 14 | license = "Apache-2.0" 15 | keywords = ["LLM", "Prompting", "Inference Acceleration", "NLP", "Machine Learning", "Language Model"] 16 | 17 | [tool.poetry.dependencies] 18 | python = ">=3.9" 19 | fschat = "^0.2.36" 20 | torch = "^2.0.1" 21 | transformers = "4.37.2" 22 | accelerate = "^0.27.2" 23 | peft = "^0.8.0" 24 | datasets = ">=2.17.0" 25 | numpy = ">=1.26.0" 26 | bitsandbytes = "^0.42.0" 27 | setuptools = "*" 28 | sentencepiece = "*" 29 | protobuf = "^4.25.3" 30 | matplotlib = "*" 31 | gradio = "*" 32 | openai = "*" 33 | anthropic = "*" 34 | zeus-ml = "*" 35 | human_eval = "*" 36 | -------------------------------------------------------------------------------- /script/eval/eval-2-1-ensemble.sh: -------------------------------------------------------------------------------- 1 | python3 prompt/evaluation/eval.py --model_path hmarkc/ppd-vicuna-7b-v1.3\ 2 | --model_name ppd-vicuna-7b-v1.3 \ 3 | --data_path ./data/alpaca_eval.json \ 4 | --save_dir ./log/eval/ \ 5 | -------------------------------------------------------------------------------- /script/latency/optimal-sparse-tree.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # use the alpaca eval dataset to find the optimal sparse tree 3 | # Accept length for Vicuna 7b 4 | python accept_length.py \ 5 | --dir-path ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-7b \ 6 | --file-name dynamic_sparse_tree \ 7 | --model-name hmarkc/ppd-vicuna-7b-v1.3 \ 8 | --eval-file-name gen_model_answer_prompt_decoding.py \ 9 | --run-baseline \ 10 | --n 1 \ 11 | --max-length 79 \ 12 | --min-length 60 \ 13 | --length-interval 9 \ 14 | --choices "[5, 10, 20, 35, 60, 120, 200, 500]" \ 15 | 16 | 17 | # latency for Vicuna 7b 18 | python3 tree_latency.py \ 19 | --model-path hmarkc/ppd-vicuna-7b-v1.3 \ 20 | --model-id vicuna_faster \ 21 | --answer-file ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-7b-21-04/tree_latency.jsonl \ 22 | --bench-name alpaca_eval \ 23 | --min-tree-length 60 \ 24 | --max-tree-length 120 \ 25 | --length-interval 3 \ 26 | --max-new-token 1024 27 | 28 | # Accept length for Vicuna 13b 29 | python accept_length.py \ 30 | --dir-path ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-13b \ 31 | --file-name dynamic_sparse_tree \ 32 | --model-name hmarkc/ppd-vicuna-13b-v1.3\ 33 | --eval-file-name gen_model_answer_prompt_decoding.py \ 34 | --max-length 120 \ 35 | --min-length 60 \ 36 | --length-interval 3 \ 37 | --n 1 \ 38 | 39 | # latency for Vicuna 13b 40 | python3 tree_latency.py \ 41 | --model-path hmarkc/ppd-vicuna-13b-v1.3 \ 42 | --model-id vicuna_faster \ 43 | --answer-file ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-13b/tree_latency.jsonl \ 44 | --bench-name alpaca_eval \ 45 | --min-tree-length 60 \ 46 | --max-tree-length 120 \ 47 | --length-interval 3 \ 48 | --max-new-token 1024 49 | 50 | # Accept length for full sparse tree 51 | # python accept_length.py \ 52 | # --dir-path data/alpaca_eval/sparse_tree_search/3-1-7b/ \ 53 | # --file-name full_sparse_tree \ 54 | # --model-name hmarkc/ppd-vicuna-7b-v1.3 \ 55 | # --eval-file-name gen_model_answer_full_sparse_tree.py \ 56 | # --choices "[5, 10, 20, 35, 60, 120, 200, 500]" \ 57 | # --n 1 \ 58 | 59 | # python accept_length.py \ 60 | # --dir-path data/alpaca_eval/sparse_tree_search/3-1-13b/ \ 61 | # --file-name sparse_tree \ 62 | # --model-name hmarkc/ppd-vicuna-13b-v1.3 \ 63 | # --eval-file-name gen_model_answer_full_sparse_tree.py \ 64 | # --max-length 120 \ 65 | # --min-length 60 \ 66 | # --length-interval 3 \ 67 | # --n 1 \ 68 | 69 | # Accept length for random sparse tree 70 | python accept_length.py \ 71 | --dir-path data/alpaca_eval/random_tree_search/3-1-7b/ \ 72 | --file-name random_sparse_tree \ 73 | --model-name hmarkc/ppd-vicuna-7b-v1.3 \ 74 | --eval-file-name gen_model_answer_random_sparse_tree.py \ 75 | --choices "[5, 10, 20, 35, 60, 120, 200, 500]" \ 76 | --n 1 \ 77 | 78 | # Accept length for MobileLLaMA 79 | python accept_length.py \ 80 | --dir-path ./data/alpaca_eval/dynamic_sparse_tree_search/MobileLLaMA \ 81 | --file-name dynamic_sparse_tree \ 82 | --model-name ../test/MobileLLaMA \ 83 | --eval-file-name gen_model_answer_prompt_decoding.py \ 84 | --n 1 \ 85 | --max-length 79 \ 86 | --min-length 60 \ 87 | --length-interval 9 \ 88 | --choices "[75, 105, 135, 165, 195, 225, 255, 285]" \ 89 | 90 | 91 | # latency for MobileLLaMA 92 | python3 tree_latency.py \ 93 | --model-path ../test/MobileLLaMA \ 94 | --model-id MobileLLaMA \ 95 | --answer-file ./data/alpaca_eval/dynamic_sparse_tree_search/MobileLLaMA/tree_latency.jsonl \ 96 | --bench-name alpaca_eval \ 97 | --min-tree-length 60 \ 98 | --max-tree-length 285 \ 99 | --length-interval 3 \ 100 | --max-new-token 1024 -------------------------------------------------------------------------------- /script/latency/vicuna-13b-gen.sh: -------------------------------------------------------------------------------- 1 | # MT bench, temperature sampling 2 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/mt_bench/experiments/vicuna-13b-baseline1.jsonl --bench-name mt_bench 3 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/mt_bench/experiments/vicuna-13b-baseline2.jsonl --bench-name mt_bench 4 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/mt_bench/experiments/vicuna-13b-baseline3.jsonl --bench-name mt_bench 5 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-13b-faster1.jsonl --tree-length 60 --bench-name mt_bench 6 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-13b-faster2.jsonl --tree-length 60 --bench-name mt_bench 7 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-13b-faster3.jsonl --tree-length 60 --bench-name mt_bench 8 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-13b-faster4.jsonl --tree-length 60 --bench-name mt_bench 9 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-13b-faster5.jsonl --tree-length 60 --bench-name mt_bench 10 | 11 | # MT bench, greedy sampling 12 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/mt_bench/experiments/vicuna-13b-baseline1-greedy.jsonl --bench-name mt_bench --temperature 0.0 13 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/mt_bench/experiments/vicuna-13b-baseline2-greedy.jsonl --bench-name mt_bench --temperature 0.0 14 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/mt_bench/experiments/vicuna-13b-baseline3-greedy.jsonl --bench-name mt_bench --temperature 0.0 15 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-13b-faster1-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0 16 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-13b-faster2-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0 17 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-13b-faster3-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0 18 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-13b-faster4-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0 19 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-13b-faster5-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0 20 | 21 | # HumanEval 22 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/humaneval/experiments/vicuna-13b-baseline1.jsonl --bench-name humaneval --max-new-token 512 23 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/humaneval/experiments/vicuna-13b-baseline2.jsonl --bench-name humaneval --max-new-token 512 24 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/humaneval/experiments/vicuna-13b-baseline3.jsonl --bench-name humaneval --max-new-token 512 25 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/humaneval/experiments/vicuna-13b-faster1.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512 26 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/humaneval/experiments/vicuna-13b-faster2.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512 27 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/humaneval/experiments/vicuna-13b-faster3.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512 28 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/humaneval/experiments/vicuna-13b-faster4.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512 29 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/humaneval/experiments/vicuna-13b-faster5.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512 30 | 31 | # GSM8K 32 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/gsm8k/experiments/vicuna-13b-baseline1.jsonl --bench-name gsm8k 33 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/gsm8k/experiments/vicuna-13b-baseline2.jsonl --bench-name gsm8k 34 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/gsm8k/experiments/vicuna-13b-baseline3.jsonl --bench-name gsm8k 35 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/gsm8k/experiments/vicuna-13b-faster1.jsonl --tree-length 60 --bench-name gsm8k 36 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/gsm8k/experiments/vicuna-13b-faster2.jsonl --tree-length 60 --bench-name gsm8k 37 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/gsm8k/experiments/vicuna-13b-faster3.jsonl --tree-length 60 --bench-name gsm8k 38 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/gsm8k/experiments/vicuna-13b-faster4.jsonl --tree-length 60 --bench-name gsm8k 39 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/gsm8k/experiments/vicuna-13b-faster5.jsonl --tree-length 60 --bench-name gsm8k 40 | -------------------------------------------------------------------------------- /script/latency/vicuna-7b-gen.sh: -------------------------------------------------------------------------------- 1 | # MT bench, temperature sampling 2 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/mt_bench/experiments/vicuna-7b-baseline1.jsonl --bench-name mt_bench 3 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/mt_bench/experiments/vicuna-7b-baseline2.jsonl --bench-name mt_bench 4 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/mt_bench/experiments/vicuna-7b-baseline3.jsonl --bench-name mt_bench 5 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-7b-faster1.jsonl --tree-length 105 --bench-name mt_bench 6 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-7b-faster2.jsonl --tree-length 105 --bench-name mt_bench 7 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-7b-faster3.jsonl --tree-length 105 --bench-name mt_bench 8 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-7b-faster4.jsonl --tree-length 105 --bench-name mt_bench 9 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-7b-faster5.jsonl --tree-length 105 --bench-name mt_bench 10 | 11 | # MT bench, greedy sampling 12 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/mt_bench/experiments/vicuna-7b-baseline1-greedy.jsonl --bench-name mt_bench --temperature 0.0 13 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/mt_bench/experiments/vicuna-7b-baseline2-greedy.jsonl --bench-name mt_bench --temperature 0.0 14 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/mt_bench/experiments/vicuna-7b-baseline3-greedy.jsonl --bench-name mt_bench --temperature 0.0 15 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-7b-faster1-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0 16 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-7b-faster2-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0 17 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-7b-faster3-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0 18 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-7b-faster4-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0 19 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-7b-faster5-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0 20 | 21 | # HumanEval 22 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/humaneval/experiments/vicuna-7b-baseline1.jsonl --bench-name humaneval --max-new-token 512 23 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/humaneval/experiments/vicuna-7b-baseline2.jsonl --bench-name humaneval --max-new-token 512 24 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/humaneval/experiments/vicuna-7b-baseline3.jsonl --bench-name humaneval --max-new-token 512 25 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/humaneval/experiments/vicuna-7b-faster1.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512 26 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/humaneval/experiments/vicuna-7b-faster2.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512 27 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/humaneval/experiments/vicuna-7b-faster3.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512 28 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/humaneval/experiments/vicuna-7b-faster4.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512 29 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/humaneval/experiments/vicuna-7b-faster5.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512 30 | 31 | # GSM8K 32 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/gsm8k/experiments/vicuna-7b-baseline1.jsonl --bench-name gsm8k 33 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/gsm8k/experiments/vicuna-7b-baseline2.jsonl --bench-name gsm8k 34 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/gsm8k/experiments/vicuna-7b-baseline3.jsonl --bench-name gsm8k 35 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/gsm8k/experiments/vicuna-7b-faster1.jsonl --tree-length 105 --bench-name gsm8k 36 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/gsm8k/experiments/vicuna-7b-faster2.jsonl --tree-length 105 --bench-name gsm8k 37 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/gsm8k/experiments/vicuna-7b-faster3.jsonl --tree-length 105 --bench-name gsm8k 38 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/gsm8k/experiments/vicuna-7b-faster4.jsonl --tree-length 105 --bench-name gsm8k 39 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/gsm8k/experiments/vicuna-7b-faster5.jsonl --tree-length 105 --bench-name gsm8k 40 | -------------------------------------------------------------------------------- /script/train/train-ensemble-attention-kd.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | accelerate launch --num_processes 4 prompt/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \ 3 | --dataset_path "./ShareGPT_training_dataset_2_distillation.pt" \ 4 | --output_dir test/ \ 5 | --num_train_epochs 1 \ 6 | --save_steps 500 \ 7 | --model_max_length 2048 \ 8 | --num_special_tokens 3 \ 9 | --virtual_tokens_per_special_token 1 \ 10 | --per_device_train_batch_size 1 \ 11 | --per_device_eval_batch_size 1 \ 12 | --gradient_accumulation_steps 4 \ 13 | --evaluation_strategy "no" \ 14 | --learning_rate 1e-2 \ 15 | --weight_decay 0.0 \ 16 | --warmup_ratio 0.0 \ 17 | --lr_scheduler_type "cosine" \ 18 | --logging_steps 10 \ 19 | --load_in_4bit \ 20 | --vt_attention_type "ensemble" \ 21 | --trainer_type "distillation_trainer" 22 | # --use_prefix_tuning \ 23 | # --prefix_virtual_tokens 10 \ 24 | # --size 100 25 | # --tf32 True \ requires at least Ampere 26 | # 2048 length not working for batch 1 27 | --------------------------------------------------------------------------------