├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── ROADMAP.md
├── assets
│   ├── category_speedup.png
│   ├── logo.png
│   ├── medusa_acc.csv
│   ├── medusa_choices.png
│   ├── medusa_demo.gif
│   ├── medusa_pipeline.jpg
│   ├── medusa_speedup_cmp.jpg
│   └── size_speedup.png
├── create_data.py
├── data_generation
│   ├── README.md
│   ├── convert_to_sharegpt.py
│   └── generate.py
├── deepspeed.json
├── llm_judge
│   ├── README.md
│   ├── data
│   │   ├── judge_prompts.jsonl
│   │   └── mt_bench
│   │       ├── model_answer
│   │       │   ├── medusa-vicuna-13b-v1.3-1-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   ├── medusa-vicuna-13b-v1.3-2-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   ├── medusa-vicuna-13b-v1.3-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   ├── medusa-vicuna-33b-v1.3-1-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   ├── medusa-vicuna-33b-v1.3-2-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   ├── medusa-vicuna-33b-v1.3-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   ├── medusa-vicuna-7b-v1.3-1-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   ├── medusa-vicuna-7b-v1.3-2-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       │   └── medusa-vicuna-7b-v1.3-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl
│   │       ├── model_judgment
│   │       │   └── gpt-4_single.jsonl
│   │       ├── question.jsonl
│   │       └── reference_answer
│   │           └── gpt-4.jsonl
│   ├── gen_judgement.py
│   ├── gen_model_answer_baseline.py
│   ├── gen_model_answer_baseline_inf_only.py
│   ├── gen_model_answer_huggingface.py
│   ├── gen_model_answer_medusa.py
│   ├── gen_model_answer_medusa_inf_only.py
│   ├── gen_model_answer_medusa_legacy.py
│   └── show_result.py
├── medusa
│   ├── __init__.py
│   ├── eval
│   │   ├── README.md
│   │   ├── gen_results.py
│   │   └── heads_accuracy.py
│   ├── hf_utils.py
│   ├── inference
│   │   ├── __init__.py
│   │   └── cli.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── kv_cache.py
│   │   ├── medusa_choices.py
│   │   ├── medusa_model.py
│   │   ├── medusa_model_legacy.py
│   │   ├── medusa_model_new.py
│   │   ├── modeling_llama_kv.py
│   │   ├── modeling_llama_kv_legacy.py
│   │   ├── modeling_mistral_kv.py
│   │   ├── utils.py
│   │   └── utils_legacy.py
│   └── train
│       ├── __init__.py
│       └── train_legacy.py
├── notebooks
│   ├── medusa_configuration_explained.ipynb
│   ├── medusa_inference_explained.ipynb
│   └── medusa_introduction.ipynb
├── pyproject.toml
├── scripts
│   ├── train_vicuna_33b_8bit.sh
│   └── train_vicuna_7b.sh
└── simple_gradio_interface.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # wandb
163 | .wandb/
164 | wandb/
165 |
166 | ShareGPT_Vicuna_unfiltered/
167 |
168 | test_medusa*
169 |
170 | # test
171 | notebooks/test*.ipynb
172 | notebooks/*.pdf
173 | llm_judge/*.sh
174 | llm_judge/data/mt_bench_test
175 | llm_judge/data/mt_bench_test_rs
176 | data
177 | medusa/eval/*.sh
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | references:
4 | - type: article
5 | authors:
6 | - family-names: Cai
7 | given-names: Tianle
8 | - family-names: Li
9 | given-names: Yuhong
10 | - family-names: Geng
11 | given-names: Zhengyang
12 | - family-names: Peng
13 | given-names: Hongwu
14 | - family-names: Lee
15 | given-names: Jason D.
16 | - family-names: Chen
17 | given-names: Deming
18 | - family-names: Dao
19 | given-names: Tri
20 | title: "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads"
21 | year: 2024
22 | journal: "arXiv preprint arXiv: 2401.10774"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads
2 |
3 |
4 | | Blog | Report | Roadmap |
6 |
7 |
8 | ---
9 | *News* 🔥
10 | - [2024/1] Medusa technical report is now available on [arXiv](https://arxiv.org/abs/2401.10774). We've added multiple new features, including Medusa-2 recipe for full-model training, self-distillation for adding Medusa to any fine-tuned LLM, etc. The new results show a 2.2-3.6x speedup over the original model on a range of LLMs.
11 |
12 | ---
13 | ## Introduction
14 |
15 | Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads.
16 |
17 |
18 | *(Demo: Medusa-1 on Vicuna-7b; see assets/medusa_demo.gif.)*
19 |
29 | We aim to tackle the three pain points of popular acceleration techniques like speculative decoding:
30 |
31 | - Requirement of a good draft model.
32 | - System complexity.
33 | - Inefficiency when using sampling-based generation.
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training. During generation, these heads each produce several likely tokens for the corresponding position. These options are then combined and processed using a tree-based attention mechanism. Finally, a typical acceptance scheme is employed to pick the longest plausible prefix from the candidates for further decoding.
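
To make the head design concrete, here is a minimal, illustrative sketch (a simplified rendition, not the repository's exact code; see `medusa/model/medusa_model.py` for the real implementation): each head is a small residual MLP on top of the base model's final hidden states, followed by a vocabulary projection that predicts a token several positions ahead. The class and argument names below are invented for illustration.

```python
# Illustrative sketch only; the actual implementation lives in
# medusa/model/medusa_model.py. Head k reads the hidden state at position t
# and predicts the token at position t + k + 2, i.e., k + 1 steps beyond the
# next token predicted by the original LM head.
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x):
        # Residual connection keeps the head close to an identity map at init.
        return x + self.act(self.linear(x))


class MedusaHeads(nn.Module):
    def __init__(self, hidden_size: int, vocab_size: int, num_heads: int = 3, num_layers: int = 1):
        super().__init__()
        self.heads = nn.ModuleList(
            nn.Sequential(
                *[ResBlock(hidden_size) for _ in range(num_layers)],
                nn.Linear(hidden_size, vocab_size, bias=False),
            )
            for _ in range(num_heads)
        )

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden_size) from the frozen base model.
        # Returns one logit tensor per head.
        return [head(hidden_states) for head in self.heads]
```

Each head's top-k predictions are then merged into tree-structured candidates and verified in a single forward pass using the tree attention mask.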
43 |
44 |
45 |
46 |
47 | We aim to solve the challenges associated with speculative decoding by implementing the following ideas:
48 |
49 | - Instead of introducing a new model, we train multiple decoding heads on the *same* model.
50 | - The training is parameter-efficient so that even the "GPU-Poor" can do it. And since there is no additional model, there is no need to adjust the distributed computing setup.
51 | - Relaxing the requirement of matching the distribution of the original model makes the non-greedy generation even faster than greedy decoding.
52 |
53 | In the initial release, our primary focus is on optimizing Medusa for a batch size of 1—a setting commonly utilized for local model hosting. In this configuration, Medusa delivers approximately a 2x speed increase across a range of Vicuna models. We are actively working to extend Medusa's capabilities by integrating it into additional inference frameworks, with the aim of achieving even greater performance gains and extending Medusa to broader settings.
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 | In the updated version, we add support for full-model training, called Medusa-2 (compared to Medusa-1, which only trains the new heads), which requires a special recipe that adds the speculative prediction ability while keeping the original model's performance.
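
Conceptually, and only as a sketch (the actual recipe lives in the axolotl fork referenced in the Training section below; the weight and decay values here are illustrative placeholders), Medusa-2 keeps the ordinary next-token loss of the base model and adds the Medusa head losses with a small weight, so the speculative ability is gained without degrading the original model:

```python
# Conceptual sketch of a Medusa-2 style combined objective; lambda_medusa and
# decay are illustrative placeholders, not the project's actual settings.
def combined_loss(lm_loss, medusa_head_losses, lambda_medusa=0.2, decay=0.8):
    # lm_loss: scalar next-token cross-entropy from the original LM head.
    # medusa_head_losses: per-head cross-entropy values, where head k is
    # supervised on the token k + 1 positions beyond the next one.
    weighted = sum((decay ** k) * loss for k, loss in enumerate(medusa_head_losses))
    return lm_loss + lambda_medusa * weighted
```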
62 |
63 | We also add support for self-distillation, which allows us to add Medusa to any fine-tuned LLM without requiring the availability of the original training data.
64 |
65 | ## Contents
66 | - [Introduction](#introduction)
67 | - [Contents](#contents)
68 | - [Installation](#installation)
69 | - [Method 1: With pip (may not be the latest version)](#method-1-with-pip-may-not-be-the-latest-version)
70 | - [Method 2: From the source (recommended)](#method-2-from-the-source-recommended)
71 | - [Model Weights](#model-weights)
72 | - [Inference](#inference)
73 | - [Training](#training)
74 | - [Training (legacy)](#training-legacy)
75 | - [Push to Hugging Face Hub](#push-to-hugging-face-hub)
76 | - [Citation](#citation)
77 | - [Codebase Guide](#codebase-guide)
78 | - [Community Adoption](#community-adoption)
79 | - [Contributing](#contributing)
80 | - [Acknowledgements](#acknowledgements)
81 |
82 | ## Installation
83 | ### Method 1: With pip (may not be the latest version)
84 | ```bash
85 | pip install medusa-llm
86 | ```
87 | ### Method 2: From the source (recommended)
88 | ```bash
89 | git clone https://github.com/FasterDecoding/Medusa.git
90 | cd Medusa
91 | pip install -e .
92 | ```
93 |
94 | ### Model Weights
95 | #### Medusa-1
96 | | Size | Chat Command | Hugging Face Repo |
97 | | ---- | --------------------------------------------- | --------------------------------------------------------------------- |
98 | | 7B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-7b-v1.3` | [FasterDecoding/medusa-vicuna-7b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3) |
99 | | 13B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-13b-v1.3` | [FasterDecoding/medusa-vicuna-13b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-13b-v1.3) |
100 | | 33B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-33b-v1.3` | [FasterDecoding/medusa-vicuna-33b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-33b-v1.3) |
101 |
102 | #### Medusa-2
103 | | Size | Chat Command | Hugging Face Repo |
104 | | ---- | --------------------------------------------- | --------------------------------------------------------------------- |
105 | | Zephyr-7B-Beta | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-zephyr-7b-beta` | [FasterDecoding/medusa-1.0-zephyr-7b-beta](https://huggingface.co/FasterDecoding/medusa-1.0-zephyr-7b-beta) |
106 | | Vicuna-7B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-7b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-7b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-7b-v1.5) |
107 | | Vicuna-13B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-13b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-13b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-13b-v1.5) |
108 | | Vicuna-33B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-33b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-33b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-33b-v1.5) |
109 |
110 |
111 | ### Inference
112 | We currently support single-GPU inference with a batch size of 1, which is the most common setup for local model hosting. We are actively working to extend Medusa's capabilities by integrating it into other inference frameworks; please don't hesitate to reach out if you are interested in contributing to this effort.
113 |
114 | You can use the following command to launch a CLI interface:
115 | ```bash
116 | CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa model]
117 | ```
118 | You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you downloaded the base model elsewhere, you can override the base model name or path with `--base-model [path of base model]`.
119 |
120 | ### Training
121 | In the updated version, we use the amazing [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library to manage the training process. Please refer to our [fork](https://github.com/ctlllll/axolotl) for the training code. The major code modifications are in [`src/axolotl/utils/models.py`](https://github.com/ctlllll/axolotl/blob/main/src/axolotl/utils/models.py). The training configs can be found in [`examples/medusa`](https://github.com/ctlllll/axolotl/tree/main/examples/medusa). A typical training command is as follows:
122 | ```bash
123 | accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml
124 | ```
125 |
126 | The data preparation code for self-distillation can be found in [`data_generation` folder](data_generation) of the current repo. For other datasets, you can directly download the data from the corresponding Hugging Face dataset repo.
127 |
128 | ### Training on various architectures
129 | *The following instructions are for the initial release of Medusa and provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.*
130 |
131 | For training, please install:
132 | ```bash
133 | pip install -e ".[train]"
134 | ```
135 | #### Prepare the data
136 | We take a public version of the ShareGPT dataset, which is a subset of the Vicuna training data. For other models, you can use the corresponding training dataset.
137 | ```bash
138 | git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered
139 | ```
140 | Remark: If you haven't installed `git-lfs`, please install it before cloning:
141 | ```bash
142 | git lfs install
143 | ```
144 |
145 | #### Adapt the data to the model you want to enable Medusa on
146 |
147 | Start by launching an inference server of your choice that runs the model you want to train on.
148 | Let's use [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as an example.
149 |
150 | For instance, you can use [text-generation-inference](https://github.com/huggingface/text-generation-inference), which you
151 | can also use after you've trained the Medusa heads.
152 |
153 | ```
154 | model=mistralai/Mistral-7B-Instruct-v0.2
155 | volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run
156 | docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --max-input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000
157 | ```
158 | Some of the sequences in ShareGPT are relatively long, so make sure the server can handle them. If there is not enough room, the script will simply skip those long conversations.
159 | This shouldn't hurt downstream performance much, but more data is always better.
160 | You can use various tradeoffs to [speed up inference](https://huggingface.co/docs/text-generation-inference/index), but the defaults should be good enough in most cases.
161 |
162 | ```
163 | python create_data.py --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json --output-filename mistral.json
164 | ```
165 |
166 | #### Train the model
167 | We follow the training setup from [FastChat](https://github.com/lm-sys/FastChat#fine-tuning), but with a much larger learning rate because we freeze the original model and only train the new heads. Below is the training command on 4 GPUs (the example uses the Mistral-7B-Instruct-v0.2 data prepared above; the same setup also works for Vicuna and larger models). Since we are only training the new heads, training does not require much memory, and only data parallelism is needed. You can modify the script to fit your own setup. You can also use `--load_in_8bit` or `--load_in_4bit` to load the base model in quantized format. A minimal sketch of the head-only objective follows the command.
168 | ```bash
169 | torchrun --nproc_per_node=4 medusa/train/train_legacy.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \
170 | --data_path mistral.json \
171 | --bf16 True \
172 | --output_dir test \
173 | --num_train_epochs 2 \
174 | --per_device_train_batch_size 8 \
175 | --per_device_eval_batch_size 8 \
176 | --gradient_accumulation_steps 4 \
177 | --evaluation_strategy "no" \
178 | --save_strategy "no" \
179 | --learning_rate 1e-3 \
180 | --weight_decay 0.0 \
181 | --warmup_ratio 0.1 \
182 | --lr_scheduler_type "cosine" \
183 | --logging_steps 1 \
184 | --tf32 True \
185 | --model_max_length 2048 \
186 | --lazy_preprocess True \
187 | --medusa_num_heads 3 \
188 | --medusa_num_layers 1 \
189 | --deepspeed deepspeed.json
190 | ```
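
For intuition, here is a minimal sketch (not `medusa/train/train_legacy.py` itself; the function name and decay value are illustrative) of the head-only objective the command above optimizes: the base model is frozen, and each head k is trained with cross-entropy against the token k + 1 positions beyond the next one, with a decaying per-head weight.

```python
# Minimal sketch of the Medusa-1 objective; names and the decay value are
# illustrative, not copied from medusa/train/train_legacy.py.
import torch
import torch.nn.functional as F


def medusa1_loss(medusa_logits, labels, decay=0.8):
    # medusa_logits: list of (batch, seq_len, vocab) tensors, one per head.
    # labels: (batch, seq_len) token ids, with prompt positions set to -100.
    total = 0.0
    for k, logits in enumerate(medusa_logits):
        shift = k + 2  # head k at position t is supervised on labels[t + k + 2]
        head_loss = F.cross_entropy(
            logits[:, :-shift].reshape(-1, logits.size(-1)),
            labels[:, shift:].reshape(-1),
            ignore_index=-100,
        )
        total = total + (decay ** k) * head_loss
    return total


# Only the head parameters receive gradients; the base model stays frozen, e.g.:
#   for p in base_model.parameters():
#       p.requires_grad_(False)
#   optimizer = torch.optim.AdamW(medusa_heads.parameters(), lr=1e-3)
```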
191 | ### Push to Hugging Face Hub
192 | You can use the following command to push your model to the Hugging Face Hub:
193 | ```bash
194 | python -m medusa.hf_utils --folder [path of the model folder] --repo [name of the repo]
195 | ```
196 |
197 | ## Citation
198 | ```bibtex
199 | @article{cai2024medusa,
200 | title = {Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads},
201 | author = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Jason D. Lee and Deming Chen and Tri Dao},
202 | year = {2024},
203 | journal = {arXiv preprint arXiv: 2401.10774}
204 | }
205 | ```
206 |
207 | ## Codebase Guide
208 | `medusa/model/medusa_model.py` is the key file for Medusa. It contains the `MedusaModel` class, which is a wrapper of the original model and the new heads. This class also has an implementation of a streaming generation method. If you want to dive into the details of Medusa, this is the place to start.
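
As a quick orientation, a typical programmatic use looks roughly like the sketch below. The names and arguments (`MedusaModel.from_pretrained`, `get_tokenizer`, `medusa_generate`, `max_steps`) reflect our reading of `medusa/model/medusa_model.py` and `medusa/inference/cli.py`; treat them as assumptions and check the source before relying on them.

```python
# Hedged usage sketch; verify method names and arguments against
# medusa/model/medusa_model.py and medusa/inference/cli.py.
import torch
from medusa.model.medusa_model import MedusaModel

model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",
    torch_dtype=torch.float16,
).cuda()
tokenizer = model.get_tokenizer()

prompt = "USER: Give me a one-sentence summary of Medusa. ASSISTANT:"
input_ids = tokenizer([prompt], return_tensors="pt").input_ids.cuda()

# medusa_generate streams partial outputs; each item contains the text so far.
text = ""
for output in model.medusa_generate(input_ids, temperature=0.7, max_steps=512):
    text = output["text"]
print(text)
```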
209 |
210 | We also provide some illustrative notebooks in `notebooks/` to help you understand the codebase.
211 |
212 | ## Community Adoption
213 | We are super excited to see that Medusa has been adopted by many open-source projects. Here is an (incomplete) list:
214 | - [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/medusa)
215 | - [TGI](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/medusa.py)
216 | - [RTP-LLM](https://github.com/alibaba/rtp-llm/blob/main/docs/SpeculativeDecoding-Tutroial.md#medusa-decoding)
217 |
218 | We are grateful to the authors for their contributions to the community and sincerely hope that Medusa can help accelerate the development of LLMs. If you are using Medusa in your project, please let us know, and we will add your project to the list.
219 |
220 | ## Contributing
221 | We welcome community contributions to Medusa. If you have an idea for how to improve it, please open an issue to discuss it with us. When submitting a pull request, please ensure that your changes are well-tested. Please split each major change into a separate pull request. We also have a [Roadmap](ROADMAP.md) summarizing our future plans for Medusa. Don't hesitate to reach out if you are interested in contributing to any of the items on the roadmap.
222 |
223 | ## Acknowledgements
224 | This codebase is influenced by remarkable projects from the LLM community, including [FastChat](https://github.com/lm-sys/FastChat), [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/), [vllm](https://github.com/vllm-project/vllm), [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl).
225 |
226 | This project is supported by [Together AI](https://together.ai/), [MyShell AI](https://myshell.ai/), [Chai AI](https://www.chai-research.com/).
227 |
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | # Roadmap
2 |
3 | ## Functionality
4 | - [ ] Batched inference
5 | - [ ] Fine-grained KV cache management
6 | - [x] Explore tree sparsity
7 | - [x] Fine-tune Medusa heads together with LM head from scratch
8 | - [x] Distill from any model without access to the original training data
9 |
10 | ## Integration
11 | ### Local Deployment
12 | - [ ] [mlc-llm](https://github.com/mlc-ai/mlc-llm)
13 | - [ ] [exllama](https://github.com/turboderp/exllama)
14 | - [ ] [llama.cpp](https://github.com/ggerganov/llama.cpp)
15 | ### Serving
16 | - [ ] [vllm](https://github.com/vllm-project/vllm)
17 | - [ ] [lightllm](https://github.com/ModelTC/lightllm)
18 | - [x] [TGI](https://github.com/huggingface/text-generation-inference)
19 | - [x] [TensorRT](https://github.com/NVIDIA/TensorRT-LLM)
--------------------------------------------------------------------------------
/assets/category_speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/category_speedup.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/logo.png
--------------------------------------------------------------------------------
/assets/medusa_acc.csv:
--------------------------------------------------------------------------------
1 | "Name","Created","Runtime","End Time","Hostname","ID","Notes","State","Updated","Tags","eval_batch_size","logging_dir","output_dir","per_device_eval_batch_size","per_device_train_batch_size","train_batch_size","train/loss","train/medusa0_loss","train/medusa0_top1","train/medusa0_top2","train/medusa0_top3","train/medusa0_top4","train/medusa0_top5","train/medusa1_loss","train/medusa1_top1","train/medusa1_top2","train/medusa1_top3","train/medusa1_top4","train/medusa1_top5","train/medusa2_loss","train/medusa2_top1","train/medusa2_top2","train/medusa2_top3","train/medusa2_top4","train/medusa2_top5","train/medusa3_loss","train/medusa3_top1","train/medusa3_top2","train/medusa3_top3","train/medusa3_top4","train/medusa3_top5","train/medusa4_loss","train/medusa4_top1","train/medusa4_top2","train/medusa4_top3","train/medusa4_top4","train/medusa4_top5","train/train_loss","train/train_runtime","train/train_samples_per_second","train/train_steps_per_second"
2 | "33b","2023-08-14T02:40:59.000Z","3199","2023-08-14T03:34:18.000Z","della-l07g4","av0zctkn","-","failed","2023-08-17T18:54:58.000Z","","4","test/runs/Aug13_22-38-18_della-l07g4","test_medusa_mlp_vicuna-33b-v1.3_medusa_5_lr_0.001_layers_1","4","4","4","19.7555","1.8279917240142824","0.6045850515365601","0.7161898016929626","0.7648836374282837","0.798293948173523","0.8222854137420654","3.4793782234191895","0.3557846248149872","0.4689888060092926","0.5287009477615356","0.5692198276519775","0.59605473279953","4.460977077484131","0.22463124990463257","0.31970855593681335","0.3785320818424225","0.4210058748722077","0.4462413489818573","5.0544304847717285","0.15958771109580994","0.24880042672157288","0.30726853013038635","0.34796518087387085","0.37551093101501465","5.388251781463623","0.13701795041561127","0.20668207108974457","0.2621290385723114","0.29838281869888306","0.32575085759162903","24.166565484840778","3234.9651","21.213","0.166"
3 | "13b","2023-08-13T22:31:29.000Z","2763","2023-08-13T23:17:32.000Z","della-l08g5","hy3g0c62","-","finished","2023-08-14T02:22:24.000Z","","8","test/runs/Aug13_18-29-53_della-l08g5","test_medusa_mlp_vicuna-13b-v1.3_medusa_5_lr_0.001_layers_1","8","8","8","19.5949","1.8737130165100095","0.5939363837242126","0.705268383026123","0.7578279972076416","0.7924950122833252","0.8161033391952515","3.575400114059448","0.3397117257118225","0.439985066652298","0.5068339705467224","0.5464711785316467","0.580268383026123","4.575368881225586","0.216078519821167","0.30852383375167847","0.3612077236175537","0.4029572308063507","0.4363816976547241","5.15444803237915","0.14997513592243197","0.23173458874225616","0.28777334094047546","0.32343438267707825","0.3541252315044403","5.453802585601807","0.12226639688014984","0.1957007795572281","0.243414506316185","0.2790755331516266","0.3070327937602997","24.09846121860838","2793.9231","24.562","0.192"
4 | "7b","2023-08-13T22:07:43.000Z","1909","2023-08-13T22:39:32.000Z","della-l08g2","ub9cluo4","-","finished","2023-08-14T02:22:20.000Z","","8","test/runs/Aug13_18-06-30_della-l08g2","test_medusa_mlp_vicuna-7b-v1.3_medusa_5_lr_0.001_layers_1","8","8","8","20.2451","2.069507122039795","0.5603876709938049","0.6717196702957153","0.7271371483802795","0.7599403262138367","0.7799453139305115","3.723043203353882","0.31635186076164246","0.4235834777355194","0.4850894510746002","0.5282057523727417","0.5625","4.692985534667969","0.2010437250137329","0.28789758682250977","0.3475397527217865","0.3834492862224579","0.4168737530708313","5.258499622344971","0.14736579358577728","0.22502483427524567","0.27373260259628296","0.3108846843242645","0.3397117257118225","5.5384345054626465","0.11754472553730012","0.18663020431995392","0.23459243774414065","0.2721172869205475","0.2998260259628296","24.93223100445568","1917.1552","35.794","0.28"
--------------------------------------------------------------------------------
/assets/medusa_choices.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_choices.png
--------------------------------------------------------------------------------
/assets/medusa_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_demo.gif
--------------------------------------------------------------------------------
/assets/medusa_pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_pipeline.jpg
--------------------------------------------------------------------------------
/assets/medusa_speedup_cmp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_speedup_cmp.jpg
--------------------------------------------------------------------------------
/assets/size_speedup.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/size_speedup.png
--------------------------------------------------------------------------------
/create_data.py:
--------------------------------------------------------------------------------
1 | import typer
2 | import json
3 | from transformers import Conversation
4 | from typing_extensions import Annotated
5 | import httpx
6 | import tqdm.asyncio  # binds tqdm and makes tqdm.asyncio.tqdm.gather available below
7 | import asyncio
8 |
9 | app = typer.Typer()
10 |
11 |
12 | client = httpx.AsyncClient(timeout=None)
13 |
14 | async def run(conv: Conversation, url: str):
15 | payload = {"model":"tgi", "messages": conv.messages}
16 | response = await client.post(url, json=payload)
17 | content = response.json()
18 | message = content["choices"][0]["message"]
19 | message.pop("name", None)
20 | conv.add_message(message)
21 |
22 |
23 |
24 |
25 | def fix_source(source):
26 | if source and source[0]["from"] == "gpt":
27 | # Skip if GPT is first to talk
28 | source = source[1:]
29 | new_source = []
30 | for item in source:
31 | role = "assistant" if item["from"] == "gpt" else "user"
32 | content = item["value"]
33 | new_source.append({"role": role, "content": content})
34 | return new_source
35 |
36 |
37 | async def recreate_conversation(conversation, sem, url):
38 | async with sem:
39 | conv = Conversation()
40 | try:
41 | for message in conversation[::2]:
42 | assert message["role"] == "user"
43 | conv.add_message(message)
44 | await run(conv, url)
45 | except Exception as e:
46 | print(e)
47 | pass
48 | return conv.messages
49 |
50 | @app.command()
51 | def main(
52 | *,
53 | input_filename: Annotated[str, typer.Option("--input-filename")],
54 | output_filename: Annotated[str, typer.Option("--output-filename")],
55 | url: Annotated[str, typer.Option("--url")] = "http://localhost:8080/v1/chat/completions",
56 | concurrency: Annotated[int, typer.Option("--concurrency")] = 64
57 | ):
58 | sem = asyncio.Semaphore(concurrency)
59 | async def _main():
60 | with open(input_filename, "r") as f:
61 | input_data = json.loads(f.read())
62 | conversations = [fix_source(source["conversations"]) for source in input_data]
63 |
64 | futures = []
65 | for conversation in conversations:
66 | future = recreate_conversation(conversation, sem, url)
67 | futures.append(future)
68 |
69 | recreated_conversations = await tqdm.asyncio.tqdm.gather(*futures)
70 |
71 | with open(output_filename, "w") as f:
72 | json.dump(recreated_conversations, f, indent=4)
73 | asyncio.run(_main())
74 |
75 |
76 | if __name__ == "__main__":
77 | app()
78 |
--------------------------------------------------------------------------------
/data_generation/README.md:
--------------------------------------------------------------------------------
1 | # Generate chat data for self-distillation
2 | We use vLLM to enable batched generation. First, install dependencies:
3 | ```bash
4 | pip install vllm openai
5 | ```
6 |
7 | ## Start server
8 |
9 | ```bash
10 | python -m vllm.entrypoints.openai.api_server \
11 | --model YOUR_MODEL_NAME --port 8000
12 | ```
13 | You can also start multiple servers with different ports to enable parallel generation. In `generate.py`, we scan the ports from 8000 to 8009 to find available servers. You can modify the code to use other ports.
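
For reference, this is roughly how `generate.py` discovers the available servers (see the full script in this folder): it probes consecutive ports and keeps every OpenAI-compatible endpoint that answers.

```python
# Mirrors the port-scanning loop in generate.py (legacy openai<1.0 client).
import openai

openai.api_key = "EMPTY"
api_base_pool = []
for port in range(8000, 8010):
    openai.api_base = f"http://localhost:{port}/v1"
    try:
        model_id = openai.Model.list()["data"][0]["id"]
        print(openai.api_base, model_id)
        api_base_pool.append(openai.api_base)
    except Exception:
        break  # stop at the first port with no running server
```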
14 |
15 | ## Generate data
16 | The following command lets the model continue the first prompt from each sample in `DATA_PATH`; this works for models that can play both roles in a conversation (e.g., Zephyr 7B). If you want to use all prompts in each sample to talk to the model repeatedly, use `--chat` instead. `--chat` mode works for more models but may take longer to generate because of repeated computation (contributions of a better implementation are welcome).
17 |
18 | ```bash
19 | python generate.py --data_path YOUR_DATA_PATH --output_path YOUR_OUTPUT_PATH --num_threads NUM_THREADS --max_tokens YOUR_MAX_TOKENS --temperature YOUR_TEMPERATURE
20 | ```
21 |
22 | ## (Optional) Format data
23 | When generated with `--chat`, the output file will follow the ShareGPT format ([example](https://github.com/lm-sys/FastChat/blob/main/data/dummy_conversation.json)).
24 | You can use the following command to convert text generated without `--chat` to the same format (the output is written next to the input file with a `_sharegpt.jsonl` suffix):
25 | ```bash
26 | python convert_to_sharegpt.py --input_path YOUR_INPUT_PATH --model_name YOUR_MODEL_NAME
27 | ```
--------------------------------------------------------------------------------
/data_generation/convert_to_sharegpt.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import concurrent.futures
5 |
6 | import openai
7 | import shortuuid
8 | import tqdm
9 |
10 | import argparse
11 | import random
12 |
13 | from tenacity import (
14 | retry,
15 | stop_after_attempt,
16 | wait_random_exponential,
17 | )
18 |
19 | from fastchat.conversation import Conversation, SeparatorStyle
20 | from fastchat.model.model_adapter import get_conversation_template
21 | from transformers import AutoTokenizer
22 |
23 | # Use the same arguments as in generate.py
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--input_path", type=str)
26 | parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta")
27 | args = parser.parse_args()
28 |
29 | conv = get_conversation_template(args.model_name)
30 | tokenizer = AutoTokenizer.from_pretrained(args.model_name)
31 |
32 | data = []
33 | with open(args.input_path) as f:
34 | for line in f.readlines():
35 | data.append(json.loads(line))
36 |
37 | def convert(text):
38 | messages = []
39 |
40 | for turn in text.split(conv.roles[0]):
41 | pairs = turn.split(conv.roles[1])
42 | if len(pairs) != 2:
43 | continue
44 | messages.append({
45 | "from": "human",
46 | "value": pairs[0].split(conv.sep)[0].strip()
47 | })
48 | messages.append({
49 | "from": "gpt",
50 | "value": pairs[1].split(conv.sep)[0].strip()
51 | })
52 | # pop the last message because it might be incomplete
53 | if len(messages) > 0:
54 | messages.pop()
55 | # make sure number of messages is even
56 | if len(messages) % 2 == 1:
57 | messages.pop()
58 | return {"conversations": messages}
59 |
60 | sharegpt_data = []
61 | for d in tqdm.tqdm(data):
62 | sample = convert(d["text"])
63 | if len(sample["conversations"]) < 2:
64 | continue
65 | sharegpt_data.append(sample)
66 |
67 | # dump to jsonl
68 | with open(args.input_path.replace(".jsonl", "_sharegpt.jsonl"), "w") as f:
69 | for d in sharegpt_data:
70 | f.write(json.dumps(d) + "\n")
--------------------------------------------------------------------------------
/data_generation/generate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import concurrent.futures
5 |
6 | import openai
7 | import shortuuid
8 | import tqdm
9 |
10 | import argparse
11 | import random
12 |
13 | from tenacity import (
14 | retry,
15 | stop_after_attempt,
16 | wait_random_exponential,
17 | )
18 |
19 | from fastchat.conversation import Conversation, SeparatorStyle
20 | from fastchat.model.model_adapter import get_conversation_template
21 |
22 | # Modify OpenAI's API key and API base to use vLLM's API server.
23 | openai.api_key = "EMPTY"
24 | openai.api_base = "http://localhost:8000/v1"
25 |
26 | api_base_pool = []
27 |
28 | # Probe ports 8000-8009 for running vLLM OpenAI-compatible servers and collect their API bases
29 | for i in range(10):
30 | openai.api_base = "http://localhost:800{}/v1".format(i)
31 | try:
32 | models = openai.Model.list()["data"][0]["id"]
33 | print(openai.api_base, models)
34 | api_base_pool.append(openai.api_base)
35 | except:
36 | break
37 |
38 | print("API base pool: ", api_base_pool)
39 |
40 | parser = argparse.ArgumentParser()
41 | parser.add_argument("--data_path", type=str)
42 | parser.add_argument("--output_path", type=str)
43 | parser.add_argument("--num_threads", type=int, default=256)
44 | parser.add_argument("--temperature", type=float, default=0.3)
45 | parser.add_argument("--max_tokens", type=int, default=2048)
46 | parser.add_argument("--chat", action="store_true")
47 | args = parser.parse_args()
48 |
49 | # Assuming the ShareGPT format
50 | data = json.load(open(args.data_path, "r"))
51 |
52 | def generate_data(messages, idx):
53 | try:
54 | # load balanced
55 | openai.api_base = api_base_pool[idx % len(api_base_pool)]
56 | model_name=openai.Model.list()["data"][0]["id"]
57 |
58 | if args.chat:
59 | converted_messages = []
60 | output_messages = []
61 | if messages[0]["from"] == "system":
62 | converted_messages.append(
63 | {
64 | "role": "system",
65 | "content": messages[0]["text"],
66 | }
67 | )
68 | output_messages.append(messages[0])
69 | messages = messages[1:]
70 | for message in messages[::2]:
71 | if message["from"] != "human":
72 | return
73 | converted_messages.append(
74 | {
75 | "role": "user",
76 | "content": message["value"],
77 | }
78 | )
79 | try:
80 | response = openai.ChatCompletion.create(
81 | model=model_name,
82 | messages=converted_messages,
83 | max_tokens=args.max_tokens,
84 | temperature=args.temperature,
85 | )
86 | if response.choices[0]['finish_reason'] == "length":
87 | break
88 | response = response.choices[0]['message']['content'].strip()
89 | output_messages.append(message)
90 | output_messages.append(
91 | {
92 | "from": "gpt",
93 | "value": response,
94 | }
95 | )
96 | converted_messages.append(
97 | {
98 | "role": "assistant",
99 | "content": response,
100 | }
101 | )
102 | except:
103 | break
104 | if len(output_messages) == 0:
105 | return
106 | with open(args.output_path, "a") as f:
107 | # write in share gpt format
108 | f.write(json.dumps({"conversations": output_messages}) + "\n")
109 | else:
110 | conv = get_conversation_template(model_name)
111 | if messages[0]["from"] == "system":
112 | conv.system_message = messages[0]["text"]
113 | messages = messages[1:]
114 | conv.append_message(conv.roles[0], messages[0]["value"])
115 | conv.append_message(conv.roles[1], None)
116 | prompt = conv.get_prompt()
117 |
118 | response = openai.Completion.create(
119 | model=model_name,
120 | prompt=prompt,
121 | max_tokens=args.max_tokens,
122 | temperature=args.temperature,
123 | ignore_eos=True,
124 | skip_special_tokens=False,
125 | spaces_between_special_tokens=False,
126 | )
127 | response = response.choices[0]['text'].strip()
128 | with open(args.output_path, "a") as f:
129 | # write in share gpt format
130 | f.write(json.dumps({"text": prompt+response}) + "\n")
131 | except Exception as e:
132 | print(e)
133 | print(prompt)
134 | print("Failed to generate data")
135 |
136 | # if output_path exists, count the number of lines and skip the first n data
137 | start = 0
138 | if os.path.exists(args.output_path):
139 | with open(args.output_path, "r") as f:
140 | start = len(f.readlines())
141 | print("Skip first {} data".format(start))
142 |
143 | with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_threads) as executor:
144 | futures = []
145 | for idx, sample in enumerate(data[start:]):
146 | future = executor.submit(
147 | generate_data,
148 | sample["conversations"],
149 | idx,
150 | )
151 | futures.append(future)
152 |
153 | for future in tqdm.tqdm(
154 | concurrent.futures.as_completed(futures), total=len(futures)
155 | ):
156 | future.result()
--------------------------------------------------------------------------------
/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 |
6 | "zero_optimization": {
7 | "stage": 3,
8 | "overlap_comm": true,
9 | "contiguous_gradients": true,
10 | "sub_group_size": 1e9,
11 | "reduce_bucket_size": "auto",
12 | "stage3_prefetch_bucket_size": "auto",
13 | "stage3_param_persistence_threshold": "auto",
14 | "stage3_max_live_parameters": 1e9,
15 | "stage3_max_reuse_distance": 1e9,
16 | "stage3_gather_16bit_weights_on_model_save": true
17 | },
18 |
19 | "gradient_accumulation_steps": "auto",
20 | "steps_per_print": 2000,
21 | "train_batch_size": "auto",
22 | "train_micro_batch_size_per_gpu": "auto",
23 | "wall_clock_breakdown": false
24 | }
25 |
--------------------------------------------------------------------------------
/llm_judge/README.md:
--------------------------------------------------------------------------------
1 | # LLM Judge
2 | | [Original Github Repository](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge)
3 |
4 | ## Installation
5 |
6 | | [Guide](https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/README.md)
7 |
8 | ## Usage
9 |
10 | We report results from three runs of Medusa with Vicuna v1.3 7B/13B/33B on a single A100 in `./data/mt_bench/model_answer/`. The settings are: `temperature` (deprecated; the default LLM Judge setting is used instead), `posterior_threshold=0.09`, `posterior_alpha=0.3`.
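
For context, `posterior_threshold` and `posterior_alpha` control Medusa's typical-acceptance test. Roughly (an illustrative sketch with an invented function name; see `medusa/model/utils.py` for the actual code), a candidate token is accepted when its probability under the original model exceeds an entropy-dependent threshold:

```python
# Illustrative sketch of the typical-acceptance criterion; check
# medusa/model/utils.py for the exact implementation.
import torch

def accepted(posterior_probs, candidate_id, posterior_threshold=0.09, posterior_alpha=0.3):
    # posterior_probs: (vocab,) distribution from the original model at this position.
    entropy = -(posterior_probs * torch.log(posterior_probs + 1e-5)).sum()
    threshold = torch.minimum(
        torch.tensor(posterior_threshold),
        posterior_alpha * torch.exp(-entropy),
    )
    return posterior_probs[candidate_id] > threshold
```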
11 |
12 | - Run benchmark
13 |
14 |
15 | ```
16 | export CUDA_VISIBLE_DEVICES=0 # set the GPU id
17 | python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-7b-v1.3 --model-id medusa-vicuna-7b-v1.3-0
18 | python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-13b-v1.3 --model-id medusa-vicuna-13b-v1.3-0
19 | python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-33b-v1.3 --model-id medusa-vicuna-33b-v1.3-0
20 | ```
21 |
22 | - Run baseline: replace `gen_model_answer_medusa.py` with `gen_model_answer_baseline.py`. (Please note that we only implement greedy inference for the wall-time comparison. If you want to use a sampling-based generator, please refer to the original repository.)
23 |
24 |
25 | - Query the results
26 |
27 | ```
28 | export OPENAI_API_KEY=YOUR_OPENAI_API_KEY # set the OpenAI API key
29 | python gen_judgement.py --model-list medusa-vicuna-7b-v1.3-0-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3
30 | ```
31 |
32 | - Show results
33 |
34 | To obtain the GPT-4 judge results for Vicuna-7B (Hugging Face greedy | Hugging Face sampling | Medusa sampling), run:
35 |
36 | ```
37 | python show_result.py
38 | ```
39 |
40 | ## Citation
41 | Please cite the original paper if you find the code or datasets helpful.
42 | ```
43 | @misc{zheng2023judging,
44 | title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
45 | author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
46 | year={2023},
47 | eprint={2306.05685},
48 | archivePrefix={arXiv},
49 | primaryClass={cs.CL}
50 | }
51 | ```
--------------------------------------------------------------------------------
/llm_judge/data/judge_prompts.jsonl:
--------------------------------------------------------------------------------
1 | {"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"}
2 | {"name": "pair-v2-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. You should choose the assistant that follows the user's instructions and answers the user's questions better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. You should focus on who provides a better answer to the second user question. Begin your evaluation by comparing the responses of the two assistants and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"}
3 | {"name": "pair-math-v1", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for math questions", "category": "math", "output_format": "[[A]]"}
4 | {"name": "pair-math-v1-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. Your evaluation should consider correctness and helpfulness. You will be given reference answers, the assistant A's answers, the assistant B's answers. Your job is to determine which assistant provides correct and helpful answers to the second user question. Begin your evaluation by comparing both assistants' answers with the reference answers. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"}
5 | {"name": "single-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"}
6 | {"name": "single-math-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"}
7 | {"name": "single-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant's answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"}
8 | {"name": "single-math-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"}
9 |
--------------------------------------------------------------------------------
/llm_judge/gen_judgement.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all]
4 | """
5 | import argparse
6 | from concurrent.futures import ThreadPoolExecutor
7 | import json
8 |
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | from fastchat.llm_judge.common import (
13 | load_questions,
14 | load_model_answers,
15 | load_judge_prompts,
16 | check_data,
17 | play_a_match_pair,
18 | play_a_match_single,
19 | get_model_list,
20 | Judge,
21 | MatchPair,
22 | MatchSingle,
23 | NEED_REF_CATS,
24 | )
25 |
26 |
27 | def make_match(
28 | questions,
29 | models,
30 | model_answers,
31 | judge,
32 | baseline_model,
33 | ref_answers=None,
34 | multi_turn=False,
35 | ):
36 | matches = []
37 | for q in questions:
38 | if multi_turn and len(q["turns"]) != 2:
39 | continue
40 | for i in range(len(models)):
41 | q_id = q["question_id"]
42 | m_1 = models[i]
43 | m_2 = baseline_model
44 | if m_1 == m_2:
45 | continue
46 | a_1 = model_answers[m_1][q_id]
47 | a_2 = model_answers[baseline_model][q_id]
48 | if ref_answers is not None:
49 | ref = ref_answers[judge.model_name][q_id]
50 | match = MatchPair(
51 | dict(q),
52 | m_1,
53 | m_2,
54 | a_1,
55 | a_2,
56 | judge,
57 | ref_answer=ref,
58 | multi_turn=multi_turn,
59 | )
60 | else:
61 | match = MatchPair(
62 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
63 | )
64 | matches.append(match)
65 | return matches
66 |
67 |
68 | def make_match_all_pairs(
69 | questions,
70 | models,
71 | model_answers,
72 | judge,
73 | baseline_model=None,
74 | ref_answers=None,
75 | multi_turn=False,
76 | ):
77 | matches = []
78 | for q in questions:
79 | if multi_turn and len(q["turns"]) != 2:
80 | continue
81 | for i in range(len(models)):
82 | for j in range(i + 1, len(models)):
83 | q_id = q["question_id"]
84 | m_1 = models[i]
85 | m_2 = models[j]
86 | a_1 = model_answers[m_1][q_id]
87 | a_2 = model_answers[m_2][q_id]
88 | if ref_answers is not None:
89 | ref = ref_answers[judge.model_name][q_id]
90 | match = MatchPair(
91 | dict(q),
92 | m_1,
93 | m_2,
94 | a_1,
95 | a_2,
96 | judge,
97 | ref_answer=ref,
98 | multi_turn=multi_turn,
99 | )
100 | else:
101 | match = MatchPair(
102 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn
103 | )
104 | matches.append(match)
105 | return matches
106 |
107 |
108 | def make_match_single(
109 | questions,
110 | models,
111 | model_answers,
112 | judge,
113 | baseline_model=None,
114 | ref_answers=None,
115 | multi_turn=False,
116 | ):
117 | matches = []
118 | for q in questions:
119 | if multi_turn and len(q["turns"]) != 2:
120 | continue
121 | for i in range(len(models)):
122 | q_id = q["question_id"]
123 | m = models[i]
124 | a = model_answers[m][q_id]
125 | if ref_answers is not None:
126 | ref = ref_answers[judge.model_name][q_id]
127 | matches.append(
128 | MatchSingle(
129 | dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn
130 | )
131 | )
132 | else:
133 | matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn))
134 | return matches
135 |
136 |
137 | def make_judge_pairwise(judge_model, judge_prompts):
138 | judges = {}
139 | judges["default"] = Judge(judge_model, judge_prompts["pair-v2"])
140 | judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True)
141 | judges["default-mt"] = Judge(
142 | judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True
143 | )
144 | judges["math-mt"] = Judge(
145 | judge_model,
146 | judge_prompts["pair-math-v1-multi-turn"],
147 | ref_based=True,
148 | multi_turn=True,
149 | )
150 | return judges
151 |
152 |
153 | def make_judge_single(judge_model, judge_prompts):
154 | judges = {}
155 | judges["default"] = Judge(judge_model, judge_prompts["single-v1"])
156 | judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True)
157 | judges["default-mt"] = Judge(
158 | judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True
159 | )
160 | judges["math-mt"] = Judge(
161 | judge_model,
162 | judge_prompts["single-math-v1-multi-turn"],
163 | ref_based=True,
164 | multi_turn=True,
165 | )
166 | return judges
167 |
168 |
169 | if __name__ == "__main__":
170 | parser = argparse.ArgumentParser()
171 | parser.add_argument(
172 | "--bench-name",
173 | type=str,
174 | default="mt_bench",
175 | help="The name of the benchmark question set.",
176 | )
177 | parser.add_argument(
178 | "--judge-file",
179 | type=str,
180 | default="data/judge_prompts.jsonl",
181 | help="The file of judge prompts.",
182 | )
183 | parser.add_argument("--judge-model", type=str, default="gpt-4")
184 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo")
185 | parser.add_argument(
186 | "--mode",
187 | type=str,
188 | default="single",
189 | choices=["pairwise-baseline", "pairwise-all", "single"],
190 | help=(
191 | "Evaluation mode. "
192 | "`pairwise-baseline` runs pairwise comparision against a baseline. "
193 | "`pairwise-all` runs pairwise comparision between all pairs. "
194 | "`single` runs single answer grading."
195 | ),
196 | )
197 | parser.add_argument(
198 | "--model-list",
199 | type=str,
200 | nargs="+",
201 | default=None,
202 | help="A list of models to be evaluated",
203 | )
204 | parser.add_argument(
205 | "--parallel", type=int, default=1, help="The number of concurrent API calls."
206 | )
207 | parser.add_argument(
208 | "--first-n", type=int, help="A debug option. Only run the first `n` judgments."
209 | )
210 | args = parser.parse_args()
211 |
212 | question_file = f"data/{args.bench_name}/question.jsonl"
213 | answer_dir = f"data/{args.bench_name}/model_answer"
214 | ref_answer_dir = f"data/{args.bench_name}/reference_answer"
215 |
216 | # Load questions
217 | questions = load_questions(question_file, None, None)
218 |
219 | # Load answers
220 | model_answers = load_model_answers(answer_dir)
221 | ref_answers = load_model_answers(ref_answer_dir)
222 |
223 | # Load judge
224 | judge_prompts = load_judge_prompts(args.judge_file)
225 |
226 | if args.first_n:
227 | questions = questions[: args.first_n]
228 |
229 | if args.model_list is None:
230 | models = get_model_list(answer_dir)
231 | else:
232 | models = args.model_list
233 |
234 | if args.mode == "single":
235 | judges = make_judge_single(args.judge_model, judge_prompts)
236 | play_a_match_func = play_a_match_single
237 | output_file = (
238 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl"
239 | )
240 | make_match_func = make_match_single
241 | baseline_model = None
242 | else:
243 | judges = make_judge_pairwise(args.judge_model, judge_prompts)
244 | play_a_match_func = play_a_match_pair
245 | output_file = (
246 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl"
247 | )
248 | if args.mode == "pairwise-all":
249 | make_match_func = make_match_all_pairs
250 | baseline_model = None
251 | else:
252 | make_match_func = make_match
253 | baseline_model = args.baseline_model
254 |
255 | check_data(questions, model_answers, ref_answers, models, judges)
256 |
257 | question_math = [q for q in questions if q["category"] in NEED_REF_CATS]
258 | question_default = [q for q in questions if q["category"] not in NEED_REF_CATS]
259 |
260 | # Make matches
261 | matches = []
262 | matches += make_match_func(
263 | question_default, models, model_answers, judges["default"], baseline_model
264 | )
265 | matches += make_match_func(
266 | question_math,
267 | models,
268 | model_answers,
269 | judges["math"],
270 | baseline_model,
271 | ref_answers,
272 | )
273 | matches += make_match_func(
274 | question_default,
275 | models,
276 | model_answers,
277 | judges["default-mt"],
278 | baseline_model,
279 | multi_turn=True,
280 | )
281 | matches += make_match_func(
282 | question_math,
283 | models,
284 | model_answers,
285 | judges["math-mt"],
286 | baseline_model,
287 | ref_answers,
288 | multi_turn=True,
289 | )
290 |
291 | # Filter out existed matches
292 | total_num_matches = len(matches)
293 | filtered_matches = []
294 | try:
295 | with open(output_file, "r") as f:
296 | existed_matches = [json.loads(line) for line in f]
297 | except FileNotFoundError:
298 | existed_matches = []
299 | uniq_ids = set(
300 | [
301 | f"{e['question_id']}_{e['model']}_{e['judge'][0]}_{e['judge'][1]}_{e['turn']}"
302 | for e in existed_matches
303 | ]
304 | )
305 | for match in matches:
306 | turn = 2 if match.judge.multi_turn else 1
307 | uniq_id = f"{match.question['question_id']}_{match.answer['model_id']}_{match.judge.model_name}_{match.judge.prompt_template['name']}_{turn}"
308 | if uniq_id in uniq_ids:
309 | print(f"Skip {uniq_id}")
310 | else:
311 | filtered_matches.append(match)
312 | matches = filtered_matches
313 |
314 | match_stat = {}
315 | match_stat["bench_name"] = args.bench_name
316 | match_stat["mode"] = args.mode
317 | match_stat["judge"] = args.judge_model
318 | match_stat["baseline"] = baseline_model
319 | match_stat["model_list"] = models
320 | match_stat["total_num_questions"] = len(questions)
321 | match_stat["total_num_matches"] = total_num_matches
322 | match_stat["current_num_matches"] = len(matches)
323 | match_stat["output_path"] = output_file
324 |
325 | # Show match stats and prompt enter to continue
326 | print("Stats:")
327 | print(json.dumps(match_stat, indent=4))
328 | input("Press Enter to confirm...")
329 |
330 | # Play matches
331 | if args.parallel == 1:
332 | for match in tqdm(matches):
333 | play_a_match_func(match, output_file=output_file)
334 | else:
335 |
336 | def play_a_match_wrapper(match):
337 | play_a_match_func(match, output_file=output_file)
338 |
339 | np.random.seed(0)
340 | np.random.shuffle(matches)
341 |
342 | with ThreadPoolExecutor(args.parallel) as executor:
343 | for match in tqdm(
344 | executor.map(play_a_match_wrapper, matches), total=len(matches)
345 | ):
346 | pass
--------------------------------------------------------------------------------
/llm_judge/gen_model_answer_medusa_legacy.py:
--------------------------------------------------------------------------------
1 | """Generate answers with local models.
2 |
3 | Usage:
4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0
5 | """
6 | import argparse
7 | import json
8 | import os
9 | import random
10 | import time
11 | import shortuuid
12 | import torch
13 | from tqdm import tqdm
14 |
15 | from fastchat.llm_judge.common import load_questions, temperature_config
16 | from fastchat.model import load_model, get_conversation_template
17 |
18 | # Medusa imports
19 | import transformers
20 |
21 |
22 | from medusa.model.utils import *
23 | from medusa.model.medusa_model import MedusaModel
24 | from medusa.model.kv_cache import initialize_past_key_values
25 | from medusa.model.medusa_choices import *
26 |
27 | def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
28 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
29 | # Avoid modifying the input_ids in-place
30 | input_ids = input_ids.clone()
31 |
32 | # Cache medusa buffers (the fixed patterns for tree attention)
33 | if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices:
34 | # Load the cached medusa buffer
35 | medusa_buffers = model.medusa_buffers
36 | else:
37 | # Initialize the medusa buffer
38 | medusa_buffers = generate_medusa_buffers(
39 | medusa_choices, device=model.base_model.device
40 | )
41 | model.medusa_buffers = medusa_buffers
42 | model.medusa_choices = medusa_choices
43 |
44 | # Initialize the past key and value states
45 | if hasattr(model, "past_key_values"):
46 | past_key_values = model.past_key_values
47 | past_key_values_data = model.past_key_values_data
48 | current_length_data = model.current_length_data
49 | # Reset the past key and value states
50 | current_length_data.zero_()
51 | else:
52 | (
53 | past_key_values,
54 | past_key_values_data,
55 | current_length_data,
56 | ) = initialize_past_key_values(model.base_model)
57 | model.past_key_values = past_key_values
58 | model.past_key_values_data = past_key_values_data
59 | model.current_length_data = current_length_data
60 |
61 | input_len = input_ids.shape[1]
62 | reset_medusa_mode(model)
63 | medusa_logits, logits = initialize_medusa(
64 | input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values
65 | )
66 | new_token = 0
67 |
68 | for idx in range(max_steps):
69 | candidates, tree_candidates = generate_candidates(
70 | medusa_logits,
71 | logits,
72 | medusa_buffers["tree_indices"],
73 | medusa_buffers["retrieve_indices"],
74 | )
75 | medusa_logits, logits, outputs = tree_decoding(
76 | model,
77 | tree_candidates,
78 | past_key_values,
79 | medusa_buffers["medusa_position_ids"],
80 | input_ids,
81 | medusa_buffers["retrieve_indices"],
82 | )
83 | best_candidate, accept_length = evaluate_posterior(
84 | logits, candidates, temperature, posterior_threshold, posterior_alpha
85 | )
86 | input_ids, logits, medusa_logits, new_token = update_inference_inputs(
87 | input_ids,
88 | candidates,
89 | best_candidate,
90 | accept_length,
91 | medusa_buffers["retrieve_indices"],
92 | outputs,
93 | logits,
94 | medusa_logits,
95 | new_token,
96 | past_key_values_data,
97 | current_length_data,
98 | )
99 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist():
100 | break
101 | if new_token > 1024:
102 | break
103 | return input_ids, new_token, idx
104 |
105 | def run_eval(
106 | model_path,
107 | model_id,
108 | question_file,
109 | question_begin,
110 | question_end,
111 | answer_file,
112 | max_new_token,
113 | num_choices,
114 | num_gpus_per_model,
115 | num_gpus_total,
116 | max_gpu_memory,
117 | temperature,
118 | posterior_threshold,
119 | posterior_alpha,
120 | medusa_choices,
121 | ):
122 | questions = load_questions(question_file, question_begin, question_end)
123 | # random shuffle the questions to balance the loading
124 | # random.shuffle(questions)
125 | shuffled_ids = [q["question_id"] for q in questions]
126 | # with open(f"data/{args.bench_name}/model_ids/{args.model_id}.shuffled_ids", "w") as fout:
127 | # json.dump(shuffled_ids, fout)
128 |
129 | # Split the question file into `num_gpus` files
130 | assert num_gpus_total % num_gpus_per_model == 0
131 | use_ray = num_gpus_total // num_gpus_per_model > 1
132 |
133 | if use_ray:
134 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)(
135 | get_model_answers
136 | ).remote
137 | else:
138 | get_answers_func = get_model_answers
139 |
140 | chunk_size = len(questions) // (num_gpus_total // num_gpus_per_model) # // 2
141 | ans_handles = []
142 | for i in range(0, len(questions), chunk_size):
143 | ans_handles.append(
144 | get_answers_func(
145 | model_path,
146 | model_id,
147 | questions[i : i + chunk_size],
148 | answer_file,
149 | max_new_token,
150 | num_choices,
151 | num_gpus_per_model,
152 | max_gpu_memory,
153 | temperature,
154 | posterior_threshold,
155 | posterior_alpha,
156 | medusa_choices,
157 | )
158 | )
159 |
160 | if use_ray:
161 | ray.get(ans_handles)
162 |
163 |
164 | @torch.inference_mode()
165 | def get_model_answers(
166 | model_path,
167 | model_id,
168 | questions,
169 | answer_file,
170 | max_new_token,
171 | num_choices,
172 | num_gpus_per_model,
173 | max_gpu_memory,
174 | temperature,
175 | posterior_threshold,
176 | posterior_alpha,
177 | medusa_choices,
178 | ):
179 |
180 | # Medusa model setup
181 | num_heads = 4
182 |
183 | model = MedusaModel.from_pretrained(
184 | model_path,
185 | medusa_num_heads = num_heads,
186 | torch_dtype=torch.float16,
187 | low_cpu_mem_usage=True,
188 | device_map="auto"
189 | )
190 |
191 | tokenizer = model.get_tokenizer()
192 |
193 | model.eval()
194 | print('Check model training state:',model.training)
195 |
196 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
197 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices)
198 |
199 | question = questions[0]
200 |
201 | # warmup
202 | for _ in range(3):
203 | torch.manual_seed(0)
204 | conv = get_conversation_template(model_id)
205 | turns = []
206 | idxs = []
207 | new_tokens = []
208 | wall_time = []
209 | for j in range(len(question["turns"])):
210 | qs = question["turns"][j]
211 | conv.append_message(conv.roles[0], qs)
212 | conv.append_message(conv.roles[1], None)
213 | prompt = conv.get_prompt()
214 | input_ids = tokenizer([prompt]).input_ids
215 |
216 | # if temperature < 1e-4:
217 | # do_sample = False
218 | # else:
219 | # do_sample = True
220 |
221 | # some models may error out when generating long outputs
222 | try:
223 | torch.cuda.synchronize()
224 | start_time = time.time()
225 | output_ids, new_token, idx = medusa_forward(
226 | torch.as_tensor(input_ids).cuda(),
227 | model,
228 | tokenizer,
229 | medusa_choices,
230 | temperature,
231 | posterior_threshold,
232 | posterior_alpha,
233 | )
234 | torch.cuda.synchronize()
235 | total_time = time.time() - start_time
236 | output_ids = output_ids[0][len(input_ids[0]) :]
237 | # be consistent with the template's stop_token_ids
238 | if conv.stop_token_ids:
239 | stop_token_ids_index = [
240 | i
241 | for i, id in enumerate(output_ids)
242 | if id in conv.stop_token_ids
243 | ]
244 | if len(stop_token_ids_index) > 0:
245 | output_ids = output_ids[: stop_token_ids_index[0]]
246 |
247 | output = tokenizer.decode(
248 | output_ids,
249 | spaces_between_special_tokens=False,
250 | )
251 | if conv.stop_str and output.find(conv.stop_str) > 0:
252 | output = output[: output.find(conv.stop_str)]
253 | for special_token in tokenizer.special_tokens_map.values():
254 | if isinstance(special_token, list):
255 | for special_tok in special_token:
256 | output = output.replace(special_tok, "")
257 | else:
258 | output = output.replace(special_token, "")
259 | output = output.strip()
260 |
261 | if conv.name == "xgen" and output.startswith("Assistant:"):
262 | output = output.replace("Assistant:", "", 1).strip()
263 | except RuntimeError as e:
264 | print("ERROR question ID: ", question["question_id"])
265 | output = "ERROR"
266 |
267 | turns.append(output)
268 | idxs.append(int(idx))
269 | new_tokens.append(int(new_token))
270 | wall_time.append(total_time)
271 | conv.messages[-1][-1] = output
272 | print('Warmup done')
273 |
274 |
275 | for question in tqdm(questions):
276 | if question["category"] in temperature_config:
277 | temperature = temperature_config[question["category"]]
278 | else:
279 | temperature = 0.7
280 |
281 | choices = []
282 | for i in range(num_choices):
283 | torch.manual_seed(i)
284 | conv = get_conversation_template(model_id)
285 | turns = []
286 | idxs = []
287 | new_tokens = []
288 | wall_time = []
289 | for j in range(len(question["turns"])):
290 | qs = question["turns"][j]
291 | conv.append_message(conv.roles[0], qs)
292 | conv.append_message(conv.roles[1], None)
293 | prompt = conv.get_prompt()
294 | input_ids = tokenizer([prompt]).input_ids
295 |
296 | # if temperature < 1e-4:
297 | # do_sample = False
298 | # else:
299 | # do_sample = True
300 |
301 | # some models may error out when generating long outputs
302 | try:
303 | torch.cuda.synchronize()
304 | start_time = time.time()
305 | output_ids, new_token, idx = medusa_forward(
306 | torch.as_tensor(input_ids).cuda(),
307 | model,
308 | tokenizer,
309 | medusa_choices,
310 | temperature,
311 | posterior_threshold,
312 | posterior_alpha,
313 | )
314 | torch.cuda.synchronize()
315 | total_time = time.time() - start_time
316 | # if model.config.is_encoder_decoder:
317 | # output_ids = output_ids[0]
318 | # else:
319 | output_ids = output_ids[0][len(input_ids[0]) :]
320 |
321 | # be consistent with the template's stop_token_ids
322 | if conv.stop_token_ids:
323 | stop_token_ids_index = [
324 | i
325 | for i, id in enumerate(output_ids)
326 | if id in conv.stop_token_ids
327 | ]
328 | if len(stop_token_ids_index) > 0:
329 | output_ids = output_ids[: stop_token_ids_index[0]]
330 |
331 | output = tokenizer.decode(
332 | output_ids,
333 | spaces_between_special_tokens=False,
334 | )
335 | if conv.stop_str and output.find(conv.stop_str) > 0:
336 | output = output[: output.find(conv.stop_str)]
337 | for special_token in tokenizer.special_tokens_map.values():
338 | if isinstance(special_token, list):
339 | for special_tok in special_token:
340 | output = output.replace(special_tok, "")
341 | else:
342 | output = output.replace(special_token, "")
343 | output = output.strip()
344 |
345 | if conv.name == "xgen" and output.startswith("Assistant:"):
346 | output = output.replace("Assistant:", "", 1).strip()
347 | except RuntimeError as e:
348 | print("ERROR question ID: ", question["question_id"])
349 | output = "ERROR"
350 |
351 | turns.append(output)
352 | idxs.append(int(idx))
353 | new_tokens.append(int(new_token))
354 | wall_time.append(total_time)
355 | conv.messages[-1][-1] = output
356 | # torch.cuda.empty_cache()
357 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time})
358 |
359 | # Dump answers
360 | os.makedirs(os.path.dirname(answer_file), exist_ok=True)
361 | with open(os.path.expanduser(answer_file), "a") as fout:
362 | ans_json = {
363 | "question_id": question["question_id"],
364 | "answer_id": shortuuid.uuid(),
365 | "model_id": model_id,
366 | "choices": choices,
367 | "tstamp": time.time(),
368 | }
369 | fout.write(json.dumps(ans_json) + "\n")
370 |
371 |
372 | def reorg_answer_file(answer_file):
373 | """Sort by question id and de-duplication"""
374 | answers = {}
375 | with open(answer_file, "r") as fin:
376 | for l in fin:
377 | qid = json.loads(l)["question_id"]
378 | answers[qid] = l
379 |
380 | qids = sorted(list(answers.keys()))
381 | with open(answer_file, "w") as fout:
382 | for qid in qids:
383 | fout.write(answers[qid])
384 |
385 |
386 | if __name__ == "__main__":
387 | parser = argparse.ArgumentParser()
388 | parser.add_argument(
389 | "--model-path",
390 | type=str,
391 | required=True,
392 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
393 | )
394 | parser.add_argument("--model-id", type=str, required=True)
395 | parser.add_argument(
396 | "--bench-name",
397 | type=str,
398 | default="mt_bench",
399 | help="The name of the benchmark question set.",
400 | )
401 | parser.add_argument(
402 | "--question-begin",
403 | type=int,
404 | help="A debug option. The begin index of questions.",
405 | )
406 | parser.add_argument(
407 | "--question-end", type=int, help="A debug option. The end index of questions."
408 | )
409 | parser.add_argument("--answer-file", type=str, help="The output answer file.")
410 | parser.add_argument(
411 | "--max-new-token",
412 | type=int,
413 | default=1024,
414 | help="The maximum number of new generated tokens.",
415 | )
416 | parser.add_argument(
417 | "--num-choices",
418 | type=int,
419 | default=1,
420 | help="How many completion choices to generate.",
421 | )
422 | parser.add_argument(
423 | "--num-gpus-per-model",
424 | type=int,
425 | default=1,
426 | help="The number of GPUs per model.",
427 | )
428 | parser.add_argument(
429 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs."
430 | )
431 | parser.add_argument(
432 | "--max-gpu-memory",
433 | type=str,
434 | help="Maxmum GPU memory used for model weights per GPU.",
435 | )
436 |
437 | # YL: Medusa args
438 | parser.add_argument(
439 | "--temperature",
440 | type=float,
441 | default=0.0,
442 | help="The temperature for medusa sampling.",
443 | )
444 |
445 | parser.add_argument(
446 | "--posterior-threshold",
447 | type=float,
448 | default=0.09,
449 | help="The posterior threshold for medusa sampling.",
450 | )
451 |
452 | parser.add_argument(
453 | "--posterior-alpha",
454 | type=float,
455 | default=0.3,
456 | help="The posterior alpha for medusa sampling.",
457 | )
458 |
459 | parser.add_argument(
460 | "--medusa-choices",
461 | type=str,
462 | default="mc_sim_7b_63",
463 | help="The medusa choices for medusa sampling.",
464 | )
465 |
466 |
467 |
468 |
469 | args = parser.parse_args()
470 |
471 | args.model_id = args.model_id+"-temperature-"+str(args.temperature)+"-posterior_threshold-"+str(args.posterior_threshold)+"-posterior_alpha-"+str(args.posterior_alpha)
472 | args.medusa_choices = eval(args.medusa_choices)
473 | if args.num_gpus_total // args.num_gpus_per_model > 1:
474 | import ray
475 |
476 | ray.init()
477 |
478 | question_file = f"data/{args.bench_name}/question.jsonl"
479 | if args.answer_file:
480 | answer_file = args.answer_file
481 | else:
482 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl"
483 |
484 | print(f"Output to {answer_file}")
485 |
486 | run_eval(
487 | args.model_path,
488 | args.model_id,
489 | question_file,
490 | args.question_begin,
491 | args.question_end,
492 | answer_file,
493 | args.max_new_token,
494 | args.num_choices,
495 | args.num_gpus_per_model,
496 | args.num_gpus_total,
497 | args.max_gpu_memory,
498 |
499 | args.temperature,
500 | args.posterior_threshold,
501 | args.posterior_alpha,
502 | args.medusa_choices,
503 | )
504 |
505 | reorg_answer_file(answer_file)
--------------------------------------------------------------------------------
/llm_judge/show_result.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 show_result.py --mode [single|pairwise-baseline|pairwise-all]
4 | """
5 | import argparse
6 | import pandas as pd
7 |
8 |
9 | def display_result_single(args):
10 | if args.input_file is None:
11 | input_file = (
12 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl"
13 | )
14 | else:
15 | input_file = args.input_file
16 |
17 | print(f"Input file: {input_file}")
18 | df_all = pd.read_json(input_file, lines=True)
19 | df = df_all[["model", "score", "turn"]]
20 | df = df[df["score"] != -1]
21 |
22 | if args.model_list is not None:
23 | df = df[df["model"].isin(args.model_list)]
24 |
25 | print("\n########## First turn ##########")
26 | df_1 = df[df["turn"] == 1].groupby(["model", "turn"]).mean()
27 | print(df_1.sort_values(by="score", ascending=False))
28 |
29 | if "mt_bench" in args.bench_name:
30 | print("\n########## Second turn ##########")
31 | df_2 = df[df["turn"] == 2].groupby(["model", "turn"]).mean()
32 | print(df_2.sort_values(by="score", ascending=False))
33 |
34 | print("\n########## Average ##########")
35 | df_3 = df[["model", "score"]].groupby(["model"]).mean()
36 | print(df_3.sort_values(by="score", ascending=False))
37 |
38 |
39 | def display_result_pairwise(args):
40 | if args.input_file is None:
41 | input_file = (
42 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl"
43 | )
44 | else:
45 | input_file = args.input_file
46 |
47 | print(f"Input file: {input_file}")
48 | df_all = pd.read_json(input_file, lines=True)
49 | df_all = df_all[(df_all["g1_winner"] != "error") & (df_all["g2_winner"] != "error")]
50 |
51 | model_list = (
52 | df_all["model_1"].unique().tolist() + df_all["model_2"].unique().tolist()
53 | )
54 | model_list = list(set(model_list))
55 |
56 | list_res = []
57 | # traverse df row by row
58 | for index, row in df_all.iterrows():
59 | if args.model_list is not None and row["model_1"] not in args.model_list:
60 | continue
61 | if args.baseline_model is not None:
62 | if args.baseline_model not in [row["model_1"], row["model_2"]]:
63 | continue
64 | if row["g1_winner"] == "tie" or row["g1_winner"] != row["g2_winner"]:
65 | list_res.append({"model": row["model_1"], "win": 0, "loss": 0, "tie": 1})
66 | list_res.append({"model": row["model_2"], "win": 0, "loss": 0, "tie": 1})
67 | else:
68 | if row["g1_winner"] == "model_1":
69 | winner = row["model_1"]
70 | loser = row["model_2"]
71 | else:
72 | winner = row["model_2"]
73 | loser = row["model_1"]
74 | list_res.append({"model": winner, "win": 1, "loss": 0, "tie": 0})
75 | list_res.append({"model": loser, "win": 0, "loss": 1, "tie": 0})
76 |
77 | df = pd.DataFrame(list_res)
78 | df = df.groupby(["model"]).sum()
79 |
80 | # remove baseline model
81 | if args.baseline_model is not None:
82 | df = df[df.index != args.baseline_model]
83 | # add win rate
84 | df["win_rate"] = df["win"] / (df["win"] + df["loss"] + df["tie"])
85 | df["loss_rate"] = df["loss"] / (df["win"] + df["loss"] + df["tie"])
86 | # each tie counts as 0.5 win + 0.5 loss
87 | df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / (
88 | df["win"] + df["loss"] + df["tie"]
89 | )
90 | # print(df.sort_values(by="win_rate", ascending=False))
91 | # print(df.sort_values(by="loss_rate", ascending=True))
92 | print(df.sort_values(by="win_rate_adjusted", ascending=False))
93 |
94 |
95 | if __name__ == "__main__":
96 | parser = argparse.ArgumentParser()
97 | parser.add_argument("--bench-name", type=str, default="mt_bench")
98 | parser.add_argument("--input-file", type=str)
99 | parser.add_argument("--judge-model", type=str, default="gpt-4")
100 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo")
101 | parser.add_argument(
102 | "--model-list",
103 | type=str,
104 | nargs="+",
105 | default=None,
106 | help="A list of models to be evaluated",
107 | )
108 | parser.add_argument(
109 | "--mode",
110 | type=str,
111 | default="single",
112 | choices=["pairwise-baseline", "pairwise-all", "single"],
113 | help=(
114 | "Evaluation mode. "
115 | "`pairwise-baseline` runs pairwise comparision against a baseline. "
116 | "`pairwise-all` runs pairwise comparision between all pairs. "
117 | "`single` runs single answer grading."
118 | ),
119 | )
120 | args = parser.parse_args()
121 |
122 | if args.mode == "single":
123 | display_result_func = display_result_single
124 | else:
125 | if args.mode == "pairwise-all":
126 | args.baseline_model = None
127 | display_result_func = display_result_pairwise
128 |
129 | print(f"Mode: {args.mode}")
130 | display_result_func(args)
--------------------------------------------------------------------------------
/medusa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/__init__.py
--------------------------------------------------------------------------------
/medusa/eval/README.md:
--------------------------------------------------------------------------------
1 |
2 | We use [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json) dataset for evaluating each head's accuracy during generation in `heads_accuracy.py`.
3 |
4 | ```
5 | python heads_accuracy.py --model_path 'FasterDecoding/medusa-vicuna-7b-v1.3' --model_name 'medusa-vicuna-7b-v1.3' --medusa_num_heads 5 --data_path '../../data/alpaca_eval.json'
6 | ```
7 |
8 |
9 | To create the tree and plot the tree (requires `pygraphviz` package), please run:
10 |
11 | ```
12 | python gen_results.py --accuracy-path '../../data/medusa-vicuna-7b-v1.3_heads_accuracy.pt' --output-path '../../data/graph.jpg'
13 | ```
14 |
15 | If you want to use the tree, please add the generated tree (in a nested tuple) to `../model/medusa_choices.py`.
16 |
17 | Citation:
18 |
19 | ```
20 | @misc{alpaca_eval,
21 | author = {Xuechen Li and Tianyi Zhang and Yann Dubois and Rohan Taori and Ishaan Gulrajani and Carlos Guestrin and Percy Liang and Tatsunori B. Hashimoto },
22 | title = {AlpacaEval: An Automatic Evaluator of Instruction-following Models},
23 | year = {2023},
24 | publisher = {GitHub},
25 | journal = {GitHub repository},
26 | howpublished = {\url{https://github.com/tatsu-lab/alpaca_eval}}
27 | }```
--------------------------------------------------------------------------------
/medusa/eval/gen_results.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import copy
3 | import networkx as nx
4 | import torch
5 | import argparse
6 |
7 | def load_accuracy_table(path):
8 | test_accuracy = torch.load(path)
9 | accuracy_table = []
10 | for i in range(len(test_accuracy)):
11 | accuracy_table.append(test_accuracy[i].sum(0)/16100)
12 | return torch.stack(accuracy_table)
13 |
14 | def get_node_expectation(accuracies, node):
15 | expectation = copy.deepcopy(accuracies[0, node[0]])
16 | for i in range(1, len(node)):
17 | expectation *= accuracies[i, node[i]]
18 | return expectation
19 |
20 | def explore_graph(accuracies, max_depth, max_child, num_iterations):
21 | explored_nodes = {}
22 | accept_nodes = [tuple([0])]
23 | expectations = get_node_expectation(accuracies, accept_nodes[0])
24 | explored_nodes[tuple(accept_nodes[0])] = expectations
25 |
26 | for _ in range(num_iterations):
27 | # find neighbors
28 | neighbors = []
29 | for node in accept_nodes:
30 | if node[-1] < max_child[len(node) - 1] - 1:
31 | neighbor = list(copy.deepcopy(node))
32 | neighbor[-1] = neighbor[-1] + 1
33 | neighbors.append(neighbor)
34 | if len(node) < max_depth:
35 | neighbor = list(copy.deepcopy(node))
36 | neighbor.append(0)
37 | neighbors.append(neighbor)
38 |
39 | # find the best neighbor
40 | best_neighbor = None
41 | best_neighbor_expectation = 0
42 | for neighbor in neighbors:
43 | if tuple(neighbor) in accept_nodes:
44 | continue
45 | if tuple(neighbor) in explored_nodes:
46 | neighbor_expectation = explored_nodes[tuple(neighbor)]
47 | else:
48 | neighbor_expectation = get_node_expectation(accuracies, neighbor)
49 | explored_nodes[tuple(neighbor)] = neighbor_expectation
50 | if neighbor_expectation > best_neighbor_expectation:
51 | best_neighbor = neighbor
52 | best_neighbor_expectation = neighbor_expectation
53 | accept_nodes.append(tuple(best_neighbor))
54 | expectations += best_neighbor_expectation
55 |
56 | return accept_nodes
57 |
58 | def plot_and_save_graph(accept_nodes, output_path):
59 | plt.figure(figsize=(40, 20))
60 |
61 | G = nx.DiGraph()
62 |
63 | for path in accept_nodes:
64 | for i in range(len(path)):
65 | if i == 0:
66 | parent = 'root'
67 | else:
68 | parent = tuple(path[:i])
69 | child = tuple(path[:i+1])
70 | G.add_edge(parent, child)
71 |
72 | pos = nx.nx_agraph.graphviz_layout(G, prog='dot')
73 | nx.draw(G, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, width=2, edge_color="gray")
74 | plt.savefig(output_path)
75 |
76 | def main():
77 | parser = argparse.ArgumentParser(description="Generate Results.")
78 | parser.add_argument('--accuracy-path', type=str, required=True, help="Path to load accuracy tensor.")
79 | parser.add_argument('--output-path', type=str, required=True, help="Path to save the generated graph.")
80 | parser.add_argument('--max-depth', type=int, default=5, help="Maximum depth of the graph.")
81 | parser.add_argument('--num-iterations', type=int, default=62, help="Number of exploration iterations.")
82 | parser.add_argument('--max-child', nargs='+', type=int, default=[10, 10, 10, 10, 10], help="Maximum number of children per depth.")
83 |
84 | args = parser.parse_args()
85 |
86 | accuracies = load_accuracy_table(args.accuracy_path)
87 | accept_nodes = explore_graph(accuracies, args.max_depth, args.max_child, args.num_iterations)
88 |
89 | print("Accepted Nodes:", accept_nodes)
90 |
91 | try:
92 | plot_and_save_graph(accept_nodes, args.output_path)
93 | print(f"Graph saved to {args.output_path}.")
94 | except Exception as e:
95 | print(f"Failed to save the graph due to the following error: {e}")
96 | print("Ensure that Graphviz and pygraphviz are installed and set up correctly.")
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
/medusa/eval/heads_accuracy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import json
4 | from contextlib import contextmanager
5 | import numpy as np
6 | from medusa.model.medusa_model import MedusaModel
7 | from medusa.model.kv_cache import *
8 | from medusa.model.utils import *
9 | from medusa.model.medusa_choices import *
10 | from copy import deepcopy
11 | import matplotlib.pyplot as plt
12 | import torch.nn.functional as F
13 | from fastchat.model.model_adapter import get_conversation_template
14 | from tqdm import tqdm
15 | import argparse
16 |
17 | def get_accuracies(medusa, logit):
18 | # get the correct counts of each head
19 | seq_len, choices, topk = medusa.shape
20 | results = []
21 | for choice in range(choices):
22 | results.append(medusa[:-choice - 1,choice].eq(logit[choice + 1:,0]))
23 | return results
24 |
25 |
26 |
27 | def main(args):
28 | model = MedusaModel.from_pretrained(
29 | args.model_path,
30 | # medusa_num_heads=args.medusa_num_heads,
31 | torch_dtype=torch.float16,
32 | low_cpu_mem_usage=True,
33 | device_map="auto"
34 | )
35 | tokenizer = model.get_tokenizer()
36 |
37 |
38 | data = json.load(open(args.data_path))
39 | past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(model.base_model)
40 | model.past_key_values = past_key_values
41 | model.past_key_values_data = past_key_values_data
42 | model.current_length_data = current_length_data
43 | results = None
44 |
45 | for sample in tqdm((data)):
46 | conv = get_conversation_template("vicuna")
47 | conv.messages = []
48 | conv.append_message(conv.roles[0], sample["instruction"])
49 | conv.append_message(conv.roles[1], "")
50 | prompt = conv.get_prompt()
51 | steps = args.steps
52 | logits_ids = []
53 | medusa_topk_ids = []
54 |
55 | with torch.inference_mode():
56 | input_ids = tokenizer([prompt]).input_ids
57 | input_ids = torch.as_tensor(input_ids).cuda()
58 | model.current_length_data.zero_() # this is for rerun
59 | reset_medusa_mode(model)
60 | medusa_logits, outputs, logits = model(
61 | input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True
62 | )
63 | _, medusa_topk = medusa_logits[...,-1,:].topk(20, dim=-1)
64 | input_id = logits[:, -1:].argmax(dim=-1)
65 | logits_ids.append(input_id.detach().cpu())
66 | medusa_topk_ids.append(medusa_topk.detach().cpu())
67 | for _ in range(steps):
68 | medusa_logits, outputs, logits = model(
69 | input_id, past_key_values=past_key_values, output_orig=True, medusa_forward=True
70 | )
71 | _, medusa_topk = medusa_logits[...,-1,:].topk(20, dim=-1)
72 | input_id = logits[:, -1:].argmax(dim=-1)
73 | logits_ids.append(input_id.detach().cpu())
74 | medusa_topk_ids.append(medusa_topk.detach().cpu())
75 | logits_ids = torch.stack(logits_ids, dim=0)
76 | medusa_topk_ids = torch.stack(medusa_topk_ids, dim=0).squeeze(2)
77 | if results is None:
78 | results = get_accuracies(medusa_topk_ids, logits_ids)
79 | else:
80 | # cat sub results
81 | cur_results = get_accuracies(medusa_topk_ids, logits_ids)
82 | for i in range(len(results)):
83 | results[i] = torch.cat((results[i], cur_results[i]), dim=0)
84 |
85 | save_path = os.path.join(args.save_dir, args.model_name + "_heads_accuracy.pt")
86 | torch.save(results, save_path)
87 |
88 | if __name__ == "__main__":
89 | parser = argparse.ArgumentParser(description="Medusa Model Evaluator")
90 |
91 | parser.add_argument("--model_path", type=str, required=True,
92 | help="Path to the pre-trained Medusa model.")
93 | parser.add_argument("--model_name", type=str, required=True,
94 | help="Name of the model.")
95 | parser.add_argument("--medusa_num_heads", type=int, default=5,
96 | help="Number of medusa heads.")
97 | parser.add_argument("--data_path", type=str, required=True,
98 | help="Path to the evaluation data in JSON format.")
99 | parser.add_argument("--save_dir", type=str, default="../../data",
100 | help="Directory to save the results.")
101 | parser.add_argument("--steps", type=int, default=20,
102 | help="Number of steps to run the model.")
103 | args = parser.parse_args()
104 |
105 | # If the save directory doesn't exist, create it
106 | if not os.path.exists(args.save_dir):
107 | os.makedirs(args.save_dir)
108 | main(args)
--------------------------------------------------------------------------------
/medusa/hf_utils.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import HfApi
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser("Upload Medusa model to HuggingFace Hub")
5 | parser.add_argument("--folder", type=str, help="Path to model folder")
6 | parser.add_argument("--repo", type=str, help="Repo name")
7 | parser.add_argument("--private", action="store_true", help="Make repo private")
8 |
9 | args = parser.parse_args()
10 |
11 | api = HfApi()
12 |
13 | api.create_repo(
14 | repo_id=args.repo,
15 | private=args.private,
16 | exist_ok=True,
17 | )
18 |
19 | api.upload_folder(
20 | folder_path=args.folder,
21 | repo_id=args.repo,
22 | )
--------------------------------------------------------------------------------
/medusa/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/inference/__init__.py
--------------------------------------------------------------------------------
/medusa/inference/cli.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
2 | """
3 | Chat with a model with command line interface.
4 |
5 | Usage:
6 | python3 -m medusa.inference.cli --model
7 | Other commands:
8 | - Type "!!exit" or an empty line to exit.
9 | - Type "!!reset" to start a new conversation.
10 | - Type "!!remove" to remove the last prompt.
11 | - Type "!!regen" to regenerate the last message.
12 | - Type "!!save " to save the conversation history to a json file.
13 | - Type "!!load " to load a conversation history from a json file.
14 | """
15 | import argparse
16 | import os
17 | import re
18 | import sys
19 | import torch
20 | from fastchat.serve.cli import SimpleChatIO, RichChatIO, ProgrammaticChatIO
21 | from fastchat.model.model_adapter import get_conversation_template
22 | from fastchat.conversation import get_conv_template
23 | import json
24 | from medusa.model.medusa_model import MedusaModel
25 |
26 |
27 | def main(args):
28 | if args.style == "simple":
29 | chatio = SimpleChatIO(args.multiline)
30 | elif args.style == "rich":
31 | chatio = RichChatIO(args.multiline, args.mouse)
32 | elif args.style == "programmatic":
33 | chatio = ProgrammaticChatIO()
34 | else:
35 | raise ValueError(f"Invalid style for console: {args.style}")
36 | try:
37 | model = MedusaModel.from_pretrained(
38 | args.model,
39 | torch_dtype=torch.float16,
40 | low_cpu_mem_usage=True,
41 | device_map="auto",
42 | load_in_8bit=args.load_in_8bit,
43 | load_in_4bit=args.load_in_4bit,
44 | )
45 | tokenizer = model.get_tokenizer()
46 | conv = None
47 |
48 | def new_chat():
49 | return get_conversation_template(args.model)
50 |
51 | def reload_conv(conv):
52 | """
53 | Reprints the conversation from the start.
54 | """
55 | for message in conv.messages[conv.offset :]:
56 | chatio.prompt_for_output(message[0])
57 | chatio.print_output(message[1])
58 |
59 | while True:
60 | if not conv:
61 | conv = new_chat()
62 |
63 | try:
64 | inp = chatio.prompt_for_input(conv.roles[0])
65 | except EOFError:
66 | inp = ""
67 |
68 | if inp == "!!exit" or not inp:
69 | print("exit...")
70 | break
71 | elif inp == "!!reset":
72 | print("resetting...")
73 | conv = new_chat()
74 | continue
75 | elif inp == "!!remove":
76 | print("removing last message...")
77 | if len(conv.messages) > conv.offset:
78 | # Assistant
79 | if conv.messages[-1][0] == conv.roles[1]:
80 | conv.messages.pop()
81 | # User
82 | if conv.messages[-1][0] == conv.roles[0]:
83 | conv.messages.pop()
84 | reload_conv(conv)
85 | else:
86 | print("No messages to remove.")
87 | continue
88 | elif inp == "!!regen":
89 | print("regenerating last message...")
90 | if len(conv.messages) > conv.offset:
91 | # Assistant
92 | if conv.messages[-1][0] == conv.roles[1]:
93 | conv.messages.pop()
94 | # User
95 | if conv.messages[-1][0] == conv.roles[0]:
96 | reload_conv(conv)
97 | # Set inp to previous message
98 | inp = conv.messages.pop()[1]
99 | else:
100 | # Shouldn't happen in normal circumstances
101 | print("No user message to regenerate from.")
102 | continue
103 | else:
104 | print("No messages to regenerate.")
105 | continue
106 | elif inp.startswith("!!save"):
107 | args = inp.split(" ", 1)
108 |
109 | if len(args) != 2:
110 | print("usage: !!save ")
111 | continue
112 | else:
113 | filename = args[1]
114 |
115 | # Add .json if extension not present
116 | if not "." in filename:
117 | filename += ".json"
118 |
119 | print("saving...", filename)
120 | with open(filename, "w") as outfile:
121 | json.dump(conv.dict(), outfile)
122 | continue
123 | elif inp.startswith("!!load"):
124 | args = inp.split(" ", 1)
125 |
126 | if len(args) != 2:
127 | print("usage: !!load ")
128 | continue
129 | else:
130 | filename = args[1]
131 |
132 | # Check if file exists and add .json if needed
133 | if not os.path.exists(filename):
134 | if (not filename.endswith(".json")) and os.path.exists(
135 | filename + ".json"
136 | ):
137 | filename += ".json"
138 | else:
139 | print("file not found:", filename)
140 | continue
141 |
142 | print("loading...", filename)
143 | with open(filename, "r") as infile:
144 | new_conv = json.load(infile)
145 |
146 | conv = get_conv_template(new_conv["template_name"])
147 | conv.set_system_message(new_conv["system_message"])
148 | conv.messages = new_conv["messages"]
149 | reload_conv(conv)
150 | continue
151 |
152 | conv.append_message(conv.roles[0], inp)
153 | conv.append_message(conv.roles[1], None)
154 | prompt = conv.get_prompt()
155 |
156 | try:
157 | chatio.prompt_for_output(conv.roles[1])
158 | input_ids = tokenizer.encode(prompt, return_tensors="pt").to(
159 | model.base_model.device
160 | )
161 | outputs = chatio.stream_output(
162 | model.medusa_generate(
163 | input_ids,
164 | temperature=args.temperature,
165 | max_steps=args.max_steps,
166 | )
167 | )
168 | conv.update_last_message(outputs.strip())
169 |
170 | except KeyboardInterrupt:
171 | print("stopped generation.")
172 | # If generation didn't finish
173 | if conv.messages[-1][1] is None:
174 | conv.messages.pop()
175 | # Remove last user message, so there isn't a double up
176 | if conv.messages[-1][0] == conv.roles[0]:
177 | conv.messages.pop()
178 |
179 | reload_conv(conv)
180 |
181 | except KeyboardInterrupt:
182 | print("exit...")
183 |
184 |
185 | if __name__ == "__main__":
186 | parser = argparse.ArgumentParser()
187 | parser.add_argument("--model", type=str, required=True, help="Model name or path.")
188 | parser.add_argument(
189 | "--load-in-8bit", action="store_true", help="Use 8-bit quantization"
190 | )
191 | parser.add_argument(
192 | "--load-in-4bit", action="store_true", help="Use 4-bit quantization"
193 | )
194 | parser.add_argument(
195 | "--conv-template", type=str, default=None, help="Conversation prompt template."
196 | )
197 | parser.add_argument(
198 | "--conv-system-msg", type=str, default=None, help="Conversation system message."
199 | )
200 | parser.add_argument("--temperature", type=float, default=0.7)
201 | parser.add_argument("--max-steps", type=int, default=512)
202 | parser.add_argument("--no-history", action="store_true")
203 | parser.add_argument(
204 | "--style",
205 | type=str,
206 | default="simple",
207 | choices=["simple", "rich", "programmatic"],
208 | help="Display style.",
209 | )
210 | parser.add_argument(
211 | "--multiline",
212 | action="store_true",
213 | help="Enable multiline input. Use ESC+Enter for newline.",
214 | )
215 | parser.add_argument(
216 | "--mouse",
217 | action="store_true",
218 | help="[Rich Style]: Enable mouse support for cursor positioning.",
219 | )
220 | parser.add_argument(
221 | "--debug",
222 | action="store_true",
223 | help="Print useful debug information (e.g., prompts)",
224 | )
225 | args = parser.parse_args()
226 | main(args)
227 |
--------------------------------------------------------------------------------
/medusa/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/model/__init__.py
--------------------------------------------------------------------------------
/medusa/model/kv_cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class KVCache:
5 | """
6 | A key-value cache for the model.
7 |
8 | This class provides a mechanism to maintain a growing cache of keys and values,
9 | particularly useful for models that benefit from caching previous states,
10 | like transformers during autoregressive decoding.
11 |
12 | Attributes:
13 | data (torch.Tensor): The tensor storing keys and values.
14 | current_length (int): Current length of the data being stored.
15 | """
16 |
17 | def __init__(self, data, current_length):
18 | """
19 | Initialize the KVCache.
20 |
21 | Args:
22 | data (torch.Tensor): Initial tensor to store the keys and values.
23 | current_length (int): Initial length of the data.
24 | """
25 | self.data = data
26 | self.current_length = current_length
27 |
28 | @property
29 | def shape(self):
30 | """Return the shape of the data tensor with updated length."""
31 | return (
32 | self.data.shape[0],
33 | self.data.shape[1],
34 | self.current_length.item(),
35 | self.data.shape[3],
36 | )
37 |
38 | def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2):
39 | """
40 |         Gather the cache entries at the given indices and write them back contiguously starting at prev_length, then update the cache length.
41 |
42 | Args:
43 | indices (torch.Tensor): Indices of the data tensor to be copied.
44 | prev_length (int): Previous length before adding new data.
45 | dim (int, optional): Dimension along which copying should be performed. Default is 2.
46 | """
47 | tgt = self.data.index_select(dim, indices)
48 | dst = self.data.narrow(dim, prev_length, tgt.shape[dim])
49 | dst.copy_(tgt, non_blocking=True)
50 | self.current_length.fill_(prev_length + tgt.shape[dim])
51 |
52 | def cat(self, tensor: torch.Tensor, dim: int = 2):
53 | """
54 | Concatenate the given tensor with the current data.
55 |
56 | Args:
57 | tensor (torch.Tensor): The tensor to be concatenated.
58 | dim (int, optional): The dimension along which concatenation should be done. Default is 2.
59 |
60 | Returns:
61 | torch.Tensor: The data tensor after concatenation up to the current length.
62 | """
63 | dst = self.data.narrow(dim, self.current_length, tensor.shape[dim])
64 | dst.copy_(tensor)
65 | self.current_length.add_(tensor.shape[dim])
66 |         return torch.narrow(self.data, dim, 0, self.current_length)
67 |
68 |
69 | def initialize_past_key_values(model):
70 | """
71 | Initialize past key and value states for a given transformer model.
72 |
73 | This function prepares key-value cache structures for the model, allowing it to store and reuse
74 | past key and value states during autoregressive decoding, which can improve efficiency.
75 |
76 | Args:
77 | model (nn.Module): The transformer model for which past key-value states need to be initialized.
78 |
79 | Returns:
80 | tuple:
81 | - past_key_values (list): A list of KVCache objects for each layer in the model.
82 | - past_key_values_data (torch.Tensor): The tensor that will store all keys and values.
83 | - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache.
84 | """
85 | # Extracting configuration from the model
86 | config = model.config
87 |     # Initialize the batch size to 1; this can be modified if different batch sizes are required
88 | batch_size = 1
89 | # Initializing a tensor to store past keys and values for all layers
90 | past_key_values_data = torch.zeros(
91 | config.num_hidden_layers * 2,
92 | batch_size,
93 | config.num_key_value_heads,
94 | config.max_position_embeddings,
95 | config.hidden_size // config.num_attention_heads,
96 | device=model.device,
97 | dtype=model.dtype,
98 | )
99 | # Initialize tensor to store the current length of the cached data for all layers.
100 | # [IMPORTANT] It needs to be kept on CPU for quick access and updates.
101 | current_length_data = torch.zeros(
102 | config.num_hidden_layers * 2, dtype=torch.long, device="cpu"
103 | )
104 | # Creating a KVCache for each pair of key and value in all layers
105 |     past_key_values = []  # filled with one [key, value] KVCache pair per layer below
106 | for i in range(config.num_hidden_layers):
107 | past_key_values.append(
108 | [
109 | KVCache(past_key_values_data[i * 2 + j], current_length_data[i * 2 + j])
110 | for j in range(2)
111 | ]
112 | )
113 | return past_key_values, past_key_values_data, current_length_data
114 |
--------------------------------------------------------------------------------
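
initialize_past_key_values above pre-allocates a single tensor of shape (2 * num_hidden_layers, batch, num_key_value_heads, max_position_embeddings, head_dim) and wraps each layer's key and value slice in a KVCache, with the running lengths kept on CPU. The following toy sketch exercises KVCache on dummy tensors; all shapes and indices are made up purely for illustration.

    import torch
    from medusa.model.kv_cache import KVCache

    # Dummy buffer of shape (batch, kv_heads, max_len, head_dim); values are illustrative.
    storage = torch.zeros(1, 8, 64, 16)
    length = torch.zeros((), dtype=torch.long)   # 0-dim length counter, kept on CPU
    cache = KVCache(storage, length)

    new_kv = torch.randn(1, 8, 5, 16)            # keys/values for 5 new positions
    cache.cat(new_kv)                            # writes into the buffer, advances the length
    print(cache.shape)                           # (1, 8, 5, 16)

    # Keep cache positions 0 and 3 (e.g. an accepted branch) and re-pack them
    # so that they follow the first 2 positions.
    cache.copy(torch.tensor([0, 3]), prev_length=2)
    print(cache.shape)                           # (1, 8, 4, 16)
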
/medusa/model/medusa_choices.py:
--------------------------------------------------------------------------------
1 | mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]
2 | vicuna_7b_stage2 = [(0,), (0, 0), (1,), (0, 1), (0, 0, 0), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 3), (3,), (0, 1, 0), (2, 0), (4,), (0, 0, 2), (0, 4), (1, 1), (1, 0, 0), (0, 0, 0, 0), (5,), (0, 0, 3), (0, 5), (0, 2, 0), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 7), (0, 0, 4), (4, 0), (1, 2), (0, 8), (7,), (0, 3, 0), (0, 0, 0, 1), (0, 0, 5), (2, 1), (0, 0, 6), (1, 0, 1), (0, 0, 1, 0), (2, 0, 0), (5, 0), (0, 9), (0, 1, 2), (8,), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7), (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0), (9,), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0), (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)]
3 | vicuna_7b_stage1_ablation = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1), (1, 0), (2,), (0, 2), (0, 0, 1), (3,), (0, 3), (0, 1, 0), (2, 0), (0, 0, 2), (0, 4), (4,), (0, 0, 0, 0), (1, 0, 0), (1, 1), (0, 0, 3), (0, 2, 0), (0, 5), (5,), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 0, 4), (1, 2), (0, 0, 0, 1), (4, 0), (0, 0, 5), (0, 7), (0, 8), (0, 3, 0), (0, 0, 1, 0), (1, 0, 1), (7,), (2, 0, 0), (0, 0, 6), (2, 1), (0, 1, 2), (5, 0), (0, 2, 1), (0, 9), (0, 0, 0, 2), (0, 4, 0), (8,), (1, 3), (0, 0, 7), (0, 1, 0, 0), (1, 1, 0), (6, 0), (9,), (0, 0, 8), (0, 0, 9), (0, 5, 0), (0, 0, 2, 0), (1, 0, 2), (0, 1, 3), (0, 0, 0, 3), (3, 0, 0), (3, 1)]
4 | vicuna_7b_stage1 = [(0,), (0, 0), (1,), (2,), (0, 1), (1, 0), (3,), (0, 2), (4,), (0, 0, 0), (0, 3), (5,), (2, 0), (0, 4), (6,), (0, 5), (1, 1), (0, 0, 1), (7,), (3, 0), (0, 6), (8,), (9,), (0, 1, 0), (0, 7), (0, 8), (4, 0), (0, 0, 2), (1, 2), (0, 9), (2, 1), (5, 0), (1, 0, 0), (0, 0, 3), (1, 3), (0, 2, 0), (0, 1, 1), (0, 0, 4), (6, 0), (1, 4), (0, 0, 5), (2, 2), (0, 3, 0), (3, 1), (0, 0, 6), (7, 0), (1, 5), (1, 0, 1), (2, 0, 0), (0, 0, 7), (8, 0), (0, 0, 0, 0), (4, 1), (0, 1, 2), (0, 4, 0), (9, 0), (0, 2, 1), (2, 3), (1, 6), (0, 0, 8), (0, 5, 0), (3, 2), (5, 1)]
5 | vicuna_13b_stage2 = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 1, 0), (3,), (0, 3), (2, 0), (0, 0, 2), (0, 0, 0, 0), (0, 4), (1, 0, 0), (1, 1), (4,), (0, 0, 3), (0, 5), (0, 2, 0), (5,), (3, 0), (0, 1, 1), (0, 6), (0, 0, 4), (0, 0, 0, 1), (0, 7), (0, 0, 5), (1, 2), (0, 0, 1, 0), (0, 3, 0), (1, 0, 1), (4, 0), (0, 0, 6), (0, 8), (2, 0, 0), (0, 9), (6,), (7,), (2, 1), (5, 0), (0, 1, 2), (0, 0, 0, 2), (8,), (0, 4, 0), (0, 1, 0, 0), (0, 2, 1), (0, 0, 7), (1, 1, 0), (1, 3), (0, 0, 2, 0), (9,), (0, 0, 8), (0, 5, 0), (0, 0, 0, 3), (0, 0, 9), (0, 1, 3), (1, 0, 2), (0, 0, 1, 1), (3, 0, 0), (1, 0, 0, 0)]
6 | vicuna_13b_stage1 = [(0,), (0, 0), (1,), (0, 1), (2,), (1, 0), (0, 0, 0), (0, 2), (3,), (0, 3), (4,), (2, 0), (0, 4), (0, 0, 1), (0, 5), (5,), (1, 1), (0, 1, 0), (6,), (0, 6), (0, 0, 2), (7,), (3, 0), (8,), (0, 7), (0, 8), (1, 0, 0), (0, 0, 3), (4, 0), (1, 2), (9,), (0, 9), (2, 1), (0, 2, 0), (0, 0, 4), (1, 3), (0, 1, 1), (0, 0, 5), (5, 0), (0, 3, 0), (0, 0, 0, 0), (0, 0, 6), (6, 0), (1, 4), (2, 0, 0), (0, 1, 2), (3, 1), (0, 4, 0), (1, 0, 1), (2, 2), (0, 0, 7), (1, 5), (7, 0), (0, 0, 8), (8, 0), (0, 5, 0), (0, 0, 9), (0, 2, 1), (1, 1, 0), (0, 1, 3), (4, 1), (2, 3), (1, 6)]
7 | vicuna_33b_stage2 = [(0,), (0, 0), (1,), (0, 1), (0, 0, 0), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 3), (3,), (0, 1, 0), (2, 0), (0, 4), (4,), (0, 0, 2), (1, 1), (1, 0, 0), (0, 5), (5,), (0, 0, 0, 0), (0, 0, 3), (3, 0), (0, 2, 0), (0, 6), (0, 1, 1), (6,), (0, 0, 4), (0, 7), (7,), (1, 2), (4, 0), (8,), (0, 3, 0), (0, 0, 5), (0, 0, 0, 1), (0, 8), (2, 1), (0, 9), (1, 0, 1), (2, 0, 0), (0, 0, 6), (5, 0), (0, 0, 1, 0), (1, 3), (0, 1, 2), (0, 4, 0), (0, 0, 7), (0, 2, 1), (9,), (1, 1, 0), (0, 0, 0, 2), (6, 0), (0, 0, 8), (0, 1, 0, 0), (7, 0), (0, 1, 3), (0, 5, 0), (1, 4), (0, 0, 9), (3, 1), (1, 0, 2), (2, 2)]
8 | vicuna_33b_stage1 = [(0,), (1,), (0, 0), (2,), (0, 1), (3,), (1, 0), (4,), (0, 2), (5,), (0, 3), (0, 0, 0), (6,), (0, 4), (2, 0), (7,), (1, 1), (0, 5), (3, 0), (8,), (9,), (0, 6), (0, 7), (0, 0, 1), (1, 2), (4, 0), (0, 1, 0), (0, 8), (0, 9), (2, 1), (0, 0, 2), (5, 0), (1, 3), (0, 0, 3), (1, 0, 0), (1, 4), (6, 0), (0, 2, 0), (3, 1), (2, 2), (0, 0, 4), (7, 0), (0, 1, 1), (1, 5), (4, 1), (0, 0, 5), (0, 3, 0), (9, 0), (8, 0), (1, 6), (0, 0, 6), (2, 3), (0, 1, 2), (3, 2), (0, 4, 0), (2, 0, 0), (1, 7), (1, 0, 1), (0, 0, 7), (5, 1), (2, 4), (0, 0, 8), (0, 2, 1)]
9 | zephyr_stage2 = [(0,), (0, 0), (1,), (0, 1), (2,), (0, 0, 0), (1, 0), (0, 2), (3,), (0, 3), (4,), (2, 0), (0, 0, 1), (0, 4), (5,), (0, 5), (0, 1, 0), (1, 1), (6,), (0, 0, 2), (3, 0), (0, 6), (7,), (0, 7), (0, 8), (0, 0, 3), (1, 0, 0), (0, 9), (0, 2, 0), (1, 2), (4, 0), (8,), (9,), (2, 1), (0, 1, 1), (0, 0, 4), (0, 0, 0, 0), (5, 0), (0, 3, 0), (1, 3), (0, 0, 5), (0, 0, 6), (6, 0), (2, 0, 0), (1, 0, 1), (0, 1, 2), (0, 4, 0), (1, 4), (3, 1), (2, 2), (0, 0, 7), (7, 0), (0, 2, 1), (0, 0, 8), (0, 1, 3), (0, 5, 0), (1, 5), (0, 0, 9), (1, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0), (4, 1), (2, 3)]
10 |
--------------------------------------------------------------------------------
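
Each entry in these presets is a path of top-k indices (roughly, one index per successive Medusa head), and the set of paths defines the sparse candidate tree consumed by generate_medusa_buffers; the root, i.e. the base model's top-1 token, is implicit. A quick, non-essential way to inspect a preset:

    from medusa.model.medusa_choices import mc_sim_7b_63

    num_nodes = len(mc_sim_7b_63)                    # tree nodes, excluding the implicit root
    depth = max(len(p) for p in mc_sim_7b_63)        # deepest path == number of heads used
    top_k = max(max(p) for p in mc_sim_7b_63) + 1    # largest per-head index referenced
    print(num_nodes, depth, top_k)                   # 63, 4, 10 (the node count matches the name)
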
/medusa/model/medusa_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
4 | from .modeling_mistral_kv import MistralForCausalLM as KVMistralForCausalLM
5 | # import transformers
6 |
7 | # # monkey patch
8 | # transformers.models.llama.modeling_llama.LlamaForCausalLM = KVLlamaForCausalLM
9 | # transformers.models.mistral.modeling_mistral.MistralForCausalLM = KVMistralForCausalLM
10 |
11 | from transformers import PreTrainedModel, PretrainedConfig
12 | from .utils import *
13 | from .kv_cache import initialize_past_key_values
14 | from .medusa_choices import *
15 | from transformers import AutoTokenizer, AutoConfig
16 | import os
17 | from huggingface_hub import hf_hub_download
18 | import warnings
19 |
20 | class MedusaConfig(PretrainedConfig):
21 | """
22 | Configuration class for Medusa model.
23 |
24 | Args:
25 |         medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 5.
26 | medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
27 | base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
28 | **kwargs: Additional keyword arguments to be passed to the parent class constructor.
29 | """
30 |
31 | def __init__(
32 | self,
33 | medusa_num_heads=5,
34 | medusa_num_layers=1,
35 | base_model_name_or_path="lmsys/vicuna-7b-v1.3",
36 | **kwargs,
37 | ):
38 | super().__init__(**kwargs)
39 | self.medusa_num_heads = medusa_num_heads
40 | self.medusa_num_layers = medusa_num_layers
41 | self.base_model_name_or_path = base_model_name_or_path
42 |
43 | class ResBlock(nn.Module):
44 | """
45 | A Residual Block module.
46 |
47 | This module performs a linear transformation followed by a SiLU activation,
48 | and then adds the result to the original input, creating a residual connection.
49 |
50 | Args:
51 | hidden_size (int): The size of the hidden layers in the block.
52 | """
53 |
54 | def __init__(self, hidden_size):
55 | super().__init__()
56 | self.linear = nn.Linear(hidden_size, hidden_size)
57 | # Initialize as an identity mapping
58 | torch.nn.init.zeros_(self.linear.weight)
59 | # Use SiLU activation to keep consistent with the Llama model
60 | self.act = nn.SiLU()
61 |
62 | def forward(self, x):
63 | """
64 | Forward pass of the ResBlock.
65 |
66 | Args:
67 | x (torch.Tensor): Input tensor.
68 |
69 | Returns:
70 | torch.Tensor: Output after the residual connection and activation.
71 | """
72 | return x + self.act(self.linear(x))
73 |
74 |
75 | class MedusaModelABC(nn.Module):
76 | """The Medusa Language Model Head.
77 |
78 | This module creates a series of prediction heads (based on the 'medusa' parameter)
79 | on top of a given base model. Each head is composed of a sequence of residual blocks
80 | followed by a linear layer.
81 | """
82 |
83 | # Load the base model
84 | # base_model_prefix = "model"
85 | # supports_gradient_checkpointing = True
86 | # _no_split_modules = ["LlamaDecoderLayer", "MistralDecoderLayer"]
87 | # _skip_keys_device_placement = "past_key_values"
88 | # _supports_flash_attn_2 = True
89 |
90 | def __init__(
91 | self,
92 | config,
93 | ):
94 | """
95 | Args:
96 | config (PretrainedConfig): The configuration of the MedusaModel.
97 | """
98 | super().__init__(config)
99 | # For compatibility with the old APIs
100 |
101 | medusa_num_heads = config.medusa_num_heads
102 | medusa_num_layers = config.medusa_num_layers
103 | base_model_name_or_path = config._name_or_path
104 | self.hidden_size = config.hidden_size
105 | self.vocab_size = config.vocab_size
106 | self.medusa = medusa_num_heads
107 | self.medusa_num_layers = medusa_num_layers
108 | self.base_model_name_or_path = base_model_name_or_path
109 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
110 | # Create a list of Medusa heads
111 | self.medusa_head = nn.ModuleList(
112 | [
113 | nn.Sequential(
114 | *([ResBlock(self.hidden_size)] * medusa_num_layers),
115 | nn.Linear(self.hidden_size, self.vocab_size, bias=False),
116 | )
117 | for _ in range(medusa_num_heads)
118 | ]
119 | )
120 | # Add a link named base_model to self
121 | @property
122 | def base_model(self):
123 | return self
124 | @classmethod
125 | def from_pretrained(
126 | cls,
127 | pretrained_model_name_or_path,
128 | *args,
129 | **kwargs,
130 | ):
131 | # Manually load config to ensure that the medusa_num_heads parameter is loaded
132 | try:
133 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
134 | return super().from_pretrained(
135 | pretrained_model_name_or_path,
136 | *args,
137 | **kwargs,
138 | config=config,
139 | )
140 |         except Exception:
141 | config = MedusaConfig.from_pretrained(pretrained_model_name_or_path)
142 | base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path)
143 | base_model_config.medusa_num_heads = 5 # TODO: fix the uploaded config (only include 2 heads)
144 | base_model_config.medusa_num_layers = config.medusa_num_layers
145 | model = super().from_pretrained(
146 | config.base_model_name_or_path,
147 | *args,
148 | **kwargs,
149 | config=base_model_config,
150 | )
151 | medusa_head_path = os.path.join(pretrained_model_name_or_path, "medusa_lm_head.pt")
152 | if os.path.exists(medusa_head_path):
153 | filename = medusa_head_path
154 | else:
155 | filename = hf_hub_download(pretrained_model_name_or_path, "medusa_lm_head.pt")
156 | medusa_head_state_dict = torch.load(filename, map_location=model.device)
157 | model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)
158 | return model
159 |
160 |
161 | def get_tokenizer(self):
162 | """Get the tokenizer of the base model.
163 |
164 | Returns:
165 | Tokenizer: The tokenizer of the base model.
166 | """
167 | return self.tokenizer
168 |
169 |
170 | def forward(
171 | self,
172 | input_ids=None,
173 | attention_mask=None,
174 | past_key_values=None,
175 | output_orig=False,
176 | position_ids=None,
177 | medusa_forward=False,
178 | **kwargs,
179 | ):
180 | """Forward pass of the MedusaModel.
181 |
182 | Args:
183 | input_ids (torch.Tensor, optional): Input token IDs.
184 | attention_mask (torch.Tensor, optional): Attention mask.
185 | labels (torch.Tensor, optional): Ground truth labels for loss computation.
186 | past_key_values (tuple, optional): Tuple containing past key and value states for attention.
187 | output_orig (bool, optional): Whether to also output predictions from the original LM head.
188 | position_ids (torch.Tensor, optional): Position IDs.
189 |
190 | Returns:
191 | torch.Tensor: A tensor containing predictions from all Medusa heads.
192 | (Optional) Original predictions from the base model's LM head.
193 | """
194 | if not medusa_forward:
195 | return super().forward(
196 | input_ids=input_ids,
197 | attention_mask=attention_mask,
198 | past_key_values=past_key_values,
199 | position_ids=position_ids,
200 | **kwargs,
201 | )
202 | with torch.inference_mode():
203 | # Pass input through the base model
204 | outputs = self.base_model.model(
205 | input_ids=input_ids,
206 | attention_mask=attention_mask,
207 | past_key_values=past_key_values,
208 | position_ids=position_ids,
209 | **kwargs,
210 | )
211 | if output_orig:
212 | orig = self.base_model.lm_head(outputs[0])
213 | # Clone the output hidden states
214 | hidden_states = outputs[0].clone()
215 | medusa_logits = []
216 | # TODO: Consider parallelizing this loop for efficiency?
217 | for i in range(self.medusa):
218 | medusa_logits.append(self.medusa_head[i](hidden_states))
219 | if output_orig:
220 | return torch.stack(medusa_logits, dim=0), outputs, orig
221 | return torch.stack(medusa_logits, dim=0)
222 | def get_medusa_choice(self, model_name):
223 | if 'vicuna' in model_name:
224 | if '7b' in model_name:
225 | return vicuna_7b_stage2
226 | elif '13b' in model_name:
227 | return vicuna_13b_stage2
228 | elif '33b' in model_name:
229 | return vicuna_33b_stage2
230 | elif 'zephyr' in model_name:
231 | return zephyr_stage2
232 | warnings.warn('Please specify medusa choice configuration!')
233 | return mc_sim_7b_63
234 |
235 | def medusa_generate(
236 | self,
237 | input_ids,
238 | attention_mask=None,
239 | temperature=0.0,
240 | max_steps=512,
241 |         # The hyperparameters below are for Medusa's tree decoding, e.g.
242 |         # top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next-next token.
243 | medusa_choices=None,
244 | posterior_threshold=0.09, # threshold validation of Medusa output
245 | # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
246 | posterior_alpha=0.3,
247 | top_p=0.8,
248 | sampling = 'typical',
249 | fast = True
250 | ):
251 | """
252 | Args:
253 | input_ids (torch.Tensor, optional): Input token IDs.
254 | attention_mask (torch.Tensor, optional): Attention mask.
255 | temperature (float, optional): Temperature for typical acceptance.
256 |             medusa_choices (list, optional): A nested list of top-k index paths that defines the sparse candidate tree (see medusa_choices.py); defaults to a preset chosen by get_medusa_choice.
257 | posterior_threshold (float, optional): Threshold for posterior validation.
258 | posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold).
259 | top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
260 | sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
261 |             fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to True.
262 | Returns:
263 | torch.Tensor: Output token IDs.
264 |
265 |         Warning: Only batch size 1 is supported for now!
266 |         """
267 |         assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now!"
268 | # Avoid modifying the input_ids in-place
269 | input_ids = input_ids.clone()
270 |
271 | # Cache medusa buffers (the fixed patterns for tree attention)
272 | if medusa_choices is None:
273 | medusa_choices = self.get_medusa_choice(self.base_model_name_or_path)
274 |
275 | if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
276 | # Load the cached medusa buffer
277 | medusa_buffers = self.medusa_buffers
278 | else:
279 | # Initialize the medusa buffer
280 | medusa_buffers = generate_medusa_buffers(
281 | medusa_choices, device=self.base_model.device
282 | )
283 | self.medusa_buffers = medusa_buffers
284 | self.medusa_choices = medusa_choices
285 |
286 | # Initialize the past key and value states
287 | if hasattr(self, "past_key_values"):
288 | past_key_values = self.past_key_values
289 | past_key_values_data = self.past_key_values_data
290 | current_length_data = self.current_length_data
291 | # Reset the past key and value states
292 | current_length_data.zero_()
293 | else:
294 | (
295 | past_key_values,
296 | past_key_values_data,
297 | current_length_data,
298 | ) = initialize_past_key_values(self.base_model)
299 | self.past_key_values = past_key_values
300 | self.past_key_values_data = past_key_values_data
301 | self.current_length_data = current_length_data
302 |
303 | input_len = input_ids.shape[1]
304 |
305 | reset_medusa_mode(self)
306 | # Initialize tree attention mask and process prefill tokens
307 | medusa_logits, logits = initialize_medusa(
308 | input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
309 | )
310 |
311 | new_token = 0
312 | last_round_token = 0
313 |
314 | for idx in range(max_steps):
315 | # Generate candidates with topk predictions from Medusa heads
316 | candidates, tree_candidates = generate_candidates(
317 | medusa_logits,
318 | logits,
319 | medusa_buffers["tree_indices"],
320 | medusa_buffers["retrieve_indices"],
321 | temperature=temperature,
322 | posterior_alpha=posterior_alpha,
323 | posterior_threshold=posterior_threshold,
324 | top_p=top_p,
325 | sampling=sampling,
326 | fast=fast,
327 | )
328 |
329 | # Use tree attention to verify the candidates and get predictions
330 | medusa_logits, logits, outputs = tree_decoding(
331 | self,
332 | tree_candidates,
333 | past_key_values,
334 | medusa_buffers["medusa_position_ids"],
335 | input_ids,
336 | medusa_buffers["retrieve_indices"],
337 | )
338 |
339 | # Evaluate the posterior of the candidates to select the accepted candidate prefix
340 | best_candidate, accept_length = evaluate_posterior(
341 | logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p=top_p, sampling=sampling, fast=fast
342 | )
343 |
344 | # Update the input_ids and logits
345 | input_ids, logits, medusa_logits, new_token = update_inference_inputs(
346 | input_ids,
347 | candidates,
348 | best_candidate,
349 | accept_length,
350 | medusa_buffers["retrieve_indices"],
351 | outputs,
352 | logits,
353 | medusa_logits,
354 | new_token,
355 | past_key_values_data,
356 | current_length_data,
357 | )
358 |
359 | yield {
360 | "text": self.tokenizer.decode(
361 | input_ids[0, input_len:],
362 | skip_special_tokens=True,
363 | spaces_between_special_tokens=False,
364 | clean_up_tokenization_spaces=True,
365 | )
366 | }
367 |
368 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
369 | break
370 |
371 |
372 | class MedusaModelLlama(MedusaModelABC, KVLlamaForCausalLM):
373 | pass
374 |
375 | class MedusaModelMistral(MedusaModelABC, KVMistralForCausalLM):
376 | pass
377 |
378 |
379 | class MedusaModel():
380 | @classmethod
381 | def from_pretrained(
382 | cls,
383 | pretrained_model_name_or_path,
384 | *args,
385 | **kwargs,
386 | ):
387 | # Manually load config to ensure that the medusa_num_heads parameter is loaded
388 | try:
389 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
390 |         except Exception:
391 | # MEDUSA-v0.1 load
392 | config = MedusaConfig.from_pretrained(pretrained_model_name_or_path)
393 | base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path)
394 | config.model_type = base_model_config.model_type
395 |
396 | if config.model_type == "llama":
397 | return MedusaModelLlama.from_pretrained(
398 | pretrained_model_name_or_path,
399 | *args,
400 | **kwargs,
401 | )
402 | elif config.model_type == "mistral":
403 | return MedusaModelMistral.from_pretrained(
404 | pretrained_model_name_or_path,
405 | *args,
406 | **kwargs,
407 | )
408 | else:
409 | raise ValueError("Only support llama and mistral for now!!")
410 |
--------------------------------------------------------------------------------
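
Loading goes through MedusaModel.from_pretrained, which reads the config to pick the Llama or Mistral variant and forwards extra keyword arguments to the underlying Hugging Face from_pretrained. A minimal loading sketch follows; the checkpoint name is illustrative, and device_map="auto" additionally assumes accelerate is installed.

    import torch
    from medusa.model.medusa_model import MedusaModel

    model = MedusaModel.from_pretrained(
        "FasterDecoding/medusa-vicuna-7b-v1.3",   # illustrative checkpoint name
        torch_dtype=torch.float16,                # extra kwargs reach the HF from_pretrained
        device_map="auto",
    )
    tokenizer = model.get_tokenizer()

    # The tree preset is resolved from the base model name unless passed explicitly.
    choices = model.get_medusa_choice(model.base_model_name_or_path)
    print(len(choices), "tree nodes,", model.medusa, "Medusa heads")
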
/medusa/model/medusa_model_legacy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import PreTrainedModel, PretrainedConfig
4 | from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
5 | from .utils import *
6 | from .kv_cache import initialize_past_key_values
7 | from .medusa_choices import mc_sim_7b_63
8 | from transformers import AutoTokenizer
9 | import os
10 | from huggingface_hub import hf_hub_download
11 |
12 |
13 | class MedusaConfig(PretrainedConfig):
14 | """
15 | Configuration class for Medusa model.
16 |
17 | Args:
18 |         medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 4.
19 | medusa_num_layers (int, optional): Number of Medusa layers. Default is 1.
20 | base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3".
21 | **kwargs: Additional keyword arguments to be passed to the parent class constructor.
22 | """
23 |
24 | def __init__(
25 | self,
26 | medusa_num_heads=4,
27 | medusa_num_layers=1,
28 | version="2",
29 | base_model_name_or_path="lmsys/vicuna-7b-v1.3",
30 | **kwargs,
31 | ):
32 | super().__init__(**kwargs)
33 | self.medusa_num_heads = medusa_num_heads
34 | self.medusa_num_layers = medusa_num_layers
35 | self.version = version
36 | self.base_model_name_or_path = base_model_name_or_path
37 |
38 |
39 | class ResBlock(nn.Module):
40 | """
41 | A Residual Block module.
42 |
43 | This module performs a linear transformation followed by a SiLU activation,
44 | and then adds the result to the original input, creating a residual connection.
45 |
46 | Args:
47 | hidden_size (int): The size of the hidden layers in the block.
48 | """
49 |
50 | def __init__(self, hidden_size):
51 | super().__init__()
52 | self.linear = nn.Linear(hidden_size, hidden_size)
53 | # Initialize as an identity mapping
54 | torch.nn.init.zeros_(self.linear.weight)
55 | # Use SiLU activation to keep consistent with the Llama model
56 | self.act = nn.SiLU()
57 |
58 | def forward(self, x):
59 | """
60 | Forward pass of the ResBlock.
61 |
62 | Args:
63 | x (torch.Tensor): Input tensor.
64 |
65 | Returns:
66 | torch.Tensor: Output after the residual connection and activation.
67 | """
68 | return x + self.act(self.linear(x))
69 |
70 |
71 | class MedusaModel(nn.Module):
72 | """The Medusa Language Model Head.
73 |
74 | This module creates a series of prediction heads (based on the 'medusa' parameter)
75 | on top of a given base model. Each head is composed of a sequence of residual blocks
76 | followed by a linear layer.
77 | """
78 |
79 | def __init__(
80 | self,
81 | base_model,
82 | medusa_num_heads=4,
83 | medusa_num_layers=1,
84 | base_model_name_or_path="lmsys/vicuna-7b-v1.3",
85 | ):
86 | """
87 | Args:
88 | base_model (nn.Module): The base language model to be used.
89 |             medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 4.
90 |             medusa_num_layers (int, optional): Number of ResBlock layers for each Medusa head. Defaults to 1.
91 | """
92 | super().__init__()
93 | self.base_model = base_model
94 | self.config = base_model.config
95 | self.hidden_size = base_model.config.hidden_size
96 | self.vocab_size = base_model.config.vocab_size
97 | self.medusa = medusa_num_heads
98 | self.medusa_num_layers = medusa_num_layers
99 | self.base_model_name_or_path = base_model_name_or_path
100 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
101 | # Create a list of Medusa heads
102 | self.medusa_head = nn.ModuleList(
103 | [
104 | nn.Sequential(
105 | *([ResBlock(self.hidden_size)] * medusa_num_layers),
106 | )
107 | for _ in range(medusa_num_heads)
108 | ]
109 | )
110 |
111 | # Ensure medusa_head's dtype and device align with the base_model
112 | self.medusa_head.to(self.base_model.dtype).to(self.base_model.device)
113 |
114 | def get_tokenizer(self):
115 | """Get the tokenizer of the base model.
116 |
117 | Returns:
118 | Tokenizer: The tokenizer of the base model.
119 | """
120 | return self.tokenizer
121 |
122 | @classmethod
123 | def from_pretrained(
124 | cls,
125 | medusa_head_name_or_path,
126 | base_model=None,
127 | medusa_num_heads=None,
128 | **kwargs,
129 | ):
130 | """
131 | Args:
132 | medusa_head_name_or_path (str): Name or path of the Medusa head to load.
133 | **kwargs: Additional keyword arguments for loading the base model.
134 |
135 | Returns:
136 | MedusaModel: A MedusaModel instance loaded from the given path.
137 | """
138 | medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path)
139 | if medusa_num_heads is not None:
140 | print("Overriding medusa_num_heads as:", medusa_num_heads)
141 | medusa_config.medusa_num_heads = medusa_num_heads
142 | if base_model is not None:
143 | print("Overriding base_model as:", base_model)
144 | medusa_config.base_model_name_or_path = base_model
145 |
146 | base_model = KVLlamaForCausalLM.from_pretrained(
147 | medusa_config.base_model_name_or_path, **kwargs
148 | )
149 |
150 | model = cls(
151 | base_model,
152 | medusa_config.medusa_num_heads,
153 | medusa_config.medusa_num_layers,
154 | medusa_config.base_model_name_or_path,
155 | )
156 | medusa_head_path = os.path.join(medusa_head_name_or_path, "medusa_lm_head.pt")
157 | if os.path.exists(medusa_head_path):
158 | filename = medusa_head_path
159 | else:
160 | filename = hf_hub_download(medusa_head_name_or_path, "medusa_lm_head.pt")
161 | medusa_head_state_dict = torch.load(filename, map_location=base_model.device)
162 | model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False)
163 |
164 | return model
165 |
166 | def forward(
167 | self,
168 | input_ids=None,
169 | attention_mask=None,
170 | labels=None,
171 | past_key_values=None,
172 | output_orig=False,
173 | position_ids=None,
174 | ):
175 | """Forward pass of the MedusaModel.
176 |
177 | Args:
178 | input_ids (torch.Tensor, optional): Input token IDs.
179 | attention_mask (torch.Tensor, optional): Attention mask.
180 | labels (torch.Tensor, optional): Ground truth labels for loss computation.
181 | past_key_values (tuple, optional): Tuple containing past key and value states for attention.
182 | output_orig (bool, optional): Whether to also output predictions from the original LM head.
183 | position_ids (torch.Tensor, optional): Position IDs.
184 |
185 | Returns:
186 | torch.Tensor: A tensor containing predictions from all Medusa heads.
187 | (Optional) Original predictions from the base model's LM head.
188 | """
189 | with torch.no_grad():
190 | # Pass input through the base model
191 | outputs = self.base_model.model(
192 | input_ids=input_ids,
193 | attention_mask=attention_mask,
194 | past_key_values=past_key_values,
195 | position_ids=position_ids,
196 | )
197 | if output_orig:
198 | orig = self.base_model.lm_head(outputs[0])
199 | # Clone the output hidden states
200 | hidden_states = outputs[0].clone()
201 | medusa_logits = []
202 | # TODO: Consider parallelizing this loop for efficiency?
203 | for i in range(self.medusa):
204 | mhidden_states = self.medusa_head[i](hidden_states)
205 | mlogits = self.base_model.lm_head(mhidden_states)
206 | medusa_logits.append(mlogits)
207 | if output_orig:
208 | return torch.stack(medusa_logits, dim=0), outputs, orig
209 | return torch.stack(medusa_logits, dim=0)
210 |
211 | def medusa_generate(
212 | self,
213 | input_ids,
214 | attention_mask=None,
215 | temperature=0.0,
216 | max_steps=512,
217 |         # The hyperparameters below are for Medusa's tree decoding, e.g.
218 |         # top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next-next token.
219 | medusa_choices=mc_sim_7b_63,
220 | posterior_threshold=0.09, # threshold validation of Medusa output
221 | # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
222 | posterior_alpha=0.3,
223 | ):
224 | """
225 | Args:
226 | input_ids (torch.Tensor, optional): Input token IDs.
227 | attention_mask (torch.Tensor, optional): Attention mask.
228 | temperature (float, optional): Temperature for typical acceptance.
229 |             medusa_choices (list, optional): A nested list of top-k index paths that defines the sparse candidate tree (see medusa_choices.py). Defaults to mc_sim_7b_63.
230 | posterior_threshold (float, optional): Threshold for posterior validation.
231 | posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold).
232 | Returns:
233 | torch.Tensor: Output token IDs.
234 |
235 |         Warning: Only batch size 1 is supported for now!
236 |         """
237 |         assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now!"
238 | # Avoid modifying the input_ids in-place
239 | input_ids = input_ids.clone()
240 |
241 | # Cache medusa buffers (the fixed patterns for tree attention)
242 | if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
243 | # Load the cached medusa buffer
244 | medusa_buffers = self.medusa_buffers
245 | else:
246 | # Initialize the medusa buffer
247 | medusa_buffers = generate_medusa_buffers(
248 | medusa_choices, device=self.base_model.device
249 | )
250 | self.medusa_buffers = medusa_buffers
251 | self.medusa_choices = medusa_choices
252 |
253 |
254 | # Initialize the past key and value states
255 | if hasattr(self, "past_key_values"):
256 | past_key_values = self.past_key_values
257 | past_key_values_data = self.past_key_values_data
258 | current_length_data = self.current_length_data
259 | # Reset the past key and value states
260 | current_length_data.zero_()
261 | else:
262 | (
263 | past_key_values,
264 | past_key_values_data,
265 | current_length_data,
266 | ) = initialize_past_key_values(self.base_model)
267 | self.past_key_values = past_key_values
268 | self.past_key_values_data = past_key_values_data
269 | self.current_length_data = current_length_data
270 |
271 | input_len = input_ids.shape[1]
272 |
273 | reset_medusa_mode(self)
274 | # Initialize tree attention mask and process prefill tokens
275 | medusa_logits, logits = initialize_medusa(
276 | input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
277 | )
278 |
279 | new_token = 0
280 | last_round_token = 0
281 |
282 | for idx in range(max_steps):
283 | # Generate candidates with topk predictions from Medusa heads
284 | candidates, tree_candidates = generate_candidates(
285 | medusa_logits,
286 | logits,
287 | medusa_buffers["tree_indices"],
288 | medusa_buffers["retrieve_indices"],
289 | )
290 |
291 | # Use tree attention to verify the candidates and get predictions
292 | medusa_logits, logits, outputs = tree_decoding(
293 | self,
294 | tree_candidates,
295 | past_key_values,
296 | medusa_buffers["medusa_position_ids"],
297 | input_ids,
298 | medusa_buffers["retrieve_indices"],
299 | )
300 |
301 | # Evaluate the posterior of the candidates to select the accepted candidate prefix
302 | best_candidate, accept_length = evaluate_posterior(
303 | logits, candidates, temperature, posterior_threshold, posterior_alpha
304 | )
305 |
306 | # Update the input_ids and logits
307 | input_ids, logits, medusa_logits, new_token = update_inference_inputs(
308 | input_ids,
309 | candidates,
310 | best_candidate,
311 | accept_length,
312 | medusa_buffers["retrieve_indices"],
313 | outputs,
314 | logits,
315 | medusa_logits,
316 | new_token,
317 | past_key_values_data,
318 | current_length_data,
319 | )
320 |
321 | yield {
322 | "text": self.tokenizer.decode(
323 | input_ids[0, input_len:],
324 | skip_special_tokens=True,
325 | spaces_between_special_tokens=False,
326 | clean_up_tokenization_spaces=True,
327 | )
328 | }
329 |
330 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
331 | break
332 |
--------------------------------------------------------------------------------
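
Compared with the current class, this legacy head contains only ResBlocks (it reuses the base model's lm_head for logits), and its from_pretrained takes the Medusa-head path plus an optional base-model override instead of a single merged checkpoint. A loading sketch under those assumptions; both paths below are placeholders.

    import torch
    from medusa.model.medusa_model_legacy import MedusaModel

    # "path/to/medusa-head" stands for a directory (or Hub repo) containing
    # config.json and medusa_lm_head.pt; the name is a placeholder.
    model = MedusaModel.from_pretrained(
        "path/to/medusa-head",
        base_model="lmsys/vicuna-7b-v1.3",   # optional override of the recorded base model
        torch_dtype=torch.float16,           # forwarded to KVLlamaForCausalLM.from_pretrained
    )
    tokenizer = model.get_tokenizer()
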
/medusa/model/medusa_model_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import PreTrainedModel, PretrainedConfig
4 | from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM
5 | from .modeling_mistral_kv import MistralForCausalLM as KVMistralForCausalLM
6 | from .utils import *
7 | from .kv_cache import initialize_past_key_values
8 | from .medusa_choices import mc_sim_7b_63
9 | from transformers import AutoTokenizer, AutoConfig
10 | import os
11 | from huggingface_hub import hf_hub_download
12 |
13 |
14 | class ResBlock(nn.Module):
15 | """
16 | A Residual Block module.
17 |
18 | This module performs a linear transformation followed by a SiLU activation,
19 | and then adds the result to the original input, creating a residual connection.
20 |
21 | Args:
22 | hidden_size (int): The size of the hidden layers in the block.
23 | """
24 |
25 | def __init__(self, hidden_size):
26 | super().__init__()
27 | self.linear = nn.Linear(hidden_size, hidden_size)
28 | # Initialize as an identity mapping
29 | torch.nn.init.zeros_(self.linear.weight)
30 | # Use SiLU activation to keep consistent with the Llama model
31 | self.act = nn.SiLU()
32 |
33 | def forward(self, x):
34 | """
35 | Forward pass of the ResBlock.
36 |
37 | Args:
38 | x (torch.Tensor): Input tensor.
39 |
40 | Returns:
41 | torch.Tensor: Output after the residual connection and activation.
42 | """
43 | return x + self.act(self.linear(x))
44 |
45 | class MedusaModel(PreTrainedModel):
46 | """The Medusa Language Model Head.
47 |
48 | This module creates a series of prediction heads (based on the 'medusa' parameter)
49 | on top of a given base model. Each head is composed of a sequence of residual blocks
50 | followed by a linear layer.
51 | """
52 |
53 | def __init__(
54 | self,
55 | config,
56 | ):
57 | """
58 | Args:
59 | config (PretrainedConfig): The configuration of the MedusaModel.
60 | """
61 | super().__init__(config)
62 | # For compatibility with the old APIs
63 | medusa_num_heads = config.medusa_num_heads
64 | medusa_num_layers = config.medusa_num_layers
65 | base_model_name_or_path = config._name_or_path
66 | self.hidden_size = config.hidden_size
67 | self.vocab_size = config.vocab_size
68 | self.medusa = medusa_num_heads
69 | self.medusa_num_layers = medusa_num_layers
70 | self.base_model_name_or_path = base_model_name_or_path
71 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
72 | # Create a list of Medusa heads
73 | self.medusa_head = nn.ModuleList(
74 | [
75 | nn.Sequential(
76 | *([ResBlock(self.hidden_size)] * medusa_num_layers),
77 | nn.Linear(self.hidden_size, self.vocab_size, bias=False),
78 | )
79 | for _ in range(medusa_num_heads)
80 | ]
81 | )
82 |
83 | # Add a link named base_model to self
84 | @property
85 | def base_model(self):
86 | return self
87 |
88 | def get_tokenizer(self):
89 | """Get the tokenizer of the base model.
90 |
91 | Returns:
92 | Tokenizer: The tokenizer of the base model.
93 | """
94 | return self.tokenizer
95 |
96 | @classmethod
97 | def from_pretrained(
98 | cls,
99 | pretrained_model_name_or_path,
100 | *args,
101 | **kwargs,
102 | ):
103 | # Manually load config to ensure that the medusa_num_heads parameter is loaded
104 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
105 | return super().from_pretrained(
106 | pretrained_model_name_or_path,
107 | *args,
108 | **kwargs,
109 | config=config,
110 | )
111 |
112 | def forward(
113 | self,
114 | input_ids=None,
115 | attention_mask=None,
116 | past_key_values=None,
117 | output_orig=False,
118 | position_ids=None,
119 | medusa_forward=False,
120 | **kwargs,
121 | ):
122 | """Forward pass of the MedusaModel.
123 |
124 | Args:
125 | input_ids (torch.Tensor, optional): Input token IDs.
126 | attention_mask (torch.Tensor, optional): Attention mask.
127 | labels (torch.Tensor, optional): Ground truth labels for loss computation.
128 | past_key_values (tuple, optional): Tuple containing past key and value states for attention.
129 | output_orig (bool, optional): Whether to also output predictions from the original LM head.
130 | position_ids (torch.Tensor, optional): Position IDs.
131 |
132 | Returns:
133 | torch.Tensor: A tensor containing predictions from all Medusa heads.
134 | (Optional) Original predictions from the base model's LM head.
135 | """
150 | if not medusa_forward:
151 | return super().forward(
152 | input_ids=input_ids,
153 | attention_mask=attention_mask,
154 | past_key_values=past_key_values,
155 | position_ids=position_ids,
156 | **kwargs,
157 | )
158 | with torch.inference_mode():
159 | # Pass input through the base model
160 | outputs = self.base_model.model(
161 | input_ids=input_ids,
162 | attention_mask=attention_mask,
163 | past_key_values=past_key_values,
164 | position_ids=position_ids,
165 | **kwargs,
166 | )
167 | if output_orig:
168 | orig = self.base_model.lm_head(outputs[0])
169 | # Clone the output hidden states
170 | hidden_states = outputs[0].clone()
171 | medusa_logits = []
172 | # TODO: Consider parallelizing this loop for efficiency?
173 | for i in range(self.medusa):
174 | medusa_logits.append(self.medusa_head[i](hidden_states))
175 | if output_orig:
176 | return torch.stack(medusa_logits, dim=0), outputs, orig
177 | return torch.stack(medusa_logits, dim=0)
178 |
179 |
180 |
181 | class MedusaLlamaModel(KVLlamaForCausalLM):
182 | """The Medusa Language Model Head.
183 |
184 | This module creates a series of prediction heads (based on the 'medusa' parameter)
185 | on top of a given base model. Each head is composed of a sequence of residual blocks
186 | followed by a linear layer.
187 | """
188 |
189 | def __init__(
190 | self,
191 | config,
192 | ):
193 | """
194 | Args:
195 | config (PretrainedConfig): The configuration of the MedusaModel.
196 | """
197 | # Load the base model
198 | super().__init__(config)
199 | # For compatibility with the old APIs
200 |
201 | medusa_num_heads = config.medusa_num_heads
202 | medusa_num_layers = config.medusa_num_layers
203 | base_model_name_or_path = config._name_or_path
204 | self.hidden_size = config.hidden_size
205 | self.vocab_size = config.vocab_size
206 | self.medusa = medusa_num_heads
207 | self.medusa_num_layers = medusa_num_layers
208 | self.base_model_name_or_path = base_model_name_or_path
209 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path)
210 | # Create a list of Medusa heads
211 | self.medusa_head = nn.ModuleList(
212 | [
213 | nn.Sequential(
214 | *([ResBlock(self.hidden_size)] * medusa_num_layers),
215 | nn.Linear(self.hidden_size, self.vocab_size, bias=False),
216 | )
217 | for _ in range(medusa_num_heads)
218 | ]
219 | )
220 |
221 | # Add a link named base_model to self
222 | @property
223 | def base_model(self):
224 | return self
225 |
226 | def get_tokenizer(self):
227 | """Get the tokenizer of the base model.
228 |
229 | Returns:
230 | Tokenizer: The tokenizer of the base model.
231 | """
232 | return self.tokenizer
233 |
234 | @classmethod
235 | def from_pretrained(
236 | cls,
237 | pretrained_model_name_or_path,
238 | *args,
239 | **kwargs,
240 | ):
241 | # Manually load config to ensure that the medusa_num_heads parameter is loaded
242 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
243 | return super().from_pretrained(
244 | pretrained_model_name_or_path,
245 | *args,
246 | **kwargs,
247 | config=config,
248 | )
249 |
250 | def forward(
251 | self,
252 | input_ids=None,
253 | attention_mask=None,
254 | past_key_values=None,
255 | output_orig=False,
256 | position_ids=None,
257 | medusa_forward=False,
258 | **kwargs,
259 | ):
260 | """Forward pass of the MedusaModel.
261 |
262 | Args:
263 | input_ids (torch.Tensor, optional): Input token IDs.
264 | attention_mask (torch.Tensor, optional): Attention mask.
265 | labels (torch.Tensor, optional): Ground truth labels for loss computation.
266 | past_key_values (tuple, optional): Tuple containing past key and value states for attention.
267 | output_orig (bool, optional): Whether to also output predictions from the original LM head.
268 | position_ids (torch.Tensor, optional): Position IDs.
269 |
270 | Returns:
271 | torch.Tensor: A tensor containing predictions from all Medusa heads.
272 | (Optional) Original predictions from the base model's LM head.
273 | """
274 | if not medusa_forward:
275 | return super().forward(
276 | input_ids=input_ids,
277 | attention_mask=attention_mask,
278 | past_key_values=past_key_values,
279 | position_ids=position_ids,
280 | **kwargs,
281 | )
282 | with torch.inference_mode():
283 | # Pass input through the base model
284 | outputs = self.base_model.model(
285 | input_ids=input_ids,
286 | attention_mask=attention_mask,
287 | past_key_values=past_key_values,
288 | position_ids=position_ids,
289 | **kwargs,
290 | )
291 | if output_orig:
292 | orig = self.base_model.lm_head(outputs[0])
293 | # Clone the output hidden states
294 | hidden_states = outputs[0].clone()
295 | medusa_logits = []
296 | # TODO: Consider parallelizing this loop for efficiency?
297 | for i in range(self.medusa):
298 | medusa_logits.append(self.medusa_head[i](hidden_states))
299 | if output_orig:
300 | return torch.stack(medusa_logits, dim=0), outputs, orig
301 | return torch.stack(medusa_logits, dim=0)
302 |
303 | def medusa_generate(
304 | self,
305 | input_ids,
306 | attention_mask=None,
307 | temperature=0.0,
308 | max_steps=512,
309 |         # The hyperparameters below are for Medusa's tree decoding, e.g.
310 |         # top-1 prediction for the next token, top-7 predictions for the next token, top-6 predictions for the next-next token.
311 | medusa_choices=mc_sim_7b_63,
312 | posterior_threshold=0.09, # threshold validation of Medusa output
313 | # another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
314 | posterior_alpha=0.3,
315 | ):
316 | """
317 | Args:
318 | input_ids (torch.Tensor, optional): Input token IDs.
319 | attention_mask (torch.Tensor, optional): Attention mask.
320 | temperature (float, optional): Temperature for typical acceptance.
321 |             medusa_choices (list, optional): A nested list of top-k index paths that defines the sparse candidate tree (see medusa_choices.py). Defaults to mc_sim_7b_63.
322 | posterior_threshold (float, optional): Threshold for posterior validation.
323 | posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold).
324 | Returns:
325 | torch.Tensor: Output token IDs.
326 |
327 |         Warning: Only batch size 1 is supported for now!
328 |         """
329 |         assert input_ids.shape[0] == 1, "Only batch size 1 is supported for now!"
330 | # Avoid modifying the input_ids in-place
331 | input_ids = input_ids.clone()
332 |
333 | # Cache medusa buffers (the fixed patterns for tree attention)
334 | if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices:
335 | # Load the cached medusa buffer
336 | medusa_buffers = self.medusa_buffers
337 | else:
338 | # Initialize the medusa buffer
339 | medusa_buffers = generate_medusa_buffers(
340 | medusa_choices, device=self.base_model.device
341 | )
342 | self.medusa_buffers = medusa_buffers
343 | self.medusa_choices = medusa_choices
344 |
345 |
346 | # Initialize the past key and value states
347 | if hasattr(self, "past_key_values"):
348 | past_key_values = self.past_key_values
349 | past_key_values_data = self.past_key_values_data
350 | current_length_data = self.current_length_data
351 | # Reset the past key and value states
352 | current_length_data.zero_()
353 | else:
354 | (
355 | past_key_values,
356 | past_key_values_data,
357 | current_length_data,
358 | ) = initialize_past_key_values(self.base_model)
359 | self.past_key_values = past_key_values
360 | self.past_key_values_data = past_key_values_data
361 | self.current_length_data = current_length_data
362 |
363 | input_len = input_ids.shape[1]
364 |
365 | reset_medusa_mode(self)
366 | # Initialize tree attention mask and process prefill tokens
367 | medusa_logits, logits = initialize_medusa(
368 | input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values
369 | )
370 |
371 | new_token = 0
372 | last_round_token = 0
373 |
374 | for idx in range(max_steps):
375 | # Generate candidates with topk predictions from Medusa heads
376 | candidates, tree_candidates = generate_candidates(
377 | medusa_logits,
378 | logits,
379 | medusa_buffers["tree_indices"],
380 | medusa_buffers["retrieve_indices"],
381 | )
382 |
383 | # Use tree attention to verify the candidates and get predictions
384 | medusa_logits, logits, outputs = tree_decoding(
385 | self,
386 | tree_candidates,
387 | past_key_values,
388 | medusa_buffers["medusa_position_ids"],
389 | input_ids,
390 | medusa_buffers["retrieve_indices"],
391 | )
392 |
393 | # Evaluate the posterior of the candidates to select the accepted candidate prefix
394 | best_candidate, accept_length = evaluate_posterior(
395 | logits, candidates, temperature, posterior_threshold, posterior_alpha
396 | )
397 |
398 | # Update the input_ids and logits
399 | input_ids, logits, medusa_logits, new_token = update_inference_inputs(
400 | input_ids,
401 | candidates,
402 | best_candidate,
403 | accept_length,
404 | medusa_buffers["retrieve_indices"],
405 | outputs,
406 | logits,
407 | medusa_logits,
408 | new_token,
409 | past_key_values_data,
410 | current_length_data,
411 | )
412 |
413 | yield {
414 | "text": self.tokenizer.decode(
415 | input_ids[0, input_len:],
416 | skip_special_tokens=True,
417 | spaces_between_special_tokens=False,
418 | clean_up_tokenization_spaces=True,
419 | )
420 | }
421 |
422 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]:
423 | break
424 |
425 | # Currently only the Llama architecture is supported.
426 | MedusaModel = MedusaLlamaModel
--------------------------------------------------------------------------------
/medusa/model/utils_legacy.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | TOPK=10 # topk for sparse tree
4 |
5 | def pad_path(path, length, pad_value=-2):
6 | """
7 | Pad the given path list with a specific value up to a specified length.
8 |
9 | Parameters:
10 | - path (list): The original list that needs padding.
11 | - length (int): The desired length of the padded list.
12 | - pad_value (optional, default=-2): The value to use for padding.
13 |
14 | Returns:
15 | - list: A new list based on the original path but padded to the desired length.
16 |
17 | Example:
18 | >>> pad_path([1,2,3], 5)
19 | [1, 2, 3, -2, -2]
20 |
21 | Note:
22 | If the given path is already longer than the specified length,
23 | then no padding occurs, and the original path is returned.
24 | """
25 |
26 | # Calculate the number of padding values needed by subtracting the length
27 | # of the path from the desired length.
28 | # Append the padding values to the original path and return the new list.
29 | return path + [pad_value] * (length - len(path))
30 |
31 | def generate_medusa_buffers(medusa_choices, device="cuda"):
32 | """
33 | Generate buffers for the Medusa structure based on the provided choices.
34 |
35 | Parameters:
36 |     - medusa_choices (list): A nested list representing the candidate tree in the Medusa structure.
37 | - device (str): Device to which the tensors should be moved. Default is "cuda".
38 |
39 | Returns:
40 | - dict: A dictionary containing buffers related to the Medusa structure.
41 | """
42 |
43 | # Sort the medusa_choices based on their lengths and then their values
44 | sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x))
45 | medusa_len = len(sorted_medusa_choices) + 1
46 |
47 | # Initialize depth_counts to keep track of how many choices have a particular depth
48 | depth_counts = []
49 | prev_depth = 0
50 | for path in sorted_medusa_choices:
51 | depth = len(path)
52 | if depth != prev_depth:
53 | depth_counts.append(0)
54 | depth_counts[depth - 1] += 1
55 | prev_depth = depth
56 |
57 | # Create the attention mask for Medusa
58 | medusa_attn_mask = torch.eye(medusa_len, medusa_len)
59 | medusa_attn_mask[:, 0] = 1
60 | start = 0
61 | for i in range(len(depth_counts)):
62 | for j in range(depth_counts[i]):
63 | cur_medusa_choice = sorted_medusa_choices[start + j]
64 | # retrieve ancestor position
65 | if len(cur_medusa_choice) == 1:
66 | continue
67 | ancestor_idx = []
68 | for c in range(len(cur_medusa_choice) - 1):
69 | ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1)
70 | medusa_attn_mask[j + start + 1, ancestor_idx] = 1
71 | start += depth_counts[i]
72 |
73 | # Generate tree indices for the Medusa structure
74 | medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long)
75 | medusa_tree_indices[0] = 0
76 | start = 0
77 | for i in range(len(depth_counts)):
78 | for j in range(depth_counts[i]):
79 | cur_medusa_choice = sorted_medusa_choices[start + j]
80 | medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1
81 | start += depth_counts[i]
82 |
83 | # Generate position IDs for the Medusa structure
84 | medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long)
85 | start = 0
86 | for i in range(len(depth_counts)):
87 | medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1
88 | start += depth_counts[i]
89 |
90 | # Generate retrieval indices for Medusa structure verification
91 | retrieve_indices_nest = []
92 | retrieve_paths = []
93 | for i in range(len(sorted_medusa_choices)):
94 | cur_medusa_choice = sorted_medusa_choices[-i-1]
95 | retrieve_indice = []
96 | if cur_medusa_choice in retrieve_paths:
97 | continue
98 | else:
99 | for c in range(len(cur_medusa_choice)):
100 | retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]))
101 | retrieve_paths.append(cur_medusa_choice[:c+1])
102 | retrieve_indices_nest.append(retrieve_indice)
103 | max_length = max([len(x) for x in retrieve_indices_nest])
104 | retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest]
105 | retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long)
106 | retrieve_indices = retrieve_indices + 1
107 | retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1)
108 |
109 | # Aggregate the generated buffers into a dictionary
110 | medusa_buffers = {
111 | "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0),
112 | "tree_indices": medusa_tree_indices,
113 | "medusa_position_ids": medusa_position_ids,
114 | "retrieve_indices": retrieve_indices,
115 | }
116 |
117 | # Move the tensors in the dictionary to the specified device
118 | medusa_buffers = {
119 | k: v.clone().to(device)
120 | if isinstance(v, torch.Tensor)
121 | else torch.tensor(v, device=device)
122 | for k, v in medusa_buffers.items()
123 | }
124 | return medusa_buffers
125 |
126 |
127 | def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values):
128 | """
129 | Initializes the Medusa structure for a given model.
130 |
131 | This function performs the following operations:
132 | 1. Forward pass through the model to obtain the Medusa logits, original model outputs, and logits.
133 | 2. Sets the Medusa attention mask within the base model.
134 |
135 | Args:
136 | - input_ids (torch.Tensor): The input tensor containing token ids.
137 |     - model (MedusaModel): The model containing the Medusa heads and the base model.
138 | - medusa_attn_mask (torch.Tensor): The attention mask designed specifically for the Medusa structure.
139 | - past_key_values (list of torch.Tensor): Contains past hidden states and past attention values.
140 |
141 | Returns:
142 | - medusa_logits (torch.Tensor): Logits from the Medusa heads.
143 | - logits (torch.Tensor): Original logits from the base model.
144 | """
145 | medusa_logits, outputs, logits = model(
146 | input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True
147 | )
148 | model.base_model.model.medusa_mask = medusa_attn_mask
149 | return medusa_logits, logits
150 |
151 |
152 | def reset_medusa_mode(
153 | model,
154 | ):
155 | """
156 |     Resets the Medusa settings of the model to their initial state.
157 |
158 |     This function ensures that after any operations involving Medusa,
159 |     the base model and its settings return to their default state.
160 |     Specifically, it performs the following tasks:
161 |     1. Clears the Medusa attention mask in the base model.
162 |     2. Resets the Medusa mode in the base model.
163 |     (The cached key/value lengths are reset separately via reset_past_key_values.)
164 |
165 |     Args:
166 |     - model (MedusaModel): The model containing the Medusa heads and the base model.
167 |
168 |     Returns:
169 |     - None. The model is modified in place.
170 |
171 | """
172 | model.base_model.model.medusa_mask = None
173 | model.base_model.model.medusa_mode = None
174 |
175 |
176 | def reset_past_key_values(passed_key_values):
177 | """
178 | Resets the current lengths in the passed key-values to zero.
179 |
180 | This function is designed to be used during the evaluation of a baseline model.
181 | It iterates through each layer's key-values and sets their current lengths to zero,
182 | effectively resetting their state.
183 |
184 | Args:
185 | - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer.
186 |
187 | Returns:
188 | - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
189 | """
190 | for i in range(len(passed_key_values)):
191 | for j in range(2):
192 | passed_key_values[i][j].current_length.fill_(0)
193 | return passed_key_values
194 |
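    | # current_length is an attribute of the pre-allocated cache objects in medusa/model/kv_cache.py:
    | # each layer holds a fixed-size key/value buffer plus a length counter, so resetting the cache
    | # only zeroes the counters rather than reallocating memory.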
195 |
196 | def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices):
197 | """
198 | Generate candidates based on provided logits and indices.
199 |
200 | Parameters:
201 | - medusa_logits (torch.Tensor): Logits associated with the Medusa structure.
202 | - logits (torch.Tensor): Original logits.
203 | - tree_indices (list or torch.Tensor): Indices associated with a tree structure.
204 | - retrieve_indices (list or torch.Tensor): Indices for retrieving candidates.
205 |
206 | Returns:
207 | - tuple: Returns cartesian candidates and tree candidates.
208 | """
209 |
210 | # Greedy decoding: Select the most probable candidate from the original logits.
211 | candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
212 |
213 | # Extract the TOPK candidates from the medusa logits.
214 | candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim = -1).indices
215 |
216 | # Combine the selected candidate from the original logits with the topk medusa logits.
217 | candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1)
218 |
219 | # Map the combined candidates to the tree indices to get tree candidates.
220 | tree_candidates = candidates[tree_indices]
221 |
222 | # Extend the tree candidates by appending a zero.
223 | tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device)], dim=0)
224 |
225 | # Retrieve the cartesian candidates using the retrieve indices.
226 | cart_candidates = tree_candidates_ext[retrieve_indices]
227 |
228 | # Unsqueeze the tree candidates for dimension consistency.
229 | tree_candidates = tree_candidates.unsqueeze(0)
230 | return cart_candidates, tree_candidates
231 |
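    | # Rough shape summary (TOPK is the module-level per-head top-k constant):
    | #   candidates      : (1 + num_heads * TOPK,)     flat pool: greedy token plus top-k per head
    | #   tree_candidates : (1, num_tree_nodes)         one token id per node of the sparse tree
    | #   cart_candidates : (num_paths, max_depth + 1)  one row per root-to-leaf path; shorter paths
    | #                                                 are padded with the appended zero token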
232 |
233 | def tree_decoding(
234 | model,
235 | tree_candidates,
236 | past_key_values,
237 | medusa_position_ids,
238 | input_ids,
239 | retrieve_indices,
240 | ):
241 | """
242 | Decode the tree candidates using the provided model and reorganize the logits.
243 |
244 | Parameters:
245 | - model (nn.Module): Model to be used for decoding the tree candidates.
246 | - tree_candidates (torch.Tensor): Input candidates based on a tree structure.
247 | - past_key_values (torch.Tensor): Past states, such as key and value pairs, used in attention layers.
248 | - medusa_position_ids (torch.Tensor): Positional IDs associated with the Medusa structure.
249 | - input_ids (torch.Tensor): Input sequence IDs.
250 | - retrieve_indices (list or torch.Tensor): Indices for reordering the logits.
251 |
252 | Returns:
253 | - tuple: Returns medusa logits, regular logits, and other outputs from the model.
254 | """
255 |
256 | # Compute new position IDs by adding the Medusa position IDs to the length of the input sequence.
257 | position_ids = medusa_position_ids + input_ids.shape[1]
258 |
259 | # Use the model to decode the tree candidates.
260 | # The model is expected to return logits for the Medusa structure, original logits, and possibly other outputs.
261 | tree_medusa_logits, outputs, tree_logits = model(
262 | tree_candidates,
263 | output_orig=True,
264 | past_key_values=past_key_values,
265 | position_ids=position_ids,
266 | medusa_forward=True,
267 | )
268 |
269 |     # Reorder the obtained logits with retrieve_indices so that row i holds the logits for candidate path i.
270 | logits = tree_logits[0, retrieve_indices]
271 | medusa_logits = tree_medusa_logits[:, 0, retrieve_indices]
272 | return medusa_logits, logits, outputs
273 |
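    | # position_ids gives every tree token the position it would occupy if its path were appended
    | # to the current sequence, which is what allows all candidate continuations to be scored in a
    | # single forward pass under the tree attention mask. After reordering with retrieve_indices,
    | # logits has shape (num_paths, path_length, vocab_size) and medusa_logits carries an extra
    | # leading head dimension.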
274 |
275 | def evaluate_posterior(
276 | logits, candidates, temperature, posterior_threshold, posterior_alpha
277 | ):
278 | """
279 | Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.
280 |
281 | Depending on the temperature value, the function either uses greedy decoding or evaluates posterior
282 | probabilities to select the best candidate.
283 |
284 | Args:
285 |     - logits (torch.Tensor): Logits for each candidate sequence, of shape (num_candidates, candidate_length, vocab_size).
286 |     - candidates (torch.Tensor): Candidate token sequences of shape (num_candidates, candidate_length).
287 | - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding.
288 | - posterior_threshold (float): Threshold for posterior probability.
289 | - posterior_alpha (float): Scaling factor for the threshold.
290 |
291 | Returns:
292 | - best_candidate (torch.Tensor): Index of the chosen best candidate.
293 | - accept_length (int): Length of the accepted candidate sequence.
294 | """
295 | # Greedy decoding based on temperature value
296 | if temperature == 0:
297 | # Find the tokens that match the maximum logits for each position in the sequence
298 | posterior_mask = (
299 | candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1)
300 | ).int()
301 | candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
302 | accept_length = candidates_accept_length.max()
303 | # Choose the best candidate
304 | if accept_length == 0:
305 | # Default to the first candidate if none are accepted
306 | best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
307 | else:
308 | best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
309 | return best_candidate, accept_length
310 | # Calculate posterior probabilities and thresholds for candidate selection
311 | posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
312 | candidates_prob = torch.gather(
313 | posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
314 | ).squeeze(-1)
315 | posterior_entropy = -torch.sum(
316 | posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
317 | ) # torch.sum(torch.log(*)) is faster than torch.prod
318 | threshold = torch.minimum(
319 | torch.ones_like(posterior_entropy) * posterior_threshold,
320 | torch.exp(-posterior_entropy) * posterior_alpha,
321 | )
322 | posterior_mask = candidates_prob > threshold
323 | candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
324 |
325 | # Choose the best candidate based on the evaluated posterior probabilities
326 | accept_length = candidates_accept_length.max()
327 | if accept_length == 0:
328 | # If no candidates are accepted, just choose the first one
329 | best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
330 | else:
331 | best_candidates = torch.where(candidates_accept_length == accept_length)[0]
332 | # Accept the best one according to likelihood
333 | likelihood = torch.sum(
334 | torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
335 | )
336 | best_candidate = best_candidates[torch.argmax(likelihood)]
337 | return best_candidate, accept_length
338 |
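    | # In the temperature > 0 branch above, a candidate token x is accepted at a position when
    | #
    | #     p(x | prefix) > min(posterior_threshold, posterior_alpha * exp(-H(p)))
    | #
    | # where H(p) is the entropy of the model's distribution at that position; the accepted length
    | # is the longest run of consecutively accepted tokens (the cumprod trick above).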
339 |
340 | def update_inference_inputs(
341 | input_ids,
342 | candidates,
343 | best_candidate,
344 | accept_length,
345 | retrieve_indices,
346 | outputs,
347 | logits,
348 | medusa_logits,
349 | new_token,
350 | past_key_values_data,
351 | current_length_data,
352 | ):
353 | """
354 | Update the input sequences and relevant tensors based on the selected best candidate from the inference results.
355 |
356 | Args:
357 | - input_ids (torch.Tensor): Current input token sequences.
358 | - candidates (torch.Tensor): Candidate token sequences generated in the current step.
359 | - best_candidate (int): Index of the chosen best candidate.
360 | - accept_length (int): Length of the accepted candidate sequence.
361 |     - retrieve_indices (torch.Tensor): Indices mapping tree positions back to flattened per-candidate sequences.
362 | - outputs, logits, medusa_logits (torch.Tensor): Model's outputs from the previous inference step.
363 | - new_token (int): Counter for the new tokens added during inference.
364 | - past_key_values_data (torch.Tensor): Tensor containing past hidden states for the transformer model.
365 | - current_length_data (torch.Tensor): Tensor containing the current length of sequences in the batch.
366 |
367 | Returns:
368 | - input_ids (torch.Tensor): Updated input token sequences.
369 | - logits (torch.Tensor): Updated logits.
370 | - medusa_logits (torch.Tensor): Updated medusa logits.
371 | - new_token (int): Updated counter for the new tokens added.
372 | """
373 | # Calculate the starting position for new tokens based on the previous input length
374 | prev_input_len = input_ids.shape[1]
375 | # Map the best candidate indices to the original indices in the sequence
376 | select_indices = (
377 | retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len
378 | )
379 | # Append the tokens from the best candidate to the input sequence
380 | input_ids = torch.cat(
381 | [input_ids, candidates[None, best_candidate, : accept_length + 1]], dim=-1
382 | )
383 | # Update the past key values based on the selected tokens
384 |     # Gather the cached entries for the accepted tokens (the copy source, despite the variable name tgt)
385 | tgt = past_key_values_data[..., select_indices, :]
386 | # Destination tensor where the relevant past information will be stored
387 | dst = past_key_values_data[..., prev_input_len : prev_input_len + tgt.shape[-2], :]
388 | # Copy relevant past information from the source to the destination
389 | dst.copy_(tgt, non_blocking=True)
390 |
391 |     # Update the current length tensor (currently only batch size 1 is supported)
392 | current_length_data.fill_(prev_input_len + tgt.shape[-2])
393 |
394 | # Extract logits and medusa logits for the accepted tokens
395 | logits = logits[None, best_candidate, accept_length : accept_length + 1]
396 | medusa_logits = medusa_logits[
397 | :, None, best_candidate, accept_length : accept_length + 1
398 | ]
399 | # Update the new token counter
400 | new_token += accept_length + 1
401 |
402 | return input_ids, logits, medusa_logits, new_token
403 |
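    | # KV-cache bookkeeping sketch: the tree-decoding step wrote cache entries for every tree node
    | # after position prev_input_len. The copy above keeps only the rows of the accepted path
    | # (select_indices), moves them to the front of that region, and truncates the cache length so
    | # the rejected branches are overwritten on the next step (pre-allocated cache, batch size 1).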
--------------------------------------------------------------------------------
/medusa/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/train/__init__.py
--------------------------------------------------------------------------------
/medusa/train/train_legacy.py:
--------------------------------------------------------------------------------
1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
2 | #
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
18 |
19 | from dataclasses import dataclass, field
20 | import json
21 | import math
22 | import pathlib
23 | from typing import Dict, Optional, Sequence
24 |
25 | import numpy as np
26 | import torch
27 | from torch import nn
28 | from torch.utils.data import Dataset
29 | import transformers
30 | from transformers import Trainer, BitsAndBytesConfig
31 | from transformers.trainer_pt_utils import LabelSmoother
32 | from safetensors.torch import save_file
33 |
34 | from fastchat.conversation import SeparatorStyle
35 | from fastchat.model.model_adapter import get_conversation_template
36 | from torch.nn import CrossEntropyLoss
37 | from torch.nn import functional as F
38 | import os
39 | from medusa.model.medusa_model_legacy import MedusaModel, MedusaConfig
40 |
41 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
42 |
43 |
44 | # Customized for training Medusa heads
45 | class CustomizedTrainer(Trainer):
46 | def compute_loss(self, model, inputs, return_outputs=False):
47 | """
48 | Compute the training loss for the model.
49 |
50 | Args:
51 | model (torch.nn.Module): The model for which to compute the loss.
52 | inputs (dict): The input data, including input IDs, attention mask, and labels.
53 | return_outputs (bool): Whether to return model outputs along with the loss.
54 |
55 | Returns:
56 | Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs.
57 | """
58 | # DDP will give us model.module
59 | if hasattr(model, "module"):
60 | medusa = model.module.medusa
61 | else:
62 | medusa = model.medusa
63 |
64 | logits = model(
65 | input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
66 | )
67 | labels = inputs["labels"]
68 | # Shift so that tokens < n predict n
69 | loss = 0
70 | loss_fct = CrossEntropyLoss()
71 | log = {}
72 | for i in range(medusa):
73 | medusa_logits = logits[i, :, : -(2 + i)].contiguous()
74 | medusa_labels = labels[..., 2 + i :].contiguous()
75 | medusa_logits = medusa_logits.view(-1, logits.shape[-1])
76 | medusa_labels = medusa_labels.view(-1)
77 | medusa_labels = medusa_labels.to(medusa_logits.device)
78 | loss_i = loss_fct(medusa_logits, medusa_labels)
79 | loss += loss_i
80 | not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
81 | medusa_labels = medusa_labels[not_ignore]
82 |
83 | # Add top-k accuracy
84 | for k in range(1, 2):
85 | _, topk = medusa_logits.topk(k, dim=-1)
86 | topk = topk[not_ignore]
87 | correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
88 | log[f"medusa{i}_top{k}"] = correct.float().mean().item()
89 |
90 | log[f"medusa{i}_loss"] = loss_i.item()
91 | self.log(log)
92 | return (loss, logits) if return_outputs else loss
93 |
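    | # Label-shift sketch for the loss above (i is the 0-indexed Medusa head): the standard LM head
    | # predicts the token 1 position ahead, and head i is trained to predict the token (i + 2)
    | # positions ahead, hence logits[i, :, :-(2 + i)] is paired with labels[..., 2 + i:]. For
    | # example, with a sequence length of 6 and i = 0, positions 0..3 of head 0 are matched against
    | # the tokens at positions 2..5.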
94 |
95 | @dataclass
96 | class ModelArguments:
97 | model_name_or_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.3")
98 | load_in_4bit: bool = field(
99 | default=False,
100 | metadata={"help": "Load in 4 bit."},
101 | )
102 | load_in_8bit: bool = field(
103 | default=False,
104 | metadata={"help": "Load in 8 bit."},
105 | )
106 |
107 |
108 | @dataclass
109 | class DataArguments:
110 | data_path: str = field(
111 | default="sharegpt_clean.json",
112 | metadata={"help": "Path to the training data."},
113 | )
114 | eval_data_path: str = field(
115 | default=None, metadata={"help": "Path to the evaluation data."}
116 | )
117 | lazy_preprocess: bool = True
118 |
119 |
120 | @dataclass
121 | class TrainingArguments(transformers.TrainingArguments):
122 | cache_dir: Optional[str] = field(default=None)
123 | report_to: Optional[str] = None
124 | optim: str = field(default="adamw_torch")
125 | model_max_length: int = field(
126 | default=2048,
127 | metadata={
128 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
129 | },
130 | )
131 | medusa_num_heads: int = field(
132 | default=1,
133 | metadata={"help": "Number of Medusa heads."},
134 | )
135 | medusa_num_layers: int = field(
136 | default=1,
137 | metadata={"help": "Number of layers for each Medusa head."},
138 | )
139 |
140 |
141 | local_rank = None
142 |
143 |
144 | def rank0_print(*args):
145 | if local_rank == 0:
146 | print(*args)
147 |
148 |
149 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
150 | """
151 | Save the model's state dictionary to a specified directory.
152 |
153 | Args:
154 | trainer (transformers.Trainer): The Hugging Face Trainer object.
155 | output_dir (str): The directory where the model state dictionary will be saved.
156 | """
157 | state_dict = trainer.model.state_dict()
158 | if trainer.args.should_save:
159 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
160 | del state_dict
161 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
162 |
163 | def preprocess(
164 | sources,
165 | tokenizer: transformers.PreTrainedTokenizer,
166 | ) -> Dict:
167 | """
168 | Preprocesses conversation data and tokenizes it for model input.
169 |
170 | Args:
171 | sources: A list of conversation sources.
172 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for tokenization.
173 |
174 | Returns:
175 | Dict: A dictionary containing tokenized inputs, labels, and attention mask.
176 | """
177 |
178 | # Apply prompt templates
179 | conversations = []
180 | prompts = []
181 | 
182 | for i, conversation in enumerate(sources):
183 | prompt = tokenizer.apply_chat_template(conversation, tokenize=False)
184 | prompts.append(prompt)
185 | conversations.append(conversation)
186 |
187 | # Tokenize conversations
188 | encoding = tokenizer(
189 | prompts,
190 | return_tensors="pt",
191 | padding="max_length",
192 | truncation=True,
193 | return_offsets_mapping=True,
194 | )
195 | # Set everything to be ignored, except the assistant part
196 | targets = torch.full_like(encoding.input_ids, IGNORE_TOKEN_ID)
197 | input_ids = encoding.input_ids
198 |
199 | # Mask targets. Only compute loss on the assistant outputs.
200 | for conv_index, (conversation, target, prompt) in enumerate(zip(conversations, targets, prompts)):
201 |
202 | for turn in conversation:
203 | if turn["role"] == "assistant":
204 | content = turn["content"]
205 |                 # Unfortunate strip() necessary because chat templates strip the content as well.
206 |                 start = prompt.index(content.strip())
207 |                 stop = start + len(content.strip())
208 |                 indices = []
209 |                 # Unmask tokens whose character span overlaps the assistant span [start, stop)
210 |                 for tok_index, (tok_start, tok_stop) in enumerate(encoding.offset_mapping[conv_index]):
211 |                     if tok_start < stop and tok_stop > start:
212 |                         indices.append(tok_index)
213 |                 target[indices] = encoding.input_ids[conv_index][indices]
214 |
215 | return dict(
216 | input_ids=input_ids,
217 | labels=targets,
218 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
219 | )
220 |
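    | # Sanity-check sketch (illustrative, not part of the training path): offset_mapping gives each
    | # token its character span in the rendered prompt, so the loop above unmasks exactly the tokens
    | # overlapping an assistant turn while everything else keeps IGNORE_TOKEN_ID.
    | #
    | #   batch = preprocess([conv], tokenizer)   # conv: a list of {"role", "content"} dicts
    | #   kept = batch["labels"][0] != IGNORE_TOKEN_ID
    | #   print(tokenizer.decode(batch["input_ids"][0][kept]))  # should print only assistant text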
221 |
222 | class SupervisedDataset(Dataset):
223 | """Dataset for supervised fine-tuning.
224 |
225 | Args:
226 | raw_data (list): A list of raw data examples.
227 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
228 | """
229 |
230 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
231 | super(SupervisedDataset, self).__init__()
232 |
233 | rank0_print("Formatting inputs...")
234 | sources = raw_data
235 | data_dict = preprocess(sources, tokenizer)
236 |
237 | self.input_ids = data_dict["input_ids"]
238 | self.labels = data_dict["labels"]
239 | self.attention_mask = data_dict["attention_mask"]
240 |
241 | def __len__(self):
242 | return len(self.input_ids)
243 |
244 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
245 | return dict(
246 | input_ids=self.input_ids[i],
247 | labels=self.labels[i],
248 | attention_mask=self.attention_mask[i],
249 | )
250 |
251 |
252 | class LazySupervisedDataset(Dataset):
253 | """Lazy dataset for supervised fine-tuning.
254 |
255 | This dataset loads data on-the-fly when requested, which can be memory-efficient but slower.
256 |
257 | Args:
258 | raw_data (list): A list of raw data examples.
259 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
260 | """
261 |
262 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
263 | super(LazySupervisedDataset, self).__init__()
264 | self.tokenizer = tokenizer
265 |
266 | rank0_print("Formatting inputs...Skip in lazy mode")
267 | 
268 | self.raw_data = raw_data
269 | self.cached_data_dict = {}
270 |
271 | def __len__(self):
272 | return len(self.raw_data)
273 |
274 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
275 | if i in self.cached_data_dict:
276 | return self.cached_data_dict[i]
277 |
278 | ret = preprocess([self.raw_data[i]], self.tokenizer)
279 | ret = dict(
280 | input_ids=ret["input_ids"][0],
281 | labels=ret["labels"][0],
282 | attention_mask=ret["attention_mask"][0],
283 | )
284 | self.cached_data_dict[i] = ret
285 |
286 | return ret
287 |
288 |
289 | def make_supervised_data_module(
290 | tokenizer: transformers.PreTrainedTokenizer, data_args
291 | ) -> Dict:
292 | """Make dataset and collator for supervised fine-tuning.
293 |
294 | Args:
295 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing.
296 | data_args: Data arguments.
297 |
298 | Returns:
299 | dict: A dictionary containing train and eval datasets.
300 | """
301 | dataset_cls = (
302 | LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
303 | )
304 | rank0_print("Loading data...")
305 |
306 | train_json = json.load(open(data_args.data_path, "r"))
307 | train_dataset = dataset_cls(train_json, tokenizer=tokenizer)
308 |
309 | if data_args.eval_data_path:
310 | eval_json = json.load(open(data_args.eval_data_path, "r"))
311 | eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
312 | else:
313 | eval_dataset = None
314 |
315 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
316 |
317 |
318 | def train():
319 | global local_rank
320 |
321 | parser = transformers.HfArgumentParser(
322 | (ModelArguments, DataArguments, TrainingArguments)
323 | )
324 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
325 | local_rank = training_args.local_rank
326 |
327 | # Set RoPE scaling factor
328 | config = transformers.AutoConfig.from_pretrained(
329 | model_args.model_name_or_path,
330 | cache_dir=training_args.cache_dir,
331 | )
332 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
333 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
334 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
335 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
336 | config.use_cache = False
337 |
338 | tokenizer = transformers.AutoTokenizer.from_pretrained(
339 | model_args.model_name_or_path,
340 | cache_dir=training_args.cache_dir,
341 | model_max_length=training_args.model_max_length,
342 | padding_side="right",
343 | use_fast=True,
344 | )
345 |     # Use the EOS token as the padding token.
346 |     tokenizer.pad_token = tokenizer.eos_token
347 |
348 | # Making sure the tokenizer works before loading the model.
349 | print(tokenizer(["This is a test", "secondary"], padding=True))
350 | print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}]))
351 |
352 | # Load model and tokenizer
353 | model = transformers.AutoModelForCausalLM.from_pretrained(
354 | model_args.model_name_or_path,
355 | config=config,
356 | cache_dir=training_args.cache_dir,
357 | torch_dtype=torch.bfloat16,
358 | )
359 |
360 | # Freeze the base model
361 | for param in model.base_model.parameters():
362 | param.requires_grad = False
363 |
364 | # Add Medusa heads
365 | medusa_lm_head = MedusaModel(
366 | model,
367 | medusa_num_heads=training_args.medusa_num_heads,
368 | medusa_num_layers=training_args.medusa_num_layers,
369 | base_model_name_or_path=model_args.model_name_or_path,
370 | )
371 |
372 | # Format output dir
373 | training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"
374 |
375 |
376 | # Load data
377 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
378 |
379 | # Generate Medusa config for pushing to HF hub
380 | medusa_config = MedusaConfig(
381 | medusa_num_heads=training_args.medusa_num_heads,
382 | medusa_num_layers=training_args.medusa_num_layers,
383 | base_model_name_or_path=model_args.model_name_or_path,
384 | version="2"
385 | )
386 |
387 | # Save Medusa config
388 | medusa_config.save_pretrained(training_args.output_dir)
389 |
390 |     # Start the trainer
391 | trainer = CustomizedTrainer(
392 | model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module
393 | )
394 |
395 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
396 | trainer.train(resume_from_checkpoint=True)
397 | else:
398 | trainer.train()
399 | model.config.use_cache = True
400 | # trainer.save_state()
401 | # safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
402 |     # Save the Medusa heads separately
403 | if hasattr(medusa_lm_head, "module"):
404 | lm_head = medusa_lm_head.module.medusa_head
405 | else:
406 | lm_head = medusa_lm_head.medusa_head
407 | import deepspeed
408 | with deepspeed.zero.GatheredParameters(lm_head.parameters()):
409 | state_dict = lm_head.state_dict()
410 |
411 | # Save Medusa heads
412 | if local_rank == 0:
413 | # Modify the tokenizer internal state before saving.
414 | tokenizer.encode("Test", truncation=None, padding="do_not_pad")
415 | tokenizer.save_pretrained(training_args.output_dir)
416 | save_file(
417 | state_dict,
418 | os.path.join(training_args.output_dir, "medusa_lm_head.safetensors"),
419 | )
420 |
421 |
422 | if __name__ == "__main__":
423 | train()
424 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "medusa-llm"
7 | version = "1.0"
8 | description = "Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads"
9 | readme = "README.md"
10 | requires-python = ">=3.9"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "fschat", "torch", "transformers>=4.34", "accelerate", "sentencepiece", "protobuf"
17 | ]
18 |
19 | [project.optional-dependencies]
20 | train = ["bitsandbytes", "wandb", "scipy"]
21 |
22 | [project.urls]
23 | "Homepage" = "https://github.com/FasterDecoding/Medusa"
24 | "Blog" = "https://sites.google.com/view/medusa-llm"
25 |
26 | [tool.setuptools.packages.find]
27 | exclude = ["assets*", "notebooks*", "scripts*", "llm_judge"]
--------------------------------------------------------------------------------
/scripts/train_vicuna_33b_8bit.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-33b-v1.3 \
2 | --data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
3 | --bf16 True \
4 | --output_dir test \
5 | --num_train_epochs 1 \
6 | --per_device_train_batch_size 8 \
7 | --per_device_eval_batch_size 8 \
8 | --gradient_accumulation_steps 4 \
9 | --evaluation_strategy "no" \
10 | --save_strategy "no" \
11 | --learning_rate 1e-3 \
12 | --weight_decay 0.0 \
13 | --warmup_ratio 0.1 \
14 | --lr_scheduler_type "cosine" \
15 | --logging_steps 1 \
16 | --tf32 True \
17 | --model_max_length 2048 \
18 | --lazy_preprocess True \
19 | --medusa_num_heads 3 \
20 | --medusa_num_layers 1 \
21 | --load_in_8bit
--------------------------------------------------------------------------------
/scripts/train_vicuna_7b.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
2 | --data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \
3 | --bf16 True \
4 | --output_dir test \
5 | --num_train_epochs 1 \
6 | --per_device_train_batch_size 8 \
7 | --per_device_eval_batch_size 8 \
8 | --gradient_accumulation_steps 4 \
9 | --evaluation_strategy "no" \
10 | --save_strategy "no" \
11 | --learning_rate 1e-3 \
12 | --weight_decay 0.0 \
13 | --warmup_ratio 0.1 \
14 | --lr_scheduler_type "cosine" \
15 | --logging_steps 1 \
16 | --tf32 True \
17 | --model_max_length 2048 \
18 | --lazy_preprocess True \
19 | --medusa_num_heads 3 \
20 | --medusa_num_layers 1
--------------------------------------------------------------------------------
/simple_gradio_interface.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import time
3 | import torch
4 | from medusa.model.medusa_model import MedusaModel
5 | from fastchat.model.model_adapter import get_conversation_template
6 |
7 | # Global variables
8 | chat_history = ""
9 | model = None
10 | tokenizer = None
11 | conv = None
12 |
13 |
14 | def load_model_function(model_name, load_in_8bit=False, load_in_4bit=False):
15 | model_name = model_name or "FasterDecoding/medusa-vicuna-7b-v1.3"
16 | global model, tokenizer, conv
17 |
18 | try:
19 | model = MedusaModel.from_pretrained(
20 | model_name,
21 | torch_dtype=torch.float16,
22 | low_cpu_mem_usage=True,
23 | device_map="auto",
24 | load_in_8bit=load_in_8bit,
25 | load_in_4bit=load_in_4bit
26 | )
27 | tokenizer = model.get_tokenizer()
28 | conv = get_conversation_template("vicuna")
29 | return "Model loaded successfully!"
30 |     except Exception as e:
31 |         return f"Error loading the model: {e}. Please check the model name and try again."
32 |
33 |
34 | def reset_conversation():
35 | """
36 | Reset the global conversation and chat history
37 | """
38 | global conv, chat_history
39 | conv = get_conversation_template("vicuna")
40 | chat_history = ""
41 |
42 |
43 | def medusa_chat_interface(user_input, temperature, max_steps, no_history):
44 | global model, tokenizer, conv, chat_history
45 |
46 | # Reset the conversation if no_history is checked
47 | if no_history:
48 | reset_conversation()
49 |
50 |     if not model or not tokenizer:
51 |         yield "Error: Model not loaded!", chat_history
52 |         return
53 | chat_history += "\nYou: " + user_input
54 | conv.append_message(conv.roles[0], user_input)
55 | conv.append_message(conv.roles[1], '')
56 | prompt = conv.get_prompt()
57 |
58 | input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device)
59 |
60 | outputs = model.medusa_generate(input_ids, temperature=temperature, max_steps=max_steps)
61 | response = ""
62 | for output in outputs:
63 | response = output['text']
64 | yield response, chat_history
65 | time.sleep(0.01)
66 |
67 | chat_history += "\nMedusa: " + response.strip()
68 |
69 |     yield response, chat_history
70 |
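    | # medusa_chat_interface is a generator, so Gradio streams each yielded (response, chat_history)
    | # pair to the two output textboxes as tokens are accepted; the final yield delivers the
    | # completed response together with the updated history. The short sleep simply paces the updates.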
71 |
72 | if __name__ == "__main__":
73 | load_model_interface = gr.Interface(
74 | load_model_function,
75 | [
76 | gr.components.Textbox(placeholder="FasterDecoding/medusa-vicuna-7b-v1.3", label="Model Name"),
77 | gr.components.Checkbox(label="Use 8-bit Quantization"),
78 | gr.components.Checkbox(label="Use 4-bit Quantization"),
79 | ],
80 | gr.components.Textbox(label="Model Load Status", type="text"),
81 | description="Load Medusa Model",
82 | title="Medusa Model Loader",
83 | live=False,
84 | api_name="load_model"
85 | )
86 |
87 | # Chat Interface
88 | chat_interface = gr.Interface(
89 | medusa_chat_interface,
90 | [
91 | gr.components.Textbox(placeholder="Ask Medusa...", label="User Input"),
92 | gr.components.Slider(minimum=0, maximum=1.5, label="Temperature"),
93 | gr.components.Slider(minimum=50, maximum=1000, label="Max Steps"),
94 | gr.components.Checkbox(label="No History"),
95 | ],
96 | [
97 | gr.components.Textbox(label="Medusa's Response", type="text"),
98 | gr.components.Textbox(label="Chat History", type="text")
99 | ],
100 | live=False,
101 | description="Chat with Medusa",
102 | title="Medusa Chatbox",
103 | api_name="chat"
104 | )
105 |
106 | # Combine the interfaces in a TabbedInterface
107 | combined_interface = gr.TabbedInterface([load_model_interface, chat_interface],
108 | ["Load Model", "Chat"])
109 |
110 | # Launch the combined interface
111 | combined_interface.queue().launch()
112 |
--------------------------------------------------------------------------------