├── .gitignore
├── LICENSE
├── README.md
├── application
│   ├── README.md
│   ├── accept_length.py
│   ├── common.py
│   ├── gen_judgment.py
│   ├── gen_model_answer_baseline.py
│   ├── gen_model_answer_prompt_decoding.py
│   ├── get_throughput_results.py
│   ├── show_result.py
│   ├── tree_latency.py
│   └── webui.py
├── assets
│   ├── .DS_Store
│   ├── Overview.png
│   ├── PPD_LOGO.png
│   ├── Speed_Mem_Train.png
│   ├── latency.png
│   └── ppd_demo.gif
├── dataset_generation
│   └── generate_dataset.py
├── examples
│   └── tutorial.ipynb
├── poetry.lock
├── prompt
│   ├── __init__.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   └── eval.py
│   ├── hf_utils.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── dynamic_sparse_trees_2_13b.py
│   │   ├── dynamic_sparse_trees_2_7b.py
│   │   ├── dynamic_sparse_trees_3_MobileLLaMA.py
│   │   ├── dynamic_sparse_trees_3_vicuna_13b.py
│   │   ├── dynamic_sparse_trees_3_vicuna_68m.py
│   │   ├── dynamic_sparse_trees_3_vicuna_7b.py
│   │   ├── full_sparse_trees_3_13b.py
│   │   ├── full_sparse_trees_3_7b.py
│   │   ├── random_sparse_trees_3_7b.py
│   │   └── sparse_tree_builder.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── kv_cache.py
│   │   ├── model.py
│   │   └── modeling_llama_custom.py
│   ├── train
│   │   ├── __init__.py
│   │   └── train.py
│   └── utils.py
├── pyproject.toml
└── script
    ├── eval
    │   └── eval-2-1-ensemble.sh
    ├── latency
    │   ├── optimal-sparse-tree.sh
    │   ├── vicuna-13b-gen.sh
    │   └── vicuna-7b-gen.sh
    └── train
        └── train-ensemble-attention-kd.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | venv
3 | data
4 | ShareGPT_Vicuna_unfiltered/
5 | test*
6 | log
7 | *egg-info
8 | *log
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Parallel Prompt Decoding: Accelerate LLM Inference with Parallel Prompting
2 |
3 | **PPD** (Parallel Prompt Decoding) is a cost-efficient method to accelerate LLM inference with trained prompt tokens appended to the input. Our technique stands out for three key features:
4 | - *Orthogonal Optimization*: PPD is orthogonal to speculative decoding, so the two can be integrated for further speedups.
5 | - *Memory Efficiency*: With a minimal runtime memory overhead of just 0.0004%, PPD is highly suitable for edge and mobile settings.
6 | - *Training Efficiency*: Training requires only 16 hours on a single A100-40GB GPU.
7 |
8 |
9 | *PPD on Vicuna-7b.*
10 |
19 | The key intuition of **PPD** lies in the observation that, if trained properly, prompt tokens appended to the input can approximate tokens generated at future timesteps, thereby partially recovering the missing conditional dependency information for multi-token generation.
20 |
27 | Inspired by the human language generation process, in which consecutive words such as common expressions and phrases are produced simultaneously, PPD uses trained prompt tokens for multi-token prediction. Specifically, these prompt tokens are appended to the original input sequence, enabling the concurrent generation of multiple output tokens in a single forward pass.
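Below is a minimal, illustrative sketch of this idea for a Hugging Face-style causal LM. The names `model`, `input_ids`, and `prompt_embeds` (a tensor of k trained prompt-token embeddings) are assumptions for illustration only; the repo's actual implementation lives under `prompt/model/`.

```python
# Illustrative sketch only -- not the repo's API. Assumes `prompt_embeds` holds k trained
# prompt-token embeddings of shape [k, hidden] and `model` is a causal LM that accepts
# `inputs_embeds` (e.g. a transformers LlamaForCausalLM).
import torch

@torch.no_grad()
def parallel_prompt_step(model, input_ids, prompt_embeds):
    """One forward pass that yields the next token plus k guesses for future tokens."""
    ctx = model.get_input_embeddings()(input_ids)                         # [1, seq_len, hidden]
    k = prompt_embeds.shape[0]
    inputs_embeds = torch.cat([ctx, prompt_embeds.unsqueeze(0)], dim=1)   # append k prompt tokens
    logits = model(inputs_embeds=inputs_embeds).logits                    # [1, seq_len + k, vocab]
    next_token = logits[0, -k - 1].argmax(-1)   # ordinary prediction for step t+1
    guesses = logits[0, -k:].argmax(-1)         # parallel guesses for steps t+2 ... t+k+1
    return next_token, guesses
```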
28 |
41 | **PPD** is designed to address the following challenges faced by current speculative decoding methods:
42 |
43 | - Low step compression rate due to conditional independence assumption.
44 | - High complexity and cost of training/maintaining draft models.
45 | - Limited applicability in edge and mobile environments due to memory overhead.
46 |
47 | Through extensive experiments on LLMs ranging from MobileLLaMA to Vicuna-13B across a wide range of benchmarks, our approach demonstrates a speedup of up to 2.49 $\times$ while maintaining a minimal runtime memory overhead of just 0.0004%.
48 |
49 | *Evaluated on a single A100 with MT-Bench.*
50 |
60 | Our paper is available [here](https://arxiv.org/abs/2405.18628v1). If you find it helpful, please cite us:
61 | ```
62 | @article{hao2024ppd,
63 | title={Hardware-aware parallel prompt decoding for memory-efficient acceleration of LLM inference},
64 | author={Chen, Hao (Mark) and Luk, Wayne and Yiu, Ka Fai Cedric and Li, Rui and Mishchenko, Konstantin and Venieris, Stylianos I and Fan, Hongxiang},
65 | journal={arXiv preprint arXiv:2405.18628},
66 | year={2024}
67 | }
68 | ```
69 |
70 | ## Contents
71 |
72 | - [Installation](#installation)
73 | - [Model Weights](#model-weights)
74 | - [Dataset Generation](#dataset-generation)
75 | - [Truncated Dataset](#truncated-dataset)
76 | - [Distillation Dataset](#distillation-dataset)
77 | - [Training Special tokens](#training-special-tokens)
78 | - [Inference](#inference)
79 | - [Chat Application](#chat-application)
80 | - [MT Bench](#mt-bench)
81 | - [Alpaca Eval](#alpaca-eval)
82 | - [HumanEval](#humaneval)
83 | - [GSM8K](#gsm8k)
84 |
85 | ## Installation
86 |
87 | ```bash
88 | git clone https://github.com/hmarkc/prompt-decoding.git
89 | cd prompt-decoding
90 | pip install -e .
91 | ```
92 |
93 | ## Model Weights
94 |
95 | | Original Model | PPD embedding weights |
96 | | ------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- |
97 | | [lmsys/vicuna-7b-v1.3](https://huggingface.co/lmsys/vicuna-7b-v1.3) | [hmarkc/ppd-vicuna-7b-v1.3](https://huggingface.co/hmarkc/ppd-vicuna-7b-v1.3) |
98 | | [lmsys/vicuna-13b-v1.3](https://huggingface.co/lmsys/vicuna-13b-v1.3) | [hmarkc/ppd-vicuna-13b-v1.3](https://huggingface.co/hmarkc/ppd-vicuna-13b-v1.3) |
99 | | [mtgv/MobileLLaMA-1.4B-Chat](https://huggingface.co/mtgv/MobileLLaMA-1.4B-Chat) | [hmarkc/ppd-MobileLLaMA-1.4B-Chat](https://huggingface.co/hmarkc/ppd-MobileLLaMA-1.4B-Chat) |
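The embedding weights can also be fetched programmatically. A small sketch using `huggingface_hub` (the file layout inside each repo is not assumed here):

```python
from huggingface_hub import snapshot_download

# Download the trained PPD prompt-token embeddings for Vicuna-7b into the local HF cache.
weights_dir = snapshot_download(repo_id="hmarkc/ppd-vicuna-7b-v1.3")
print(weights_dir)
```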
100 |
101 | ## Dataset Generation
102 |
103 | ### Truncated Dataset
104 |
105 | Given a dataset, each sample is randomly truncated to reduce the contextual bias in the training of the special tokens. A distillation dataset is then generated from the truncated dataset.
106 |
107 | The truncated dataset needs to be generated first. Here is how a dataset for 3 special tokens can be generated from the ShareGPT dataset.
108 |
109 | ```
110 | python generate_dataset.py --dataset_type finetune --num_special_tokens 3 --data_path ShareGPT_V4.3_unfiltered_cleaned_split.json --model_max_length 2048
111 | ```
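Conceptually, the truncation step cuts each tokenized sample at a random point so that the special tokens are trained against varied contexts. A hedged sketch of the idea (the actual logic lives in `dataset_generation/generate_dataset.py` and may differ):

```python
import random

def randomly_truncate(token_ids, min_len=16):
    """Keep a random-length prefix of a tokenized sample to reduce contextual bias."""
    if len(token_ids) <= min_len:
        return token_ids
    cut = random.randint(min_len, len(token_ids))
    return token_ids[:cut]
```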
112 |
113 | ### Distillation Dataset
114 |
115 | Then, we can generate the distillation dataset from the truncated dataset. `--data_path` is the path to the previously generated truncated dataset, and `--model_name_or_path` is the model whose output distribution we want to distill.
116 |
117 | ```
118 | python generate_dataset.py --dataset_type distillation --num_special_tokens 3 --data_path ShareGPT_training_dataset_3_finetune_2048.pt --model_max_length 2048 --model_name_or_path lmsys/vicuna-7b-v1.3
119 | ```
120 |
121 | ## Training Special tokens
122 |
123 | Example command to train Vicuna-7b with the distillation dataset "ShareGPT_training_dataset_2_distillation.pt":
124 |
125 | ```
126 | accelerate launch --num_processes 4 prompt/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
127 | --dataset_path "./ShareGPT_training_dataset_2_distillation.pt" \
128 | --output_dir test/ \
129 | --num_train_epochs 1 \
130 | --save_steps 500 \
131 | --model_max_length 2048 \
132 | --num_special_tokens 3 \
133 | --virtual_tokens_per_special_token 1 \
134 | --per_device_train_batch_size 1 \
135 | --per_device_eval_batch_size 1 \
136 | --gradient_accumulation_steps 4 \
137 | --evaluation_strategy "no" \
138 | --learning_rate 1e-2 \
139 | --weight_decay 0.0 \
140 | --warmup_ratio 0.0 \
141 | --lr_scheduler_type "cosine" \
142 | --logging_steps 10 \
143 | --load_in_4bit \
144 | --vt_attention_type "ensemble" \
145 | --trainer_type "distillation_trainer"
146 | ```
147 |
148 | Point `--dataset_path` to the location of the distillation dataset and set `--trainer_type` to "distillation_trainer" to train with knowledge distillation. `--num_special_tokens` specifies the number of special tokens to train. `--virtual_tokens_per_special_token` is the number of virtual tokens used per special token; set it to 1 to achieve the lowest latency.
149 |
150 | ## Inference
151 |
152 | We employ dynamically extended tree attention with top-K candidates for inference. The supported evaluation datasets currently include [Alpaca Eval](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json), [MT Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge), and [HumanEval](https://github.com/openai/human-eval).
153 |
154 | Refer to [application/README.md](application/README.md) for instructions on installing the libraries required for these datasets.
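At a high level, each decoding step scores a set of candidate continuations (built from the prompt tokens' top-K guesses and arranged as a sparse tree) in a single forward pass, and keeps the longest candidate that the base model itself agrees with. A simplified, single-candidate sketch of that acceptance rule (the repo's tree-attention implementation under `prompt/inference/` is more involved):

```python
import torch

@torch.no_grad()
def accept_longest_prefix(verify_logits, candidate_ids):
    """verify_logits[i]: base-model logits at candidate position i;
    candidate_ids[i + 1]: the guessed token that position i must justify."""
    predicted = verify_logits.argmax(-1)
    n_accept = 0
    for i in range(len(candidate_ids) - 1):
        if predicted[i].item() != candidate_ids[i + 1].item():
            break
        n_accept += 1
    # The first token plus every verified guess is emitted in this decoding step.
    return candidate_ids[: n_accept + 1]
```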
155 |
156 | ### Chat Application
157 |
158 | We implemented a simple chat application using `gradio`. To start a server for the chat application, run `python application/webui.py --ppd-path <path-to-ppd-weights>`.
159 |
160 | ### MT Bench
161 |
162 | - To obtain the latency of the baseline model (without the use of special tokens), run
163 |
164 | ```
165 | python3 gen_model_answer_baseline.py \
166 | --model-path <model-path> \
167 | --model-id <model-id> \
168 | --answer-file <answer-file> \
169 | --bench-name mt_bench \
170 | --temperature <temperature>
171 | ```
172 |
173 | - To obtain the latency of the model with special tokens, run
174 |
175 | ```
176 | python3 gen_model_answer_prompt_decoding.py \
177 | --model-path <ppd-model-path> \
178 | --model-id <model-id> \
179 | --answer-file <answer-file> \
180 | --tree-length 105 \
181 | --bench-name mt_bench \
182 | --temperature <temperature>
183 | ```
184 |
185 | `--model-path` is the path to the trained special tokens and `--tree-length` is the length of the sparse tree used.
186 |
187 | - To view the latency results of a generated `.jsonl` file, run
188 |
189 | ```
190 | python get_throughput_results.py data/mt_bench/experiments/vicuna-7b-faster1.jsonl --n 3
191 | ```
192 |
193 | `--n` specifies the number of experiment runs to average over.
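For reference, the reported throughput is simply total generated tokens divided by total wall time, mirroring `get_throughput_results` in `application/accept_length.py`:

```python
import json

def throughput(jsonl_path):
    """Average tokens/s over all answers in a generated .jsonl file."""
    new_tokens, wall_time = [], []
    with open(jsonl_path) as f:
        for line in f:
            for choice in json.loads(line)["choices"]:
                new_tokens.extend(choice["new_tokens"])
                wall_time.extend(choice["wall_time"])
    return sum(new_tokens) / sum(wall_time)
```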
194 |
195 | ### Alpaca Eval
196 |
197 | We use the Alpaca Eval dataset for evaluation. The latency results can be obtained using the same scripts as MT Bench by adding `--bench-name alpaca_eval`.
198 |
199 | - To compare the latencies and accept lengths between sparse trees with different sizes, run:
200 |
201 | ```
202 | python accept_length.py \
203 | --dir-path <output-dir> \
204 | --file-name <file-prefix> \
205 | --model-name <model-name> \
206 | --eval-file-name gen_model_answer_prompt_decoding.py \
207 | --n 1 \
208 | --max-length 120 \
209 | --min-length 60 \
210 | --length-interval 9 \
211 | --choices "[75, 105, 135, 165, 195, 225, 255, 285]"
212 |
213 | python3 tree_latency.py \
214 | --model-path <ppd-model-path> \
215 | --model-id <model-id> \
216 | --answer-file <answer-file> \
217 | --bench-name alpaca_eval \
218 | --min-tree-length 60 \
219 | --max-tree-length 120 \
220 | --length-interval 3 \
221 | --max-new-token 1024
222 | ```
223 |
224 | This [script](script/latency/optimal-sparse-tree.sh) runs the latency tests on a range of sparse trees.
225 |
226 | ### HumanEval
227 |
228 | The latency results of HumanEval can be obtained using the same script as MT Bench and adding `--bench-name humaneval`.
229 |
230 | ### GSM8K
231 |
232 | The latency results of GSM8K can be obtained using the same script as MT Bench and adding `--bench-name gsm8k`.
233 |
234 |
--------------------------------------------------------------------------------
/application/README.md:
--------------------------------------------------------------------------------
1 | ## LLM Judge
2 |
3 | | [installation](https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/README.md)
4 |
5 | ### Citation
6 | ```
7 | @misc{zheng2023judging,
8 | title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
9 | author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
10 | year={2023},
11 | eprint={2306.05685},
12 | archivePrefix={arXiv},
13 | primaryClass={cs.CL}
14 | }
15 | ```
16 |
17 | ## Alpaca Eval
18 |
19 | | [installation](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json)
20 |
21 | ### Citation
22 |
23 | ```
24 | @misc{alpaca_eval,
25 | author = {Xuechen Li and Tianyi Zhang and Yann Dubois and Rohan Taori and Ishaan Gulrajani and Carlos Guestrin and Percy Liang and Tatsunori B. Hashimoto },
26 | title = {AlpacaEval: An Automatic Evaluator of Instruction-following Models},
27 | year = {2023},
28 | publisher = {GitHub},
29 | journal = {GitHub repository},
30 | howpublished = {\url{https://github.com/tatsu-lab/alpaca_eval}}
31 | }
32 | ```
33 |
34 | ## HumanEval
35 |
36 | | [installation](https://github.com/openai/human-eval/blob/master/README.md)
37 |
38 | ### Citation
39 |
40 | ```
41 | @article{chen2021codex,
42 | title={Evaluating Large Language Models Trained on Code},
43 | author={Mark Chen and Jerry Tworek and Heewoo Jun and Qiming Yuan and Henrique Ponde de Oliveira Pinto and Jared Kaplan and Harri Edwards and Yuri Burda and Nicholas Joseph and Greg Brockman and Alex Ray and Raul Puri and Gretchen Krueger and Michael Petrov and Heidy Khlaaf and Girish Sastry and Pamela Mishkin and Brooke Chan and Scott Gray and Nick Ryder and Mikhail Pavlov and Alethea Power and Lukasz Kaiser and Mohammad Bavarian and Clemens Winter and Philippe Tillet and Felipe Petroski Such and Dave Cummings and Matthias Plappert and Fotios Chantzis and Elizabeth Barnes and Ariel Herbert-Voss and William Hebgen Guss and Alex Nichol and Alex Paino and Nikolas Tezak and Jie Tang and Igor Babuschkin and Suchir Balaji and Shantanu Jain and William Saunders and Christopher Hesse and Andrew N. Carr and Jan Leike and Josh Achiam and Vedant Misra and Evan Morikawa and Alec Radford and Matthew Knight and Miles Brundage and Mira Murati and Katie Mayer and Peter Welinder and Bob McGrew and Dario Amodei and Sam McCandlish and Ilya Sutskever and Wojciech Zaremba},
44 | year={2021},
45 | eprint={2107.03374},
46 | archivePrefix={arXiv},
47 | primaryClass={cs.LG}
48 | }
49 | ```
50 |
51 | ## GSM8K
52 |
53 | | There is no need to install the `GSM8K` dataset manually.
54 |
55 | ### Citation
56 |
57 | ```
58 | @article{cobbe2021gsm8k,
59 | title={Training Verifiers to Solve Math Word Problems},
60 | author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Chen, Mark and Jun, Heewoo and Kaiser, Lukasz and Plappert, Matthias and Tworek, Jerry and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
61 | journal={arXiv preprint arXiv:2110.14168},
62 | year={2021}
63 | }
64 | ```
65 |
--------------------------------------------------------------------------------
/application/accept_length.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 | import argparse
5 | from tqdm import tqdm
11 |
12 | def get_throughput_results(input_file):
13 | with open(input_file, "r") as f:
14 | data = [json.loads(line) for line in f]
15 |
16 | new_tokens = []
17 | wall_time = []
18 | throughputs = []
19 | for d in data:
20 | for choice in d["choices"]:
21 | new_tokens.extend(choice["new_tokens"])
22 | wall_time.extend(choice["wall_time"])
23 | for i in range(len(choice["new_tokens"])):
24 | throughputs.append(choice["new_tokens"][i] / choice["wall_time"][i])
25 |
26 | return sum(new_tokens) / sum(wall_time)
27 | # return sum(throughputs) / len(throughputs)
28 |
29 | def main(model_name, dir_path, file_name, eval_file_name, max_length, min_length, length_interval, run_baseline, choices, n):
30 | max_throughput, best_sparse_tree = 0, None  # best_sparse_tree stays None if no runs succeed
31 | # run baseline first
32 | if run_baseline:
33 | for j in range(n):
34 | if os.system(f"python3 gen_model_answer_baseline.py --model-path ../test/{model_name} --model-id vicuna --answer-file {dir_path}/baseline_{j}.jsonl --bench-name alpaca_eval --max-new-token 20"):
35 | raise ValueError("Failed to generate baseline")
36 | if choices is not None:
37 | r = eval(choices)
38 | else:
39 | r = range(min_length, max_length+1, length_interval)
40 | for i in tqdm(r):
41 | throughputs = []
42 | for j in range(n):
43 | # if file does not exist, generate it
44 | if not os.path.exists(f"{dir_path}/{file_name}{i}_{j}.jsonl"):
45 | # use alpaca dataset for evaluation
46 | if os.system(f"python3 {eval_file_name} --model-path ../test/{model_name} --model-id vicuna_faster --answer-file {dir_path}/{file_name}{i}_{j}.jsonl --tree-length {i} --bench-name alpaca_eval --max-new-token 20"):
47 | raise ValueError("Failed to generate prompt decoding model")
48 |
49 | if os.path.exists(f"{dir_path}/{file_name}{i}_{j}.jsonl"):
50 | throughput = get_throughput_results(f"{dir_path}/{file_name}{i}_{j}.jsonl")
51 | throughputs.append(throughput)
52 |
53 | if len(throughputs) > 0:
54 | throughput = sum(throughputs) / len(throughputs)
55 | std = pd.Series(throughputs).std()
56 |
57 | if throughput > max_throughput:
58 | max_throughput = throughput
59 | best_sparse_tree = i
60 |
61 | print(f"Throughput for sparse tree {i}: {throughput:.3f} tokens/s", f"std: {std:.3f}")
62 | print(f"Best sparse tree: {best_sparse_tree}")
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--dir-path", type=str, default="data/mt_bench/dynamic_sparse_tree_search/3-1-7b", help="Path to the directory")
67 | parser.add_argument("--file-name", type=str, default="dynamic_sparse_tree", help="Name of the file")
68 | parser.add_argument("--model-name", type=str, default="vicuna-7b-3-1", help="Name of the model")
69 | parser.add_argument("--eval-file-name", type=str, default="gen_model_answer_prompt_decoding.py", help="Name of the evaluation file")
70 | parser.add_argument("--max-length", type=int, default=100, help="Max length of the sparse tree")
71 | parser.add_argument("--min-length", type=int, default=60, help="Min length of the sparse tree")
72 | parser.add_argument("--length-interval", type=int, default=1, help="Interval of the length of the sparse tree")
73 | parser.add_argument("--run-baseline", action="store_true", help="Run baseline first")
74 | parser.add_argument("--choices", type=str, default=None, help="Choices for the prompt decoding model")
75 | parser.add_argument("--n", type=int, default=1, help="Number of files to group")
76 | args = parser.parse_args()
77 |
78 | main(model_name=args.model_name,
79 | dir_path=args.dir_path,
80 | file_name=args.file_name,
81 | eval_file_name=args.eval_file_name,
82 | max_length=args.max_length,
83 | min_length=args.min_length,
84 | length_interval=args.length_interval,
85 | run_baseline=args.run_baseline,
86 | choices=args.choices,
87 | n=args.n)
88 |
89 |
90 |
91 |
92 |
--------------------------------------------------------------------------------
/application/common.py:
--------------------------------------------------------------------------------
1 | """
2 | Common data structures and utilities.
3 | """
4 |
5 | import ast
6 | import dataclasses
7 | import glob
8 | import json
9 | import os
10 | import re
11 | import time
12 | from typing import Optional
13 |
14 | import openai
15 | import anthropic
16 |
17 | from fastchat.model.model_adapter import (
18 | get_conversation_template,
19 | ANTHROPIC_MODEL_LIST,
20 | OPENAI_MODEL_LIST,
21 | )
22 |
23 | # API setting constants
24 | API_MAX_RETRY = 16
25 | API_RETRY_SLEEP = 10
26 | API_ERROR_OUTPUT = "$ERROR$"
27 |
28 | TIE_DELTA = 0.1
29 |
30 | # Categories that need reference answers
31 | NEED_REF_CATS = ["math", "reasoning", "coding", "arena-hard-200"]
32 |
33 | # Extract scores from judgments
34 | two_score_pattern = re.compile(r"\[\[(\d+\.?\d*),\s?(\d+\.?\d*)\]\]")
35 | two_score_pattern_backup = re.compile(r"\[(\d+\.?\d*),\s?(\d+\.?\d*)\]")
36 | one_score_pattern = re.compile(r"\[\[(\d+\.?\d*)\]\]")
37 | one_score_pattern_backup = re.compile(r"\[(\d+\.?\d*)\]")
38 |
39 | # Sampling temperature configs for each question category
40 | temperature_config = {
41 | "writing": 0.7,
42 | "roleplay": 0.7,
43 | "extraction": 0.0,
44 | "math": 0.0,
45 | "coding": 0.0,
46 | "reasoning": 0.0,
47 | "stem": 0.1,
48 | "humanities": 0.1,
49 | "arena-hard-200": 0.0,
50 | }
51 |
52 | reverse_model_map = {
53 | "model_1": "model_2",
54 | "model_2": "model_1",
55 | }
56 |
57 |
58 | @dataclasses.dataclass
59 | class Judge:
60 | model_name: str
61 | prompt_template: dict
62 | ref_based: bool = False
63 | multi_turn: bool = False
64 |
65 |
66 | @dataclasses.dataclass
67 | class MatchSingle:
68 | question: dict
69 | model: str
70 | answer: dict
71 | judge: Judge
72 | ref_answer: dict = None
73 | multi_turn: bool = False
74 |
75 |
76 | @dataclasses.dataclass
77 | class MatchPair:
78 | question: dict
79 | model_1: str
80 | model_2: str
81 | answer_1: dict
82 | answer_2: dict
83 | judge: Judge
84 | ref_answer: dict = None
85 | multi_turn: bool = False
86 |
87 |
88 | def load_questions(question_file: str, begin: Optional[int], end: Optional[int]):
89 | """Load questions from a file."""
90 | questions = []
91 | with open(question_file, "r") as ques_file:
92 | for line in ques_file:
93 | if line:
94 | questions.append(json.loads(line))
95 | questions = questions[begin:end]
96 | return questions
97 |
98 |
99 | def load_model_answers(answer_dir: str):
100 | """Load model answers.
101 |
102 | The return value is a python dict of type:
103 | Dict[model_name: str -> Dict[question_id: int -> answer: dict]]
104 | """
105 | filenames = glob.glob(os.path.join(answer_dir, "*.jsonl"))
106 | filenames.sort()
107 | model_answers = {}
108 |
109 | for filename in filenames:
110 | model_name = os.path.basename(filename)[:-6]
111 | answer = {}
112 | with open(filename) as fin:
113 | for line in fin:
114 | line = json.loads(line)
115 | answer[line["question_id"]] = line
116 | model_answers[model_name] = answer
117 |
118 | return model_answers
119 |
120 |
121 | def load_judge_prompts(prompt_file: str):
122 | """Load judge prompts.
123 |
124 | The return value is a python dict of type:
125 | Dict[judge_name: str -> dict]
126 | """
127 | prompts = {}
128 | with open(prompt_file) as fin:
129 | for line in fin:
130 | line = json.loads(line)
131 | prompts[line["name"]] = line
132 | return prompts
133 |
134 |
135 | def run_judge_single(question, answer, judge, ref_answer, multi_turn=False):
136 | kwargs = {}
137 | model = judge.model_name
138 | if ref_answer is not None:
139 | kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0]
140 | if multi_turn:
141 | kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1]
142 |
143 | if multi_turn:
144 | user_prompt = judge.prompt_template["prompt_template"].format(
145 | question_1=question["turns"][0],
146 | question_2=question["turns"][1],
147 | answer_1=answer["choices"][0]["turns"][0],
148 | answer_2=answer["choices"][0]["turns"][1],
149 | **kwargs,
150 | )
151 | else:
152 | user_prompt = judge.prompt_template["prompt_template"].format(
153 | question=question["turns"][0],
154 | answer=answer["choices"][0]["turns"][0],
155 | **kwargs,
156 | )
157 |
158 | rating = -1
159 |
160 | system_prompt = judge.prompt_template["system_prompt"]
161 | conv = get_conversation_template(model)
162 | conv.set_system_message(system_prompt)
163 | conv.append_message(conv.roles[0], user_prompt)
164 | conv.append_message(conv.roles[1], None)
165 |
166 | if model in OPENAI_MODEL_LIST:
167 | judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048)
168 | elif model in ANTHROPIC_MODEL_LIST:
169 | judgment = chat_completion_anthropic(
170 | model, conv, temperature=0, max_tokens=1024
171 | )
172 | else:
173 | raise ValueError(f"Invalid judge model name: {model}")
174 |
175 | if judge.prompt_template["output_format"] == "[[rating]]":
176 | match = re.search(one_score_pattern, judgment)
177 | if not match:
178 | match = re.search(one_score_pattern_backup, judgment)
179 |
180 | if match:
181 | rating = ast.literal_eval(match.groups()[0])
182 | else:
183 | rating = -1
184 | else:
185 | raise ValueError(
186 | f"invalid output format: {judge.prompt_template['output_format']}"
187 | )
188 |
189 | return rating, user_prompt, judgment
190 |
191 |
192 | def play_a_match_single(match: MatchSingle, output_file: str):
193 | question, model, answer, judge, ref_answer, multi_turn = (
194 | match.question,
195 | match.model,
196 | match.answer,
197 | match.judge,
198 | match.ref_answer,
199 | match.multi_turn,
200 | )
201 |
202 | if judge.prompt_template["type"] == "single":
203 | score, user_prompt, judgment = run_judge_single(
204 | question, answer, judge, ref_answer, multi_turn=multi_turn
205 | )
206 |
207 | question_id = question["question_id"]
208 | turn = 1 if not multi_turn else 2
209 | result = {
210 | "question_id": question_id,
211 | "model": model,
212 | "judge": (judge.model_name, judge.prompt_template["name"]),
213 | "user_prompt": user_prompt,
214 | "judgment": judgment,
215 | "score": score,
216 | "turn": turn,
217 | "tstamp": time.time(),
218 | }
219 | print(
220 | f"question: {question_id}, turn: {turn}, model: {model}, "
221 | f"score: {score}, "
222 | f"judge: {(judge.model_name, judge.prompt_template['name'])}"
223 | )
224 | else:
225 | raise ValueError(f"invalid judge type: {judge.prompt_template['type']}")
226 |
227 | if output_file:
228 | os.makedirs(os.path.dirname(output_file), exist_ok=True)
229 | with open(output_file, "a") as fout:
230 | fout.write(json.dumps(result) + "\n")
231 |
232 | return result
233 |
234 |
235 | def run_judge_pair(question, answer_a, answer_b, judge, ref_answer, multi_turn=False):
236 | kwargs = {}
237 | model = judge.model_name
238 | if ref_answer is not None:
239 | kwargs["ref_answer_1"] = ref_answer["choices"][0]["turns"][0]
240 | if multi_turn:
241 | kwargs["ref_answer_2"] = ref_answer["choices"][0]["turns"][1]
242 |
243 | if multi_turn:
244 | system_prompt = judge.prompt_template["system_prompt"]
245 | user_prompt = judge.prompt_template["prompt_template"].format(
246 | question_1=question["turns"][0],
247 | question_2=question["turns"][1],
248 | answer_a_1=answer_a["choices"][0]["turns"][0],
249 | answer_b_1=answer_b["choices"][0]["turns"][0],
250 | answer_a_2=answer_a["choices"][0]["turns"][1],
251 | answer_b_2=answer_b["choices"][0]["turns"][1],
252 | **kwargs,
253 | )
254 | else:
255 | system_prompt = judge.prompt_template["system_prompt"]
256 | user_prompt = judge.prompt_template["prompt_template"].format(
257 | question=question["turns"][0],
258 | answer_a=answer_a["choices"][0]["turns"][0],
259 | answer_b=answer_b["choices"][0]["turns"][0],
260 | **kwargs,
261 | )
262 |
263 | winner = "error"
264 |
265 | conv = get_conversation_template(model)
266 | conv.append_message(conv.roles[0], user_prompt)
267 | conv.append_message(conv.roles[1], None)
268 |
269 | if model in OPENAI_MODEL_LIST:
270 | conv.set_system_message(system_prompt)
271 | judgment = chat_completion_openai(model, conv, temperature=0, max_tokens=2048)
272 | elif model in ANTHROPIC_MODEL_LIST:
273 | if system_prompt != "You are a helpful assistant.":
274 | user_prompt = "[Instruction]\n" + system_prompt + "\n\n" + user_prompt
275 | conv.messages[0][1] = user_prompt
276 | judgment = chat_completion_anthropic(
277 | model, conv, temperature=0, max_tokens=1024
278 | )
279 | else:
280 | raise ValueError(f"Invalid judge model name: {model}")
281 |
282 | if judge.prompt_template["output_format"] == "[[A]]":
283 | if "[[A]]" in judgment:
284 | winner = "A"
285 | elif "[[B]]" in judgment:
286 | winner = "B"
287 | elif "[[C]]" in judgment:
288 | winner = "tie"
289 | else:
290 | winner = "error"
291 | elif judge.prompt_template["output_format"] == "[[rating_a,rating_b]]":
292 | match = re.search(two_score_pattern, judgment)
293 | if not match:
294 | match = re.search(two_score_pattern_backup, judgment)
295 | if match:
296 | scores = [ast.literal_eval(s.strip()) for s in match.groups()]
297 | if abs(scores[0] - scores[1]) <= TIE_DELTA:
298 | winner = "tie"
299 | elif scores[0] > scores[1]:
300 | winner = "A"
301 | else:
302 | winner = "B"
303 | else:
304 | winner = "error"
305 | else:
306 | raise ValueError(
307 | f"invalid output format: {judge.prompt_template['output_format']}"
308 | )
309 |
310 | return winner, user_prompt, judgment
311 |
312 |
313 | def play_a_match_pair(match: MatchPair, output_file: str):
314 | question, model_1, model_2, answer_1, answer_2, judge, ref_answer, multi_turn = (
315 | match.question,
316 | match.model_1,
317 | match.model_2,
318 | match.answer_1,
319 | match.answer_2,
320 | match.judge,
321 | match.ref_answer,
322 | match.multi_turn,
323 | )
324 |
325 | if judge.prompt_template["type"] == "pairwise":
326 | g1_winner, g1_user_prompt, g1_judgment = run_judge_pair(
327 | question, answer_1, answer_2, judge, ref_answer, multi_turn=multi_turn
328 | )
329 | g2_winner, g2_user_prompt, g2_judgment = run_judge_pair(
330 | question, answer_2, answer_1, judge, ref_answer, multi_turn=multi_turn
331 | )
332 |
333 | g1_map = {"A": "model_1", "B": "model_2"}
334 | g2_map = {"A": "model_2", "B": "model_1"}
335 | g1_winner = g1_map.get(g1_winner, g1_winner)
336 | g2_winner = g2_map.get(g2_winner, g2_winner)
337 | question_id = question["question_id"]
338 | turn = 1 if not multi_turn else 2
339 |
340 | result = {
341 | "question_id": question_id,
342 | "model_1": model_1,
343 | "model_2": model_2,
344 | "g1_winner": g1_winner,
345 | "g2_winner": g2_winner,
346 | "judge": (judge.model_name, judge.prompt_template["name"]),
347 | "g1_user_prompt": g1_user_prompt,
348 | "g1_judgment": g1_judgment,
349 | "g2_user_prompt": g2_user_prompt,
350 | "g2_judgment": g2_judgment,
351 | "turn": turn,
352 | "tstamp": time.time(),
353 | }
354 |
355 | print(
356 | f"question: {question_id}, turn: {turn}, model_1: {model_1}, model_2: {model_2}, "
357 | f"g1_winner: {g1_winner}, g2_winner: {g2_winner}, "
358 | f"judge: {(judge.model_name, judge.prompt_template['name'])}"
359 | )
360 | elif judge.prompt_template["type"] == "single":
361 | m1_score, m1_user_prompt, m1_judgment = run_judge_single(
362 | question, answer_1, judge
363 | )
364 | m2_score, m2_user_prompt, m2_judgment = run_judge_single(
365 | question, answer_2, judge
366 | )
367 |
368 | if abs(m1_score - m2_score) <= TIE_DELTA:
369 | winner = "tie"
370 | elif m1_score > m2_score:
371 | winner = "model_1"
372 | else:
373 | winner = "model_2"
374 |
375 | question_id = question["question_id"]
376 | result = {
377 | "question_id": question_id,
378 | "model_1": model_1,
379 | "model_2": model_2,
380 | "g1_winner": winner,
381 | "g2_winner": winner,
382 | "judge": (judge.model_name, judge.prompt_template["name"]),
383 | "g1_user_prompt": m1_user_prompt,
384 | "g1_judgment": m1_judgment,
385 | "g2_user_prompt": m2_user_prompt,
386 | "g2_judgment": m2_judgment,
387 | "m1_score": m1_score,
388 | "m2_score": m2_score,
389 | "tstamp": time.time(),
390 | }
391 | print(
392 | f"question: {question_id}, model_1: {model_1}, model_2: {model_2}, "
393 | f"winner: {winner}, m1_score: {m1_score}, m2_score: {m2_score}, "
394 | f"judge: {(judge.model_name, judge.prompt_template['name'])}"
395 | )
396 | else:
397 | raise ValueError(f"invalid judge type: {judge.prompt_template['type']}")
398 |
399 | if output_file:
400 | os.makedirs(os.path.dirname(output_file), exist_ok=True)
401 | with open(output_file, "a") as fout:
402 | fout.write(json.dumps(result) + "\n")
403 |
404 | return result
405 |
406 |
407 | def chat_completion_openai(model, conv, temperature, max_tokens, api_dict=None):
408 | if api_dict is not None:
409 | openai.api_base = api_dict["api_base"]
410 | openai.api_key = api_dict["api_key"]
411 | output = API_ERROR_OUTPUT
412 | for _ in range(API_MAX_RETRY):
413 | try:
414 | messages = conv.to_openai_api_messages()
415 | response = openai.ChatCompletion.create(
416 | model=model,
417 | messages=messages,
418 | n=1,
419 | temperature=temperature,
420 | max_tokens=max_tokens,
421 | )
422 | output = response["choices"][0]["message"]["content"]
423 | break
424 | except openai.error.OpenAIError as e:
425 | print(type(e), e)
426 | time.sleep(API_RETRY_SLEEP)
427 |
428 | return output
429 |
430 |
431 | def chat_completion_openai_azure(model, conv, temperature, max_tokens, api_dict=None):
432 | openai.api_type = "azure"
433 | openai.api_version = "2023-07-01-preview"
434 | if api_dict is not None:
435 | openai.api_base = api_dict["api_base"]
436 | openai.api_key = api_dict["api_key"]
437 | else:
438 | openai.api_base = os.environ["AZURE_OPENAI_ENDPOINT"]
439 | openai.api_key = os.environ["AZURE_OPENAI_KEY"]
440 |
441 | if "azure-" in model:
442 | model = model[6:]
443 |
444 | output = API_ERROR_OUTPUT
445 | for _ in range(API_MAX_RETRY):
446 | try:
447 | messages = conv.to_openai_api_messages()
448 | response = openai.ChatCompletion.create(
449 | engine=model,
450 | messages=messages,
451 | n=1,
452 | temperature=temperature,
453 | max_tokens=max_tokens,
454 | )
455 | output = response["choices"][0]["message"]["content"]
456 | break
457 | except openai.error.InvalidRequestError as e:
458 | print(type(e), e)
459 | break
460 | except openai.error.OpenAIError as e:
461 | print(type(e), e)
462 | time.sleep(API_RETRY_SLEEP)
463 | except KeyError:
464 | print(response)
465 | break
466 |
467 | return output
468 |
469 |
470 | def chat_completion_anthropic(model, conv, temperature, max_tokens, api_dict=None):
471 | if api_dict is not None and "api_key" in api_dict:
472 | api_key = api_dict["api_key"]
473 | else:
474 | api_key = os.environ["ANTHROPIC_API_KEY"]
475 |
476 | output = API_ERROR_OUTPUT
477 | for _ in range(API_MAX_RETRY):
478 | try:
479 | c = anthropic.Anthropic(api_key=api_key)
480 | prompt = conv.get_prompt()
481 | response = c.completions.create(
482 | model=model,
483 | prompt=prompt,
484 | stop_sequences=[anthropic.HUMAN_PROMPT],
485 | max_tokens_to_sample=max_tokens,
486 | temperature=temperature,
487 | )
488 | output = response.completion
489 | break
490 | except anthropic.APIError as e:
491 | print(type(e), e)
492 | time.sleep(API_RETRY_SLEEP)
493 | return output.strip()
494 |
495 |
496 | def chat_completion_palm(chat_state, model, conv, temperature, max_tokens):
497 | from fastchat.serve.api_provider import init_palm_chat
498 |
499 | assert model == "palm-2-chat-bison-001"
500 |
501 | if chat_state is None:
502 | chat_state = init_palm_chat("chat-bison@001")
503 |
504 | parameters = {
505 | "temperature": temperature,
506 | "top_p": 0.8,
507 | "top_k": 40,
508 | "max_output_tokens": max_tokens,
509 | }
510 | output = API_ERROR_OUTPUT
511 | for _ in range(API_MAX_RETRY):
512 | try:
513 | response = chat_state.send_message(conv.messages[-2][1], **parameters)
514 | output = response.text
515 | break
516 | except Exception as e:
517 | print(type(e), e)
518 | time.sleep(API_RETRY_SLEEP)
519 | return chat_state, output
520 |
521 |
522 | def normalize_game_key_single(gamekey, result):
523 | """Make the model names sorted in a game key."""
524 | qid, model_1, model_2 = gamekey
525 | if model_1 < model_2:
526 | return gamekey, result
527 | else:
528 | new_gamekey = (qid, model_2, model_1)
529 | new_result = {
530 | "winners": tuple(reverse_model_map.get(x, x) for x in result["winners"]),
531 | "g1_judgment": result["g2_judgment"],
532 | "g2_judgment": result["g1_judgment"],
533 | }
534 | return new_gamekey, new_result
535 |
536 |
537 | def normalize_game_key_dict(judgment_dict):
538 | """Make the model names sorted in the game keys."""
539 | ret = {}
540 | for key, value in judgment_dict.items():
541 | new_key, new_value = normalize_game_key_single(key, value)
542 | ret[new_key] = new_value
543 | return ret
544 |
545 |
546 | def load_pairwise_model_judgments(filename: str):
547 | """Load model judgments.
548 |
549 | The return value is a dict of type:
550 | Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict]
551 | """
552 | judge_dict = {}
553 |
554 | for line in open(filename):
555 | obj = json.loads(line)
556 | judge = tuple(obj["judge"])
557 | qid, model_1, model_2 = obj["question_id"], obj["model_1"], obj["model_2"]
558 |
559 | if judge not in judge_dict:
560 | judge_dict[judge] = {}
561 |
562 | if "winner" in obj:
563 | winner = obj["winner"]
564 | elif "g1_winner" in obj and "g2_winner" in obj:
565 | g1_winner, g2_winner = obj["g1_winner"], obj["g2_winner"]
566 | if g1_winner == g2_winner:
567 | winner = g1_winner
568 | else:
569 | winner = "inconsistent"
570 | else:
571 | raise ValueError(f"Invalid keys: {list(obj.keys())}")
572 |
573 | gamekey = (qid, model_1, model_2)
574 | winners = (winner,)
575 |
576 | judge_dict[judge][gamekey] = {
577 | "winners": winners,
578 | "g1_judgment": obj["g1_judgment"],
579 | "g2_judgment": obj["g2_judgment"],
580 | }
581 |
582 | # Make the model names sorted in the game keys
583 | normalized = {}
584 | for judge, value in judge_dict.items():
585 | normalized[judge] = normalize_game_key_dict(value)
586 | return normalized
587 |
588 |
589 | def load_single_model_judgments(filename: str):
590 | """Load model judgments.
591 |
592 | The return value is a dict of type:
593 | Dict[judge: Tuple -> Dict[game_key: tuple -> game_result: dict]
594 | """
595 | judge_dict = {}
596 |
597 | for line in open(filename):
598 | obj = json.loads(line)
599 | judge = tuple(obj["judge"])
600 | qid, model = obj["question_id"], obj["model"]
601 |
602 | if judge not in judge_dict:
603 | judge_dict[judge] = {}
604 |
605 | gamekey = (qid, model)
606 |
607 | judge_dict[judge][gamekey] = {
608 | "score": obj["score"],
609 | "judgment": obj["judgment"],
610 | }
611 | return judge_dict
612 |
613 |
614 | def resolve_pairwise_judgment_dict(
615 | question, model_judgments_normal, model_judgments_math, multi_turn=False
616 | ):
617 | """Return the correct pairwise judge."""
618 | if multi_turn:
619 | if question["category"] in NEED_REF_CATS:
620 | return model_judgments_math[("gpt-4", "pair-math-v1-multi-turn")]
621 | return model_judgments_normal[("gpt-4", "pair-v2-multi-turn")]
622 |
623 | if question["category"] in NEED_REF_CATS:
624 | return model_judgments_math[("gpt-4", "pair-math-v1")]
625 | else:
626 | return model_judgments_normal[("gpt-4", "pair-v2")]
627 |
628 |
629 | def resolve_single_judgment_dict(
630 | question, model_judgments_normal, model_judgments_math, multi_turn=False
631 | ):
632 | """Return the correct single answer grading judge."""
633 | if multi_turn:
634 | if question["category"] in NEED_REF_CATS:
635 | return model_judgments_math[("gpt-4", "single-math-v1-multi-turn")]
636 | return model_judgments_normal[("gpt-4", "single-v1-multi-turn")]
637 |
638 | if question["category"] in NEED_REF_CATS:
639 | return model_judgments_math[("gpt-4", "single-math-v1")]
640 | else:
641 | return model_judgments_normal[("gpt-4", "single-v1")]
642 |
643 |
644 | def get_pairwise_judge_explanation(gamekey, judgment_dict):
645 | """Get model judge explanation."""
646 | try:
647 | qid, model_1, model_2 = gamekey
648 | if model_1 < model_2:
649 | res = judgment_dict[gamekey]
650 | g1_judgment, g2_judgment = res["g1_judgment"], res["g2_judgment"]
651 | else:
652 | new_gamekey = (qid, model_2, model_1)
653 | res = judgment_dict[new_gamekey]
654 |
655 | model_1, model_2 = model_1, model_2
656 | g1_judgment, g2_judgment = res["g2_judgment"], res["g1_judgment"]
657 |
658 | return (
659 | f"**Game 1**. **A**: {model_1}, **B**: {model_2}\n\n"
660 | f"**Judgment**: {g1_judgment}"
661 | + f"\n\n`--------------------------`\n\n"
662 | + f"**Game 2**. **A**: {model_2}, **B**: {model_1}\n\n"
663 | f"**Judgment**: {g2_judgment}"
664 | )
665 | except KeyError:
666 | return "N/A"
667 |
668 |
669 | def get_single_judge_explanation(gamekey, judgment_dict):
670 | """Get model judge explanation."""
671 | try:
672 | qid, model = gamekey
673 |
674 | res = judgment_dict[gamekey]
675 |
676 | g1_judgment = res["judgment"]
677 | g1_score = res["score"]
678 |
679 | return (
680 | f"**Game 1**. **A**: {model}, **Score**: {g1_score}\n\n"
681 | f"**Judgment**: {g1_judgment}"
682 | )
683 | except KeyError:
684 | return "N/A"
685 |
686 |
687 | def check_data(questions, model_answers, ref_answers, models, judges):
688 | # check model answers
689 | for m in models:
690 | assert m in model_answers, f"Missing model answer for {m}"
691 | m_answer = model_answers[m]
692 | for q in questions:
693 | assert (
694 | q["question_id"] in m_answer
695 | ), f"Missing model {m}'s answer to Question {q['question_id']}"
696 | # check ref answers
697 | for jg in judges.values():
698 | if not jg.ref_based:
699 | continue
700 | for q in questions:
701 | if q["category"] not in NEED_REF_CATS:
702 | continue
703 | assert (
704 | q["question_id"] in ref_answers[jg.model_name]
705 | ), f"Missing reference answer to Question {q['question_id']} for judge {jg.model_name}"
706 |
707 |
708 | def get_model_list(answer_dir):
709 | file_paths = glob.glob(f"{answer_dir}/*.jsonl")
710 | file_names = [os.path.splitext(os.path.basename(f))[0] for f in file_paths]
711 | return file_names
712 |
--------------------------------------------------------------------------------
/application/gen_judgment.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all]
4 | """
5 | import argparse
6 | from concurrent.futures import ThreadPoolExecutor
7 | import json
8 |
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | from fastchat.llm_judge.common import (
13 | load_questions,
14 | load_model_answers,
15 | load_judge_prompts,
16 | check_data,
17 | play_a_match_pair,
18 | play_a_match_single,
19 | get_model_list,
20 | Judge,
21 | MatchPair,
22 | MatchSingle,
23 | NEED_REF_CATS,
24 | )
25 |
26 |
27 | def make_match(
28 | questions,
29 | models,
30 | model_answers,
31 | judge,
32 | baseline_model,
33 | ref_answers=None,
34 | multi_turn=False,
35 | ):
36 | matches = []
37 | for q in questions:
38 | if multi_turn and len(q["turns"]) != 2:
39 | continue
40 | for i in range(len(models)):
41 | q_id = q["question_id"]
42 | m_1 = models[i]
43 | m_2 = baseline_model
44 | if m_1 == m_2:
45 | continue
46 | a_1 = model_answers[m_1][q_id]
47 | a_2 = model_answers[baseline_model][q_id]
48 | if ref_answers is not None:
49 | ref = ref_answers[judge.model_name][q_id]
50 | match = MatchPair(
51 | dict(q),
52 | m_1,
53 | m_2,
54 | a_1,
55 | a_2,
56 | judge,
57 | ref_answer=ref,
58 | multi_turn=multi_turn,
59 | )
60 | else:
61 | match = MatchPair(
62 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
63 | )
64 | matches.append(match)
65 | return matches
66 |
67 |
68 | def make_match_all_pairs(
69 | questions,
70 | models,
71 | model_answers,
72 | judge,
73 | baseline_model=None,
74 | ref_answers=None,
75 | multi_turn=False,
76 | ):
77 | matches = []
78 | for q in questions:
79 | if multi_turn and len(q["turns"]) != 2:
80 | continue
81 | for i in range(len(models)):
82 | for j in range(i + 1, len(models)):
83 | q_id = q["question_id"]
84 | m_1 = models[i]
85 | m_2 = models[j]
86 | a_1 = model_answers[m_1][q_id]
87 | a_2 = model_answers[m_2][q_id]
88 | if ref_answers is not None:
89 | ref = ref_answers[judge.model_name][q_id]
90 | match = MatchPair(
91 | dict(q),
92 | m_1,
93 | m_2,
94 | a_1,
95 | a_2,
96 | judge,
97 | ref_answer=ref,
98 | multi_turn=multi_turn,
99 | )
100 | else:
101 | match = MatchPair(
102 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
103 | )
104 | matches.append(match)
105 | return matches
106 |
107 |
108 | def make_match_single(
109 | questions,
110 | models,
111 | model_answers,
112 | judge,
113 | baseline_model=None,
114 | ref_answers=None,
115 | multi_turn=False,
116 | ):
117 | matches = []
118 | for q in questions:
119 | if multi_turn and len(q["turns"]) != 2:
120 | continue
121 | for i in range(len(models)):
122 | q_id = q["question_id"]
123 | m = models[i]
124 | a = model_answers[m][q_id]
125 | if ref_answers is not None:
126 | ref = ref_answers[judge.model_name][q_id]
127 | matches.append(
128 | MatchSingle(
129 | dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn
130 | )
131 | )
132 | else:
133 | matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn))
134 | return matches
135 |
136 |
137 | def make_judge_pairwise(judge_model, judge_prompts):
138 | judges = {}
139 | judges["default"] = Judge(judge_model, judge_prompts["pair-v2"])
140 | judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True)
141 | judges["default-mt"] = Judge(
142 | judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True
143 | )
144 | judges["math-mt"] = Judge(
145 | judge_model,
146 | judge_prompts["pair-math-v1-multi-turn"],
147 | ref_based=True,
148 | multi_turn=True,
149 | )
150 | return judges
151 |
152 |
153 | def make_judge_single(judge_model, judge_prompts):
154 | judges = {}
155 | judges["default"] = Judge(judge_model, judge_prompts["single-v1"])
156 | judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True)
157 | judges["default-mt"] = Judge(
158 | judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True
159 | )
160 | judges["math-mt"] = Judge(
161 | judge_model,
162 | judge_prompts["single-math-v1-multi-turn"],
163 | ref_based=True,
164 | multi_turn=True,
165 | )
166 | return judges
167 |
168 |
169 | if __name__ == "__main__":
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument(
172 | "--bench-name",
173 | type=str,
174 | default="mt_bench",
175 | help="The name of the benchmark question set.",
176 | )
177 | parser.add_argument(
178 | "--judge-file",
179 | type=str,
180 | default="data/judge_prompts.jsonl",
181 | help="The file of judge prompts.",
182 | )
183 | parser.add_argument("--judge-model", type=str, default="gpt-4")
184 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo")
185 | parser.add_argument(
186 | "--mode",
187 | type=str,
188 | default="single",
189 | choices=["pairwise-baseline", "pairwise-all", "single"],
190 | help=(
191 | "Evaluation mode. "
192 |             "`pairwise-baseline` runs pairwise comparison against a baseline. "
193 |             "`pairwise-all` runs pairwise comparison between all pairs. "
194 | "`single` runs single answer grading."
195 | ),
196 | )
197 | parser.add_argument(
198 | "--model-list",
199 | type=str,
200 | nargs="+",
201 | default=None,
202 | help="A list of models to be evaluated",
203 | )
204 | parser.add_argument(
205 | "--parallel", type=int, default=1, help="The number of concurrent API calls."
206 | )
207 | parser.add_argument(
208 | "--first-n", type=int, help="A debug option. Only run the first `n` judgments."
209 | )
210 | args = parser.parse_args()
211 |
212 | question_file = f"data/{args.bench_name}/question.jsonl"
213 | answer_dir = f"data/{args.bench_name}/model_answer"
214 | ref_answer_dir = f"data/{args.bench_name}/reference_answer"
215 |
216 | # Load questions
217 | questions = load_questions(question_file, None, None)
218 |
219 | # Load answers
220 | model_answers = load_model_answers(answer_dir)
221 | ref_answers = load_model_answers(ref_answer_dir)
222 |
223 | # Load judge
224 | judge_prompts = load_judge_prompts(args.judge_file)
225 |
226 | if args.first_n:
227 | questions = questions[: args.first_n]
228 |
229 | if args.model_list is None:
230 | models = get_model_list(answer_dir)
231 | else:
232 | models = args.model_list
233 |
234 | if args.mode == "single":
235 | judges = make_judge_single(args.judge_model, judge_prompts)
236 | play_a_match_func = play_a_match_single
237 | output_file = (
238 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl"
239 | )
240 | make_match_func = make_match_single
241 | baseline_model = None
242 | else:
243 | judges = make_judge_pairwise(args.judge_model, judge_prompts)
244 | play_a_match_func = play_a_match_pair
245 | output_file = (
246 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl"
247 | )
248 | if args.mode == "pairwise-all":
249 | make_match_func = make_match_all_pairs
250 | baseline_model = None
251 | else:
252 | make_match_func = make_match
253 | baseline_model = args.baseline_model
254 |
255 | check_data(questions, model_answers, ref_answers, models, judges)
256 |
257 | question_math = [q for q in questions if q["category"] in NEED_REF_CATS]
258 | question_default = [q for q in questions if q["category"] not in NEED_REF_CATS]
259 |
260 | # Make matches
261 | matches = []
262 | matches += make_match_func(
263 | question_default, models, model_answers, judges["default"], baseline_model
264 | )
265 | matches += make_match_func(
266 | question_math,
267 | models,
268 | model_answers,
269 | judges["math"],
270 | baseline_model,
271 | ref_answers,
272 | )
273 | matches += make_match_func(
274 | question_default,
275 | models,
276 | model_answers,
277 | judges["default-mt"],
278 | baseline_model,
279 | multi_turn=True,
280 | )
281 | matches += make_match_func(
282 | question_math,
283 | models,
284 | model_answers,
285 | judges["math-mt"],
286 | baseline_model,
287 | ref_answers,
288 | multi_turn=True,
289 | )
290 |
291 | match_stat = {}
292 | match_stat["bench_name"] = args.bench_name
293 | match_stat["mode"] = args.mode
294 | match_stat["judge"] = args.judge_model
295 | match_stat["baseline"] = baseline_model
296 | match_stat["model_list"] = models
297 | match_stat["total_num_questions"] = len(questions)
298 | match_stat["total_num_matches"] = len(matches)
299 | match_stat["output_path"] = output_file
300 |
301 | # Show match stats and prompt enter to continue
302 | print("Stats:")
303 | print(json.dumps(match_stat, indent=4))
304 | input("Press Enter to confirm...")
305 |
306 | # Play matches
307 | if args.parallel == 1:
308 | for match in tqdm(matches):
309 | play_a_match_func(match, output_file=output_file)
310 | else:
311 |
312 | def play_a_match_wrapper(match):
313 | play_a_match_func(match, output_file=output_file)
314 |
315 | np.random.seed(0)
316 | np.random.shuffle(matches)
317 |
318 | with ThreadPoolExecutor(args.parallel) as executor:
319 | for match in tqdm(
320 | executor.map(play_a_match_wrapper, matches), total=len(matches)
321 | ):
322 | pass
323 |
--------------------------------------------------------------------------------
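For reference, the match-construction step above reduces to a simple cross product: every evaluated model is paired with the fixed baseline for each question, and the baseline is never paired against itself. The sketch below illustrates only that selection logic; StubMatch and pair_against_baseline are hypothetical stand-ins, not the fastchat.llm_judge classes the script actually builds.

# Minimal sketch of the baseline-pairing logic in make_match() (illustrative only).
from dataclasses import dataclass

@dataclass
class StubMatch:
    question_id: str
    model_1: str
    model_2: str

def pair_against_baseline(questions, models, baseline):
    matches = []
    for q in questions:
        for m in models:
            if m == baseline:
                continue  # never judge the baseline against itself
            matches.append(StubMatch(q["question_id"], m, baseline))
    return matches

qs = [{"question_id": "q1"}, {"question_id": "q2"}]
print(pair_against_baseline(qs, ["model-a", "model-b", "gpt-3.5-turbo"], "gpt-3.5-turbo"))
# -> 4 matches: each of model-a and model-b against the baseline, for q1 and q2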
/application/gen_model_answer_baseline.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 gen_model_answer_baseline.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import random
10 | import time
11 | import transformers
12 |
13 | import shortuuid
14 | import torch
15 | from tqdm import tqdm
16 |
17 | from fastchat.llm_judge.common import load_questions, temperature_config
18 | from fastchat.model import load_model, get_conversation_template
19 | from fastchat.utils import str_to_torch_dtype
20 |
21 | import sys
22 |
23 | from prompt.model.kv_cache import initialize_past_key_values
24 | from prompt.model.model import AutoPromptDecoder, PromptDecoder, PromptConfig
25 | from prompt.model.modeling_llama_custom import LlamaForCausalLM as CustomLlamaForCausalLM
26 | from human_eval.data import write_jsonl, read_problems
27 | from datasets import load_dataset
28 | from zeus.monitor import ZeusMonitor
29 |
30 |
31 | def infer(input_ids, model, tokenizer, choices, max_steps = 512, temperature=0.7, posterior_threshold = 0.09, posterior_alpha = 0.3, sampling='greedy', max_new_token=1024):
32 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
33 | # Avoid modifying the input_ids in-place
34 | input_ids = input_ids.clone()
35 |
36 | if not hasattr(model, "inference_buffers"):
37 | print('Generate buffers')
38 | model.generate_dynamic_buffers(choices)
39 | # Initialize the past key and value states
40 | if hasattr(model, "past_key_values"):
41 | past_key_values = model.past_key_values
42 | past_key_values_data = model.past_key_values_data
43 | current_length_data = model.current_length_data
44 | # Reset the past key and value states
45 | current_length_data.zero_()
46 | else:
47 | print('Initialize past key values')
48 | (
49 | past_key_values,
50 | past_key_values_data,
51 | current_length_data,
52 | ) = initialize_past_key_values(model.base_model)
53 | model.past_key_values = past_key_values
54 | model.past_key_values_data = past_key_values_data
55 | model.current_length_data = current_length_data
56 |
57 | input_len = input_ids.shape[1]
58 | model.base_model.model.tree_mask = None
59 | model.base_model.model.vt_attention_mask = None
60 | model.base_model.model.prompt_token_indices = None
61 | outputs = model.base_model(input_ids, past_key_values = past_key_values, use_cache=True)
62 | new_token = 0
63 |
64 | for idx in range(max_steps):
65 | input_id = outputs.logits[:, -1:].argmax(dim=-1)
66 | outputs = model.base_model(input_id, use_cache=True, past_key_values = past_key_values)
67 | input_ids = torch.cat([input_ids, input_id], dim=-1)
68 | new_token += 1
69 |
70 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
71 | break
72 | if new_token > max_new_token:
73 | break
74 |
75 | return input_ids, new_token, idx
76 |
77 |
78 | def run_eval(
79 | model_path,
80 | model_id,
81 | question_file,
82 | question_begin,
83 | question_end,
84 | answer_file,
85 | max_new_token,
86 | num_choices,
87 | num_gpus_per_model,
88 | num_gpus_total,
89 | max_gpu_memory,
90 | dtype,
91 | revision,
92 | benchname,
93 | warmup,
94 | tree_length,
95 | temperature,
96 | posterior_threshold,
97 | posterior_alpha,
98 | sampling,
99 | gpu_power
100 | ):
101 | if benchname == 'mt_bench':
102 | questions = load_questions(question_file, question_begin, question_end)
103 | elif benchname == 'humaneval':
104 | questions = read_problems()
105 | questions = list(questions.values())[question_begin:question_end]
106 | elif benchname == 'alpaca_eval':
107 | questions = json.load(open(question_file))
108 | elif benchname == 'gsm8k':
109 |         # only use the first 500 questions from the test set
110 | questions = load_dataset('gsm8k', 'main', streaming=False, split='test')['question'][:500]
111 | else:
112 | raise ValueError("Unknown benchmark name")
113 |
114 | # random shuffle the questions to balance the loading
115 | # random.shuffle(questions)
116 |
117 | # Split the question file into `num_gpus` files
118 | assert num_gpus_total % num_gpus_per_model == 0
119 | use_ray = num_gpus_total // num_gpus_per_model > 1
120 |
121 | if use_ray:
122 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)(
123 | get_model_answers
124 | ).remote
125 | else:
126 | get_answers_func = get_model_answers
127 |
128 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
129 | ans_handles = []
130 | for i in range(0, len(questions), chunk_size):
131 | ans_handles.append(
132 | get_answers_func(
133 | model_path,
134 | model_id,
135 | questions[i : i + chunk_size],
136 | answer_file,
137 | max_new_token,
138 | num_choices,
139 | num_gpus_per_model,
140 | max_gpu_memory,
141 | dtype=dtype,
142 | revision=revision,
143 | benchname=benchname,
144 | warmup=warmup,
145 | tree_length=tree_length,
146 | temperature=temperature,
147 | posterior_threshold=posterior_threshold,
148 | posterior_alpha=posterior_alpha,
149 | sampling=sampling,
150 | gpu_power=gpu_power
151 | )
152 | )
153 |
154 | if use_ray:
155 | ray.get(ans_handles)
156 |
157 |
158 | @torch.inference_mode()
159 | def get_model_answers(
160 | model_path,
161 | model_id,
162 | questions,
163 | answer_file,
164 | max_new_token,
165 | num_choices,
166 | num_gpus_per_model,
167 | max_gpu_memory,
168 | dtype,
169 | revision,
170 | benchname,
171 | warmup,
172 | tree_length,
173 | temperature,
174 | posterior_threshold,
175 | posterior_alpha,
176 | sampling,
177 | gpu_power
178 | ):
179 | model = AutoPromptDecoder.from_pretrained(
180 | model_path,
181 | low_cpu_mem_usage=True,
182 | # load_in_4bit=True,
183 | torch_dtype=torch.float16,
184 | # device_map="auto"
185 | )
186 | model.cuda()
187 | tokenizer = model.tokenizer
188 |
189 | model.eval()
190 | print('Check model training state:',model.training)
191 |
192 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
193 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices)
194 |
195 | # warmup
196 | for i in range(warmup):
197 | torch.manual_seed(0)
198 | question = questions[i]
199 | if benchname == 'mt_bench':
200 | num_turns = len(question["turns"])
201 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k':
202 | num_turns = 1
203 | conv = get_conversation_template(model_id)
204 | turns = []
205 | idxs = []
206 | new_tokens = []
207 | wall_time = []
208 | for j in range(num_turns):
209 | if benchname == 'mt_bench':
210 | qs = question["turns"][j]
211 | conv.append_message(conv.roles[0], qs)
212 | conv.append_message(conv.roles[1], None)
213 | prompt = conv.get_prompt()
214 | elif benchname == 'humaneval':
215 | qs = question["prompt"]
216 | prompt = qs
217 | elif benchname == 'alpaca_eval':
218 | conv = get_conversation_template(model_id)
219 | conv.messages = []
220 | conv.append_message(conv.roles[0], question["instruction"])
221 | conv.append_message(conv.roles[1], "")
222 | prompt = conv.get_prompt()
223 | elif benchname == 'gsm8k':
224 | qs = question
225 | conv.append_message(conv.roles[0], qs)
226 | conv.append_message(conv.roles[1], None)
227 | prompt = conv.get_prompt()
228 |
229 | input_ids = tokenizer([prompt]).input_ids
230 |
231 | # try:
232 | torch.cuda.synchronize()
233 | start_time = time.time()
234 | output_ids, new_token, idx = infer(
235 | torch.as_tensor(input_ids).cuda(),
236 | model,
237 | tokenizer,
238 | tree_length,
239 | temperature=0,
240 | posterior_threshold=posterior_threshold,
241 | posterior_alpha=posterior_alpha,
242 | sampling=sampling,
243 | max_new_token=max_new_token
244 | )
245 | torch.cuda.synchronize()
246 | total_time = time.time() - start_time
247 | if benchname == 'mt_bench':
248 | output_ids = output_ids[0][len(input_ids[0]) :]
249 | # be consistent with the template's stop_token_ids
250 | if conv.stop_token_ids:
251 | stop_token_ids_index = [
252 | i
253 | for i, id in enumerate(output_ids)
254 | if id in conv.stop_token_ids
255 | ]
256 | if len(stop_token_ids_index) > 0:
257 | output_ids = output_ids[: stop_token_ids_index[0]]
258 |
259 | output = tokenizer.decode(
260 | output_ids,
261 | spaces_between_special_tokens=False,
262 | )
263 |
264 | if conv.stop_str and output.find(conv.stop_str) > 0:
265 | output = output[: output.find(conv.stop_str)]
266 | for special_token in tokenizer.special_tokens_map.values():
267 | if isinstance(special_token, list):
268 | for special_tok in special_token:
269 | output = output.replace(special_tok, "")
270 | else:
271 | output = output.replace(special_token, "")
272 | output = output.strip()
273 |
274 | if conv.name == "xgen" and output.startswith("Assistant:"):
275 | output = output.replace("Assistant:", "", 1).strip()
276 | conv.messages[-1][-1] = output
277 |
278 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k':
279 | output = tokenizer.decode(
280 | output_ids[0].tolist(),
281 | spaces_between_special_tokens=False,
282 | )
283 | # except RuntimeError as e:
284 | # print(e)
285 | # print("ERROR question ID: ", question["question_id"])
286 | # output = "ERROR"
287 |
288 | turns.append(output)
289 | idxs.append(int(idx))
290 | new_tokens.append(int(new_token))
291 | wall_time.append(total_time)
292 | print('Warmup done', warmup, 'steps')
293 |
294 |
295 | for i, question in tqdm(enumerate(questions), total=len(questions)):
296 | if benchname == 'mt_bench':
297 | question_id = question["question_id"]
298 | num_turns = len(question["turns"])
299 | elif benchname == 'humaneval':
300 | question_id = question["task_id"]
301 | num_turns = 1
302 | elif benchname == 'alpaca_eval' or benchname == 'gsm8k':
303 | question_id = i
304 | num_turns = 1
305 |
306 | if "category" in question and question["category"] in temperature_config:
307 | if temperature is not None:
308 | print(f"Overwriting temperature with {temperature} from command line")
309 | temp = temperature
310 | else:
311 | print(f"Using temperature from config for category {question['category']}")
312 | temp = temperature_config[question["category"]]
313 | else:
314 | print(f"Unknown category, using default temperature 0.0")
315 | temp = 0.0
316 |
317 | choices = []
318 | for i in range(num_choices):
319 | torch.manual_seed(0)
320 | conv = get_conversation_template(model_id)
321 | turns = []
322 | idxs = []
323 | new_tokens = []
324 | wall_time = []
325 | power = []
326 | energy = []
327 | for j in range(num_turns):
328 | if benchname == 'mt_bench':
329 | qs = question["turns"][j]
330 | conv.append_message(conv.roles[0], qs)
331 | conv.append_message(conv.roles[1], None)
332 | prompt = conv.get_prompt()
333 | elif benchname == 'humaneval':
334 | qs = question["prompt"]
335 | prompt = qs
336 | elif benchname == 'alpaca_eval':
337 | conv.messages = []
338 | conv.append_message(conv.roles[0], question["instruction"])
339 | conv.append_message(conv.roles[1], "")
340 | prompt = conv.get_prompt()
341 | elif benchname == 'gsm8k':
342 | qs = question
343 | conv.append_message(conv.roles[0], qs)
344 | conv.append_message(conv.roles[1], None)
345 | prompt = conv.get_prompt()
346 | input_ids = tokenizer([prompt]).input_ids
347 |
348 | try:
349 | torch.cuda.synchronize()
350 | if gpu_power:
351 | gpu_indices = [int(i) for i in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
352 | assert len(gpu_indices) == 1, "Only support single GPU for power measurement now"
353 | monitor = ZeusMonitor([0])
354 | monitor.begin_window("infer")
355 | start_time = time.time()
356 | output_ids, new_token, idx = infer(
357 | torch.as_tensor(input_ids).cuda(),
358 | model,
359 | tokenizer,
360 | tree_length,
361 | temperature=temp,
362 | posterior_threshold=posterior_threshold,
363 | posterior_alpha=posterior_alpha,
364 | sampling=sampling,
365 | max_new_token=max_new_token
366 | )
367 |
368 | torch.cuda.synchronize()
369 | if gpu_power:
370 | gpu_power_result = monitor.end_window("infer")
371 | total_time = time.time() - start_time
372 | if benchname == 'mt_bench':
373 |
374 | # if model.config.is_encoder_decoder:
375 | # output_ids = output_ids[0]
376 | # else:
377 | output_ids = output_ids[0][len(input_ids[0]) :]
378 |
379 | # be consistent with the template's stop_token_ids
380 | if conv.stop_token_ids:
381 | stop_token_ids_index = [
382 | i
383 | for i, id in enumerate(output_ids)
384 | if id in conv.stop_token_ids
385 | ]
386 | if len(stop_token_ids_index) > 0:
387 | output_ids = output_ids[: stop_token_ids_index[0]]
388 |
389 | output = tokenizer.decode(
390 | output_ids,
391 | spaces_between_special_tokens=False,
392 | )
393 | if conv.stop_str and output.find(conv.stop_str) > 0:
394 | output = output[: output.find(conv.stop_str)]
395 | for special_token in tokenizer.special_tokens_map.values():
396 | if isinstance(special_token, list):
397 | for special_tok in special_token:
398 | output = output.replace(special_tok, "")
399 | else:
400 | output = output.replace(special_token, "")
401 | output = output.strip()
402 |
403 | if conv.name == "xgen" and output.startswith("Assistant:"):
404 | output = output.replace("Assistant:", "", 1).strip()
405 | conv.messages[-1][-1] = output
406 |
407 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k':
408 | output = tokenizer.decode(
409 | output_ids[0].tolist(),
410 | spaces_between_special_tokens=False,
411 | )
412 |
413 | except RuntimeError as e:
414 |                 print("ERROR question ID: ", question_id)
415 | print(e)
416 | output = "ERROR"
417 |
418 | turns.append(output)
419 | idxs.append(int(idx))
420 | new_tokens.append(int(new_token))
421 | wall_time.append(total_time)
422 | if gpu_power:
423 | power.append(gpu_power_result.energy[0] / gpu_power_result.time)
424 | energy.append(gpu_power_result.energy[0])
425 | # torch.cuda.empty_cache()
426 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time})
427 | if gpu_power:
428 | choices[-1]["power"] = power
429 | choices[-1]["energy"] = energy
430 |
431 | # Dump answers
432 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
433 | with open(os.path.expanduser(answer_file), "a") as fout:
434 | ans_json = {
435 | "question_id": question_id,
436 | "answer_id": shortuuid.uuid(),
437 | "model_id": model_id,
438 | "choices": choices,
439 | "tstamp": time.time(),
440 | }
441 | fout.write(json.dumps(ans_json) + "\n")
442 |
443 |
444 | def reorg_answer_file(answer_file):
445 |     """Sort answers by question id and remove duplicates."""
446 | answers = {}
447 | with open(answer_file, "r") as fin:
448 | for l in fin:
449 | qid = json.loads(l)["question_id"]
450 | answers[qid] = l
451 |
452 | qids = sorted(list(answers.keys()))
453 | with open(answer_file, "w") as fout:
454 | for qid in qids:
455 | fout.write(answers[qid])
456 |
457 |
458 | if __name__ == "__main__":
459 | parser = argparse.ArgumentParser()
460 | parser.add_argument(
461 | "--model-path",
462 | type=str,
463 | required=True,
464 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
465 | )
466 | parser.add_argument(
467 | "--model-id", type=str, required=True, help="A custom name for the model."
468 | )
469 | parser.add_argument(
470 | "--bench-name",
471 | type=str,
472 | default="mt_bench",
473 | help="The name of the benchmark question set.",
474 | )
475 | parser.add_argument(
476 | "--question-begin",
477 | type=int,
478 | help="A debug option. The begin index of questions.",
479 | )
480 | parser.add_argument(
481 | "--question-end", type=int, help="A debug option. The end index of questions."
482 | )
483 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
484 | parser.add_argument(
485 | "--max-new-token",
486 | type=int,
487 | default=1024,
488 | help="The maximum number of new generated tokens.",
489 | )
490 | parser.add_argument(
491 | "--num-choices",
492 | type=int,
493 | default=1,
494 | help="How many completion choices to generate.",
495 | )
496 | parser.add_argument(
497 | "--num-gpus-per-model",
498 | type=int,
499 | default=1,
500 | help="The number of GPUs per model.",
501 | )
502 | parser.add_argument(
503 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
504 | )
505 | parser.add_argument(
506 | "--max-gpu-memory",
507 | type=str,
508 |         help="Maximum GPU memory used for model weights per GPU.",
509 | )
510 | parser.add_argument(
511 | "--dtype",
512 | type=str,
513 | choices=["float32", "float16", "bfloat16"],
514 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
515 | default=None,
516 | )
517 | parser.add_argument(
518 | "--revision",
519 | type=str,
520 | default="main",
521 | help="The model revision to load.",
522 | )
523 | parser.add_argument(
524 | "--warmup",
525 | type=int,
526 | default=3,
527 | help="The number of warmup steps.",
528 | )
529 | parser.add_argument(
530 | "--tree-length",
531 | type=str,
532 | default="75",
533 |         help="The sparse tree size to use; selects the dynamic_sparse_trees_<n> configuration.",
534 | )
535 | parser.add_argument(
536 | "--temperature",
537 | type=float,
538 | default=None,
539 | help="The temperature for sampling.",
540 | )
541 | parser.add_argument(
542 | "--posterior_threshold",
543 | type=float,
544 | default=0.09,
545 | help="The threshold for posterior sampling.",
546 | )
547 | parser.add_argument(
548 | "--posterior_alpha",
549 | type=float,
550 | default=0.3,
551 | help="The alpha for posterior sampling.",
552 | )
553 | parser.add_argument(
554 | "--sampling",
555 | type=str,
556 | default='greedy',
557 | help="The sampling method for decoding."
558 | )
559 | parser.add_argument(
560 | "--gpu-power",
561 | action='store_true',
562 | default=False,
563 | help="Whether to measure power consumption."
564 | )
565 |
566 | args = parser.parse_args()
567 |
568 | if 'vicuna-13b' in args.model_path.lower():
569 | if '-2-' in args.model_path:
570 | from prompt.inference.dynamic_sparse_trees_2_13b import *
571 | print('Using 13b sparse trees')
572 | elif '-3-' in args.model_path:
573 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import *
574 | print('Using 13b 3-1 sparse trees')
575 | else:
576 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import *
577 | print('Using 13b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-')
578 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length)
579 | elif 'vicuna-7b' in args.model_path.lower():
580 | if '-2-' in args.model_path:
581 | from prompt.inference.dynamic_sparse_trees_2_7b import *
582 | print('Using 7b 2-1 sparse trees')
583 | elif '-3-' in args.model_path:
584 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import *
585 | print('Using 7b 3-1 sparse trees')
586 | else:
587 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import *
588 | print('Using 7b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-')
589 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length)
590 | elif 'mobilellama' in args.model_path.lower():
591 | from prompt.inference.dynamic_sparse_trees_3_MobileLLaMA import *
592 | print('Using MobileLLaMA 3-1 sparse trees')
593 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length)
594 | else:
595 | raise ValueError("Unknown model path")
596 |
597 | if args.num_gpus_total // args.num_gpus_per_model > 1:
598 | import ray
599 |
600 | ray.init()
601 |
602 | question_file = None
603 | if args.bench_name == 'mt_bench':
604 | question_file = f"data/{args.bench_name}/question.jsonl"
605 | elif args.bench_name == 'alpaca_eval':
606 | question_file = f"data/{args.bench_name}/alpaca_eval.json"
607 | if args.answer_file:
608 | answer_file = args.answer_file
609 | else:
610 |         answer_file_name = f"{args.model_id}-temperature-{args.temperature}-posterior_threshold-{args.posterior_threshold}-posterior_alpha-{args.posterior_alpha}-sampling-{args.sampling}"
611 | answer_file = f"data/{args.bench_name}/model_answer/{answer_file_name}.jsonl"
612 |
613 | print(f"Output to {answer_file}")
614 | run_eval(
615 | model_path=args.model_path,
616 | model_id=args.model_id,
617 | question_file=question_file,
618 | question_begin=args.question_begin,
619 | question_end=args.question_end,
620 | answer_file=answer_file,
621 | max_new_token=args.max_new_token,
622 | num_choices=args.num_choices,
623 | num_gpus_per_model=args.num_gpus_per_model,
624 | num_gpus_total=args.num_gpus_total,
625 | max_gpu_memory=args.max_gpu_memory,
626 | dtype=str_to_torch_dtype(args.dtype),
627 | revision=args.revision,
628 | benchname=args.bench_name,
629 | warmup=args.warmup,
630 | tree_length=args.tree_length,
631 | temperature=args.temperature,
632 | posterior_threshold=args.posterior_threshold,
633 | posterior_alpha=args.posterior_alpha,
634 | sampling=args.sampling,
635 | gpu_power=args.gpu_power
636 | )
637 |
638 | reorg_answer_file(answer_file)
639 |
--------------------------------------------------------------------------------
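The baseline infer() above generates one token per forward pass while reusing a manually managed KV cache. Below is a minimal sketch of the same greedy, cache-reusing loop written against the plain Hugging Face API; the model name is a placeholder and a CUDA device is assumed, so treat it as illustrative rather than a drop-in replacement for the script's AutoPromptDecoder path.

# Sketch only: greedy decoding with an explicit KV cache, mirroring the baseline loop.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

@torch.inference_mode()
def greedy_decode(prompt, model_name="lmsys/vicuna-7b-v1.3", max_new_tokens=64):
    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda().eval()
    input_ids = tok(prompt, return_tensors="pt").input_ids.cuda()
    generated, past = input_ids, None
    for _ in range(max_new_tokens):
        # feed the full prompt once, then only the newest token plus the cached keys/values
        step_input = generated if past is None else generated[:, -1:]
        out = model(step_input, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1:].argmax(dim=-1)
        generated = torch.cat([generated, next_id], dim=-1)
        if next_id.item() == tok.eos_token_id:
            break
    return tok.decode(generated[0, input_ids.shape[1]:], skip_special_tokens=True)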
/application/gen_model_answer_prompt_decoding.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 gen_model_answer_prompt_decoding.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import random
10 | import time
11 | import transformers
12 |
13 | import shortuuid
14 | import torch
15 | from tqdm import tqdm
16 |
17 | from fastchat.llm_judge.common import load_questions, temperature_config
18 | from fastchat.model import load_model, get_conversation_template
19 | from fastchat.utils import str_to_torch_dtype
20 |
21 | import sys
22 |
23 | from prompt.model.kv_cache import initialize_past_key_values
24 | from prompt.model.model import AutoPromptDecoder, PromptDecoder, PromptConfig
25 | from prompt.model.modeling_llama_custom import LlamaForCausalLM as CustomLlamaForCausalLM
26 | from human_eval.data import write_jsonl, read_problems
27 | from datasets import load_dataset
28 | from zeus.monitor import ZeusMonitor
29 |
30 |
31 | def infer(input_ids, model, tokenizer, tree_length, max_steps = 512, temperature=0.7, posterior_threshold = 0.09, posterior_alpha = 0.3, sampling='greedy', max_new_token=1024):
32 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
33 | # Avoid modifying the input_ids in-place
34 | input_ids = input_ids.clone()
35 |
36 | if not hasattr(model, "inference_buffers"):
37 | print('Generate buffers')
38 | model.generate_dynamic_buffers(tree_length)
39 | # Initialize the past key and value states
40 | if hasattr(model, "past_key_values"):
41 | past_key_values = model.past_key_values
42 | past_key_values_data = model.past_key_values_data
43 | current_length_data = model.current_length_data
44 | # Reset the past key and value states
45 | current_length_data.zero_()
46 | else:
47 | (
48 | past_key_values,
49 | past_key_values_data,
50 | current_length_data,
51 | ) = initialize_past_key_values(model.base_model)
52 | model.past_key_values = past_key_values
53 | model.past_key_values_data = past_key_values_data
54 | model.current_length_data = current_length_data
55 |
56 | input_len = input_ids.shape[1]
57 | logits, prompt_logits = model.start_inference(input_ids, past_key_values, current_length_data)
58 | new_token = 0
59 |
60 | for idx in range(max_steps):
61 | candidates, tree_candidates_embeds = model.generate_candidates(
62 | logits,
63 | prompt_logits,
64 | temperature,
65 | posterior_threshold,
66 | posterior_alpha,
67 | sampling)
68 | logits, all_logits = model.tree_decoding(tree_candidates_embeds, past_key_values, input_ids)
69 | best_candidate, accept_length = model.evaluate_posterior(
70 | logits,
71 | candidates,
72 | temperature,
73 | posterior_threshold,
74 | posterior_alpha,
75 | sampling)
76 | input_ids, logits, prompt_logits, new_token = model.update_inference_inputs(
77 | input_ids,
78 | candidates,
79 | best_candidate,
80 | accept_length,
81 | logits,
82 | all_logits,
83 | new_token,
84 | past_key_values_data,
85 | current_length_data,
86 | )
87 |
88 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
89 | break
90 | if new_token > max_new_token:
91 | break
92 |
93 | return input_ids, new_token, idx
94 |
95 | def run_eval(
96 | model_path,
97 | model_id,
98 | question_file,
99 | question_begin,
100 | question_end,
101 | answer_file,
102 | max_new_token,
103 | num_choices,
104 | num_gpus_per_model,
105 | num_gpus_total,
106 | max_gpu_memory,
107 | dtype,
108 | revision,
109 | benchname,
110 | warmup,
111 | tree_length,
112 | temperature,
113 | posterior_threshold,
114 | posterior_alpha,
115 | sampling,
116 | gpu_power
117 | ):
118 | if benchname == 'mt_bench':
119 | questions = load_questions(question_file, question_begin, question_end)
120 | elif benchname == 'humaneval':
121 | questions = read_problems()
122 | questions = list(questions.values())[question_begin:question_end]
123 | elif benchname == 'alpaca_eval':
124 | questions = json.load(open(question_file))
125 | elif benchname == 'gsm8k':
126 |         # only use the first 500 questions from the test set
127 | questions = load_dataset('gsm8k', 'main', streaming=False, split='test')['question'][:500]
128 | else:
129 | raise ValueError("Unknown benchmark name")
130 |
131 | # random shuffle the questions to balance the loading
132 | # random.shuffle(questions)
133 |
134 | # Split the question file into `num_gpus` files
135 | assert num_gpus_total % num_gpus_per_model == 0
136 | use_ray = num_gpus_total // num_gpus_per_model > 1
137 |
138 | if use_ray:
139 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)(
140 | get_model_answers
141 | ).remote
142 | else:
143 | get_answers_func = get_model_answers
144 |
145 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model)
146 | ans_handles = []
147 | for i in range(0, len(questions), chunk_size):
148 | ans_handles.append(
149 | get_answers_func(
150 | model_path,
151 | model_id,
152 | questions[i : i + chunk_size],
153 | answer_file,
154 | max_new_token,
155 | num_choices,
156 | num_gpus_per_model,
157 | max_gpu_memory,
158 | dtype=dtype,
159 | revision=revision,
160 | benchname=benchname,
161 | warmup=warmup,
162 | tree_length=tree_length,
163 | temperature=temperature,
164 | posterior_threshold=posterior_threshold,
165 | posterior_alpha=posterior_alpha,
166 | sampling=sampling,
167 | gpu_power=gpu_power
168 | )
169 | )
170 |
171 | if use_ray:
172 | ray.get(ans_handles)
173 |
174 |
175 | @torch.inference_mode()
176 | def get_model_answers(
177 | model_path,
178 | model_id,
179 | questions,
180 | answer_file,
181 | max_new_token,
182 | num_choices,
183 | num_gpus_per_model,
184 | max_gpu_memory,
185 | dtype,
186 | revision,
187 | benchname,
188 | warmup,
189 | tree_length,
190 | temperature,
191 | posterior_threshold,
192 | posterior_alpha,
193 | sampling,
194 | gpu_power
195 | ):
196 | model = AutoPromptDecoder.from_pretrained(
197 | model_path,
198 | low_cpu_mem_usage=True,
199 | # load_in_4bit=True,
200 | torch_dtype=torch.float16,
201 | # device_map="auto"
202 | )
203 | model.cuda()
204 | tokenizer = model.tokenizer
205 |
206 | model.eval()
207 | print('Check model training state:',model.training)
208 |
209 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
210 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices)
211 |
212 | # warmup
213 | for i in range(warmup):
214 | torch.manual_seed(0)
215 | question = questions[i]
216 | if benchname == 'mt_bench':
217 | num_turns = len(question["turns"])
218 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k':
219 | num_turns = 1
220 | conv = get_conversation_template(model_id)
221 | turns = []
222 | idxs = []
223 | new_tokens = []
224 | wall_time = []
225 | for j in range(num_turns):
226 | if benchname == 'mt_bench':
227 | qs = question["turns"][j]
228 | conv.append_message(conv.roles[0], qs)
229 | conv.append_message(conv.roles[1], None)
230 | prompt = conv.get_prompt()
231 | elif benchname == 'humaneval':
232 | qs = question["prompt"]
233 | prompt = qs
234 | elif benchname == 'alpaca_eval':
235 | conv = get_conversation_template(model_id)
236 | conv.messages = []
237 | conv.append_message(conv.roles[0], question["instruction"])
238 | conv.append_message(conv.roles[1], "")
239 | prompt = conv.get_prompt()
240 | elif benchname == 'gsm8k':
241 | qs = question
242 | conv.append_message(conv.roles[0], qs)
243 | conv.append_message(conv.roles[1], None)
244 | prompt = conv.get_prompt()
245 |
246 | input_ids = tokenizer([prompt]).input_ids
247 |
248 | # try:
249 | torch.cuda.synchronize()
250 | start_time = time.time()
251 | output_ids, new_token, idx = infer(
252 | torch.as_tensor(input_ids).cuda(),
253 | model,
254 | tokenizer,
255 | tree_length,
256 | temperature=0,
257 | posterior_threshold=posterior_threshold,
258 | posterior_alpha=posterior_alpha,
259 | sampling=sampling,
260 | max_new_token=max_new_token
261 | )
262 | torch.cuda.synchronize()
263 | total_time = time.time() - start_time
264 | if benchname == 'mt_bench':
265 | output_ids = output_ids[0][len(input_ids[0]) :]
266 | # be consistent with the template's stop_token_ids
267 | if conv.stop_token_ids:
268 | stop_token_ids_index = [
269 | i
270 | for i, id in enumerate(output_ids)
271 | if id in conv.stop_token_ids
272 | ]
273 | if len(stop_token_ids_index) > 0:
274 | output_ids = output_ids[: stop_token_ids_index[0]]
275 |
276 | output = tokenizer.decode(
277 | output_ids,
278 | spaces_between_special_tokens=False,
279 | )
280 |
281 | if conv.stop_str and output.find(conv.stop_str) > 0:
282 | output = output[: output.find(conv.stop_str)]
283 | for special_token in tokenizer.special_tokens_map.values():
284 | if isinstance(special_token, list):
285 | for special_tok in special_token:
286 | output = output.replace(special_tok, "")
287 | else:
288 | output = output.replace(special_token, "")
289 | output = output.strip()
290 |
291 | if conv.name == "xgen" and output.startswith("Assistant:"):
292 | output = output.replace("Assistant:", "", 1).strip()
293 | conv.messages[-1][-1] = output
294 |
295 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k':
296 | output = tokenizer.decode(
297 | output_ids[0].tolist(),
298 | spaces_between_special_tokens=False,
299 | )
300 | # except RuntimeError as e:
301 | # print(e)
302 | # print("ERROR question ID: ", question["question_id"])
303 | # output = "ERROR"
304 |
305 | turns.append(output)
306 | idxs.append(int(idx))
307 | new_tokens.append(int(new_token))
308 | wall_time.append(total_time)
309 | print('Warmup done', warmup, 'steps')
310 |
311 |
312 | for i, question in tqdm(enumerate(questions), total=len(questions)):
313 | if benchname == 'mt_bench':
314 | question_id = question["question_id"]
315 | num_turns = len(question["turns"])
316 | elif benchname == 'humaneval':
317 | question_id = question["task_id"]
318 | num_turns = 1
319 | elif benchname == 'alpaca_eval' or benchname == 'gsm8k':
320 | question_id = i
321 | num_turns = 1
322 |
323 | if "category" in question and question["category"] in temperature_config:
324 | if temperature is not None:
325 | print(f"Overwriting temperature with {temperature} from command line")
326 | temp = temperature
327 | else:
328 | print(f"Using temperature from config for category {question['category']}")
329 | temp = temperature_config[question["category"]]
330 | else:
331 | print(f"Unknown category, using default temperature 0.0")
332 | temp = 0.0
333 |
334 | choices = []
335 | for i in range(num_choices):
336 | torch.manual_seed(0)
337 | conv = get_conversation_template(model_id)
338 | turns = []
339 | idxs = []
340 | new_tokens = []
341 | wall_time = []
342 | power = []
343 | energy = []
344 | for j in range(num_turns):
345 | if benchname == 'mt_bench':
346 | qs = question["turns"][j]
347 | conv.append_message(conv.roles[0], qs)
348 | conv.append_message(conv.roles[1], None)
349 | prompt = conv.get_prompt()
350 | elif benchname == 'humaneval':
351 | qs = question["prompt"]
352 | prompt = qs
353 | elif benchname == 'alpaca_eval':
354 | conv.messages = []
355 | conv.append_message(conv.roles[0], question["instruction"])
356 | conv.append_message(conv.roles[1], "")
357 | prompt = conv.get_prompt()
358 | elif benchname == 'gsm8k':
359 | qs = question
360 | conv.append_message(conv.roles[0], qs)
361 | conv.append_message(conv.roles[1], None)
362 | prompt = conv.get_prompt()
363 | input_ids = tokenizer([prompt]).input_ids
364 |
365 | try:
366 | torch.cuda.synchronize()
367 | if gpu_power:
368 | gpu_indices = [int(i) for i in os.environ['CUDA_VISIBLE_DEVICES'].split(',')]
369 | assert len(gpu_indices) == 1, "Only support single GPU for power measurement now"
370 | monitor = ZeusMonitor([0])
371 | monitor.begin_window("infer")
372 | start_time = time.time()
373 | output_ids, new_token, idx = infer(
374 | torch.as_tensor(input_ids).cuda(),
375 | model,
376 | tokenizer,
377 | tree_length,
378 | temperature=temp,
379 | posterior_threshold=posterior_threshold,
380 | posterior_alpha=posterior_alpha,
381 | sampling=sampling,
382 | max_new_token=max_new_token
383 | )
384 |
385 | torch.cuda.synchronize()
386 | if gpu_power:
387 | gpu_power_result = monitor.end_window("infer")
388 | total_time = time.time() - start_time
389 | if benchname == 'mt_bench':
390 |
391 | # if model.config.is_encoder_decoder:
392 | # output_ids = output_ids[0]
393 | # else:
394 | output_ids = output_ids[0][len(input_ids[0]) :]
395 |
396 | # be consistent with the template's stop_token_ids
397 | if conv.stop_token_ids:
398 | stop_token_ids_index = [
399 | i
400 | for i, id in enumerate(output_ids)
401 | if id in conv.stop_token_ids
402 | ]
403 | if len(stop_token_ids_index) > 0:
404 | output_ids = output_ids[: stop_token_ids_index[0]]
405 |
406 | output = tokenizer.decode(
407 | output_ids,
408 | spaces_between_special_tokens=False,
409 | )
410 | if conv.stop_str and output.find(conv.stop_str) > 0:
411 | output = output[: output.find(conv.stop_str)]
412 | for special_token in tokenizer.special_tokens_map.values():
413 | if isinstance(special_token, list):
414 | for special_tok in special_token:
415 | output = output.replace(special_tok, "")
416 | else:
417 | output = output.replace(special_token, "")
418 | output = output.strip()
419 |
420 | if conv.name == "xgen" and output.startswith("Assistant:"):
421 | output = output.replace("Assistant:", "", 1).strip()
422 | conv.messages[-1][-1] = output
423 |
424 | elif benchname == 'humaneval' or benchname == 'alpaca_eval' or benchname == 'gsm8k':
425 | output = tokenizer.decode(
426 | output_ids[0].tolist(),
427 | spaces_between_special_tokens=False,
428 | )
429 |
430 | except RuntimeError as e:
431 |                 print("ERROR question ID: ", question_id)
432 | print(e)
433 | output = "ERROR"
434 |
435 | turns.append(output)
436 | idxs.append(int(idx))
437 | new_tokens.append(int(new_token))
438 | wall_time.append(total_time)
439 | if gpu_power:
440 | power.append(gpu_power_result.energy[0] / gpu_power_result.time)
441 | energy.append(gpu_power_result.energy[0])
442 | # torch.cuda.empty_cache()
443 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time})
444 | if gpu_power:
445 | choices[-1]["power"] = power
446 | choices[-1]["energy"] = energy
447 |
448 | # Dump answers
449 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
450 | with open(os.path.expanduser(answer_file), "a") as fout:
451 | ans_json = {
452 | "question_id": question_id,
453 | "answer_id": shortuuid.uuid(),
454 | "model_id": model_id,
455 | "choices": choices,
456 | "tstamp": time.time(),
457 | }
458 | fout.write(json.dumps(ans_json) + "\n")
459 |
460 |
461 | def reorg_answer_file(answer_file):
462 |     """Sort answers by question id and remove duplicates."""
463 | answers = {}
464 | with open(answer_file, "r") as fin:
465 | for l in fin:
466 | qid = json.loads(l)["question_id"]
467 | answers[qid] = l
468 |
469 | qids = sorted(list(answers.keys()))
470 | with open(answer_file, "w") as fout:
471 | for qid in qids:
472 | fout.write(answers[qid])
473 |
474 |
475 | if __name__ == "__main__":
476 | parser = argparse.ArgumentParser()
477 | parser.add_argument(
478 | "--model-path",
479 | type=str,
480 | required=True,
481 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
482 | )
483 | parser.add_argument(
484 | "--model-id", type=str, required=True, help="A custom name for the model."
485 | )
486 | parser.add_argument(
487 | "--bench-name",
488 | type=str,
489 | default="mt_bench",
490 | help="The name of the benchmark question set.",
491 | )
492 | parser.add_argument(
493 | "--question-begin",
494 | type=int,
495 | help="A debug option. The begin index of questions.",
496 | )
497 | parser.add_argument(
498 | "--question-end", type=int, help="A debug option. The end index of questions."
499 | )
500 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
501 | parser.add_argument(
502 | "--max-new-token",
503 | type=int,
504 | default=1024,
505 | help="The maximum number of new generated tokens.",
506 | )
507 | parser.add_argument(
508 | "--num-choices",
509 | type=int,
510 | default=1,
511 | help="How many completion choices to generate.",
512 | )
513 | parser.add_argument(
514 | "--num-gpus-per-model",
515 | type=int,
516 | default=1,
517 | help="The number of GPUs per model.",
518 | )
519 | parser.add_argument(
520 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
521 | )
522 | parser.add_argument(
523 | "--max-gpu-memory",
524 | type=str,
525 |         help="Maximum GPU memory used for model weights per GPU.",
526 | )
527 | parser.add_argument(
528 | "--dtype",
529 | type=str,
530 | choices=["float32", "float16", "bfloat16"],
531 | help="Override the default dtype. If not set, it will use float16 on GPU and float32 on CPU.",
532 | default=None,
533 | )
534 | parser.add_argument(
535 | "--revision",
536 | type=str,
537 | default="main",
538 | help="The model revision to load.",
539 | )
540 | parser.add_argument(
541 | "--warmup",
542 | type=int,
543 | default=3,
544 | help="The number of warmup steps.",
545 | )
546 | parser.add_argument(
547 | "--tree-length",
548 | type=str,
549 | default="63",
550 |         help="The sparse tree size to use; selects the dynamic_sparse_trees_<n> configuration.",
551 | )
552 | parser.add_argument(
553 | "--temperature",
554 | type=float,
555 | default=None,
556 | help="The temperature for sampling.",
557 | )
558 | parser.add_argument(
559 | "--posterior_threshold",
560 | type=float,
561 | default=0.09,
562 | help="The threshold for posterior sampling.",
563 | )
564 | parser.add_argument(
565 | "--posterior_alpha",
566 | type=float,
567 | default=0.3,
568 | help="The alpha for posterior sampling.",
569 | )
570 | parser.add_argument(
571 | "--sampling",
572 | type=str,
573 | default='greedy',
574 | help="The sampling method for decoding."
575 | )
576 | parser.add_argument(
577 | "--gpu-power",
578 | action='store_true',
579 | default=False,
580 | help="Whether to measure power consumption."
581 | )
582 |
583 | args = parser.parse_args()
584 |
585 | if 'vicuna-13b' in args.model_path.lower():
586 | if '-2-' in args.model_path:
587 | from prompt.inference.dynamic_sparse_trees_2_13b import *
588 | print('Using 13b sparse trees')
589 | elif '-3-' in args.model_path:
590 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import *
591 | print('Using 13b 3-1 sparse trees')
592 | else:
593 | from prompt.inference.dynamic_sparse_trees_3_vicuna_13b import *
594 | print('Using 13b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-')
595 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length)
596 | elif 'vicuna-7b' in args.model_path.lower():
597 | if '-2-' in args.model_path:
598 | from prompt.inference.dynamic_sparse_trees_2_7b import *
599 | print('Using 7b 2-1 sparse trees')
600 | elif '-3-' in args.model_path:
601 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import *
602 | print('Using 7b 3-1 sparse trees')
603 | else:
604 | from prompt.inference.dynamic_sparse_trees_3_vicuna_7b import *
605 | print('Using 7b 3-1 sparse trees, this is the default because the model path does not contain -2- or -3-')
606 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length)
607 | elif 'mobilellama' in args.model_path.lower():
608 | from prompt.inference.dynamic_sparse_trees_3_MobileLLaMA import *
609 | print('Using MobileLLaMA 3-1 sparse trees')
610 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length)
611 | elif 'vicuna-68m' in args.model_path.lower():
612 | from prompt.inference.dynamic_sparse_trees_3_vicuna_68m import *
613 | print('Using Vicuna-68m sparse trees')
614 | args.tree_length = eval("dynamic_sparse_trees_" + args.tree_length)
615 | else:
616 | raise ValueError("Unknown model path")
617 | if args.num_gpus_total // args.num_gpus_per_model > 1:
618 | import ray
619 |
620 | ray.init()
621 |
622 | question_file = None
623 | if args.bench_name == 'mt_bench':
624 | question_file = f"data/{args.bench_name}/question.jsonl"
625 | elif args.bench_name == 'alpaca_eval':
626 | question_file = f"data/{args.bench_name}/alpaca_eval.json"
627 | if args.answer_file:
628 | answer_file = args.answer_file
629 | else:
630 |         answer_file_name = f"{args.model_id}-temperature-{args.temperature}-posterior_threshold-{args.posterior_threshold}-posterior_alpha-{args.posterior_alpha}-sampling-{args.sampling}"
631 | answer_file = f"data/{args.bench_name}/model_answer/{answer_file_name}.jsonl"
632 |
633 | print(f"Output to {answer_file}")
634 | run_eval(
635 | model_path=args.model_path,
636 | model_id=args.model_id,
637 | question_file=question_file,
638 | question_begin=args.question_begin,
639 | question_end=args.question_end,
640 | answer_file=answer_file,
641 | max_new_token=args.max_new_token,
642 | num_choices=args.num_choices,
643 | num_gpus_per_model=args.num_gpus_per_model,
644 | num_gpus_total=args.num_gpus_total,
645 | max_gpu_memory=args.max_gpu_memory,
646 | dtype=str_to_torch_dtype(args.dtype),
647 | revision=args.revision,
648 | benchname=args.bench_name,
649 | warmup=args.warmup,
650 | tree_length=args.tree_length,
651 | temperature=args.temperature,
652 | posterior_threshold=args.posterior_threshold,
653 | posterior_alpha=args.posterior_alpha,
654 | sampling=args.sampling,
655 | gpu_power=args.gpu_power
656 | )
657 |
658 | reorg_answer_file(answer_file)
659 |
--------------------------------------------------------------------------------
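Compared with the baseline script, infer() here advances several tokens per step: generate_candidates() proposes continuations from the prompt-token logits, tree_decoding() scores them in one batched forward pass, evaluate_posterior() picks the best candidate and its accept length, and update_inference_inputs() commits the accepted tokens to the KV cache. The toy function below illustrates only the greedy acceptance idea (the longest prefix on which the verifier's argmax agrees with the proposal); it is not the repo's evaluate_posterior().

# Toy greedy-acceptance check, in the spirit of evaluate_posterior(sampling='greedy').
import torch

def greedy_accept_length(candidate_ids, verifier_logits):
    # verifier_logits[k] are the base model's logits at the position preceding
    # candidate_ids[k]; accept the longest matching prefix of the proposal.
    preds = verifier_logits.argmax(dim=-1)
    accept = 0
    for proposed, predicted in zip(candidate_ids.tolist(), preds.tolist()):
        if proposed != predicted:
            break
        accept += 1
    return accept

logits = torch.full((3, 10), -1e9)
logits[0, 4] = logits[1, 7] = logits[2, 1] = 0.0   # verifier prefers tokens 4, 7, 1
print(greedy_accept_length(torch.tensor([4, 7, 9]), logits))  # -> 2 accepted tokens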
/application/get_throughput_results.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import pandas as pd
4 | import os
5 | import argparse
6 |
7 | def get_throughput_results(input_file):
8 | with open(input_file, "r") as f:
9 | data = [json.loads(line) for line in f]
10 |
11 | new_tokens = []
12 | wall_time = []
13 | throughputs = []
14 | accept_lengths = []
15 | idxs = []
16 | for d in data:
17 | for choice in d["choices"]:
18 | new_tokens.extend(choice["new_tokens"])
19 | wall_time.extend(choice["wall_time"])
20 | for i in range(len(choice["new_tokens"])):
21 | throughputs.append(choice["new_tokens"][i] / choice["wall_time"][i])
22 | accept_lengths.append(choice["new_tokens"][i] / (choice["idxs"][i]+1))
23 | idxs.extend([idx+1 for idx in choice["idxs"]])
24 |
25 | return sum(new_tokens) / sum(wall_time), sum(throughputs) / len(throughputs), sum(new_tokens) / sum(idxs), sum(wall_time) / sum(idxs)
26 |
27 |
28 | def get_tree_latency(input_file):
29 | with open(input_file, "r") as f:
30 | data = [json.loads(line) for line in f]
31 |
32 | latencies = {}
33 | for d in data:
34 | latencies[d['tree_length']] = sum(d['choices'][0]['wall_time'])
35 |
36 | return latencies
37 |
38 |
39 | def get_gpu_power(input_file):
40 | with open(input_file, "r") as f:
41 | data = [json.loads(line) for line in f]
42 |
43 | gpu_power = []
44 | gpu_energy = []
45 | new_tokens = []
46 | power_per_token = []
47 | energy_per_token = []
48 | for d in data:
49 | for choice in d["choices"]:
50 | gpu_power.extend(choice["power"])
51 | gpu_energy.extend(choice["energy"])
52 | new_tokens.extend(choice["new_tokens"])
53 | power_per_token.append(sum(choice["power"]) / sum(choice["new_tokens"]))
54 | energy_per_token.append(sum(choice["energy"]) / sum(choice["new_tokens"]))
55 |
56 | return sum(gpu_power) / sum(new_tokens), \
57 | sum(gpu_power) / len(gpu_power), \
58 | sum(gpu_energy) / sum(new_tokens), \
59 | sum(gpu_energy) / len(gpu_energy)
60 |
61 |
62 | def print_average_throughputs(input_files):
63 |     throughputs1, throughputs2, accept_lengths, forward_pass_time = list(zip(*[get_throughput_results(input_file) for input_file in input_files]))
64 | print(f"Macro-Average throughput: {sum(throughputs1) / len(throughputs1):.3f} tokens/s")
65 | print(f"std: {pd.Series(throughputs1).std():.3f}")
66 | print(f"Micro-Average throughput: {sum(throughputs2) / len(throughputs2):.3f} tokens/s")
67 | print(f"std: {pd.Series(throughputs2).std():.3f}")
68 |     print(f"Average accept lengths: {sum(accept_lengths) / len(accept_lengths):.5f}")
69 |     print(f"std: {pd.Series(accept_lengths).std():.5f}")
70 | print(f"Average forward pass time: {sum(forward_pass_time) / len(forward_pass_time):.5f} s")
71 | print(f"std: {pd.Series(forward_pass_time).std():.5f}")
72 |
73 |     return (sum(accept_lengths) / len(accept_lengths)) / (sum(forward_pass_time) / len(forward_pass_time))
74 |
75 |
76 | def print_gpu_power(input_files):
77 | power_per_token, power, energy_per_token, energy = list(zip(*[get_gpu_power(input_file) for input_file in input_files]))
78 | print(f"Power per token: {sum(power_per_token) / len(power_per_token):.3f} W/token")
79 | print(f"std: {pd.Series(power_per_token).std():.3f}")
80 |     print(f"Average power: {sum(power) / len(power):.3f} W")
81 | print(f"std: {pd.Series(power).std():.3f}")
82 | print(f"Energy per token: {sum(energy_per_token) / len(energy_per_token):.3f} J/token")
83 | print(f"std: {pd.Series(energy_per_token).std():.3f}")
84 |     print(f"Average energy: {sum(energy) / len(energy):.3f} J")
85 | print(f"std: {pd.Series(energy).std():.3f}")
86 |
87 |
88 |
89 | def parse_file_name(file_name):
90 | # file name is in the format prefix{ddd}_0.json or prefix{dd}_1.json or prefix{d}_2.json where d is a digit
91 | prefix = file_name.split("_")[-2]
92 | rst = 0
93 | for i in range(1, 4):
94 | if prefix[-i:].isdigit():
95 | rst = prefix[-i:]
96 | return int(rst)
97 |
98 |
99 | def main(args):
100 | input_files = args.input_files
101 | # if the input is a directory, iterate over all files in the directory
102 | if os.path.isdir(input_files[0]):
103 | input_files = [os.path.join(input_files[0], f) for f in os.listdir(input_files[0]) if 'tree_latency.jsonl' not in f]
104 | # group files by prefix. Assume that the files are named as prefix1_0.json, prefix1_1.json, prefix2_0.json, prefix2_1.json, etc. The resulting list should be prefix0_0.json, prefix0_1.json, prefix0_2.json, prefix1_0.json, prefix1_1.json, etc.
105 | input_files = sorted(input_files, key=lambda x: parse_file_name(x))
106 | # for each group, get the average throughput using print_average_throughputs
107 |         max_accept_length_to_forward_pass_time = 0
108 | best_tree = None
109 | for i in range(0, len(input_files), args.n):
110 | # print prefix
111 | print(">>>", input_files[i])
112 | if args.gpu_power:
113 | print_gpu_power(input_files[i:i+args.n])
114 | else:
115 | ratio = print_average_throughputs(input_files[i:i+args.n])
116 |                 if ratio > max_accept_length_to_forward_pass_time:
117 |                     max_accept_length_to_forward_pass_time = ratio
118 |                     best_tree = input_files[i].split("_")[-2]
119 |         if not args.gpu_power:
120 |             print(f"Best sparse tree: {best_tree}, ratio: {max_accept_length_to_forward_pass_time}")
121 | elif 'tree_latency.jsonl' in input_files[0]:
122 | latencies = get_tree_latency(input_files[0])
123 | input_files = [os.path.join(input_files[1], f) for f in os.listdir(input_files[1]) if 'tree_latency.jsonl' not in f]
124 | input_files = sorted(input_files, key=lambda x: parse_file_name(x))
125 | # for each group, get the average throughput using print_average_throughputs
126 |         max_accept_length_to_forward_pass_time = 0
127 | best_tree = None
128 | for i in range(0, len(input_files), 1):
129 | # print prefix
130 | print(">>>", input_files[i])
131 | tree_length = parse_file_name(input_files[i])
132 | if tree_length not in latencies:
133 | continue
134 |             _, _, accept_lengths, _ = get_throughput_results(input_files[i])
135 |             ratio = accept_lengths / latencies[tree_length]
136 |             if ratio > max_accept_length_to_forward_pass_time:
137 |                 max_accept_length_to_forward_pass_time = ratio
138 |                 best_tree = input_files[i].split("_")[-2]
139 |             print("Ratio: ", ratio, 'Accept length: ', accept_lengths, 'Latency: ', latencies[tree_length])
140 |         print(f"Best sparse tree: {best_tree}, ratio: {max_accept_length_to_forward_pass_time}")
141 | else:
142 | if args.gpu_power:
143 | print_gpu_power(input_files)
144 | else:
145 | print_average_throughputs(input_files)
146 |
147 |
148 | if __name__ == "__main__":
149 | # parse arguments
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument("input_files", nargs="+", help="Input files to get throughput results")
152 | parser.add_argument("--n", type=int, default=1, help="Number of files to group")
153 | parser.add_argument("--gpu-power", action="store_true", help="Get GPU power", default=False)
154 | args = parser.parse_args()
155 | main(args)
--------------------------------------------------------------------------------
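get_throughput_results() above expects the JSONL records written by the generation scripts and derives, per turn, throughput as new_tokens / wall_time and the mean accept length as new_tokens / (idx + 1), where idx is the zero-based number of decoding steps. A worked example with hypothetical numbers:

# Per-record metrics on a made-up answer record, matching get_throughput_results().
record = {
    "choices": [{
        "new_tokens": [120, 96],   # tokens generated per turn
        "wall_time": [1.5, 1.2],   # seconds per turn
        "idxs": [39, 31],          # zero-based decoding steps per turn
    }]
}
choice = record["choices"][0]
for t, w, i in zip(choice["new_tokens"], choice["wall_time"], choice["idxs"]):
    print(f"throughput={t / w:.1f} tok/s, accept_length={t / (i + 1):.2f} tok/step")
# -> throughput=80.0 tok/s, accept_length=3.00 tok/step (both turns)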
/application/show_result.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 show_result.py --mode [single|pairwise-baseline|pairwise-all]
4 | """
5 | import argparse
6 | import pandas as pd
7 |
8 |
9 | def display_result_single(args):
10 | if args.input_file is None:
11 | input_file = (
12 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl"
13 | )
14 | else:
15 | input_file = args.input_file
16 |
17 | print(f"Input file: {input_file}")
18 | df_all = pd.read_json(input_file, lines=True)
19 | df = df_all[["model", "score", "turn"]]
20 | df = df[df["score"] != -1]
21 |
22 | if args.model_list is not None:
23 | df = df[df["model"].isin(args.model_list)]
24 |
25 | print("\n########## First turn ##########")
26 | df_1 = df[df["turn"] == 1].groupby(["model", "turn"]).mean()
27 | print(df_1.sort_values(by="score", ascending=False))
28 |
29 | if args.bench_name == "mt_bench":
30 | print("\n########## Second turn ##########")
31 | df_2 = df[df["turn"] == 2].groupby(["model", "turn"]).mean()
32 | print(df_2.sort_values(by="score", ascending=False))
33 |
34 | print("\n########## Average ##########")
35 | df_3 = df[["model", "score"]].groupby(["model"]).mean()
36 | print(df_3.sort_values(by="score", ascending=False))
37 |
38 |
39 | def display_result_pairwise(args):
40 | if args.input_file is None:
41 | input_file = (
42 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl"
43 | )
44 | else:
45 | input_file = args.input_file
46 |
47 | print(f"Input file: {input_file}")
48 | df_all = pd.read_json(input_file, lines=True)
49 | df_all = df_all[(df_all["g1_winner"] != "error") & (df_all["g2_winner"] != "error")]
50 |
51 | model_list = (
52 | df_all["model_1"].unique().tolist() + df_all["model_2"].unique().tolist()
53 | )
54 | model_list = list(set(model_list))
55 |
56 | list_res = []
57 | # traverse df row by row
58 | for index, row in df_all.iterrows():
59 | if args.model_list is not None and row["model_1"] not in args.model_list:
60 | continue
61 | if args.baseline_model is not None:
62 | if args.baseline_model not in [row["model_1"], row["model_2"]]:
63 | continue
64 | if row["g1_winner"] == "tie" or row["g1_winner"] != row["g2_winner"]:
65 | list_res.append({"model": row["model_1"], "win": 0, "loss": 0, "tie": 1})
66 | list_res.append({"model": row["model_2"], "win": 0, "loss": 0, "tie": 1})
67 | else:
68 | if row["g1_winner"] == "model_1":
69 | winner = row["model_1"]
70 | loser = row["model_2"]
71 | else:
72 | winner = row["model_2"]
73 | loser = row["model_1"]
74 | list_res.append({"model": winner, "win": 1, "loss": 0, "tie": 0})
75 | list_res.append({"model": loser, "win": 0, "loss": 1, "tie": 0})
76 |
77 | df = pd.DataFrame(list_res)
78 | df = df.groupby(["model"]).sum()
79 |
80 | # remove baseline model
81 | if args.baseline_model is not None:
82 | df = df[df.index != args.baseline_model]
83 | # add win rate
84 | df["win_rate"] = df["win"] / (df["win"] + df["loss"] + df["tie"])
85 | df["loss_rate"] = df["loss"] / (df["win"] + df["loss"] + df["tie"])
86 | # each tie counts as 0.5 win + 0.5 loss
87 | df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / (
88 | df["win"] + df["loss"] + df["tie"]
89 | )
90 | # print(df.sort_values(by="win_rate", ascending=False))
91 | # print(df.sort_values(by="loss_rate", ascending=True))
92 | print(df.sort_values(by="win_rate_adjusted", ascending=False))
93 |
94 |
95 | if __name__ == "__main__":
96 | parser = argparse.ArgumentParser()
97 | parser.add_argument("--bench-name", type=str, default="mt_bench")
98 | parser.add_argument("--input-file", type=str)
99 | parser.add_argument("--judge-model", type=str, default="gpt-4")
100 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo")
101 | parser.add_argument(
102 | "--model-list",
103 | type=str,
104 | nargs="+",
105 | default=None,
106 | help="A list of models to be evaluated",
107 | )
108 | parser.add_argument(
109 | "--mode",
110 | type=str,
111 | default="single",
112 | choices=["pairwise-baseline", "pairwise-all", "single"],
113 | help=(
114 | "Evaluation mode. "
115 |         "`pairwise-baseline` runs pairwise comparison against a baseline. "
116 |         "`pairwise-all` runs pairwise comparison between all pairs. "
117 | "`single` runs single answer grading."
118 | ),
119 | )
120 | args = parser.parse_args()
121 |
122 | if args.mode == "single":
123 | display_result_func = display_result_single
124 | else:
125 | if args.mode == "pairwise-all":
126 | args.baseline_model = None
127 | display_result_func = display_result_pairwise
128 |
129 | print(f"Mode: {args.mode}")
130 | display_result_func(args)
131 |
--------------------------------------------------------------------------------
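In display_result_pairwise above, each tie contributes half a win and half a loss to win_rate_adjusted. A short, self-contained sketch of that aggregation on fabricated judgments (model names and counts are illustrative only):

import pandas as pd

# Fabricated per-match outcomes, in the same shape as list_res above.
list_res = [
    {"model": "model-a", "win": 1, "loss": 0, "tie": 0},
    {"model": "model-a", "win": 0, "loss": 0, "tie": 1},
    {"model": "model-b", "win": 0, "loss": 1, "tie": 0},
    {"model": "model-b", "win": 0, "loss": 0, "tie": 1},
]

df = pd.DataFrame(list_res).groupby("model").sum()
total = df["win"] + df["loss"] + df["tie"]
df["win_rate"] = df["win"] / total
df["loss_rate"] = df["loss"] / total
# Each tie counts as 0.5 win + 0.5 loss.
df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / total
print(df.sort_values(by="win_rate_adjusted", ascending=False))
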
/application/webui.py:
--------------------------------------------------------------------------------
1 | # adapted from https://github.com/SafeAILab/EAGLE/blob/d08fe3f23e5f1d986bb50f786af60e5c4f7f757e/eagle/application/webui.py#L4
2 | import os
3 | import time
4 |
5 | import gradio as gr
6 | import argparse
7 | from prompt.utils import *
8 | from prompt.model.model import PromptDecoder, AutoPromptDecoder, PromptConfig
9 | from prompt.model.kv_cache import *
10 | import torch
11 | from fastchat.model import get_conversation_template
12 | import re
13 |
14 |
15 | def truncate_list(lst, num):
16 | if num not in lst:
17 | return lst
18 |
19 |
20 | first_index = lst.index(num)
21 |
22 |
23 | return lst[:first_index + 1]
24 |
25 |
26 | def find_list_markers(text):
27 |
28 | pattern = re.compile(r'(?m)(^\d+\.\s|\n)')
29 | matches = pattern.finditer(text)
30 |
31 |
32 | return [(match.start(), match.end()) for match in matches]
33 |
34 |
35 | def checkin(pointer,start,marker):
36 | for b,e in marker:
37 |         if b<=pointer<e:
38 |             return True
39 |         if b<=start<e:
40 |             return True
41 |     return False
42 |
43 |
44 | def highlight_text(text, text_list, color="black"):
45 |
46 |     pointer = 0
47 |     result = ""
48 |
49 |     markers = find_list_markers(text)
50 |
51 |     for sub_text in text_list:
52 |         start = text.find(sub_text, pointer)
53 |         if start == -1:
54 |             continue
55 |
56 |         end = start + len(sub_text)
57 |
58 |         # leave list markers unstyled so that markdown list rendering is not broken
59 |         if checkin(pointer, start, markers):
60 |             result += text[pointer:start]
61 |         else:
62 |             result += f"<span style='color: {color};'>{text[pointer:start]}</span>"
63 |
64 | result += sub_text
65 |
66 | pointer = end
67 |
68 | if pointer < len(text):
69 |         result += f"<span style='color: {color};'>{text[pointer:]}</span>"
70 |
71 | return result
72 |
73 |
74 | def warmup(model):
75 | conv = get_conversation_template('vicuna')
76 | conv.append_message(conv.roles[0], "Hello")
77 | conv.append_message(conv.roles[1], None)
78 | prompt = conv.get_prompt()
79 | input_ids = model.tokenizer([prompt]).input_ids
80 | input_ids = torch.as_tensor(input_ids).cuda()
81 | for output_ids in model.ppd_generate(input_ids):
82 | ol=output_ids.shape[1]
83 |
84 |
85 | def bot(history, temperature, use_ppd, highlight_ppd,session_state,):
86 | if not history:
87 | return history, "0.00 tokens/s", "0.00", session_state
88 | pure_history = session_state.get("pure_history", [])
89 | conv = get_conversation_template('vicuna')
90 |
91 | for query, response in pure_history:
92 | conv.append_message(conv.roles[0], query)
93 | conv.append_message(conv.roles[1], response)
94 |
95 | prompt = conv.get_prompt()
96 |
97 | input_ids = model.tokenizer([prompt]).input_ids
98 | input_ids = torch.as_tensor(input_ids).cuda()
99 | input_len = input_ids.shape[1]
100 | naive_text = []
101 | cu_len = input_len
102 | totaltime=0
103 | start_time=time.time()
104 | total_ids=0
105 | if use_ppd:
106 |
107 | for output_ids in model.ppd_generate(input_ids, temperature=temperature, max_steps=args.max_new_token):
108 | totaltime+=(time.time()-start_time)
109 | total_ids+=1
110 | decode_ids = output_ids[0, input_len:].tolist()
111 | decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id)
112 | text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False,
113 | clean_up_tokenization_spaces=True, )
114 | naive_text.append(model.tokenizer.decode(output_ids[0, cu_len], skip_special_tokens=True,
115 | spaces_between_special_tokens=False,
116 | clean_up_tokenization_spaces=True, ))
117 |
118 | cu_len = output_ids.shape[1]
119 | colored_text = highlight_text(text, naive_text, "orange")
120 | if highlight_ppd:
121 | history[-1][1] = colored_text
122 | else:
123 | history[-1][1] = text
124 | pure_history[-1][1] = text
125 | session_state["pure_history"] = pure_history
126 | new_tokens = cu_len-input_len
127 | yield history,f"{new_tokens/totaltime:.2f} tokens/s",f"{new_tokens/total_ids:.2f}",session_state
128 | start_time = time.time()
129 |
130 |
131 | else:
132 | for output_ids in model.naive_generate(input_ids, temperature=temperature, max_steps=args.max_new_token):
133 | totaltime += (time.time() - start_time)
134 | total_ids+=1
135 | decode_ids = output_ids[0, input_len:].tolist()
136 | decode_ids = truncate_list(decode_ids, model.tokenizer.eos_token_id)
137 | text = model.tokenizer.decode(decode_ids, skip_special_tokens=True, spaces_between_special_tokens=False,
138 | clean_up_tokenization_spaces=True, )
139 | naive_text.append(model.tokenizer.decode(output_ids[0, cu_len], skip_special_tokens=True,
140 | spaces_between_special_tokens=False,
141 | clean_up_tokenization_spaces=True, ))
142 | cu_len = output_ids.shape[1]
143 | colored_text = highlight_text(text, naive_text, "orange")
144 | if highlight_ppd and use_ppd:
145 | history[-1][1] = colored_text
146 | else:
147 | history[-1][1] = text
148 | history[-1][1] = text
149 | pure_history[-1][1] = text
150 | new_tokens = cu_len - input_len
151 | yield history,f"{new_tokens/totaltime:.2f} tokens/s",f"{new_tokens/total_ids:.2f}",session_state
152 | start_time = time.time()
153 |
154 |
155 | def user(user_message, history,session_state):
156 | if history==None:
157 | history=[]
158 | pure_history = session_state.get("pure_history", [])
159 | pure_history += [[user_message, None]]
160 | session_state["pure_history"] = pure_history
161 | return "", history + [[user_message, None]],session_state
162 |
163 |
164 | def regenerate(history,session_state):
165 | if not history:
166 | return history, None,"0.00 tokens/s","0.00",session_state
167 | pure_history = session_state.get("pure_history", [])
168 | pure_history[-1][-1] = None
169 | session_state["pure_history"]=pure_history
170 | if len(history) > 1: # Check if there's more than one entry in history (i.e., at least one bot response)
171 | new_history = history[:-1] # Remove the last bot response
172 | last_user_message = history[-1][0] # Get the last user message
173 | return new_history + [[last_user_message, None]], None,"0.00 tokens/s","0.00",session_state
174 | history[-1][1] = None
175 | return history, None,"0.00 tokens/s","0.00",session_state
176 |
177 |
178 | def clear(history,session_state):
179 | pure_history = session_state.get("pure_history", [])
180 | pure_history = []
181 | session_state["pure_history"] = pure_history
182 | return [],"0.00 tokens/s","0.00",session_state
183 |
184 |
185 | parser = argparse.ArgumentParser()
186 | parser.add_argument(
187 | "--ppd-path",
188 | type=str,
189 | default="hmarkc/ppd-vicuna-7b-v1.3",
190 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
191 | )
192 | parser.add_argument(
193 | "--load-in-8bit", action="store_true", help="Use 8-bit quantization"
194 | )
195 | parser.add_argument(
196 | "--load-in-4bit", action="store_true", help="Use 4-bit quantization"
197 | )
198 | parser.add_argument(
199 | "--max-new-token",
200 | type=int,
201 | default=512,
202 | help="The maximum number of new generated tokens.",
203 | )
204 | args = parser.parse_args()
205 |
206 | model = AutoPromptDecoder.from_pretrained(
207 | args.ppd_path,
208 | low_cpu_mem_usage=True,
209 | torch_dtype=torch.float16,
210 | )
211 | model.cuda()
212 | model.eval()
213 | warmup(model)
214 |
215 | custom_css = """
216 | #speed textarea {
217 | color: red;
218 | font-size: 30px;
219 | }"""
220 |
221 | with gr.Blocks(css=custom_css) as demo:
222 | gs = gr.State({"pure_history": []})
223 | gr.Markdown('''## PPD Chatbot''')
224 | with gr.Row():
225 | speed_box = gr.Textbox(label="Speed", elem_id="speed", interactive=False, value="0.00 tokens/s")
226 | compression_box = gr.Textbox(label="Compression Ratio", elem_id="speed", interactive=False, value="0.00")
227 |
228 | chatbot = gr.Chatbot(height=600,show_label=False)
229 |
230 |
231 | msg = gr.Textbox(label="Your input")
232 | with gr.Row():
233 | send_button = gr.Button("Send")
234 | stop_button = gr.Button("Stop")
235 | regenerate_button = gr.Button("Regenerate")
236 | clear_button = gr.Button("Clear")
237 |
238 | with gr.Row():
239 | with gr.Column():
240 | use_ppd = gr.Checkbox(label="Use PPD", value=True)
241 | highlight_ppd = gr.Checkbox(label="Highlight the tokens generated by PPD", value=True)
242 | temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="temperature", value=0.5)
243 | note=gr.Markdown(show_label=False,value='''The Compression Ratio is defined as the number of generated tokens divided by the number of forward passes in the original LLM. If "Highlight the tokens generated by PPD" is checked, the tokens correctly guessed by PPD
244 |                                             will be displayed in orange. Note: checking this option may cause issues rendering special formatting in a few cases, especially when generating code.''')
245 | enter_event=msg.submit(user, [msg, chatbot,gs], [msg, chatbot,gs], queue=True).then(
246 | bot, [chatbot, temperature, use_ppd, highlight_ppd,gs], [chatbot,speed_box,compression_box,gs]
247 | )
248 | clear_button.click(clear, [chatbot,gs], [chatbot,speed_box,compression_box,gs], queue=True)
249 |
250 | send_event=send_button.click(user, [msg, chatbot,gs], [msg, chatbot,gs],queue=True).then(
251 | bot, [chatbot, temperature, use_ppd, highlight_ppd,gs], [chatbot,speed_box,compression_box,gs]
252 | )
253 | regenerate_event=regenerate_button.click(regenerate, [chatbot,gs], [chatbot, msg,speed_box,compression_box,gs],queue=True).then(
254 | bot, [chatbot, temperature, use_ppd, highlight_ppd,gs], [chatbot,speed_box,compression_box,gs]
255 | )
256 | stop_button.click(fn=None, inputs=None, outputs=None, cancels=[send_event,regenerate_event,enter_event])
257 | demo.queue()
258 | demo.launch(share=True)
--------------------------------------------------------------------------------
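The demo above reports two live metrics from bot(): generation speed (new_tokens/totaltime) and the compression ratio (new_tokens/total_ids), where total_ids counts forward passes of the base LLM. A toy calculation with made-up counts shows how the two readouts relate:

# All numbers below are fabricated; in the demo they come from the generation loop.
new_tokens = 96        # tokens emitted after the prompt (cu_len - input_len)
total_time_s = 1.6     # wall-clock time spent in the generation loop
forward_passes = 30    # loop iterations, i.e. forward passes of the base LLM (total_ids)

tokens_per_second = new_tokens / total_time_s
compression_ratio = new_tokens / forward_passes  # > 1 when PPD accepts extra tokens per pass

print(f"{tokens_per_second:.2f} tokens/s")
print(f"Compression Ratio: {compression_ratio:.2f}")
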
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/Overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/Overview.png
--------------------------------------------------------------------------------
/assets/PPD_LOGO.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/PPD_LOGO.png
--------------------------------------------------------------------------------
/assets/Speed_Mem_Train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/Speed_Mem_Train.png
--------------------------------------------------------------------------------
/assets/latency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/latency.png
--------------------------------------------------------------------------------
/assets/ppd_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/assets/ppd_demo.gif
--------------------------------------------------------------------------------
/dataset_generation/generate_dataset.py:
--------------------------------------------------------------------------------
1 | from prompt.utils import *
2 | import argparse
3 | import math
4 | from transformers import LlamaForCausalLM
5 | from transformers import BitsAndBytesConfig
6 |
7 | def generate_fine_tune_dataset(args):
8 | tokenizer = transformers.AutoTokenizer.from_pretrained(
9 | args.model_name_or_path,
10 | model_max_length=args.model_max_length,
11 | padding_side="left",
12 | use_fast=False,
13 | truncation=True
14 | )
15 | tokenizer.pad_token = tokenizer.unk_token
16 | data = get_finetune_dataset(tokenizer=tokenizer, data_path=args.data_path, size=args.size, offset=args.num_special_tokens+1)
17 |
18 | torch.save(data, f"{args.save_path}_{args.num_special_tokens}_finetune_{args.model_max_length}.pt")
19 |
20 |
21 | def generate_self_distillation_dataset(args):
22 | # Set RoPE scaling factor
23 | config = transformers.AutoConfig.from_pretrained(args.model_name_or_path)
24 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
25 | if orig_ctx_len and args.model_max_length > orig_ctx_len:
26 | scaling_factor = float(math.ceil(args.model_max_length / orig_ctx_len))
27 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
28 | config.use_cache = False
29 |
30 | config = transformers.AutoConfig.from_pretrained(args.model_name_or_path)
31 |
32 | quantization_config = BitsAndBytesConfig(
33 | load_in_4bit=True,
34 | bnb_4bit_compute_dtype=torch.bfloat16,
35 | bnb_4bit_use_double_quant=True,
36 | bnb_4bit_quant_type="nf4",
37 | )
38 |
39 | if config.model_type == "llama":
40 | model = LlamaForCausalLM.from_pretrained(
41 | args.model_name_or_path,
42 | config=config,
43 | low_cpu_mem_usage=True,
44 | quantization_config=quantization_config,
45 | )
46 | else:
47 | raise ValueError("Only support llama for now")
48 | data = get_self_distillation_dataset(model=model, data_path=args.data_path, num_special_tokens=args.num_special_tokens)
49 |
50 | model_name = args.model_name_or_path.split("/")[-1]
51 |     torch.save(data, f"{args.save_path}_{args.num_special_tokens}_{model_name}_distillation_{args.model_max_length}.pt")
52 |
53 |
54 | if __name__ == "__main__":
55 | parser = argparse.ArgumentParser()
56 | parser.add_argument("--model_name_or_path", type=str, default="lmsys/vicuna-7b-v1.3")
57 | parser.add_argument("--model_max_length", type=int, default=2048)
58 | parser.add_argument("--save_path", type=str, default="data/ShareGPT_training_dataset")
59 | parser.add_argument("--data_path", type=str, default="data/ShareGPT_training_dataset_2.pt")
60 | parser.add_argument("--size", type=int, default=None)
61 | parser.add_argument("--num_special_tokens", type=int, default=2)
62 | parser.add_argument("--dataset_type", type=str, default="finetune", choices=["finetune", "distillation"])
63 | args = parser.parse_args()
64 |
65 | if args.dataset_type == "finetune":
66 | generate_fine_tune_dataset(args)
67 | elif args.dataset_type == "distillation":
68 | generate_self_distillation_dataset(args)
69 |
--------------------------------------------------------------------------------
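generate_dataset.py above serializes the prepared dataset with torch.save. A minimal sketch of loading it back, assuming the default arguments shown above (the file name follows the f-string in generate_fine_tune_dataset; adjust the path to whatever --save_path, --num_special_tokens, and --model_max_length were actually used):

import torch

# Run from the repository root so that prompt.utils, which defines the dataset class,
# is importable when unpickling. Path assumes the defaults above:
# save_path="data/ShareGPT_training_dataset", num_special_tokens=2, model_max_length=2048.
dataset = torch.load("data/ShareGPT_training_dataset_2_finetune_2048.pt")

print(len(dataset))                      # number of training examples
sample = dataset[0]                      # dict with input_ids, labels, attention_mask
print(sample["input_ids"].shape, sample["labels"].shape)
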
/prompt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/__init__.py
--------------------------------------------------------------------------------
/prompt/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/evaluation/__init__.py
--------------------------------------------------------------------------------
/prompt/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | import numpy as np
5 | from prompt.model.model import *
6 | import matplotlib.pyplot as plt
7 | import torch.nn.functional as F
8 | from fastchat.model.model_adapter import get_conversation_template
9 | from tqdm import tqdm
10 | import argparse
11 | import matplotlib.pyplot as plt
12 |
13 | # Once
14 | # 1st iteration: Once upon [a time]
15 | # 2nd iteration: Once upon a [time there]
16 | # 3rd iteration: Once upon a time [there was]
17 | def get_accuracies(approximate_ids, logit):
18 | results = []
19 | _, _, num_special_tokens, _ = approximate_ids.shape
20 | for i in range(num_special_tokens):
21 | match = approximate_ids[:-1-i, :, i].eq(logit[1+i:, :, :1])
22 | results.append(match)
23 | # print(match.shape)
24 | # accuracy = match.any(dim=-1).sum().float() / (match.shape[0] * match.shape[1])
25 | # print(approximate_ids.shape, logit.shape)
26 | return results
27 |
28 | def plot_accuracies(eval_data, save_path):
29 | plt.figure()
30 | for i, data in enumerate(eval_data):
31 | results= []
32 | for K in range(1, 11):
33 | results.append((data[:, :, :K].any(dim=-1).sum().float() / (data.shape[0] * data.shape[1])).cpu())
34 | plt.plot(results, label=f"{i}th prediction")
35 | print(f"{i+1}th accuracy - {', '.join(['Top '+str(i+1)+' : '+str(result.item()) for i, result in enumerate(results)])}")
36 | plt.xlabel("K")
37 | plt.ylabel("Accuracy")
38 | plt.xticks(range(10))
39 | plt.yticks(np.arange(0, 1.1, 0.1))
40 | plt.legend()
41 | plt.show()
42 | plt.savefig(save_path)
43 |
44 | def main(args):
45 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46 | data = json.load(open(args.data_path))
47 |
48 | if args.eval_result_path:
49 | eval_data = torch.load(args.eval_result_path)
50 | plot_accuracies(eval_data, os.path.join(args.save_dir, args.model_name + "_accuracy.png"))
51 | return
52 |
53 |
54 | model = AutoPromptDecoder.from_pretrained(
55 | args.model_path,
56 | torch_dtype=torch.float16,
57 | low_cpu_mem_usage=True
58 | )
59 | tokenizer = model.tokenizer
60 | model = model.to(device)
61 |
62 | config = model.active_peft_config
63 | num_special_tokens = config.num_special_tokens
64 | virtual_tokens_per_special_token = config.virtual_tokens_per_special_token
65 | total_virtual_tokens = num_special_tokens * virtual_tokens_per_special_token
66 | # TODO: KV Cache
67 | results = None
68 |
69 | for sample in tqdm((data)):
70 | conv = get_conversation_template("vicuna")
71 | conv.messages = []
72 | conv.append_message(conv.roles[0], sample["instruction"])
73 | conv.append_message(conv.roles[1], "")
74 | prompt = conv.get_prompt()
75 | steps = args.steps
76 | logits_ids = []
77 | approximate_ids = []
78 |
79 | with torch.inference_mode():
80 | input_ids = tokenizer([prompt]).input_ids
81 | input_ids = torch.as_tensor(input_ids).to(device)
82 | outputs = model(input_ids)
83 | logits = outputs.logits
84 | pred = torch.argmax(logits[:, -num_special_tokens-1, :], dim=-1)
85 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous()
86 | _, approximate_tokens = prompt_logits.topk(10, dim=-1)
87 | # print(pred.device, input_ids.device, model.device)
88 | preds = torch.cat((input_ids, pred.unsqueeze(0)), dim=-1)
89 | # print(f"Exact token: {tokenizer.batch_decode(pred)}, approximate tokens: {tokenizer.batch_decode(approximate_tokens.squeeze(0))}")
90 | logits_ids.append(preds[:, -1:].detach())
91 | approximate_ids.append(approximate_tokens.detach())
92 | for _ in range(steps):
93 | outputs= model(preds)
94 | logits = outputs.logits
95 | pred = torch.argmax(logits[:, -num_special_tokens-1, :], dim=-1)
96 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous()
97 | _, approximate_tokens = prompt_logits.topk(10, dim=-1)
98 | # print(f"Exact token: {tokenizer.batch_decode(pred)}, approximate tokens: {tokenizer.batch_decode(approximate_tokens.squeeze(0))}")
99 | preds = torch.cat((preds, pred.unsqueeze(0)), dim=-1)
100 | logits_ids.append(preds[:, -1:].detach())
101 | approximate_ids.append(approximate_tokens.detach())
102 | logits_ids = torch.stack(logits_ids, dim=0)
103 | approximate_ids = torch.stack(approximate_ids, dim=0).squeeze(2)
104 | if results is None:
105 | results = get_accuracies(approximate_ids, logits_ids)
106 | else:
107 | # cat sub results
108 | cur_results = get_accuracies(approximate_ids, logits_ids)
109 | for i in range(len(results)):
110 | results[i] = torch.cat((results[i], cur_results[i]), dim=0)
111 |
112 | save_path = os.path.join(args.save_dir, args.model_name + "_accuracy.pt")
113 | torch.save(results, save_path)
114 | plot_accuracies(results, os.path.join(args.save_dir, args.model_name + "_accuracy.png"))
115 |
116 | if __name__ == "__main__":
117 | parser = argparse.ArgumentParser(description="Model Evaluator")
118 |
119 | parser.add_argument("--model_path", type=str, required=True,
120 | help="Path to the pre-trained model.")
121 | parser.add_argument("--model_name", type=str, required=True,
122 | help="Name of the model.")
123 | parser.add_argument("--data_path", type=str, required=True,
124 | help="Path to the evaluation data in JSON format.")
125 | parser.add_argument("--save_dir", type=str, default="./",
126 | help="Directory to save the results.")
127 | parser.add_argument("--steps", type=int, default=20,
128 | help="Number of steps to run the model.")
129 | parser.add_argument("--eval_result_path", type=str, default=None, required=False,
130 | help="Path to the evaluation result.")
131 | args = parser.parse_args()
132 |
133 | # If the save directory doesn't exist, create it
134 | if not os.path.exists(args.save_dir):
135 | os.makedirs(args.save_dir)
136 | main(args)
--------------------------------------------------------------------------------
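The comment block at the top of eval.py sketches how the i-th special token issued at one step is meant to predict the token the model actually emits i+1 steps later; get_accuracies implements this by shifting the exact-token sequence before comparing against the top-k guesses. A self-contained toy version of the same indexing on random ids (all shapes and values below are illustrative only):

import torch

steps, batch, num_special_tokens, top_k = 5, 1, 2, 3
# approximate_ids: top-k guesses of each special token at every step.
approximate_ids = torch.randint(0, 10, (steps, batch, num_special_tokens, top_k))
# exact_ids: the token the model actually emitted at every step.
exact_ids = torch.randint(0, 10, (steps, batch, 1))

for i in range(num_special_tokens):
    # Prediction made i+1 tokens ahead vs. the token realised i+1 steps later.
    match = approximate_ids[:-1 - i, :, i].eq(exact_ids[1 + i:, :, :1])
    top_k_acc = match.any(dim=-1).float().mean().item()
    print(f"special token {i}: top-{top_k} accuracy on random ids = {top_k_acc:.2f}")
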
/prompt/hf_utils.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import HfApi
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser("Upload PPD model to HuggingFace Hub")
5 | parser.add_argument("--folder", type=str, help="Path to model folder")
6 | parser.add_argument("--repo", type=str, help="Repo name")
7 | parser.add_argument("--private", action="store_true", help="Make repo private")
8 |
9 | args = parser.parse_args()
10 |
11 | api = HfApi()
12 |
13 | api.create_repo(
14 | repo_id=args.repo,
15 | private=args.private,
16 | exist_ok=True,
17 | )
18 |
19 | api.upload_folder(
20 | folder_path=args.folder,
21 | repo_id=args.repo,
22 | )
--------------------------------------------------------------------------------
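After running hf_utils.py, one quick way to confirm the upload is to list the files in the target repo; the repo id below is a placeholder.

from huggingface_hub import HfApi

api = HfApi()
# Replace with the value passed to --repo above.
print(api.list_repo_files(repo_id="your-username/your-ppd-model"))
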
/prompt/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/inference/__init__.py
--------------------------------------------------------------------------------
/prompt/inference/sparse_tree_builder.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 |
5 | import argparse
6 |
7 | from pprint import pprint
8 | from copy import deepcopy
9 | import numpy as np
10 | import concurrent.futures
11 |
12 |
13 | def load_accuracy(file_name):
14 | eval_data = torch.load(file_name)
15 | accuracies = []
16 | for i, data in enumerate(eval_data):
17 | results= []
18 | for K in range(1, 11):
19 | results.append((data[:, :, :K].any(dim=-1).sum().float() / (data.shape[0] * data.shape[1])).cpu())
20 | print(f"{i+1}th accuracy - {', '.join(['Top '+str(i+1)+' : '+str(result.item()) for i, result in enumerate(results)])}")
21 | accuracy = [results[0]]
22 | for i in range(1, len(results)):
23 | accuracy.append(results[i] - results[i-1])
24 | accuracies.append(accuracy)
25 | accuracies = torch.tensor(accuracies)
26 | return accuracies
27 |
28 |
29 | def expected_accuracy(es, vs, print_output=False):
30 | # only calculate depth of 3
31 | e10, e20, e30 = es[0]
32 | e11, e21, e31 = es[1]
33 | e12, e22, e32 = es[2]
34 | e13, e23, e33 = es[3]
35 | v1, v2, v3 = vs
36 | a = np.array([[0, e11-1, e21, e31],
37 | [0, e12, e22-1, e32],
38 | [1, e13, e23, e33-1],
39 | [1, 1, 1, 1]])
40 | output = np.linalg.solve(a, [0, 0, 0, 1])
41 | if print_output:
42 | print("Expected Probability:")
43 | pprint(output)
44 | return output[1]*v1 + output[2]*v2 + output[3]*v3
45 |
46 |
47 | def find_all_candidates(parent, width, depth):
48 | if depth == 0:
49 | return []
50 | candidates = []
51 | for i in range(width):
52 | candidates.append(parent+[i])
53 | candidates.extend(find_all_candidates(parent+[i], width, depth-1))
54 | return candidates
55 |
56 |
57 | def find_optimal_sparse_tree(accuracies, num_candidates, max_depth=None):
58 | # Generate all possible candidates of varying lengths
59 | if max_depth:
60 | accuracies = accuracies[:max_depth]
61 | candidates = find_all_candidates([], accuracies.shape[1], accuracies.shape[0])
62 |
63 | # Calculate cumulative accuracy for each candidate
64 | candidate_accuracies = []
65 | for candidate in candidates:
66 | cumulative_accuracy = 1.0
67 | for idx, top_i in enumerate(candidate):
68 | cumulative_accuracy *= accuracies[idx, top_i]
69 | candidate_accuracies.append((cumulative_accuracy, candidate))
70 |
71 | # Sort candidates by their cumulative accuracy in descending order and select top n
72 | top_candidates = sorted(candidate_accuracies, key=lambda x: x[0], reverse=True)[:num_candidates]
73 |
74 | # Extract just the candidate paths
75 | top_candidate_paths = [list(candidate) for _, candidate in top_candidates]
76 | top_candidate_accs = [round(acc.cpu().item(), 5) for acc, _ in top_candidates]
77 |
78 | return top_candidate_paths, top_candidate_accs
79 |
80 |
81 | def find_optimal_extended_sparse_tree(accuracies, input_length_limit):
82 | # input_length_limit = num_candidates + sum(candidate_accuracy_n * num_special_tokens_n)
83 | # n is the depth of accuracies
84 | n = accuracies.shape[0]
85 | # generate and store the optimal sparse tree for each num_candidates and each depth
86 | optimal_sparse_trees = {}
87 | optimal_sparse_tree_accuracies = {}
88 | for depth in range(1, n+1):
89 | candidates, accs = find_optimal_sparse_tree(accuracies, input_length_limit, depth)
90 | candidate_acc_pairs = list(zip(candidates, accs))
91 | for length in range(1, input_length_limit+1):
92 | ls = []
93 | for candidate, acc in candidate_acc_pairs[:length]:
94 | sum_children_acc = sum([a for c, a in candidate_acc_pairs[:length] if c[:-1] == candidate])
95 | ls.append((candidate, acc - sum_children_acc))
96 | optimal_sparse_trees[(depth, length)] = ls
97 | for size in range(1, input_length_limit+1):
98 | optimal_sparse_tree_accuracies[(depth, size)] = sum(accs[:size])
99 | best_sparse_trees = None
100 | best_expected_acc = -1
101 | best_num_tree_nodes = None
102 | # only calculate depth of 3 for now
103 | for tree_node2 in range(input_length_limit//(n+1), input_length_limit//2 + 1):
104 | for tree_node3 in range(input_length_limit//(n+1), input_length_limit//2 + 1):
105 | tree_nodes = {1: 10, 2: tree_node2, 3: tree_node3}
106 | sparse_trees, expected_acc = find_extended_sparse_tree_fixed_tree_nodes(tree_nodes, input_length_limit, optimal_sparse_tree_accuracies, deepcopy(optimal_sparse_trees), n)
107 | if expected_acc > best_expected_acc:
108 | best_sparse_trees = sparse_trees
109 | best_expected_acc = expected_acc
110 | best_num_tree_nodes = tree_nodes
111 | print("Input limit", input_length_limit, "Tree nodes:", best_num_tree_nodes, "Expected Acc:", best_expected_acc)
112 | return best_sparse_trees
113 |
114 |
115 | def find_extended_sparse_tree_fixed_tree_nodes(tree_nodes, input_length_limit, optimal_sparse_tree_accuracies, candidate_acc_pairs, n):
116 | optimal_sparse_trees = {}
117 | for depth in range(1, n+1):
118 | num_nodes = n * len(candidate_acc_pairs[(depth, tree_nodes[depth])])
119 | optimal_sparse_trees[depth] = [[candidate, acc, n] for candidate, acc in candidate_acc_pairs[(depth, tree_nodes[depth])]]
120 | while num_nodes > input_length_limit - tree_nodes[depth]:
121 | min_accuracy_loss = float('-inf')
122 | min_index = 0
123 | for i, (_, candidate_acc, num_special_token) in enumerate(optimal_sparse_trees[depth]):
124 | if num_special_token == 1:
125 | continue
126 | accuracy_loss = candidate_acc * (optimal_sparse_tree_accuracies[(num_special_token-1, tree_nodes[num_special_token-1])] - \
127 | optimal_sparse_tree_accuracies[(num_special_token, tree_nodes[num_special_token])])
128 | if accuracy_loss > min_accuracy_loss:
129 | min_accuracy_loss = accuracy_loss
130 | min_index = i
131 | if min_accuracy_loss == float('inf'):
132 | break
133 | optimal_sparse_trees[depth][min_index][2] = optimal_sparse_trees[depth][min_index][2] - 1
134 | num_nodes -= 1
135 |
136 | es = []
137 | vs = []
138 | for depth in range(1, n+1):
139 | e_depth = [0] * (n+1)
140 | for (_, candidate_acc, num_special_token) in optimal_sparse_trees[depth]:
141 | e_depth[num_special_token] += candidate_acc
142 | e_depth[0] = 1 - sum(e_depth[1:])
143 | es.append(e_depth)
144 | vs.append(optimal_sparse_tree_accuracies[(depth, tree_nodes[depth])])
145 | es = np.array(es).T.tolist()
146 | acc = expected_accuracy(es, vs)
147 |
148 | return optimal_sparse_trees, acc
149 |
150 |
151 | def sparse_tree_info(best_sparse_trees):
152 | print("Best Sparse Trees:")
153 | pprint(best_sparse_trees)
154 |
155 | es = []
156 | vs = []
157 | for depth, best_sparse_tree in best_sparse_trees.items():
158 | print(f"Depth: {depth}")
159 | print("Number of tree nodes:", len(best_sparse_tree))
160 | print("Number of special tokens:", sum([num_special_token for _, _, num_special_token in best_sparse_tree]))
161 | acc_1 = sum([accuracy for _, accuracy, num_special_token in best_sparse_tree if num_special_token == 1])
162 | acc_2 = sum([accuracy for _, accuracy, num_special_token in best_sparse_tree if num_special_token == 2])
163 | acc_3 = sum([accuracy for _, accuracy, num_special_token in best_sparse_tree if num_special_token == 3])
164 | print("Probabilities to 1 special token:", acc_1)
165 | print("Probabilities to 2 special tokens:", acc_2)
166 | print("Probabilities to 3 special tokens:", acc_3)
167 | print("Probability to None", 1 - (acc_1 + acc_2 + acc_3))
168 |
169 |
170 | def write_sparse_tree_to_file(file_name, min_input_length, max_input_length, accuracies):
171 | def task(i):
172 | # This is the task that will be executed by each thread.
173 | # It returns a tuple of (i, result) so we know which iteration it belongs to.
174 | return i, find_optimal_extended_sparse_tree(accuracies, i)
175 |
176 | # Prepare to collect the results
177 | results = []
178 |
179 | # Using ThreadPoolExecutor to execute tasks concurrently
180 | with concurrent.futures.ThreadPoolExecutor() as executor:
181 | # Map the task function to the range of values concurrently
182 | future_to_i = {executor.submit(task, i): i for i in range(min_input_length, max_input_length+1)}
183 |
184 | for future in concurrent.futures.as_completed(future_to_i):
185 | i = future_to_i[future]
186 | try:
187 | # Collect results as they are completed
188 | results.append(future.result())
189 | except Exception as exc:
190 | print(f'Generated an exception: {exc}')
191 |
192 | # Sorting results to ensure they are in order
193 | results.sort(key=lambda x: x[0])
194 |
195 | # Writing results to file sequentially
196 | with open(file_name, "w") as f:
197 | for i, result in results:
198 | f.write(f"# Dynamic Sparse Trees for input length limit of {i}\n")
199 | f.write(f"dynamic_sparse_trees_{i} = {result}\n")
200 |
201 |
202 | def main(args):
203 | accuracies = load_accuracy(args.accuracies_file)
204 | write_sparse_tree_to_file(args.output_file, args.min_input_length, args.max_input_length, accuracies)
205 |
206 |
207 | if __name__ == "__main__":
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument("--accuracies-file", type=str, required=True, help="Path to the accuracies file.")
210 | parser.add_argument("--output-file", type=str, required=True, help="Path to the output file.")
211 | parser.add_argument("--min-input-length", type=int, required=True, help="Minimum input length limit.")
212 | parser.add_argument("--max-input-length", type=int, required=True, help="Maximum input length limit.")
213 | args = parser.parse_args()
214 | main(args)
215 |
216 |
217 |
--------------------------------------------------------------------------------
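find_optimal_sparse_tree above enumerates every path through the per-depth top-k acceptance probabilities and keeps the paths with the highest cumulative (product) probability. A toy run of that ranking with a fabricated two-depth accuracy table, assuming the repository root is on PYTHONPATH so the module imports:

import torch
from prompt.inference.sparse_tree_builder import find_optimal_sparse_tree

# Fabricated per-depth top-k acceptance probabilities
# (rows = prediction depth, columns = top-1 / top-2 / top-3 rank).
accuracies = torch.tensor([
    [0.60, 0.15, 0.05],   # 1st special token
    [0.40, 0.10, 0.04],   # 2nd special token
])

paths, accs = find_optimal_sparse_tree(accuracies, num_candidates=4)
for path, acc in zip(paths, accs):
    print(path, acc)   # e.g. [0] 0.6, [0, 0] 0.24, [1] 0.15, ...
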
/prompt/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/model/__init__.py
--------------------------------------------------------------------------------
/prompt/model/kv_cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 |
4 | # adapted from Medusa: https://github.com/FasterDecoding/Medusa/blob/5e980538695096e7e372c1e27a6bcf142bfeab11/medusa/model/kv_cache.py
5 | class KVCache:
6 | """
7 | A key-value cache for the model.
8 |
9 | This class provides a mechanism to maintain a growing cache of keys and values,
10 | particularly useful for models that benefit from caching previous states,
11 | like transformers during autoregressive decoding.
12 |
13 | Attributes:
14 | data (torch.Tensor): The tensor storing keys and values.
15 | current_length (int): Current length of the data being stored.
16 | """
17 |
18 | def __init__(self, data: torch.Tensor, current_length: int):
19 | """
20 | Initialize the KVCache.
21 |
22 | Args:
23 | data (torch.Tensor): Initial tensor to store the keys and values.
24 | current_length (int): Initial length of the data.
25 | normal_token_indices (torch.Tensor): Indices of the normal tokens in the data tensor.
26 | """
27 | self.data = data
28 | self.current_length = current_length
29 |
30 | @property
31 | def shape(self):
32 | """Return the shape of the data tensor with updated length."""
33 | return (
34 | self.data.shape[0],
35 | self.data.shape[1],
36 | self.current_length.item(),
37 | self.data.shape[3],
38 | )
39 |
40 | def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
41 | """
42 | Copy values from the current data at specified indices to a new location.
43 |
44 | Args:
45 | indices (torch.Tensor): Indices of the data tensor to be copied.
46 | prev_length (int): Previous length before adding new data.
47 | dim (int, optional): Dimension along which copying should be performed. Default is 2.
48 | """
49 | tgt = self.data.index_select(dim, indices)
50 | dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
51 | dst.copy_(tgt, non_blocking=True)
52 | self.current_length.fill_(prev_length + tgt.shape[dim])
53 |
54 | def cat(self, tensor: torch.Tensor, dim: int = 2):
55 | """
56 | Concatenate the given tensor with the current data.
57 |
58 | Args:
59 | tensor (torch.Tensor): The tensor to be concatenated.
60 | dim (int, optional): The dimension along which concatenation should be done. Default is 2.
61 |
62 | Returns:
63 | torch.Tensor: The data tensor after concatenation up to the current length.
64 | """
65 | # if normal_token_indices is None:
66 | # dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
67 | # dst.copy_(tensor, non_blocking=True)
68 | # self.current_length.add_(tensor.shape[dim])
69 | # return torch.narrow(self.data, 2, 0, self.current_length)
70 | # dst = self.data.narrow(dim, self.current_length, len(normal_token_indices))
71 | # dst.copy_(tensor.index_select(dim, normal_token_indices), non_blocking=True)
72 | # rst = torch.cat([torch.narrow(self.data, 2, 0, self.current_length), tensor], dim)
73 | # self.current_length.add_(len(normal_token_indices))
74 | # return rst
75 | dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
76 | dst.copy_(tensor, non_blocking=True)
77 | self.current_length.add_(tensor.shape[dim])
78 | return torch.narrow(self.data, 2, 0, self.current_length)
79 |
80 |
81 | def initialize_past_key_values(model):
82 | """
83 | Initialize past key and value states for a given transformer model.
84 |
85 | This function prepares key-value cache structures for the model, allowing it to store and reuse
86 | past key and value states during autoregressive decoding, which can improve efficiency.
87 |
88 | Args:
89 | model (nn.Module): The transformer model for which past key-value states need to be initialized.
90 |
91 | Returns:
92 | tuple:
93 | - past_key_values (list): A list of KVCache objects for each layer in the model.
94 | - past_key_values_data (torch.Tensor): The tensor that will store all keys and values.
95 | - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache.
96 | """
97 | # Extracting configuration from the model
98 | config = model.config
99 | # Initializing the batch size to 1, this can be modified if different batch sizes are required
100 | batch_size = 1
101 | # Initializing a tensor to store past keys and values for all layers
102 | past_key_values_data = torch.zeros(
103 | config.num_hidden_layers * 2,
104 | batch_size,
105 | config.num_key_value_heads,
106 | # llama max_position_embeddings is 4096 instead of 2048
107 | config.max_position_embeddings*2,
108 | config.hidden_size // config.num_attention_heads,
109 | device=model.device,
110 | dtype=model.dtype,
111 | )
112 | # Initialize tensor to store the current length of the cached data for all layers.
113 | # [IMPORTANT] It needs to be kept on CPU for quick access and updates.
114 | current_length_data = torch.zeros(
115 | config.num_hidden_layers * 2, dtype=torch.long, device="cpu"
116 | )
117 | # Creating a KVCache for each pair of key and value in all layers
118 | past_key_values = [] * config.num_hidden_layers
119 | for i in range(config.num_hidden_layers):
120 | past_key_values.append(
121 | [
122 | KVCache(past_key_values_data[i * 2 + j], current_length_data[i * 2 + j])
123 | for j in range(2)
124 | ]
125 | )
126 | return past_key_values, past_key_values_data, current_length_data
--------------------------------------------------------------------------------
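A standalone sketch of how the KVCache above is used, detached from any real model: a pre-allocated [batch, heads, max_len, head_dim] buffer plus a 0-dim length counter, grown with cat(). It assumes the repository root is on PYTHONPATH; the sizes are arbitrary.

import torch
from prompt.model.kv_cache import KVCache

batch, heads, max_len, head_dim = 1, 4, 16, 8
buffer = torch.zeros(batch, heads, max_len, head_dim)
current_length = torch.tensor(0)   # 0-dim counter, as in initialize_past_key_values

cache = KVCache(buffer, current_length)
new_keys = torch.randn(batch, heads, 3, head_dim)   # e.g. keys for 3 new tokens
valid = cache.cat(new_keys)                         # append and get the filled slice
print(valid.shape, cache.shape)                     # torch.Size([1, 4, 3, 8]) (1, 4, 3, 8)
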
/prompt/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hmarkc/parallel-prompt-decoding/2ba3a1cd9328f274662c9d3a00a0a28b9ef3b874/prompt/train/__init__.py
--------------------------------------------------------------------------------
/prompt/train/train.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | import warnings
3 | import math
4 | import pathlib
5 | from typing import Dict, Optional
6 |
7 | import torch
8 | import transformers
9 | from transformers import Trainer, BitsAndBytesConfig
10 | from transformers.trainer_pt_utils import LabelSmoother
11 |
12 | from torch.nn import CrossEntropyLoss
13 | import torch.nn.functional as F
14 |
15 | from prompt.utils import *
16 | from prompt.model.model import PromptDecoder, PromptConfig, AutoPromptDecoder
17 | from prompt.model.modeling_llama_custom import LlamaForCausalLM as CustomLlamaForCausalLM
18 |
19 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
20 |
21 | class ParamEfficientFineTuner(Trainer):
22 | def compute_loss(self, model, inputs, return_outputs=False):
23 | """
24 | Compute the training loss for the model.
25 |
26 | Args:
27 | model (torch.nn.Module): The model for which to compute the loss.
28 | inputs (dict): The input data, including input IDs, attention mask, and labels.
29 | return_outputs (bool): Whether to return model outputs along with the loss.
30 |
31 | Returns:
32 | Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
33 | """
34 | num_special_tokens = self.model.active_peft_config.num_special_tokens
35 | if torch.any(inputs["input_ids"][:, -1] == self.tokenizer.eos_token_id):
36 | warnings.warn("Input ends with EOS token.")
37 | input_ids = inputs["input_ids"]
38 | attention_mask = inputs["attention_mask"]
39 | labels = inputs["labels"]
40 |
41 | outputs = model(
42 | input_ids=input_ids, attention_mask=attention_mask
43 | )
44 | logits = outputs.logits
45 |
46 | # Calculate loss on the prompt tokens
47 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous()
48 | prompt_labels = labels[..., -num_special_tokens:].contiguous()
49 | prompt_labels = prompt_labels.to(logits.device)
50 | loss = 0
51 | loss_fn = CrossEntropyLoss()
52 | decay_coefficient = 0.8
53 | for i in range(num_special_tokens):
54 | loss += loss_fn(prompt_logits[:, i, :], prompt_labels[:, i]) * (decay_coefficient ** i)
55 | if num_special_tokens > 0:
56 | loss /= num_special_tokens
57 | return (loss, outputs) if return_outputs else loss
58 |
59 |
60 | class DistillationTrainer(Trainer):
61 | def compute_loss(self, model, inputs, return_outputs=False):
62 | """
63 | Compute the training loss for the model.
64 |
65 | Args:
66 | model (torch.nn.Module): The model for which to compute the loss.
67 | inputs (dict): The input data, including input IDs, attention mask, and labels.
68 | return_outputs (bool): Whether to return model outputs along with the loss.
69 |
70 | Returns:
71 | Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
72 | """
73 | num_special_tokens = self.model.active_peft_config.num_special_tokens
74 | if torch.any(inputs["input_ids"][:, -1] == self.tokenizer.eos_token_id):
75 | warnings.warn("Input ends with EOS token.")
76 | input_ids = inputs["input_ids"]
77 | attention_mask = inputs["attention_mask"]
78 | labels = inputs["labels"]
79 |
80 | outputs = model(
81 | input_ids=input_ids, attention_mask=attention_mask
82 | )
83 | logits = outputs.logits
84 |
85 | # Calculate loss on the prompt tokens
86 | prompt_logits = logits[:, -num_special_tokens:, :].contiguous()
87 | prompt_labels = labels.contiguous()
88 | prompt_labels = prompt_labels.to(logits.device)
89 | loss = 0
90 | decay_coefficient = 0.8
91 | for i in range(num_special_tokens):
92 | loss_i = F.kl_div(
93 | F.log_softmax(prompt_logits[:, i, :], dim=-1),
94 | F.softmax(prompt_labels[:, i, :], dim=-1),
95 | reduction='sum'
96 | ) / prompt_logits.shape[0]
97 | loss += loss_i * (decay_coefficient ** i)
98 | if num_special_tokens > 0:
99 | loss /= num_special_tokens
100 | return (loss, outputs) if return_outputs else loss
101 |
102 |
103 | @dataclass
104 | class ModelArguments:
105 | model_name_or_path: str = field(default="lmsys/vicuna-7b-v1.3")
106 | num_special_tokens: int = field(default=1)
107 | virtual_tokens_per_special_token: int = field(default=1)
108 | use_custom_lm_head: bool = field(default=False)
109 | use_prefix_tuning: bool = field(default=False)
110 | prefix_virtual_tokens: int = field(default=10)
111 | vt_attention_type: str = field(default="decoder")
112 | aggregation_type: str = field(default="mean")
113 | num_exits: int = field(default=1)
114 | load_in_4bit: bool = field(
115 | default=False,
116 | metadata={"help": "Load in 4 bit."},
117 | )
118 | load_in_8bit: bool = field(
119 | default=False,
120 | metadata={"help": "Load in 8 bit."},
121 | )
122 |
123 |
124 | @dataclass
125 | class DataArguments:
126 | dataset_path: Optional[str] = field(
127 | default=None, metadata={"help": "Path to the saved dataset."},
128 | )
129 | size: Optional[int] = field(
130 | default=None, metadata={"help": "Number of examples to use."}
131 | )
132 | use_chunked: bool = field(
133 | default=False, metadata={"help": "Whether to use chunked dataset."}
134 | )
135 |
136 |
137 | @dataclass
138 | class TrainingArguments(transformers.TrainingArguments):
139 | cache_dir: str = field(default=None)
140 | optim: str = field(default="adamw_torch")
141 | trainer_type: str = field(default="param_efficient_fine_tuner", metadata={"help": "Trainer type: param_efficient_fine_tuner, distillation_trainer"})
142 | stage1_model_path: Optional[str] = field(
143 | default=None,
144 | metadata={"help": "Path to the stage 1 model."},
145 | )
146 | lm_head_lr_multiplier: float = field(default=0.1)
147 | model_max_length: int = field(
148 | default=1024,
149 | metadata={
150 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
151 | },
152 | )
153 |
154 |
155 |
156 | def train():
157 | parser = transformers.HfArgumentParser(
158 | (ModelArguments, DataArguments, TrainingArguments)
159 | )
160 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
161 |
162 | quantization_config = BitsAndBytesConfig(
163 | load_in_4bit=True,
164 | bnb_4bit_compute_dtype=torch.bfloat16,
165 | bnb_4bit_use_double_quant=True,
166 | bnb_4bit_quant_type="nf4",
167 | )
168 |
169 | print("load_in_4_bits", model_args.load_in_4bit)
170 |
171 | # Create model
172 | peft_config = PromptConfig(
173 | tokenizer_name_or_path=model_args.model_name_or_path,
174 | base_model_name_or_path=model_args.model_name_or_path,
175 | num_special_tokens=model_args.num_special_tokens,
176 | virtual_tokens_per_special_token=model_args.virtual_tokens_per_special_token,
177 | use_prefix_tuning=model_args.use_prefix_tuning,
178 | prefix_virtual_tokens=model_args.prefix_virtual_tokens,
179 | vt_attention_type=VTAttentionType.from_str(model_args.vt_attention_type),
180 | aggregation_type=AggregationType.from_str(model_args.aggregation_type),
181 | use_custom_lm_head=model_args.use_custom_lm_head,
182 | num_exits=model_args.num_exits,
183 | )
184 | if training_args.stage1_model_path:
185 | model = AutoPromptDecoder.from_pretrained(
186 | training_args.stage1_model_path,
187 | low_cpu_mem_usage=True,
188 | cache_dir=training_args.cache_dir,
189 | quantization_config=quantization_config if model_args.load_in_4bit else None,
190 | new_config=peft_config,
191 | )
192 | else:
193 | # Set RoPE scaling factor
194 | config = transformers.AutoConfig.from_pretrained(
195 | model_args.model_name_or_path,
196 | cache_dir=training_args.cache_dir,
197 | )
198 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
199 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
200 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
201 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
202 | config.use_cache = False
203 |
204 | config = transformers.AutoConfig.from_pretrained(
205 | model_args.model_name_or_path,
206 | cache_dir=training_args.cache_dir
207 | )
208 |
209 | if config.model_type == "llama":
210 | base_model = CustomLlamaForCausalLM.from_pretrained(
211 | model_args.model_name_or_path,
212 | config=config,
213 | cache_dir=training_args.cache_dir,
214 | low_cpu_mem_usage=True,
215 | quantization_config=quantization_config if model_args.load_in_4bit else None,
216 | # load_in_4bit=model_args.load_in_4bit,
217 | # load_in_8bit=model_args.load_in_8bit,
218 | )
219 | else:
220 | raise ValueError("Only support llama for now")
221 |
222 | for param in base_model.base_model.parameters():
223 | param.requires_grad = False
224 | model = PromptDecoder(base_model, peft_config)
225 | print(model.print_trainable_parameters(), model)
226 |
227 | # Output dir
228 | training_args.output_dir = f"{training_args.output_dir}/prompt_{model_args.model_name_or_path.split('/')[-1]}_{model_args.num_special_tokens}_{model_args.virtual_tokens_per_special_token}_cl{training_args.model_max_length}_{model_args.vt_attention_type.upper()}_{model_args.aggregation_type}{'_custom_lm_head' if model_args.use_custom_lm_head else ''}{'_prefix' + str(model_args.prefix_virtual_tokens) if model_args.use_prefix_tuning else ''}_exits{model_args.num_exits}"
229 |
230 | tokenizer = transformers.AutoTokenizer.from_pretrained(
231 | model_args.model_name_or_path,
232 | cache_dir=training_args.cache_dir,
233 | model_max_length=training_args.model_max_length,
234 | padding_side="left",
235 | use_fast=False,
236 | truncation=True
237 | )
238 | tokenizer.pad_token = tokenizer.unk_token
239 |
240 | # Load data
241 | if data_args.use_chunked:
242 | data = ChunkDataset(data_args.dataset_path)
243 | else:
244 | data = torch.load(data_args.dataset_path)
245 | data.set_size(data_args.size)
246 |
247 | # Set up optimizer
248 | optimizer_grouped_parameters = [
249 | {
250 | "params": [
251 | p for n, p in model.named_parameters() if (p.requires_grad and "lm_head" in n)
252 | ],
253 | "lr": training_args.learning_rate * training_args.lm_head_lr_multiplier,
254 | "weight_decay": training_args.weight_decay,
255 | },
256 | {
257 | "params": [
258 | p for n, p in model.named_parameters() if (p.requires_grad and "prompt_encoder" in n)
259 | ],
260 | "lr": training_args.learning_rate,
261 | "weight_decay": training_args.weight_decay,
262 | },
263 | ]
264 | optim_cls, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
265 | optimizer = optim_cls(optimizer_grouped_parameters, **optim_kwargs)
266 |
267 |     # Start trainer
268 | if training_args.trainer_type == "distillation_trainer":
269 | trainer = DistillationTrainer(
270 | model=model, tokenizer=tokenizer, args=training_args, train_dataset=data, eval_dataset=None, optimizers=(optimizer, None)
271 | )
272 | elif training_args.trainer_type == "param_efficient_fine_tuner":
273 | trainer = ParamEfficientFineTuner(
274 | model=model, tokenizer=tokenizer, args=training_args, train_dataset=data, eval_dataset=None, optimizers=(optimizer, None)
275 | )
276 | else:
277 | raise ValueError(f"Trainer type {training_args.trainer_type} not supported.")
278 |
279 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
280 | print("Resuming training...")
281 | trainer.train(resume_from_checkpoint=True)
282 | else:
283 | trainer.train()
284 |
285 | # Save model
286 | model.save_pretrained(training_args.output_dir)
287 |
288 | if __name__ == "__main__":
289 | train()
--------------------------------------------------------------------------------
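ParamEfficientFineTuner.compute_loss above weights the cross-entropy of each special-token head by 0.8**i and averages over the number of special tokens, so earlier (nearer) predictions dominate the loss. A toy version of that decayed multi-head loss; shapes and values are fabricated, whereas the real logits and labels come from the model and dataset.

import torch
from torch.nn import CrossEntropyLoss

batch, num_special_tokens, vocab = 2, 3, 11
prompt_logits = torch.randn(batch, num_special_tokens, vocab)
prompt_labels = torch.randint(0, vocab, (batch, num_special_tokens))

loss_fn = CrossEntropyLoss()
decay_coefficient = 0.8
loss = 0.0
for i in range(num_special_tokens):
    # Each head's loss is down-weighted geometrically with its distance i.
    loss += loss_fn(prompt_logits[:, i, :], prompt_labels[:, i]) * (decay_coefficient ** i)
loss = loss / num_special_tokens
print(loss.item())
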
/prompt/utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | import json
3 | import os
4 | from enum import Enum
5 | from typing import Dict, Optional, Sequence
6 |
7 | from tqdm import tqdm
8 | import torch
9 | import torch.nn.functional as F
10 | from torch.utils.data import Dataset, DataLoader
11 | import transformers
12 | from transformers.trainer_pt_utils import LabelSmoother
13 | from fastchat.conversation import SeparatorStyle
14 | from fastchat.model.model_adapter import get_conversation_template
15 |
16 | import prompt.inference.dynamic_sparse_trees_3_vicuna_13b as dynamic_sparse_trees_3_vicuna_13b
17 | import prompt.inference.dynamic_sparse_trees_3_vicuna_7b as dynamic_sparse_trees_3_vicuna_7b
18 | import prompt.inference.dynamic_sparse_trees_3_MobileLLaMA as dynamic_sparse_trees_3_mobilellama
19 |
20 |
21 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
22 |
23 |
24 | class VTAttentionType(str, Enum):
25 | """Attention type for VicunaTuning
26 | """
27 | DECODER = "decoder"
28 | ENCODER = "encoder"
29 | ENSEMBLE = "ensemble"
30 |
31 | def __str__(self):
32 | return self.value
33 |
34 | @staticmethod
35 | def from_str(s):
36 | s = s.lower()
37 | if s == "decoder":
38 | return VTAttentionType.DECODER
39 | elif s == "encoder":
40 | return VTAttentionType.ENCODER
41 | elif s == "ensemble":
42 | return VTAttentionType.ENSEMBLE
43 | else:
44 | raise ValueError(f"Invalid attention type: {s}")
45 |
46 |
47 | class AggregationType(str, Enum):
48 | """Aggregation type for VicunaTuning
49 | """
50 | MEAN = "mean"
51 | WEIGHTED = "weighted"
52 | ADAPTIVAE_WEIGHTED = "adaptive_weighted"
53 |
54 | def __str__(self):
55 | return self.value
56 |
57 | @staticmethod
58 | def from_str(s):
59 | s = s.lower()
60 | if s == "mean":
61 | return AggregationType.MEAN
62 | elif s == "weighted":
63 | return AggregationType.WEIGHTED
64 | elif s == "adaptive_weighted":
65 | return AggregationType.ADAPTIVAE_WEIGHTED
66 | else:
67 | raise ValueError(f"Invalid aggregation type: {s}")
68 |
69 |
70 | def preprocess(
71 | sources,
72 | tokenizer: transformers.PreTrainedTokenizer,
73 | ) -> Dict:
74 | """
75 | Preprocesses conversation data and tokenizes it for model input.
76 |
77 | Args:
78 | sources: A list of conversation sources.
79 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for tokenization.
80 |
81 | Returns:
82 | Dict: A dictionary containing tokenized inputs, labels, and attention mask.
83 | """
84 | conv = get_conversation_template("vicuna")
85 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
86 |
87 | # Apply prompt templates
88 | conversations = []
89 | for i, source in enumerate(sources):
90 | if roles[source[0]["from"]] != conv.roles[0]:
91 | # Skip the first one if it is not from human
92 | source = source[1:]
93 |
94 | conv.messages = []
95 | for j, sentence in enumerate(source):
96 | role = roles[sentence["from"]]
97 | assert role == conv.roles[j % 2], f"{i}, {j}, {role}, {conv.roles[j % 2]}"
98 | conv.append_message(role, sentence["value"])
99 | conversations.append(conv.get_prompt())
100 |
101 | # Tokenize conversations
102 | input_ids = tokenizer(
103 | conversations,
104 | return_tensors="pt",
105 | padding="max_length",
106 | max_length=tokenizer.model_max_length,
107 | truncation=True
108 | ).input_ids
109 | targets = input_ids.clone()
110 |
111 | assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
112 |
113 | # Mask targets. Only compute loss on the assistant outputs.
114 | sep = conv.sep + conv.roles[1] + ": "
115 | # print("sep", sep)
116 | for conversation, target in zip(conversations, targets):
117 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
118 |
119 | turns = conversation.split(conv.sep2)
120 | # the number of preceding padding tokens
121 | cur_len = 1
122 | for p in target:
123 | if p == tokenizer.pad_token_id:
124 | cur_len += 1
125 | else:
126 | break
127 | target[:cur_len] = IGNORE_TOKEN_ID
128 | # target_imm = target.clone()
129 | # target_imm[target_imm == -100] = 0
130 | # print("target1", tokenizer.decode(target_imm))
131 | for i, turn in enumerate(turns):
132 | if turn == "":
133 | break
134 | turn_len = len(tokenizer(turn).input_ids)
135 |
136 | parts = turn.split(sep)
137 | if len(parts) != 2:
138 | break
139 | parts[0] += sep
140 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
141 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
142 |
143 | # Ignore the user instructions
144 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
145 | # print(cur_len, cur_len + instruction_len)
146 | # target_imm = target.clone()
147 | # target_imm[target_imm == -100] = 0
148 | # print("target2", tokenizer.decode(target_imm))
149 | cur_len += turn_len
150 |
151 | target[cur_len:] = IGNORE_TOKEN_ID
152 |
153 | if cur_len < tokenizer.model_max_length:
154 | if cur_len != total_len:
155 | target[:] = IGNORE_TOKEN_ID
156 |
157 | # a= (input_ids[0, :] != targets[0, :]).nonzero(as_tuple=False)
158 | # print("input_ids compare to targets", a)
159 | # print("targets compare to input_ids", a.shape)
160 | # print(targets[0, input_ids[0, :] != targets[0, :]])
161 | return dict(
162 | input_ids=input_ids,
163 | labels=targets,
164 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
165 | )
166 |
167 |
168 |
169 | class FineTuningDataset(Dataset):
170 | """Dataset for fine-tuning.
171 | Args:
172 | raw_data (list): A list of raw data examples.
173 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
174 | offset (int): Number of label tokens kept past the randomly chosen truncation point.
175 | """
176 |
177 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, offset):
178 | super(FineTuningDataset, self).__init__()
179 |
180 | sources = [example["conversations"] for example in raw_data]
181 | data_dict = preprocess(sources, tokenizer)
182 | block_indices = find_last_positive_block(data_dict["labels"], IGNORE_TOKEN_ID, offset)
183 | input_ids, attention_mask, labels = randomly_truncate(data_dict["input_ids"],
184 | data_dict["attention_mask"],
185 | data_dict["labels"],
186 | block_indices,
187 | offset,
188 | tokenizer.pad_token_id,
189 | IGNORE_TOKEN_ID)
190 |
191 | self.input_ids = input_ids
192 | self.labels = labels
193 | self.attention_mask = attention_mask
194 |
195 | def __len__(self):
196 | return len(self.input_ids)
197 |
198 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
199 | return dict(
200 | input_ids=self.input_ids[i],
201 | labels=self.labels[i],
202 | attention_mask=self.attention_mask[i],
203 | )
204 |
205 | def set_size(self, size):
206 | if size is None:
207 | return
208 | self.input_ids = self.input_ids[:size]
209 | self.labels = self.labels[:size]
210 | self.attention_mask = self.attention_mask[:size]
211 |
212 |
213 | def get_finetune_dataset(
214 | tokenizer: transformers.PreTrainedTokenizer, data_path, size: Optional[int] = None, offset=0
215 | ) -> FineTuningDataset:
216 | """Make the dataset for supervised fine-tuning.
217 | Args:
218 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
219 | data_path: Path to the JSON file containing the conversation data.
220 | size (Optional[int]): Maximum number of examples to load; defaults to the full file.
221 | offset (int): Number of label tokens kept past the random truncation point.
222 | Returns:
223 | FineTuningDataset: The preprocessed training dataset.
224 | """
225 |
226 | json_file = json.load(open(data_path, "r"))
227 | size = size or len(json_file)
228 | dataset = FineTuningDataset(json_file[:size], tokenizer=tokenizer, offset=offset)
229 | return dataset
230 |
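# Example (sketch, hypothetical data path): build the randomly truncated SFT dataset
# from a ShareGPT-style JSON file, keeping one extra label token past the cut:
#   >>> train_ds = get_finetune_dataset(tok, "data/sharegpt.json", size=1000, offset=1)
#   >>> len(train_ds), train_ds[0]["input_ids"].shape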
231 |
232 | def find_last_positive_block(A, ignored_id, n):
233 | """
234 | Find the start and end indices of the last block of non-ignored entries of at least size n in each row of A.
235 | Args:
236 | - A (torch.Tensor): Input tensor of shape [N, L]
237 | - ignored_id (int): Value that marks ignored entries (e.g. IGNORE_TOKEN_ID)
238 | - n (int): Minimum size of the block
239 |
240 | Returns:
241 | - torch.Tensor: Tensor of shape [N, 2] with the start and end indices of that block, or [-1, -1] if no such block exists
242 | """
243 | N, L = A.shape
244 | indices = torch.full((N, 2), -1, dtype=torch.long) # Initialize with -1
245 |
246 | for i in range(N):
247 | last_pos_end = -1
248 | block_size = 0
249 |
250 | for j in range(L-1, -1, -1):
251 | if A[i, j] != ignored_id:
252 | if last_pos_end == -1:
253 | last_pos_end = j # Mark the end of a positive block
254 | block_size += 1
255 | else:
256 | if last_pos_end != -1:
257 | if block_size >= n:
258 | indices[i, 0] = j + 1 # Start of the last positive block
259 | indices[i, 1] = last_pos_end
260 | break
261 | else:
262 | # Reset for next block search
263 | last_pos_end = -1
264 | block_size = 0
265 | if j == 0 and last_pos_end != -1 and block_size >= n:
266 | indices[i, 0] = 0
267 | indices[i, 1] = last_pos_end
268 |
269 | return indices
270 |
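# Worked example (editorial): with ignored_id=-100 and n=2, the last block of
# non-ignored labels below spans indices 4..6, so the returned indices are [4, 6]:
#   >>> find_last_positive_block(torch.tensor([[-100, 5, 6, -100, 7, 8, 9]]), -100, 2)
#   tensor([[4, 6]])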
271 |
272 | def randomly_truncate(input_ids, attention_mask, labels, positions, k, pad_token_id=0, IGNORE_TOKEN_ID=IGNORE_TOKEN_ID):
273 | N, L = input_ids.shape
274 | # Initialize the tensor that will hold the truncated sequences
275 | truncated_batch = torch.full_like(input_ids, pad_token_id)
276 | truncated_attention_mask = torch.full_like(attention_mask, 0)
277 | truncated_labels = torch.full_like(labels, IGNORE_TOKEN_ID)
278 |
279 | for i in range(N):
280 | start, end = positions[i]
281 | # The cut has to leave at least k elements truncated, so we adjust the end accordingly
282 | # Also, ensure the cut is at least at the start position or further to the right
283 | if start == -1 or end == -1:
284 | cut = L-k
285 | else:
286 | valid_end = max(start + 1, end - k + 1)
287 | # Randomly choose a cut point from start to the valid_end
288 | cut = torch.randint(start, valid_end, (1,)).item()
289 | # print(start, cut, L-cut)
290 | # Truncate the sequence and pad from the left
291 | try:
292 | truncated_batch[i, L-cut:] = input_ids[i, :cut]
293 | truncated_attention_mask[i, L-cut:] = attention_mask[i, :cut]
294 | truncated_labels[i, L-cut-k:] = labels[i, :cut+k]
295 | except Exception as e:
296 | # Dump the offending indices and shapes before re-raising to aid debugging.
297 | print(i, start, end, cut)
298 | print(i, L-cut, cut, L, input_ids[i, :cut].shape, truncated_batch[i, L-cut:].shape)
299 | print(i, L-cut-k, cut+k, L, labels[i, :cut+k].shape, truncated_labels[i, L-cut-k:].shape)
300 | raise Exception("Error in truncation") from e
301 |
302 | return truncated_batch, truncated_attention_mask, truncated_labels
303 |
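# Editorial note: each row is left-padded so that only its first `cut` tokens survive,
# while the labels keep `k` extra positions past the cut; rows whose block indices are
# [-1, -1] (no usable block) fall back to cut = L - k.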
304 |
305 | class DistillationDataset(Dataset):
306 | """Dataset for distillation.
307 |
308 | Args:
309 | data (list): A list of data containing input_ids, labels, and attention_mask.
310 | """
311 |
312 | def __init__(self, data):
313 | super(DistillationDataset, self).__init__()
314 | self.input_ids = [d["input_ids"] for d in data]
315 | self.labels = [d["labels"] for d in data]
316 | self.attention_mask = [d["attention_mask"] for d in data]
317 |
318 | def __len__(self):
319 | return len(self.input_ids)
320 |
321 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
322 | return dict(
323 | input_ids=self.input_ids[i],
324 | labels=self.labels[i],
325 | attention_mask=self.attention_mask[i],
326 | )
327 |
328 | def set_size(self, size):
329 | if size is None:
330 | return
331 | self.input_ids = self.input_ids[:size]
332 | self.labels = self.labels[:size]
333 | self.attention_mask = self.attention_mask[:size]
334 |
335 |
336 | def get_self_distillation_dataset(model, data_path, num_special_tokens):
337 | dataset = torch.load(data_path)
338 | dataloader = DataLoader(dataset, batch_size=4, shuffle=False)
339 | data = []
340 | model.eval()
341 | # Using a DataLoader is faster, but batched inputs need more memory
342 | for batch in tqdm(dataloader):
343 | input_ids = batch["input_ids"]
344 | attention_mask = batch["attention_mask"]
345 | batch_size, seq_length = input_ids.shape
346 | preds = input_ids.clone()
347 | batch_labels = []
348 |
349 | for j in range(num_special_tokens+1):
350 | with torch.inference_mode():
351 | outputs = model(input_ids=preds, attention_mask=attention_mask)
352 | logits = outputs.logits
353 | input_id = logits[:, -1:, :].argmax(-1)
354 |
355 | if j > 0:
356 | batch_labels.append(logits[:, -1, :])
357 |
358 | preds = torch.cat([preds, input_id], dim=1)
359 | attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1).to(attention_mask.device)], dim=1)
360 |
361 | labels = torch.stack(batch_labels, dim=1)
362 | for i in range(batch_size):
363 | data.append({"input_ids": preds[i, :-num_special_tokens-1], "labels": labels[i], "attention_mask": attention_mask[i, :-num_special_tokens-1]})
364 | return DistillationDataset(data)
365 |
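# Sketch of intended usage (editorial; "chunk.pt" is a hypothetical path): `model` is the
# base LM in eval mode, and each example gains `num_special_tokens` soft-label distributions:
#   >>> distill_ds = get_self_distillation_dataset(model, "chunk.pt", num_special_tokens=3)
#   >>> distill_ds[0]["labels"].shape  # (num_special_tokens, vocab_size)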
366 |
367 | def chunk_dataset(dataset_path, chunk_size, output_dir):
368 | dataset = torch.load(dataset_path)
369 | total_size = len(dataset)
370 | print(f"Total size: {total_size}")
371 | for i in tqdm(range(0, total_size, chunk_size)):
372 | chunk = dataset[i:i+chunk_size]
373 | torch.save(chunk, os.path.join(output_dir, f'dataset_chunk_{i//chunk_size}.pt'))
374 |
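# Example (sketch, hypothetical paths): split a large saved dataset into 1000-example
# chunks so that ChunkDataset below can stream them from disk; the output directory is
# assumed to exist already:
#   >>> chunk_dataset("ShareGPT_training_dataset.pt", 1000, "chunks/")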
375 |
376 | class ChunkDataset(Dataset):
377 | def __init__(self, chunk_dir):
378 | super(ChunkDataset, self).__init__()
379 | self.chunk_dir = chunk_dir
380 | # List all chunk files
381 | self.chunk_files = [os.path.join(chunk_dir, f) for f in os.listdir(chunk_dir) if f.startswith('dataset_chunk_')]
382 | self.chunk_files.sort(key=lambda x: (len(x), x))
383 | # Calculate offsets and total length
384 | self.lengths = [torch.load(f, map_location='cpu')['input_ids'].__len__() for f in self.chunk_files]
385 | self.cumulative_lengths = [sum(self.lengths[:i+1]) for i in range(len(self.lengths))]
386 |
387 | def __len__(self):
388 | return self.cumulative_lengths[-1]
389 |
390 | def __getitem__(self, idx):
391 | # Find which chunk contains the item
392 | chunk_idx = next(i for i, total in enumerate(self.cumulative_lengths) if total > idx)
393 | if chunk_idx > 0:
394 | idx -= self.cumulative_lengths[chunk_idx-1] # Adjust index relative to the chunk
395 |
396 | # Load the chunk
397 | chunk = torch.load(self.chunk_files[chunk_idx], map_location='cpu')
398 |
399 | # Extract and return the item
400 | return dict(
401 | input_ids=chunk['input_ids'][idx],
402 | labels=chunk['labels'][idx],
403 | attention_mask=chunk['attention_mask'][idx],
404 | )
405 |
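# Sketch (editorial): ChunkDataset re-loads one chunk from disk per __getitem__ call, so
# it can be wrapped in a standard DataLoader without keeping every chunk in memory:
#   >>> loader = DataLoader(ChunkDataset("chunks/"), batch_size=4, shuffle=True)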
406 |
407 | def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
408 | original_logit = logit.clone()
409 | logit = logit / temperature
410 | probs = torch.softmax(logit, dim=-1)
411 | entropy = -torch.sum(
412 | probs * torch.log(probs + 1e-5), dim=-1
413 | )
414 | threshold = torch.minimum(
415 | torch.ones_like(entropy) * posterior_threshold,
416 | torch.exp(-entropy) * posterior_alpha,
417 | )
418 | indices_to_remove = probs < threshold.unsqueeze(-1)
419 | logit[indices_to_remove] = float('-inf')
420 | prob = F.softmax(logit, dim=-1)
421 | try:
422 | sampled_tokens = torch.multinomial(prob, 1)
423 | except Exception as e:
424 | print(prob.max(), prob.min())
425 | print(logit.max(), logit.min())
426 | print(original_logit.max(), original_logit.min())
427 | print(temperature, original_logit.max() / temperature, original_logit.min() / temperature)
428 | print(indices_to_remove.any())
429 | raise Exception("Error in sampling") from e
430 | return sampled_tokens
431 |
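# Editorial note: the masking above implements a typical-acceptance rule: a token is
# kept only if p(token) >= min(posterior_threshold, posterior_alpha * exp(-H(p))), where
# H(p) is the entropy of the temperature-scaled distribution; everything below the
# threshold is set to -inf before the multinomial draw.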
432 |
433 | def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
434 | logits = logits[:, :-1] / temperature
435 | n_samples, n_tokens = logits.shape[0], logits.shape[1]
436 | logits = logits.view(n_samples*n_tokens, -1)
437 | probs = F.softmax(logits, dim=-1)
438 | entropy = -torch.sum(
439 | probs * torch.log(probs + 1e-5), dim=-1
440 | )
441 | threshold = torch.minimum(
442 | torch.ones_like(entropy) * posterior_threshold,
443 | torch.exp(-entropy) * posterior_alpha,
444 | )
445 | indices_to_remove = probs < threshold.unsqueeze(-1)
446 | logits[indices_to_remove] = float('-1e4')
447 | sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
448 | sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
449 | posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
450 | return posterior_mask
451 |
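# Editorial note: get_typical_posterior_mask applies the same typical-sampling filter at
# every candidate position, draws one token per position, and returns a 0/1 mask marking
# where candidates[:, 1:] match the draws; downstream acceptance logic can use this mask
# to decide how many speculated tokens to keep.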
452 |
453 | def pad_path(path, length, pad_value=-2):
454 | """
455 | Pad the given path list with a specific value up to a specified length.
456 |
457 | Parameters:
458 | - path (list): The original list that needs padding.
459 | - length (int): The desired length of the padded list.
460 | - pad_value (optional, default=-2): The value to use for padding.
461 |
462 | Returns:
463 | - list: A new list based on the original path but padded to the desired length.
464 |
465 | Example:
466 | >>> pad_path([1,2,3], 5)
467 | [1, 2, 3, -2, -2]
468 |
469 | Note:
470 | If the given path is already longer than the specified length,
471 | then no padding occurs, and the original path is returned.
472 | """
473 |
474 | # Calculate the number of padding values needed by subtracting the length
475 | # of the path from the desired length.
476 | # Append the padding values to the original path and return the new list.
477 | return path + [pad_value] * (length - len(path))
478 |
479 | def get_dynamic_sparse_tree(model_path):
480 | if 'vicuna-13b' in model_path.lower():
481 | print('Using 13b 3-1 sparse trees')
482 | tree = dynamic_sparse_trees_3_vicuna_13b.dynamic_sparse_trees_60
483 | elif 'vicuna-7b' in model_path.lower():
484 | print('Using 7b 3-1 sparse trees')
485 | tree = dynamic_sparse_trees_3_vicuna_7b.dynamic_sparse_trees_105
486 | elif 'mobilellama' in model_path.lower():
487 | print('Using MobileLLaMA 3-1 sparse trees')
488 | tree = dynamic_sparse_trees_3_MobileLLaMA.dynamic_sparse_trees_285
489 | else:
490 | raise ValueError("Unknown model path")
491 | return tree
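# Example (editorial): the helper dispatches on substrings of the model path, e.g.
#   >>> get_dynamic_sparse_tree("hmarkc/ppd-vicuna-7b-v1.3")
# returns dynamic_sparse_trees_3_vicuna_7b.dynamic_sparse_trees_105, matching the
# --tree-length 105 used in the Vicuna-7b latency scripts.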
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["poetry-core"]
3 | build-backend = "poetry.core.masonry.api"
4 |
5 | [tool.poetry]
6 | name = "parallel-prompt-decoding"
7 | version = "1.0"
8 | description = "Cheap and efficient LLM inference acceleration using prompting"
9 | packages = [
10 | { include = "prompt", from = "." }
11 | ]
12 | authors = ["Mark (Hao) Chen "]
13 | readme = "README.md"
14 | license = "Apache-2.0"
15 | keywords = ["LLM", "Prompting", "Inference Acceleration", "NLP", "Machine Learning", "Language Model"]
16 |
17 | [tool.poetry.dependencies]
18 | python = ">=3.9"
19 | fschat = "^0.2.36"
20 | torch = "^2.0.1"
21 | transformers = "4.37.2"
22 | accelerate = "^0.27.2"
23 | peft = "^0.8.0"
24 | datasets = ">=2.17.0"
25 | numpy = ">=1.26.0"
26 | bitsandbytes = "^0.42.0"
27 | setuptools = "*"
28 | sentencepiece = "*"
29 | protobuf = "^4.25.3"
30 | matplotlib = "*"
31 | gradio = "*"
32 | openai = "*"
33 | anthropic = "*"
34 | zeus-ml = "*"
35 | human_eval = "*"
36 |
--------------------------------------------------------------------------------
/script/eval/eval-2-1-ensemble.sh:
--------------------------------------------------------------------------------
1 | python3 prompt/evaluation/eval.py --model_path hmarkc/ppd-vicuna-7b-v1.3 \
2 | --model_name ppd-vicuna-7b-v1.3 \
3 | --data_path ./data/alpaca_eval.json \
4 | --save_dir ./log/eval/ \
5 |
--------------------------------------------------------------------------------
/script/latency/optimal-sparse-tree.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # use the alpaca eval dataset to find the optimal sparse tree
3 | # Accept length for Vicuna 7b
4 | python accept_length.py \
5 | --dir-path ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-7b \
6 | --file-name dynamic_sparse_tree \
7 | --model-name hmarkc/ppd-vicuna-7b-v1.3 \
8 | --eval-file-name gen_model_answer_prompt_decoding.py \
9 | --run-baseline \
10 | --n 1 \
11 | --max-length 79 \
12 | --min-length 60 \
13 | --length-interval 9 \
14 | --choices "[5, 10, 20, 35, 60, 120, 200, 500]" \
15 |
16 |
17 | # latency for Vicuna 7b
18 | python3 tree_latency.py \
19 | --model-path hmarkc/ppd-vicuna-7b-v1.3 \
20 | --model-id vicuna_faster \
21 | --answer-file ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-7b-21-04/tree_latency.jsonl \
22 | --bench-name alpaca_eval \
23 | --min-tree-length 60 \
24 | --max-tree-length 120 \
25 | --length-interval 3 \
26 | --max-new-token 1024
27 |
28 | # Accept length for Vicuna 13b
29 | python accept_length.py \
30 | --dir-path ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-13b \
31 | --file-name dynamic_sparse_tree \
32 | --model-name hmarkc/ppd-vicuna-13b-v1.3 \
33 | --eval-file-name gen_model_answer_prompt_decoding.py \
34 | --max-length 120 \
35 | --min-length 60 \
36 | --length-interval 3 \
37 | --n 1 \
38 |
39 | # latency for Vicuna 13b
40 | python3 tree_latency.py \
41 | --model-path hmarkc/ppd-vicuna-13b-v1.3 \
42 | --model-id vicuna_faster \
43 | --answer-file ./data/alpaca_eval/dynamic_sparse_tree_search/3-1-13b/tree_latency.jsonl \
44 | --bench-name alpaca_eval \
45 | --min-tree-length 60 \
46 | --max-tree-length 120 \
47 | --length-interval 3 \
48 | --max-new-token 1024
49 |
50 | # Accept length for full sparse tree
51 | # python accept_length.py \
52 | # --dir-path data/alpaca_eval/sparse_tree_search/3-1-7b/ \
53 | # --file-name full_sparse_tree \
54 | # --model-name hmarkc/ppd-vicuna-7b-v1.3 \
55 | # --eval-file-name gen_model_answer_full_sparse_tree.py \
56 | # --choices "[5, 10, 20, 35, 60, 120, 200, 500]" \
57 | # --n 1 \
58 |
59 | # python accept_length.py \
60 | # --dir-path data/alpaca_eval/sparse_tree_search/3-1-13b/ \
61 | # --file-name sparse_tree \
62 | # --model-name hmarkc/ppd-vicuna-13b-v1.3 \
63 | # --eval-file-name gen_model_answer_full_sparse_tree.py \
64 | # --max-length 120 \
65 | # --min-length 60 \
66 | # --length-interval 3 \
67 | # --n 1 \
68 |
69 | # Accept length for random sparse tree
70 | python accept_length.py \
71 | --dir-path data/alpaca_eval/random_tree_search/3-1-7b/ \
72 | --file-name random_sparse_tree \
73 | --model-name hmarkc/ppd-vicuna-7b-v1.3 \
74 | --eval-file-name gen_model_answer_random_sparse_tree.py \
75 | --choices "[5, 10, 20, 35, 60, 120, 200, 500]" \
76 | --n 1 \
77 |
78 | # Accept length for MobileLLaMA
79 | python accept_length.py \
80 | --dir-path ./data/alpaca_eval/dynamic_sparse_tree_search/MobileLLaMA \
81 | --file-name dynamic_sparse_tree \
82 | --model-name ../test/MobileLLaMA \
83 | --eval-file-name gen_model_answer_prompt_decoding.py \
84 | --n 1 \
85 | --max-length 79 \
86 | --min-length 60 \
87 | --length-interval 9 \
88 | --choices "[75, 105, 135, 165, 195, 225, 255, 285]" \
89 |
90 |
91 | # latency for MobileLLaMA
92 | python3 tree_latency.py \
93 | --model-path ../test/MobileLLaMA \
94 | --model-id MobileLLaMA \
95 | --answer-file ./data/alpaca_eval/dynamic_sparse_tree_search/MobileLLaMA/tree_latency.jsonl \
96 | --bench-name alpaca_eval \
97 | --min-tree-length 60 \
98 | --max-tree-length 285 \
99 | --length-interval 3 \
100 | --max-new-token 1024
--------------------------------------------------------------------------------
/script/latency/vicuna-13b-gen.sh:
--------------------------------------------------------------------------------
1 | # MT bench, temperature sampling
2 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/mt_bench/experiments/vicuna-13b-baseline1.jsonl --bench-name mt_bench
3 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/mt_bench/experiments/vicuna-13b-baseline2.jsonl --bench-name mt_bench
4 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/mt_bench/experiments/vicuna-13b-baseline3.jsonl --bench-name mt_bench
5 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-13b-faster1.jsonl --tree-length 60 --bench-name mt_bench
6 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-13b-faster2.jsonl --tree-length 60 --bench-name mt_bench
7 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-13b-faster3.jsonl --tree-length 60 --bench-name mt_bench
8 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-13b-faster4.jsonl --tree-length 60 --bench-name mt_bench
9 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-13b-faster5.jsonl --tree-length 60 --bench-name mt_bench
10 |
11 | # MT bench, greedy sampling
12 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/mt_bench/experiments/vicuna-13b-baseline1-greedy.jsonl --bench-name mt_bench --temperature 0.0
13 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/mt_bench/experiments/vicuna-13b-baseline2-greedy.jsonl --bench-name mt_bench --temperature 0.0
14 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/mt_bench/experiments/vicuna-13b-baseline3-greedy.jsonl --bench-name mt_bench --temperature 0.0
15 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-13b-faster1-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0
16 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-13b-faster2-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0
17 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-13b-faster3-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0
18 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-13b-faster4-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0
19 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-13b-faster5-greedy.jsonl --tree-length 60 --bench-name mt_bench --temperature 0.0
20 |
21 | # HumanEval
22 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/humaneval/experiments/vicuna-13b-baseline1.jsonl --bench-name humaneval --max-new-token 512
23 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/humaneval/experiments/vicuna-13b-baseline2.jsonl --bench-name humaneval --max-new-token 512
24 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/humaneval/experiments/vicuna-13b-baseline3.jsonl --bench-name humaneval --max-new-token 512
25 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/humaneval/experiments/vicuna-13b-faster1.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512
26 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/humaneval/experiments/vicuna-13b-faster2.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512
27 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/humaneval/experiments/vicuna-13b-faster3.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512
28 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/humaneval/experiments/vicuna-13b-faster4.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512
29 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/humaneval/experiments/vicuna-13b-faster5.jsonl --tree-length 60 --bench-name humaneval --max-new-token 512
30 |
31 | # GSM8K
32 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline1 --answer-file data/gsm8k/experiments/vicuna-13b-baseline1.jsonl --bench-name gsm8k
33 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline2 --answer-file data/gsm8k/experiments/vicuna-13b-baseline2.jsonl --bench-name gsm8k
34 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-13b-baseline3 --answer-file data/gsm8k/experiments/vicuna-13b-baseline3.jsonl --bench-name gsm8k
35 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster1 --answer-file data/gsm8k/experiments/vicuna-13b-faster1.jsonl --tree-length 60 --bench-name gsm8k
36 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster2 --answer-file data/gsm8k/experiments/vicuna-13b-faster2.jsonl --tree-length 60 --bench-name gsm8k
37 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster3 --answer-file data/gsm8k/experiments/vicuna-13b-faster3.jsonl --tree-length 60 --bench-name gsm8k
38 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster4 --answer-file data/gsm8k/experiments/vicuna-13b-faster4.jsonl --tree-length 60 --bench-name gsm8k
39 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-13b-v1.3 --model-id vicuna-faster5 --answer-file data/gsm8k/experiments/vicuna-13b-faster5.jsonl --tree-length 60 --bench-name gsm8k
40 |
--------------------------------------------------------------------------------
/script/latency/vicuna-7b-gen.sh:
--------------------------------------------------------------------------------
1 | # MT bench, temperature sampling
2 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/mt_bench/experiments/vicuna-7b-baseline1.jsonl --bench-name mt_bench
3 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/mt_bench/experiments/vicuna-7b-baseline2.jsonl --bench-name mt_bench
4 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/mt_bench/experiments/vicuna-7b-baseline3.jsonl --bench-name mt_bench
5 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-7b-faster1.jsonl --tree-length 105 --bench-name mt_bench
6 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-7b-faster2.jsonl --tree-length 105 --bench-name mt_bench
7 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-7b-faster3.jsonl --tree-length 105 --bench-name mt_bench
8 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-7b-faster4.jsonl --tree-length 105 --bench-name mt_bench
9 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-7b-faster5.jsonl --tree-length 105 --bench-name mt_bench
10 |
11 | # MT bench, greedy sampling
12 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/mt_bench/experiments/vicuna-7b-baseline1-greedy.jsonl --bench-name mt_bench --temperature 0.0
13 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/mt_bench/experiments/vicuna-7b-baseline2-greedy.jsonl --bench-name mt_bench --temperature 0.0
14 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/mt_bench/experiments/vicuna-7b-baseline3-greedy.jsonl --bench-name mt_bench --temperature 0.0
15 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/mt_bench/experiments/vicuna-7b-faster1-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0
16 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/mt_bench/experiments/vicuna-7b-faster2-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0
17 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/mt_bench/experiments/vicuna-7b-faster3-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0
18 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/mt_bench/experiments/vicuna-7b-faster4-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0
19 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/mt_bench/experiments/vicuna-7b-faster5-greedy.jsonl --tree-length 105 --bench-name mt_bench --temperature 0.0
20 |
21 | # HumanEval
22 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/humaneval/experiments/vicuna-7b-baseline1.jsonl --bench-name humaneval --max-new-token 512
23 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/humaneval/experiments/vicuna-7b-baseline2.jsonl --bench-name humaneval --max-new-token 512
24 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/humaneval/experiments/vicuna-7b-baseline3.jsonl --bench-name humaneval --max-new-token 512
25 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/humaneval/experiments/vicuna-7b-faster1.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512
26 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/humaneval/experiments/vicuna-7b-faster2.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512
27 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/humaneval/experiments/vicuna-7b-faster3.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512
28 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/humaneval/experiments/vicuna-7b-faster4.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512
29 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/humaneval/experiments/vicuna-7b-faster5.jsonl --tree-length 105 --bench-name humaneval --max-new-token 512
30 |
31 | # GSM8K
32 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline1 --answer-file data/gsm8k/experiments/vicuna-7b-baseline1.jsonl --bench-name gsm8k
33 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline2 --answer-file data/gsm8k/experiments/vicuna-7b-baseline2.jsonl --bench-name gsm8k
34 | python3 gen_model_answer_baseline.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-7b-baseline3 --answer-file data/gsm8k/experiments/vicuna-7b-baseline3.jsonl --bench-name gsm8k
35 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster1 --answer-file data/gsm8k/experiments/vicuna-7b-faster1.jsonl --tree-length 105 --bench-name gsm8k
36 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster2 --answer-file data/gsm8k/experiments/vicuna-7b-faster2.jsonl --tree-length 105 --bench-name gsm8k
37 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster3 --answer-file data/gsm8k/experiments/vicuna-7b-faster3.jsonl --tree-length 105 --bench-name gsm8k
38 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster4 --answer-file data/gsm8k/experiments/vicuna-7b-faster4.jsonl --tree-length 105 --bench-name gsm8k
39 | python3 gen_model_answer_prompt_decoding.py --model-path hmarkc/ppd-vicuna-7b-v1.3 --model-id vicuna-faster5 --answer-file data/gsm8k/experiments/vicuna-7b-faster5.jsonl --tree-length 105 --bench-name gsm8k
40 |
--------------------------------------------------------------------------------
/script/train/train-ensemble-attention-kd.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | accelerate launch --num_processes 4 prompt/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
3 | --dataset_path "./ShareGPT_training_dataset_2_distillation.pt" \
4 | --output_dir test/ \
5 | --num_train_epochs 1 \
6 | --save_steps 500 \
7 | --model_max_length 2048 \
8 | --num_special_tokens 3 \
9 | --virtual_tokens_per_special_token 1 \
10 | --per_device_train_batch_size 1 \
11 | --per_device_eval_batch_size 1 \
12 | --gradient_accumulation_steps 4 \
13 | --evaluation_strategy "no" \
14 | --learning_rate 1e-2 \
15 | --weight_decay 0.0 \
16 | --warmup_ratio 0.0 \
17 | --lr_scheduler_type "cosine" \
18 | --logging_steps 10 \
19 | --load_in_4bit \
20 | --vt_attention_type "ensemble" \
21 | --trainer_type "distillation_trainer"
22 | # --use_prefix_tuning \
23 | # --prefix_virtual_tokens 10 \
24 | # --size 100
25 | # --tf32 True \  (requires at least an Ampere GPU)
26 | # a sequence length of 2048 does not work with batch size 1
27 |
--------------------------------------------------------------------------------