├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── ROADMAP.md ├── assets ├── category_speedup.png ├── logo.png ├── medusa_acc.csv ├── medusa_choices.png ├── medusa_demo.gif ├── medusa_pipeline.jpg ├── medusa_speedup_cmp.jpg └── size_speedup.png ├── create_data.py ├── data_generation ├── README.md ├── convert_to_sharegpt.py └── generate.py ├── deepspeed.json ├── llm_judge ├── README.md ├── data │ ├── judge_prompts.jsonl │ └── mt_bench │ │ ├── model_answer │ │ ├── medusa-vicuna-13b-v1.3-1-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── medusa-vicuna-13b-v1.3-2-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── medusa-vicuna-13b-v1.3-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── medusa-vicuna-33b-v1.3-1-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── medusa-vicuna-33b-v1.3-2-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── medusa-vicuna-33b-v1.3-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── medusa-vicuna-7b-v1.3-1-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── medusa-vicuna-7b-v1.3-2-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ └── medusa-vicuna-7b-v1.3-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3.jsonl │ │ ├── model_judgment │ │ └── gpt-4_single.jsonl │ │ ├── question.jsonl │ │ └── reference_answer │ │ └── gpt-4.jsonl ├── gen_judgement.py ├── gen_model_answer_baseline.py ├── gen_model_answer_baseline_inf_only.py ├── gen_model_answer_huggingface.py ├── gen_model_answer_medusa.py ├── gen_model_answer_medusa_inf_only.py ├── gen_model_answer_medusa_legacy.py └── show_result.py ├── medusa ├── __init__.py ├── eval │ ├── README.md │ ├── gen_results.py │ └── heads_accuracy.py ├── hf_utils.py ├── inference │ ├── __init__.py │ └── cli.py ├── model │ ├── __init__.py │ ├── kv_cache.py │ ├── medusa_choices.py │ ├── medusa_model.py │ ├── medusa_model_legacy.py │ ├── medusa_model_new.py │ ├── modeling_llama_kv.py │ ├── modeling_llama_kv_legacy.py │ ├── modeling_mistral_kv.py │ ├── utils.py │ └── utils_legacy.py └── train │ ├── __init__.py │ └── train_legacy.py ├── notebooks ├── medusa_configuration_explained.ipynb ├── medusa_inference_explained.ipynb └── medusa_introduction.ipynb ├── pyproject.toml ├── scripts ├── train_vicuna_33b_8bit.sh └── train_vicuna_7b.sh └── simple_gradio_interface.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # wandb 163 | .wandb/ 164 | wandb/ 165 | 166 | ShareGPT_Vicuna_unfiltered/ 167 | 168 | test_medusa* 169 | 170 | # test 171 | notebooks/test*.ipynb 172 | notebooks/*.pdf 173 | llm_judge/*.sh 174 | llm_judge/data/mt_bench_test 175 | llm_judge/data/mt_bench_test_rs 176 | data 177 | medusa/eval/*.sh -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | references: 4 | - type: article 5 | authors: 6 | - family-names: Cai 7 | given-names: Tianle 8 | - family-names: Li 9 | given-names: Yuhong 10 | - family-names: Geng 11 | given-names: Zhengyang 12 | - family-names: Peng 13 | given-names: Hongwu 14 | - family-names: Lee 15 | given-names: Jason D. 16 | - family-names: Chen 17 | given-names: Deming 18 | - family-names: Dao 19 | given-names: Tri 20 | title: "Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads" 21 | year: 2024 22 | journal: "arXiv preprint arXiv: 2401.10774" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Medusa

 Medusa: Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads

2 | 3 |

4 | | Blog | Report | Roadmap | 6 |

7 | 8 | --- 9 | *News* 🔥 10 | - [2024/1] The Medusa technical report is now available on [arXiv](https://arxiv.org/abs/2401.10774). We've added multiple new features, including the Medusa-2 recipe for full-model training and self-distillation for adding Medusa to any fine-tuned LLM. The new results show a 2.2-3.6x speedup over the original model on a range of LLMs. 11 | 12 | --- 13 | ## Introduction 14 | 15 | Medusa is a simple framework that democratizes the acceleration techniques for LLM generation with multiple decoding heads. 16 | 17 |
18 | 19 | 20 | 21 |
22 |
23 | Medusa-1 on Vicuna-7b. 24 |
25 |
26 |
27 | 28 | 29 | We aim to tackle the three pain points of popular acceleration techniques like speculative decoding: 30 | 31 | - Requirement of a good draft model. 32 | - System complexity. 33 | - Inefficiency when using sampling-based generation. 34 | 35 | 36 |
37 | 38 | 39 | 40 |
41 |
42 | Medusa adds extra "heads" to LLMs to predict multiple future tokens simultaneously. When augmenting a model with Medusa, the original model stays untouched, and only the new heads are fine-tuned during training. During generation, these heads each produce multiple likely words for the corresponding position. These options are then combined and processed using a tree-based attention mechanism. Finally, a typical acceptance scheme is employed to pick the longest plausible prefix from the candidates for further decoding. 43 |
44 |
45 |
46 | 47 | We aim to solve the challenges associated with speculative decoding by implementing the following ideas: 48 | 49 | - Instead of introducing a new model, we train multiple decoding heads on the *same* model. 50 | - The training is parameter-efficient so that even the "GPU-Poor" can do it. And since there is no additional model, there is no need to adjust the distributed computing setup. 51 | - Relaxing the requirement of matching the distribution of the original model makes the non-greedy generation even faster than greedy decoding. 52 | 53 | In the initial release, our primary focus is on optimizing Medusa for a batch size of 1—a setting commonly utilized for local model hosting. In this configuration, Medusa delivers approximately a 2x speed increase across a range of Vicuna models. We are actively working to extend Medusa's capabilities by integrating it into additional inference frameworks, with the aim of achieving even greater performance gains and extending Medusa to broader settings. 54 | 55 |
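To make the decoding procedure described above concrete, here is a deliberately simplified sketch of a single Medusa step with greedy acceptance. It is illustrative only: the actual implementation (see `medusa/model/utils.py` and the notebooks) evaluates many candidate continuations at once through tree attention in a single forward pass and uses a typical-acceptance rule rather than exact argmax matching.

```python
# Toy sketch of one Medusa decoding step (greedy acceptance). Model calls are stubbed
# with simple callables so the example runs on its own.
from typing import Callable, List, Sequence

NextToken = Callable[[Sequence[int]], int]

def medusa_step(context: List[int], base_next: NextToken, heads: List[NextToken]) -> List[int]:
    # One pass over the context yields the "free" next token plus K speculative guesses,
    # where head k guesses the token k positions after that next token.
    next_token = base_next(context)
    guesses = [head(context) for head in heads]

    # Verification: keep a guess only while the base model itself would have produced it.
    # In the real model these checks all come from one batched forward pass, not a loop.
    accepted = [next_token]
    for guess in guesses:
        if base_next(context + accepted) != guess:
            break
        accepted.append(guess)
    return accepted  # between 1 and K+1 tokens are emitted per step

if __name__ == "__main__":
    # Perfectly informed toy heads: every guess is accepted, so one step emits 4 tokens.
    base_next = lambda ctx: (ctx[-1] + 1) % 100
    heads = [lambda ctx, k=k: (ctx[-1] + 1 + k) % 100 for k in (1, 2, 3)]
    print(medusa_step([5, 7, 11], base_next, heads))  # -> [12, 13, 14, 15]
```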

56 | 57 | 58 | 59 |

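For readers who want a picture of what those extra heads are, the sketch below shows the kind of module Medusa attaches on top of the backbone's last hidden state: a small residual feed-forward block followed by a vocabulary projection, one head per future position. This is a minimal approximation based on the technical report; the exact modules and their initialization live in `medusa/model/medusa_model.py`, and the head count and depth correspond to the `--medusa_num_heads` / `--medusa_num_layers` training flags used later in this README.

```python
# Minimal sketch of a Medusa head: a residual feed-forward block on the backbone's
# last hidden state, followed by an LM-head-style projection to the vocabulary.
# The sizes below are toy values; in practice hidden_size and vocab_size match the backbone.
import torch
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The residual connection keeps the head close to an identity map at initialization.
        return x + self.act(self.linear(x))

class MedusaHead(nn.Module):
    """Predicts the token k positions beyond the next token (one head per offset k)."""
    def __init__(self, hidden_size: int, vocab_size: int, num_layers: int = 1):
        super().__init__()
        self.blocks = nn.Sequential(*(ResBlock(hidden_size) for _ in range(num_layers)))
        self.proj = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(self.blocks(hidden_states))

# Three heads on a toy backbone: each returns logits of shape (batch, seq_len, vocab_size).
heads = nn.ModuleList(MedusaHead(hidden_size=256, vocab_size=1000) for _ in range(3))
hidden = torch.randn(2, 8, 256)            # stand-in for the backbone's last hidden states
logits = [head(hidden) for head in heads]  # one tensor of future-token logits per head
```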
60 | 61 | In the updated version, we add support for full-model training, called Medusa-2 (compared to Medusa-1, which only trains the new heads), which requires a special recipe that adds the speculative prediction ability while keeping the original model's performance. 62 | 63 | We also add support for self-distillation, which allows us to add Medusa to any fine-tuned LLM without requiring the availability of the original training data. 64 | 65 | ## Contents 66 | - [Introduction](#introduction) 67 | - [Contents](#contents) 68 | - [Installation](#installation) 69 | - [Method 1: With pip (may not be the latest version)](#method-1-with-pip-may-not-be-the-latest-version) 70 | - [Method 2: From the source (recommended)](#method-2-from-the-source-recommended) 71 | - [Model Weights](#model-weights) 72 | - [Inference](#inference) 73 | - [Training](#training) 74 | - [Training (legacy)](#training-legacy) 75 | - [Push to Hugging Face Hub](#push-to-hugging-face-hub) 76 | - [Citation](#citation) 77 | - [Codebase Guide](#codebase-guide) 78 | - [Community Adoption](#community-adoption) 79 | - [Contributing](#contributing) 80 | - [Acknowledgements](#acknowledgements) 81 | 82 | ## Installation 83 | ### Method 1: With pip (may not be the latest version) 84 | ```bash 85 | pip install medusa-llm 86 | ``` 87 | ### Method 2: From the source (recommended) 88 | ```bash 89 | git clone https://github.com/FasterDecoding/Medusa.git 90 | cd Medusa 91 | pip install -e . 92 | ``` 93 | 94 | ### Model Weights 95 | #### Medusa-1 96 | | Size | Chat Command | Hugging Face Repo | 97 | | ---- | --------------------------------------------- | --------------------------------------------------------------------- | 98 | | 7B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-7b-v1.3` | [FasterDecoding/medusa-vicuna-7b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-7b-v1.3) | 99 | | 13B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-13b-v1.3` | [FasterDecoding/medusa-vicuna-13b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-13b-v1.3) | 100 | | 33B | `python -m medusa.inference.cli --model FasterDecoding/medusa-vicuna-33b-v1.3` | [FasterDecoding/medusa-vicuna-33b-v1.3](https://huggingface.co/FasterDecoding/medusa-vicuna-33b-v1.3) | 101 | 102 | #### Medusa-2 103 | | Size | Chat Command | Hugging Face Repo | 104 | | ---- | --------------------------------------------- | --------------------------------------------------------------------- | 105 | | Zephyr-7B-Beta | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-zephyr-7b-beta` | [FasterDecoding/medusa-1.0-zephyr-7b-beta](https://huggingface.co/FasterDecoding/medusa-1.0-zephyr-7b-beta) | 106 | | Vicuna-7B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-7b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-7b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-7b-v1.5) | 107 | | Vicuna-13B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-13b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-13b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-13b-v1.5) | 108 | | Vicuna-33B-v1.5 | `python -m medusa.inference.cli --model FasterDecoding/medusa-1.0-vicuna-33b-v1.5` | [FasterDecoding/medusa-1.0-vicuna-33b-v1.5](https://huggingface.co/FasterDecoding/medusa-1.0-vicuna-33b-v1.5) | 109 | 110 | 111 | ### Inference 112 | We currently support single-GPU inference with a batch size of 1, which is the most common setup for local model hosting. 
We are actively working to extend Medusa's capabilities by integrating it into other inference frameworks; please don't hesitate to reach out if you are interested in contributing to this effort. 113 | 114 | You can use the following command to launch a CLI interface: 115 | ```bash 116 | CUDA_VISIBLE_DEVICES=0 python -m medusa.inference.cli --model [path of medusa model] 117 | ``` 118 | You can also pass `--load-in-8bit` or `--load-in-4bit` to load the base model in quantized format. If you download the base model elsewhere, you may override the base model name or path with `--base-model [path of base model]`. 119 | 120 | ### Training 121 | In the updated version, we use the amazing [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl) library to manage the training process. Please refer to our [fork](https://github.com/ctlllll/axolotl) for the training code. The major code modifications are in [`src/axolotl/utils/models.py`](https://github.com/ctlllll/axolotl/blob/main/src/axolotl/utils/models.py). The training configs can be found in [`examples/medusa`](https://github.com/ctlllll/axolotl/tree/main/examples/medusa). A typical training command is as follows: 122 | ```bash 123 | accelerate launch -m axolotl.cli.train examples/medusa/your_config.yml 124 | ``` 125 | 126 | The data preparation code for self-distillation can be found in the [`data_generation`](data_generation) folder of the current repo. For other datasets, you can directly download the data from the corresponding Hugging Face dataset repo. 127 | 128 | ### Training on various architectures 129 | *The following instructions are for the initial release of Medusa; they provide a minimal example of how to train a Medusa-1 model. For the updated version, please refer to the previous section.* 130 | 131 | For training, please install: 132 | ```bash 133 | pip install -e ".[train]" 134 | ``` 135 | #### Prepare the data 136 | We take a public version of the ShareGPT dataset, which is a subset of the Vicuna training data. For other models, you can use the corresponding training dataset. 137 | ```bash 138 | git clone https://huggingface.co/datasets/Aeala/ShareGPT_Vicuna_unfiltered 139 | ``` 140 | Remark: If you haven't installed `git-lfs`, please install it before cloning: 141 | ```bash 142 | git lfs install 143 | ``` 144 | 145 | #### Adapt the data to the model you want to enable Medusa on 146 | 147 | Start by launching an inference server of your choice to run the model you want to train on. 148 | Let's use [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) as an example. 149 | 150 | For instance, you can use [text-generation-inference](https://github.com/huggingface/text-generation-inference), which you 151 | can also use after you've trained the Medusa heads. 152 | 153 | ``` 154 | model=mistralai/Mistral-7B-Instruct-v0.2 155 | volume=$PWD/data # share a volume with the Docker container to avoid downloading weights every run 156 | docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:latest --model-id $model --max-input-length 4000 --max-total-tokens 4096 --max-batch-prefill-tokens 4000 157 | ``` 158 | Some of the sequences in ShareGPT are relatively long, so make sure the server can handle inputs of that length. If you do not have enough room, the script will simply skip those long conversations. 159 | This shouldn't hurt downstream performance much, but more data is always better. 160 | You can use various tradeoffs to [speed up inference](https://huggingface.co/docs/text-generation-inference/index), but the defaults should be good enough in most cases.
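For reference, the conversion performed by `create_data.py` (invoked below) keeps the user turns from each ShareGPT record and regenerates the assistant turns with the model behind the server. A rough sketch of the input and output shapes, with illustrative values (field names follow `create_data.py` in this repo):

```python
# One ShareGPT-style input record (a list of these is read from --input-filename):
sharegpt_record = {
    "conversations": [
        {"from": "human", "value": "How do I reverse a list in Python?"},
        {"from": "gpt", "value": "original assistant answer (discarded and regenerated)"},
    ]
}

# One regenerated conversation as written to --output-filename: the user turns are kept,
# and each assistant turn is produced by the model served at --url.
regenerated_conversation = [
    {"role": "user", "content": "How do I reverse a list in Python?"},
    {"role": "assistant", "content": "answer generated by the target model"},
]
```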
161 | 162 | ``` 163 | python create_data.py --input-filename ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json --output-filename mistral.json 164 | ``` 165 | 166 | #### Train the model 167 | We follow the training setup from [FastChat](https://github.com/lm-sys/FastChat#fine-tuning), but with a much larger learning rate because we freeze the original model and only train the new heads. Here is the training command for the example above (Mistral-7B-Instruct-v0.2) on 4 GPUs. Since we are only training the new heads, the training does not require a lot of memory, and only data parallelism is needed. You can modify the script to fit your own setup. For larger models, we use the same setup. You can also use `--load_in_8bit` or `--load_in_4bit` to load the base model in quantized format. 168 | ```bash 169 | torchrun --nproc_per_node=4 medusa/train/train_legacy.py --model_name_or_path mistralai/Mistral-7B-Instruct-v0.2 \ 170 | --data_path mistral.json \ 171 | --bf16 True \ 172 | --output_dir test \ 173 | --num_train_epochs 2 \ 174 | --per_device_train_batch_size 8 \ 175 | --per_device_eval_batch_size 8 \ 176 | --gradient_accumulation_steps 4 \ 177 | --evaluation_strategy "no" \ 178 | --save_strategy "no" \ 179 | --learning_rate 1e-3 \ 180 | --weight_decay 0.0 \ 181 | --warmup_ratio 0.1 \ 182 | --lr_scheduler_type "cosine" \ 183 | --logging_steps 1 \ 184 | --tf32 True \ 185 | --model_max_length 2048 \ 186 | --lazy_preprocess True \ 187 | --medusa_num_heads 3 \ 188 | --medusa_num_layers 1 \ 189 | --deepspeed deepspeed.json 190 | ``` 191 | ### Push to Hugging Face Hub 192 | You can use the following command to push your model to the Hugging Face Hub: 193 | ```bash 194 | python -m medusa.hf_utils --folder [path of the model folder] --repo [name of the repo] 195 | ``` 196 | 197 | ## Citation 198 | ```bibtex 199 | @article{cai2024medusa, 200 | title = {Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads}, 201 | author = {Tianle Cai and Yuhong Li and Zhengyang Geng and Hongwu Peng and Jason D. Lee and Deming Chen and Tri Dao}, 202 | year = {2024}, 203 | journal = {arXiv preprint arXiv: 2401.10774} 204 | } 205 | ``` 206 | 207 | ## Codebase Guide 208 | `medusa/model/medusa_model.py` is the key file for Medusa. It contains the `MedusaModel` class, which is a wrapper around the original model and the new heads. This class also implements a streaming generation method. If you want to dive into the details of Medusa, this is the place to start. 209 | 210 | We also provide some illustrative notebooks in `notebooks/` to help you understand the codebase. 211 | 212 | ## Community Adoption 213 | We are super excited to see that Medusa has been adopted by many open-source projects. Here is an (incomplete) list: 214 | - [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/medusa) 215 | - [TGI](https://github.com/huggingface/text-generation-inference/blob/main/server/text_generation_server/utils/medusa.py) 216 | - [RTP-LLM](https://github.com/alibaba/rtp-llm/blob/main/docs/SpeculativeDecoding-Tutroial.md#medusa-decoding) 217 | 218 | We are grateful to the authors for their contributions to the community and sincerely hope that Medusa can help accelerate the development of LLMs. If you are using Medusa in your project, please let us know, and we will add your project to the list.
219 | 220 | ## Contributing 221 | We welcome community contributions to Medusa. If you have an idea for how to improve it, please open an issue to discuss it with us. When submitting a pull request, please ensure that your changes are well-tested. Please split each major change into a separate pull request. We also have a [Roadmap](ROADMAP.md) summarizing our future plans for Medusa. Don't hesitate to reach out if you are interested in contributing to any of the items on the roadmap. 222 | 223 | ## Acknowledgements 224 | This codebase is influenced by remarkable projects from the LLM community, including [FastChat](https://github.com/lm-sys/FastChat), [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/), [vllm](https://github.com/vllm-project/vllm), [axolotl](https://github.com/OpenAccess-AI-Collective/axolotl). 225 | 226 | This project is supported by [Together AI](https://together.ai/), [MyShell AI](https://myshell.ai/), [Chai AI](https://www.chai-research.com/). 227 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | # Roadmap 2 | 3 | ## Functionality 4 | - [ ] Batched inference 5 | - [ ] Fine-grained KV cache management 6 | - [x] Explore tree sparsity 7 | - [x] Fine-tune Medusa heads together with LM head from scratch 8 | - [x] Distill from any model without access to the original training data 9 | 10 | ## Integration 11 | ### Local Deployment 12 | - [ ] [mlc-llm](https://github.com/mlc-ai/mlc-llm) 13 | - [ ] [exllama](https://github.com/turboderp/exllama) 14 | - [ ] [llama.cpp](https://github.com/ggerganov/llama.cpp) 15 | ### Serving 16 | - [ ] [vllm](https://github.com/vllm-project/vllm) 17 | - [ ] [lightllm](https://github.com/ModelTC/lightllm) 18 | - [x] [TGI](https://github.com/huggingface/text-generation-inference) 19 | - [x] [TensorRT](https://github.com/NVIDIA/TensorRT-LLM) -------------------------------------------------------------------------------- /assets/category_speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/category_speedup.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/logo.png -------------------------------------------------------------------------------- /assets/medusa_acc.csv: -------------------------------------------------------------------------------- 1 | "Name","Created","Runtime","End 
Time","Hostname","ID","Notes","State","Updated","Tags","eval_batch_size","logging_dir","output_dir","per_device_eval_batch_size","per_device_train_batch_size","train_batch_size","train/loss","train/medusa0_loss","train/medusa0_top1","train/medusa0_top2","train/medusa0_top3","train/medusa0_top4","train/medusa0_top5","train/medusa1_loss","train/medusa1_top1","train/medusa1_top2","train/medusa1_top3","train/medusa1_top4","train/medusa1_top5","train/medusa2_loss","train/medusa2_top1","train/medusa2_top2","train/medusa2_top3","train/medusa2_top4","train/medusa2_top5","train/medusa3_loss","train/medusa3_top1","train/medusa3_top2","train/medusa3_top3","train/medusa3_top4","train/medusa3_top5","train/medusa4_loss","train/medusa4_top1","train/medusa4_top2","train/medusa4_top3","train/medusa4_top4","train/medusa4_top5","train/train_loss","train/train_runtime","train/train_samples_per_second","train/train_steps_per_second" 2 | "33b","2023-08-14T02:40:59.000Z","3199","2023-08-14T03:34:18.000Z","della-l07g4","av0zctkn","-","failed","2023-08-17T18:54:58.000Z","","4","test/runs/Aug13_22-38-18_della-l07g4","test_medusa_mlp_vicuna-33b-v1.3_medusa_5_lr_0.001_layers_1","4","4","4","19.7555","1.8279917240142824","0.6045850515365601","0.7161898016929626","0.7648836374282837","0.798293948173523","0.8222854137420654","3.4793782234191895","0.3557846248149872","0.4689888060092926","0.5287009477615356","0.5692198276519775","0.59605473279953","4.460977077484131","0.22463124990463257","0.31970855593681335","0.3785320818424225","0.4210058748722077","0.4462413489818573","5.0544304847717285","0.15958771109580994","0.24880042672157288","0.30726853013038635","0.34796518087387085","0.37551093101501465","5.388251781463623","0.13701795041561127","0.20668207108974457","0.2621290385723114","0.29838281869888306","0.32575085759162903","24.166565484840778","3234.9651","21.213","0.166" 3 | "13b","2023-08-13T22:31:29.000Z","2763","2023-08-13T23:17:32.000Z","della-l08g5","hy3g0c62","-","finished","2023-08-14T02:22:24.000Z","","8","test/runs/Aug13_18-29-53_della-l08g5","test_medusa_mlp_vicuna-13b-v1.3_medusa_5_lr_0.001_layers_1","8","8","8","19.5949","1.8737130165100095","0.5939363837242126","0.705268383026123","0.7578279972076416","0.7924950122833252","0.8161033391952515","3.575400114059448","0.3397117257118225","0.439985066652298","0.5068339705467224","0.5464711785316467","0.580268383026123","4.575368881225586","0.216078519821167","0.30852383375167847","0.3612077236175537","0.4029572308063507","0.4363816976547241","5.15444803237915","0.14997513592243197","0.23173458874225616","0.28777334094047546","0.32343438267707825","0.3541252315044403","5.453802585601807","0.12226639688014984","0.1957007795572281","0.243414506316185","0.2790755331516266","0.3070327937602997","24.09846121860838","2793.9231","24.562","0.192" 4 | 
"7b","2023-08-13T22:07:43.000Z","1909","2023-08-13T22:39:32.000Z","della-l08g2","ub9cluo4","-","finished","2023-08-14T02:22:20.000Z","","8","test/runs/Aug13_18-06-30_della-l08g2","test_medusa_mlp_vicuna-7b-v1.3_medusa_5_lr_0.001_layers_1","8","8","8","20.2451","2.069507122039795","0.5603876709938049","0.6717196702957153","0.7271371483802795","0.7599403262138367","0.7799453139305115","3.723043203353882","0.31635186076164246","0.4235834777355194","0.4850894510746002","0.5282057523727417","0.5625","4.692985534667969","0.2010437250137329","0.28789758682250977","0.3475397527217865","0.3834492862224579","0.4168737530708313","5.258499622344971","0.14736579358577728","0.22502483427524567","0.27373260259628296","0.3108846843242645","0.3397117257118225","5.5384345054626465","0.11754472553730012","0.18663020431995392","0.23459243774414065","0.2721172869205475","0.2998260259628296","24.93223100445568","1917.1552","35.794","0.28" -------------------------------------------------------------------------------- /assets/medusa_choices.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_choices.png -------------------------------------------------------------------------------- /assets/medusa_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_demo.gif -------------------------------------------------------------------------------- /assets/medusa_pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_pipeline.jpg -------------------------------------------------------------------------------- /assets/medusa_speedup_cmp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/medusa_speedup_cmp.jpg -------------------------------------------------------------------------------- /assets/size_speedup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/assets/size_speedup.png -------------------------------------------------------------------------------- /create_data.py: -------------------------------------------------------------------------------- 1 | import typer 2 | import json 3 | from transformers import Conversation 4 | from typing_extensions import Annotated 5 | import httpx 6 | import tqdm 7 | import asyncio 8 | 9 | app = typer.Typer() 10 | 11 | 12 | client = httpx.AsyncClient(timeout=None) 13 | 14 | async def run(conv: Conversation, url: str): 15 | payload = {"model":"tgi", "messages": conv.messages} 16 | response = await client.post(url, json=payload) 17 | content = response.json() 18 | message = content["choices"][0]["message"] 19 | message.pop("name", None) 20 | conv.add_message(message) 21 | 22 | 23 | 24 | 25 | def fix_source(source): 26 | if source and source[0]["from"] == "gpt": 27 | # Skip if GPT is first to talk 28 | source = source[1:] 29 | new_source = [] 30 | for item in source: 31 | role = "assistant" if item["from"] == "gpt" else "user" 32 | content = item["value"] 33 | 
new_source.append({"role": role, "content": content}) 34 | return new_source 35 | 36 | 37 | async def recreate_conversation(conversation, sem, url): 38 | async with sem: 39 | conv = Conversation() 40 | try: 41 | for message in conversation[::2]: 42 | assert message["role"] == "user" 43 | conv.add_message(message) 44 | await run(conv, url) 45 | except Exception as e: 46 | print(e) 47 | pass 48 | return conv.messages 49 | 50 | @app.command() 51 | def main( 52 | *, 53 | input_filename: Annotated[str, typer.Option("--input-filename")], 54 | output_filename: Annotated[str, typer.Option("--output-filename")], 55 | url: Annotated[str, typer.Option("--url")] = "http://localhost:8080/v1/chat/completions", 56 | concurrency: Annotated[int, typer.Option("--concurrency")] = 64 57 | ): 58 | sem = asyncio.Semaphore(concurrency) 59 | async def _main(): 60 | with open(input_filename, "r") as f: 61 | input_data = json.loads(f.read()) 62 | conversations = [fix_source(source["conversations"]) for source in input_data] 63 | 64 | futures = [] 65 | for conversation in conversations: 66 | future = recreate_conversation(conversation, sem, url) 67 | futures.append(future) 68 | 69 | recreated_conversations = await tqdm.asyncio.tqdm.gather(*futures) 70 | 71 | with open(output_filename, "w") as f: 72 | json.dump(recreated_conversations, f, indent=4) 73 | asyncio.run(_main()) 74 | 75 | 76 | if __name__ == "__main__": 77 | app() 78 | -------------------------------------------------------------------------------- /data_generation/README.md: -------------------------------------------------------------------------------- 1 | # Generate chat data for self-distillation 2 | We use vLLM to enable batched generation. First, install dependencies: 3 | ```bash 4 | pip install vllm openai 5 | ``` 6 | 7 | ## Start server 8 | 9 | ```bash 10 | python -m vllm.entrypoints.openai.api_server \ 11 | --model YOUR_MODEL_NAME --port 8000 12 | ``` 13 | You can also start multiple servers with different ports to enable parallel generation. In `generate.py`, we scan the ports from 8000 to 8009 to find available servers. You can modify the code to use other ports. 14 | 15 | ## Generate data 16 | The following command lets the model continue the first prompt from each sample in `DATA_PATH`; this is suitable for models that can play both roles in a conversation (e.g., Zephyr 7B). If you want to use all prompts in each sample to talk to the model repeatedly, use `--chat` instead. `--chat` mode works for more models but may take longer to generate because of repeated computation (contributions of a better implementation are welcome). 17 | 18 | ```bash 19 | python generate.py --data_path YOUR_DATA_PATH --output_path YOUR_OUTPUT_PATH --num_threads NUM_THREADS --max_tokens YOUR_MAX_TOKENS --temperature YOUR_TEMPERATURE 20 | ``` 21 | 22 | ## (Optional) Format data 23 | When generated with `--chat`, the output file will follow the ShareGPT format ([example](https://github.com/lm-sys/FastChat/blob/main/data/dummy_conversation.json)).
24 | You can use the following command to convert the generated text withour `--chat` to the same format: 25 | ```bash 26 | python convert_to_sharegpt.py --input_path YOUR_INPUT_PATH --model_name YOUR_MODEL_NAME --output_path YOUR_OUTPUT_PATH 27 | ``` -------------------------------------------------------------------------------- /data_generation/convert_to_sharegpt.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | import concurrent.futures 5 | 6 | import openai 7 | import shortuuid 8 | import tqdm 9 | 10 | import argparse 11 | import random 12 | 13 | from tenacity import ( 14 | retry, 15 | stop_after_attempt, 16 | wait_random_exponential, 17 | ) 18 | 19 | from fastchat.conversation import Conversation, SeparatorStyle 20 | from fastchat.model.model_adapter import get_conversation_template 21 | from transformers import AutoTokenizer 22 | 23 | # Use the same arguments as in generate.py 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--input_path", type=str) 26 | parser.add_argument("--model_name", type=str, default="HuggingFaceH4/zephyr-7b-beta") 27 | args = parser.parse_args() 28 | 29 | conv = get_conversation_template(args.model_name) 30 | tokenizer = AutoTokenizer.from_pretrained(args.model_name) 31 | 32 | data = [] 33 | with open(args.input_path) as f: 34 | for line in f.readlines(): 35 | data.append(json.loads(line)) 36 | 37 | def convert(text): 38 | messages = [] 39 | 40 | for turn in text.split(conv.roles[0]): 41 | pairs = turn.split(conv.roles[1]) 42 | if len(pairs) != 2: 43 | continue 44 | messages.append({ 45 | "from": "human", 46 | "value": pairs[0].split(conv.sep)[0].strip() 47 | }) 48 | messages.append({ 49 | "from": "gpt", 50 | "value": pairs[1].split(conv.sep)[0].strip() 51 | }) 52 | # pop the last message because it might be incomplete 53 | if len(messages) > 0: 54 | messages.pop() 55 | # make sure number of messages is even 56 | if len(messages) % 2 == 1: 57 | messages.pop() 58 | return {"conversations": messages} 59 | 60 | sharegpt_data = [] 61 | for d in tqdm.tqdm(data): 62 | sample = convert(d["text"]) 63 | if len(sample["conversations"]) < 2: 64 | continue 65 | sharegpt_data.append(sample) 66 | 67 | # dump to jsonl 68 | with open(args.input_path.replace(".jsonl", "_sharegpt.jsonl"), "w") as f: 69 | for d in sharegpt_data: 70 | f.write(json.dumps(d) + "\n") -------------------------------------------------------------------------------- /data_generation/generate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | import concurrent.futures 5 | 6 | import openai 7 | import shortuuid 8 | import tqdm 9 | 10 | import argparse 11 | import random 12 | 13 | from tenacity import ( 14 | retry, 15 | stop_after_attempt, 16 | wait_random_exponential, 17 | ) 18 | 19 | from fastchat.conversation import Conversation, SeparatorStyle 20 | from fastchat.model.model_adapter import get_conversation_template 21 | 22 | # Modify OpenAI's API key and API base to use vLLM's API server. 
23 | openai.api_key = "EMPTY" 24 | openai.api_base = "http://localhost:8000/v1" 25 | 26 | api_base_pool = [] 27 | 28 | # List models API 29 | for i in range(10): 30 | openai.api_base = "http://localhost:800{}/v1".format(i) 31 | try: 32 | models = openai.Model.list()["data"][0]["id"] 33 | print(openai.api_base, models) 34 | api_base_pool.append(openai.api_base) 35 | except: 36 | break 37 | 38 | print("API base pool: ", api_base_pool) 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--data_path", type=str) 42 | parser.add_argument("--output_path", type=str) 43 | parser.add_argument("--num_threads", type=int, default=256) 44 | parser.add_argument("--temperature", type=float, default=0.3) 45 | parser.add_argument("--max_tokens", type=int, default=2048) 46 | parser.add_argument("--chat", action="store_true") 47 | args = parser.parse_args() 48 | 49 | # Assuming the ShareGPT format 50 | data = json.load(open(args.data_path, "r")) 51 | 52 | def generate_data(messages, idx): 53 | try: 54 | # load balanced 55 | openai.api_base = api_base_pool[idx % len(api_base_pool)] 56 | model_name=openai.Model.list()["data"][0]["id"] 57 | 58 | if args.chat: 59 | converted_messages = [] 60 | output_messages = [] 61 | if messages[0]["from"] == "system": 62 | converted_messages.append( 63 | { 64 | "role": "system", 65 | "content": messages[0]["text"], 66 | } 67 | ) 68 | output_messages.append(messages[0]) 69 | messages = messages[1:] 70 | for message in messages[::2]: 71 | if message["from"] != "human": 72 | return 73 | converted_messages.append( 74 | { 75 | "role": "user", 76 | "content": message["value"], 77 | } 78 | ) 79 | try: 80 | response = openai.ChatCompletion.create( 81 | model=model_name, 82 | messages=converted_messages, 83 | max_tokens=args.max_tokens, 84 | temperature=args.temperature, 85 | ) 86 | if response.choices[0]['finish_reason'] == "length": 87 | break 88 | response = response.choices[0]['message']['content'].strip() 89 | output_messages.append(message) 90 | output_messages.append( 91 | { 92 | "from": "gpt", 93 | "value": response, 94 | } 95 | ) 96 | converted_messages.append( 97 | { 98 | "role": "assistant", 99 | "content": response, 100 | } 101 | ) 102 | except: 103 | break 104 | if len(output_messages) == 0: 105 | return 106 | with open(args.output_path, "a") as f: 107 | # write in share gpt format 108 | f.write(json.dumps({"conversations": output_messages}) + "\n") 109 | else: 110 | conv = get_conversation_template(model_name) 111 | if messages[0]["from"] == "system": 112 | conv.system_message = messages[0]["text"] 113 | messages = messages[1:] 114 | conv.append_message(conv.roles[0], messages[0]["value"]) 115 | conv.append_message(conv.roles[1], None) 116 | prompt = conv.get_prompt() 117 | 118 | response = openai.Completion.create( 119 | model=model_name, 120 | prompt=prompt, 121 | max_tokens=args.max_tokens, 122 | temperature=args.temperature, 123 | ignore_eos=True, 124 | skip_special_tokens=False, 125 | spaces_between_special_tokens=False, 126 | ) 127 | response = response.choices[0]['text'].strip() 128 | with open(args.output_path, "a") as f: 129 | # write in share gpt format 130 | f.write(json.dumps({"text": prompt+response}) + "\n") 131 | except Exception as e: 132 | print(e) 133 | print(prompt) 134 | print("Failed to generate data") 135 | 136 | # if output_path exists, count the number of lines and skip the first n data 137 | start = 0 138 | if os.path.exists(args.output_path): 139 | with open(args.output_path, "r") as f: 140 | start = len(f.readlines()) 141 | 
print("Skip first {} data".format(start)) 142 | 143 | with concurrent.futures.ThreadPoolExecutor(max_workers=args.num_threads) as executor: 144 | futures = [] 145 | for idx, sample in enumerate(data[start:]): 146 | future = executor.submit( 147 | generate_data, 148 | sample["conversations"], 149 | idx, 150 | ) 151 | futures.append(future) 152 | 153 | for future in tqdm.tqdm( 154 | concurrent.futures.as_completed(futures), total=len(futures) 155 | ): 156 | future.result() -------------------------------------------------------------------------------- /deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | 6 | "zero_optimization": { 7 | "stage": 3, 8 | "overlap_comm": true, 9 | "contiguous_gradients": true, 10 | "sub_group_size": 1e9, 11 | "reduce_bucket_size": "auto", 12 | "stage3_prefetch_bucket_size": "auto", 13 | "stage3_param_persistence_threshold": "auto", 14 | "stage3_max_live_parameters": 1e9, 15 | "stage3_max_reuse_distance": 1e9, 16 | "stage3_gather_16bit_weights_on_model_save": true 17 | }, 18 | 19 | "gradient_accumulation_steps": "auto", 20 | "steps_per_print": 2000, 21 | "train_batch_size": "auto", 22 | "train_micro_batch_size_per_gpu": "auto", 23 | "wall_clock_breakdown": false 24 | } 25 | -------------------------------------------------------------------------------- /llm_judge/README.md: -------------------------------------------------------------------------------- 1 | # LLM Judge 2 | | [Original Github Repository](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge) 3 | 4 | ## Installation 5 | 6 | | [Guide](https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/README.md) 7 | 8 | ## Usage 9 | 10 | We report the 3 times running results of the Medusa X Vicuna v1.3 7/13/33b on a single A100 in `./data/mt_bench/model_answer/`. The original settings are: `temperature` (it is deprecated and use the default LLM Judge setting), `posterior_threshold=0.09`, `posterior_alpha=0.3`. 11 | 12 | - Run benchmark 13 | 14 | 15 | ``` 16 | export CUDA_VISIBLE_DEVICES=0 # set the GPU id 17 | python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-7b-v1.3 --model-id medusa-vicuna-7b-v1.3-0 18 | python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-13b-v1.3 --model-id medusa-vicuna-13b-v1.3-0 19 | python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-33b-v1.3 --model-id medusa-vicuna-33b-v1.3-0 20 | ``` 21 | 22 | - Run baseline: replace `gen_model_answer_medusa.py` with `gen_model_answer_baseline.py` (Please note we only implement the greedy inference for wall-time comparison. If you want to use the sampling generator, please refer to the original repository.) 23 | 24 | 25 | - Query the results 26 | 27 | ``` 28 | export OPENAI_API_KEY=$OPENAI_API_KEYs # set the OpenAI API key 29 | python gen_judgement.py --model-list medusa-vicuna-7b-v1.3-0-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3 30 | ``` 31 | 32 | - Show results 33 | 34 | To obtain the results of GPT-4 judge for Vicuna-7b ( Huggingface greedy | Huggingface sampling | Medusa sampling), run: 35 | 36 | ``` 37 | python show_result.py 38 | ``` 39 | 40 | ## Citation 41 | Please cite the original paper if you find the code or datasets helpful. 
42 | ``` 43 | @misc{zheng2023judging, 44 | title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena}, 45 | author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric. P Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica}, 46 | year={2023}, 47 | eprint={2306.05685}, 48 | archivePrefix={arXiv}, 49 | primaryClass={cs.CL} 50 | } 51 | ``` -------------------------------------------------------------------------------- /llm_judge/data/judge_prompts.jsonl: -------------------------------------------------------------------------------- 1 | {"name": "pair-v2", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. You should choose the assistant that follows the user's instructions and answers the user's question better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. Begin your evaluation by comparing the two responses and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[A]]"} 2 | {"name": "pair-v2-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. You should choose the assistant that follows the user's instructions and answers the user's questions better. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of their responses. You should focus on who provides a better answer to the second user question. Begin your evaluation by comparing the responses of the two assistants and provide a short explanation. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. 
After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} 3 | {"name": "pair-math-v1", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer, assistant A's answer, and assistant B's answer. Your job is to evaluate which assistant's answer is better. Begin your evaluation by comparing both assistants' answers with the reference answer. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "[User Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant A's Answer]\n{answer_a}\n[The End of Assistant A's Answer]\n\n[The Start of Assistant B's Answer]\n{answer_b}\n[The End of Assistant B's Answer]", "description": "Prompt for math questions", "category": "math", "output_format": "[[A]]"} 4 | {"name": "pair-math-v1-multi-turn", "type": "pairwise", "system_prompt": "Please act as an impartial judge and evaluate the quality of the responses provided by two AI assistants to the user questions. Your evaluation should consider correctness and helpfulness. You will be given reference answers, the assistant A's answers, the assistant B's answers. Your job is to determine which assistant provides correct and helpful answers to the second user question. Begin your evaluation by comparing both assistants' answers with the reference answers. Identify and correct any mistakes. Avoid any position biases and ensure that the order in which the responses were presented does not influence your decision. Do not allow the length of the responses to influence your evaluation. Do not favor certain names of the assistants. Be as objective as possible. 
After providing your explanation, output your final verdict by strictly following this format: \"[[A]]\" if assistant A is better, \"[[B]]\" if assistant B is better, and \"[[C]]\" for a tie.", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_a_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_a_2}\n\n<|The End of Assistant A's Conversation with User|>\n\n\n<|The Start of Assistant B's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant B:\n{answer_b_1}\n\n### User:\n{question_2}\n\n### Assistant B:\n{answer_b_2}\n\n<|The End of Assistant B's Conversation with User|>", "description": "Prompt for multi-turn general questions", "category": "general", "output_format": "[[A]]"} 5 | {"name": "single-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} 6 | {"name": "single-math-v1", "type": "single", "system_prompt": "You are a helpful assistant.", "prompt_template": "[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n[Question]\n{question}\n\n[The Start of Reference Answer]\n{ref_answer_1}\n[The End of Reference Answer]\n\n[The Start of Assistant's Answer]\n{answer}\n[The End of Assistant's Answer]", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} 7 | {"name": "single-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. Your evaluation should consider factors such as the helpfulness, relevance, accuracy, depth, creativity, and level of detail of the response. You evaluation should focus on the assistant's answer to the second user question. Begin your evaluation by providing a short explanation. Be as objective as possible. 
After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "general", "output_format": "[[rating]]"} 8 | {"name": "single-math-v1-multi-turn", "type": "single", "system_prompt": "Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question. Your evaluation should consider correctness and helpfulness. You will be given a reference answer and the assistant's answer. You evaluation should focus on the assistant's answer to the second question. Begin your evaluation by comparing the assistant's answer with the reference answer. Identify and correct any mistakes. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: \"[[rating]]\", for example: \"Rating: [[5]]\".\n\n", "prompt_template": "<|The Start of Reference Answer|>\n\n### User:\n{question_1}\n\n### Reference answer:\n{ref_answer_1}\n\n### User:\n{question_2}\n\n### Reference answer:\n{ref_answer_2}\n\n<|The End of Reference Answer|>\n\n\n<|The Start of Assistant A's Conversation with User|>\n\n### User:\n{question_1}\n\n### Assistant A:\n{answer_1}\n\n### User:\n{question_2}\n\n### Assistant A:\n{answer_2}\n\n<|The End of Assistant A's Conversation with User|>", "description": "Prompt for general questions", "category": "math", "output_format": "[[rating]]"} 9 | -------------------------------------------------------------------------------- /llm_judge/gen_judgement.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python gen_judgment.py --model-list [LIST-OF-MODEL-ID] --parallel [num-concurrent-api-call] --mode [single|pairwise-baseline|pairwise-all] 4 | """ 5 | import argparse 6 | from concurrent.futures import ThreadPoolExecutor 7 | import json 8 | 9 | import numpy as np 10 | from tqdm import tqdm 11 | 12 | from fastchat.llm_judge.common import ( 13 | load_questions, 14 | load_model_answers, 15 | load_judge_prompts, 16 | check_data, 17 | play_a_match_pair, 18 | play_a_match_single, 19 | get_model_list, 20 | Judge, 21 | MatchPair, 22 | MatchSingle, 23 | NEED_REF_CATS, 24 | ) 25 | 26 | 27 | def make_match( 28 | questions, 29 | models, 30 | model_answers, 31 | judge, 32 | baseline_model, 33 | ref_answers=None, 34 | multi_turn=False, 35 | ): 36 | matches = [] 37 | for q in questions: 38 | if multi_turn and len(q["turns"]) != 2: 39 | continue 40 | for i in range(len(models)): 41 | q_id = q["question_id"] 42 | m_1 = models[i] 43 | m_2 = baseline_model 44 | if m_1 == m_2: 45 | continue 46 | a_1 = model_answers[m_1][q_id] 47 | a_2 = model_answers[baseline_model][q_id] 48 | if ref_answers is not None: 49 | ref = ref_answers[judge.model_name][q_id] 50 | match = MatchPair( 51 | dict(q), 52 | m_1, 53 | m_2, 54 | a_1, 55 | a_2, 56 | judge, 57 | ref_answer=ref, 58 | multi_turn=multi_turn, 59 | ) 60 | else: 61 | match = MatchPair( 62 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn 63 | ) 64 | matches.append(match) 65 | return matches 66 | 67 | 68 | def make_match_all_pairs( 69 | questions, 70 | models, 71 | 
model_answers, 72 | judge, 73 | baseline_model=None, 74 | ref_answers=None, 75 | multi_turn=False, 76 | ): 77 | matches = [] 78 | for q in questions: 79 | if multi_turn and len(q["turns"]) != 2: 80 | continue 81 | for i in range(len(models)): 82 | for j in range(i + 1, len(models)): 83 | q_id = q["question_id"] 84 | m_1 = models[i] 85 | m_2 = models[j] 86 | a_1 = model_answers[m_1][q_id] 87 | a_2 = model_answers[m_2][q_id] 88 | if ref_answers is not None: 89 | ref = ref_answers[judge.model_name][q_id] 90 | match = MatchPair( 91 | dict(q), 92 | m_1, 93 | m_2, 94 | a_1, 95 | a_2, 96 | judge, 97 | ref_answer=ref, 98 | multi_turn=multi_turn, 99 | ) 100 | else: 101 | match = MatchPair( 102 | dict(q), m_1, m_2, a_1, a_2, judge, multi_turn=multi_turn 103 | ) 104 | matches.append(match) 105 | return matches 106 | 107 | 108 | def make_match_single( 109 | questions, 110 | models, 111 | model_answers, 112 | judge, 113 | baseline_model=None, 114 | ref_answers=None, 115 | multi_turn=False, 116 | ): 117 | matches = [] 118 | for q in questions: 119 | if multi_turn and len(q["turns"]) != 2: 120 | continue 121 | for i in range(len(models)): 122 | q_id = q["question_id"] 123 | m = models[i] 124 | a = model_answers[m][q_id] 125 | if ref_answers is not None: 126 | ref = ref_answers[judge.model_name][q_id] 127 | matches.append( 128 | MatchSingle( 129 | dict(q), m, a, judge, ref_answer=ref, multi_turn=multi_turn 130 | ) 131 | ) 132 | else: 133 | matches.append(MatchSingle(dict(q), m, a, judge, multi_turn=multi_turn)) 134 | return matches 135 | 136 | 137 | def make_judge_pairwise(judge_model, judge_prompts): 138 | judges = {} 139 | judges["default"] = Judge(judge_model, judge_prompts["pair-v2"]) 140 | judges["math"] = Judge(judge_model, judge_prompts["pair-math-v1"], ref_based=True) 141 | judges["default-mt"] = Judge( 142 | judge_model, judge_prompts["pair-v2-multi-turn"], multi_turn=True 143 | ) 144 | judges["math-mt"] = Judge( 145 | judge_model, 146 | judge_prompts["pair-math-v1-multi-turn"], 147 | ref_based=True, 148 | multi_turn=True, 149 | ) 150 | return judges 151 | 152 | 153 | def make_judge_single(judge_model, judge_prompts): 154 | judges = {} 155 | judges["default"] = Judge(judge_model, judge_prompts["single-v1"]) 156 | judges["math"] = Judge(judge_model, judge_prompts["single-math-v1"], ref_based=True) 157 | judges["default-mt"] = Judge( 158 | judge_model, judge_prompts["single-v1-multi-turn"], multi_turn=True 159 | ) 160 | judges["math-mt"] = Judge( 161 | judge_model, 162 | judge_prompts["single-math-v1-multi-turn"], 163 | ref_based=True, 164 | multi_turn=True, 165 | ) 166 | return judges 167 | 168 | 169 | if __name__ == "__main__": 170 | parser = argparse.ArgumentParser() 171 | parser.add_argument( 172 | "--bench-name", 173 | type=str, 174 | default="mt_bench", 175 | help="The name of the benchmark question set.", 176 | ) 177 | parser.add_argument( 178 | "--judge-file", 179 | type=str, 180 | default="data/judge_prompts.jsonl", 181 | help="The file of judge prompts.", 182 | ) 183 | parser.add_argument("--judge-model", type=str, default="gpt-4") 184 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") 185 | parser.add_argument( 186 | "--mode", 187 | type=str, 188 | default="single", 189 | choices=["pairwise-baseline", "pairwise-all", "single"], 190 | help=( 191 | "Evaluation mode. " 192 | "`pairwise-baseline` runs pairwise comparision against a baseline. " 193 | "`pairwise-all` runs pairwise comparision between all pairs. " 194 | "`single` runs single answer grading." 
195 | ), 196 | ) 197 | parser.add_argument( 198 | "--model-list", 199 | type=str, 200 | nargs="+", 201 | default=None, 202 | help="A list of models to be evaluated", 203 | ) 204 | parser.add_argument( 205 | "--parallel", type=int, default=1, help="The number of concurrent API calls." 206 | ) 207 | parser.add_argument( 208 | "--first-n", type=int, help="A debug option. Only run the first `n` judgments." 209 | ) 210 | args = parser.parse_args() 211 | 212 | question_file = f"data/{args.bench_name}/question.jsonl" 213 | answer_dir = f"data/{args.bench_name}/model_answer" 214 | ref_answer_dir = f"data/{args.bench_name}/reference_answer" 215 | 216 | # Load questions 217 | questions = load_questions(question_file, None, None) 218 | 219 | # Load answers 220 | model_answers = load_model_answers(answer_dir) 221 | ref_answers = load_model_answers(ref_answer_dir) 222 | 223 | # Load judge 224 | judge_prompts = load_judge_prompts(args.judge_file) 225 | 226 | if args.first_n: 227 | questions = questions[: args.first_n] 228 | 229 | if args.model_list is None: 230 | models = get_model_list(answer_dir) 231 | else: 232 | models = args.model_list 233 | 234 | if args.mode == "single": 235 | judges = make_judge_single(args.judge_model, judge_prompts) 236 | play_a_match_func = play_a_match_single 237 | output_file = ( 238 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" 239 | ) 240 | make_match_func = make_match_single 241 | baseline_model = None 242 | else: 243 | judges = make_judge_pairwise(args.judge_model, judge_prompts) 244 | play_a_match_func = play_a_match_pair 245 | output_file = ( 246 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" 247 | ) 248 | if args.mode == "pairwise-all": 249 | make_match_func = make_match_all_pairs 250 | baseline_model = None 251 | else: 252 | make_match_func = make_match 253 | baseline_model = args.baseline_model 254 | 255 | check_data(questions, model_answers, ref_answers, models, judges) 256 | 257 | question_math = [q for q in questions if q["category"] in NEED_REF_CATS] 258 | question_default = [q for q in questions if q["category"] not in NEED_REF_CATS] 259 | 260 | # Make matches 261 | matches = [] 262 | matches += make_match_func( 263 | question_default, models, model_answers, judges["default"], baseline_model 264 | ) 265 | matches += make_match_func( 266 | question_math, 267 | models, 268 | model_answers, 269 | judges["math"], 270 | baseline_model, 271 | ref_answers, 272 | ) 273 | matches += make_match_func( 274 | question_default, 275 | models, 276 | model_answers, 277 | judges["default-mt"], 278 | baseline_model, 279 | multi_turn=True, 280 | ) 281 | matches += make_match_func( 282 | question_math, 283 | models, 284 | model_answers, 285 | judges["math-mt"], 286 | baseline_model, 287 | ref_answers, 288 | multi_turn=True, 289 | ) 290 | 291 | # Filter out existed matches 292 | total_num_matches = len(matches) 293 | filtered_matches = [] 294 | try: 295 | with open(output_file, "r") as f: 296 | existed_matches = [json.loads(line) for line in f] 297 | except FileNotFoundError: 298 | existed_matches = [] 299 | uniq_ids = set( 300 | [ 301 | f"{e['question_id']}_{e['model']}_{e['judge'][0]}_{e['judge'][1]}_{e['turn']}" 302 | for e in existed_matches 303 | ] 304 | ) 305 | for match in matches: 306 | turn = 2 if match.judge.multi_turn else 1 307 | uniq_id = f"{match.question['question_id']}_{match.answer['model_id']}_{match.judge.model_name}_{match.judge.prompt_template['name']}_{turn}" 308 | if uniq_id in uniq_ids: 309 | 
print(f"Skip {uniq_id}") 310 | else: 311 | filtered_matches.append(match) 312 | matches = filtered_matches 313 | 314 | match_stat = {} 315 | match_stat["bench_name"] = args.bench_name 316 | match_stat["mode"] = args.mode 317 | match_stat["judge"] = args.judge_model 318 | match_stat["baseline"] = baseline_model 319 | match_stat["model_list"] = models 320 | match_stat["total_num_questions"] = len(questions) 321 | match_stat["total_num_matches"] = total_num_matches 322 | match_stat["current_num_matches"] = len(matches) 323 | match_stat["output_path"] = output_file 324 | 325 | # Show match stats and prompt enter to continue 326 | print("Stats:") 327 | print(json.dumps(match_stat, indent=4)) 328 | input("Press Enter to confirm...") 329 | 330 | # Play matches 331 | if args.parallel == 1: 332 | for match in tqdm(matches): 333 | play_a_match_func(match, output_file=output_file) 334 | else: 335 | 336 | def play_a_match_wrapper(match): 337 | play_a_match_func(match, output_file=output_file) 338 | 339 | np.random.seed(0) 340 | np.random.shuffle(matches) 341 | 342 | with ThreadPoolExecutor(args.parallel) as executor: 343 | for match in tqdm( 344 | executor.map(play_a_match_wrapper, matches), total=len(matches) 345 | ): 346 | pass -------------------------------------------------------------------------------- /llm_judge/gen_model_answer_medusa_legacy.py: -------------------------------------------------------------------------------- 1 | """Generate answers with local models. 2 | 3 | Usage: 4 | python3 gen_model_answer.py --model-path lmsys/fastchat-t5-3b-v1.0 --model-id fastchat-t5-3b-v1.0 5 | """ 6 | import argparse 7 | import json 8 | import os 9 | import random 10 | import time 11 | import shortuuid 12 | import torch 13 | from tqdm import tqdm 14 | 15 | from fastchat.llm_judge.common import load_questions, temperature_config 16 | from fastchat.model import load_model, get_conversation_template 17 | 18 | # Medusa imports 19 | import transformers 20 | 21 | 22 | from medusa.model.utils import * 23 | from medusa.model.medusa_model import MedusaModel 24 | from medusa.model.kv_cache import initialize_past_key_values 25 | from medusa.model.medusa_choices import * 26 | 27 | def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512): 28 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 
29 | # Avoid modifying the input_ids in-place 30 | input_ids = input_ids.clone() 31 | 32 | # Cache medusa buffers (the fixed patterns for tree attention) 33 | if hasattr(model, "medusa_choices") and model.medusa_choices == medusa_choices: 34 | # Load the cached medusa buffer 35 | medusa_buffers = model.medusa_buffers 36 | else: 37 | # Initialize the medusa buffer 38 | medusa_buffers = generate_medusa_buffers( 39 | medusa_choices, device=model.base_model.device 40 | ) 41 | model.medusa_buffers = medusa_buffers 42 | model.medusa_choices = medusa_choices 43 | 44 | # Initialize the past key and value states 45 | if hasattr(model, "past_key_values"): 46 | past_key_values = model.past_key_values 47 | past_key_values_data = model.past_key_values_data 48 | current_length_data = model.current_length_data 49 | # Reset the past key and value states 50 | current_length_data.zero_() 51 | else: 52 | ( 53 | past_key_values, 54 | past_key_values_data, 55 | current_length_data, 56 | ) = initialize_past_key_values(model.base_model) 57 | model.past_key_values = past_key_values 58 | model.past_key_values_data = past_key_values_data 59 | model.current_length_data = current_length_data 60 | 61 | input_len = input_ids.shape[1] 62 | reset_medusa_mode(model) 63 | medusa_logits, logits = initialize_medusa( 64 | input_ids, model, medusa_buffers["medusa_attn_mask"], past_key_values 65 | ) 66 | new_token = 0 67 | 68 | for idx in range(max_steps): 69 | candidates, tree_candidates = generate_candidates( 70 | medusa_logits, 71 | logits, 72 | medusa_buffers["tree_indices"], 73 | medusa_buffers["retrieve_indices"], 74 | ) 75 | medusa_logits, logits, outputs = tree_decoding( 76 | model, 77 | tree_candidates, 78 | past_key_values, 79 | medusa_buffers["medusa_position_ids"], 80 | input_ids, 81 | medusa_buffers["retrieve_indices"], 82 | ) 83 | best_candidate, accept_length = evaluate_posterior( 84 | logits, candidates, temperature, posterior_threshold, posterior_alpha 85 | ) 86 | input_ids, logits, medusa_logits, new_token = update_inference_inputs( 87 | input_ids, 88 | candidates, 89 | best_candidate, 90 | accept_length, 91 | medusa_buffers["retrieve_indices"], 92 | outputs, 93 | logits, 94 | medusa_logits, 95 | new_token, 96 | past_key_values_data, 97 | current_length_data, 98 | ) 99 | if tokenizer.eos_token_id in input_ids[0, input_len:].tolist(): 100 | break 101 | if new_token > 1024: 102 | break 103 | return input_ids, new_token, idx 104 | 105 | def run_eval( 106 | model_path, 107 | model_id, 108 | question_file, 109 | question_begin, 110 | question_end, 111 | answer_file, 112 | max_new_token, 113 | num_choices, 114 | num_gpus_per_model, 115 | num_gpus_total, 116 | max_gpu_memory, 117 | temperature, 118 | posterior_threshold, 119 | posterior_alpha, 120 | medusa_choices, 121 | ): 122 | questions = load_questions(question_file, question_begin, question_end) 123 | # random shuffle the questions to balance the loading 124 | # random.shuffle(questions) 125 | shuffled_ids = [q["question_id"] for q in questions] 126 | # with open(f"data/{args.bench_name}/model_ids/{args.model_id}.shuffled_ids", "w") as fout: 127 | # json.dump(shuffled_ids, fout) 128 | 129 | # Split the question file into `num_gpus` files 130 | assert num_gpus_total % num_gpus_per_model == 0 131 | use_ray = num_gpus_total // num_gpus_per_model > 1 132 | 133 | if use_ray: 134 | get_answers_func = ray.remote(num_gpus=num_gpus_per_model)( 135 | get_model_answers 136 | ).remote 137 | else: 138 | get_answers_func = get_model_answers 139 | 140 | chunk_size = 
len(questions) // (num_gpus_total // num_gpus_per_model) # // 2 141 | ans_handles = [] 142 | for i in range(0, len(questions), chunk_size): 143 | ans_handles.append( 144 | get_answers_func( 145 | model_path, 146 | model_id, 147 | questions[i : i + chunk_size], 148 | answer_file, 149 | max_new_token, 150 | num_choices, 151 | num_gpus_per_model, 152 | max_gpu_memory, 153 | temperature, 154 | posterior_threshold, 155 | posterior_alpha, 156 | medusa_choices, 157 | ) 158 | ) 159 | 160 | if use_ray: 161 | ray.get(ans_handles) 162 | 163 | 164 | @torch.inference_mode() 165 | def get_model_answers( 166 | model_path, 167 | model_id, 168 | questions, 169 | answer_file, 170 | max_new_token, 171 | num_choices, 172 | num_gpus_per_model, 173 | max_gpu_memory, 174 | temperature, 175 | posterior_threshold, 176 | posterior_alpha, 177 | medusa_choices, 178 | ): 179 | 180 | # Medusa model setup 181 | num_heads = 4 182 | 183 | model = MedusaModel.from_pretrained( 184 | model_path, 185 | medusa_num_heads = num_heads, 186 | torch_dtype=torch.float16, 187 | low_cpu_mem_usage=True, 188 | device_map="auto" 189 | ) 190 | 191 | tokenizer = model.get_tokenizer() 192 | 193 | model.eval() 194 | print('Check model training state:',model.training) 195 | 196 | cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES') 197 | print('CUDA VISIBLE DEVICES:', cuda_visible_devices) 198 | 199 | question = questions[0] 200 | 201 | # warmup 202 | for _ in range(3): 203 | torch.manual_seed(0) 204 | conv = get_conversation_template(model_id) 205 | turns = [] 206 | idxs = [] 207 | new_tokens = [] 208 | wall_time = [] 209 | for j in range(len(question["turns"])): 210 | qs = question["turns"][j] 211 | conv.append_message(conv.roles[0], qs) 212 | conv.append_message(conv.roles[1], None) 213 | prompt = conv.get_prompt() 214 | input_ids = tokenizer([prompt]).input_ids 215 | 216 | # if temperature < 1e-4: 217 | # do_sample = False 218 | # else: 219 | # do_sample = True 220 | 221 | # some models may error out when generating long outputs 222 | try: 223 | torch.cuda.synchronize() 224 | start_time = time.time() 225 | output_ids, new_token, idx = medusa_forward( 226 | torch.as_tensor(input_ids).cuda(), 227 | model, 228 | tokenizer, 229 | medusa_choices, 230 | temperature, 231 | posterior_threshold, 232 | posterior_alpha, 233 | ) 234 | torch.cuda.synchronize() 235 | total_time = time.time() - start_time 236 | output_ids = output_ids[0][len(input_ids[0]) :] 237 | # be consistent with the template's stop_token_ids 238 | if conv.stop_token_ids: 239 | stop_token_ids_index = [ 240 | i 241 | for i, id in enumerate(output_ids) 242 | if id in conv.stop_token_ids 243 | ] 244 | if len(stop_token_ids_index) > 0: 245 | output_ids = output_ids[: stop_token_ids_index[0]] 246 | 247 | output = tokenizer.decode( 248 | output_ids, 249 | spaces_between_special_tokens=False, 250 | ) 251 | if conv.stop_str and output.find(conv.stop_str) > 0: 252 | output = output[: output.find(conv.stop_str)] 253 | for special_token in tokenizer.special_tokens_map.values(): 254 | if isinstance(special_token, list): 255 | for special_tok in special_token: 256 | output = output.replace(special_tok, "") 257 | else: 258 | output = output.replace(special_token, "") 259 | output = output.strip() 260 | 261 | if conv.name == "xgen" and output.startswith("Assistant:"): 262 | output = output.replace("Assistant:", "", 1).strip() 263 | except RuntimeError as e: 264 | print("ERROR question ID: ", question["question_id"]) 265 | output = "ERROR" 266 | 267 | turns.append(output) 268 | 
idxs.append(int(idx)) 269 | new_tokens.append(int(new_token)) 270 | wall_time.append(total_time) 271 | conv.messages[-1][-1] = output 272 | print('Warmup done') 273 | 274 | 275 | for question in tqdm(questions): 276 | if question["category"] in temperature_config: 277 | temperature = temperature_config[question["category"]] 278 | else: 279 | temperature = 0.7 280 | 281 | choices = [] 282 | for i in range(num_choices): 283 | torch.manual_seed(i) 284 | conv = get_conversation_template(model_id) 285 | turns = [] 286 | idxs = [] 287 | new_tokens = [] 288 | wall_time = [] 289 | for j in range(len(question["turns"])): 290 | qs = question["turns"][j] 291 | conv.append_message(conv.roles[0], qs) 292 | conv.append_message(conv.roles[1], None) 293 | prompt = conv.get_prompt() 294 | input_ids = tokenizer([prompt]).input_ids 295 | 296 | # if temperature < 1e-4: 297 | # do_sample = False 298 | # else: 299 | # do_sample = True 300 | 301 | # some models may error out when generating long outputs 302 | try: 303 | torch.cuda.synchronize() 304 | start_time = time.time() 305 | output_ids, new_token, idx = medusa_forward( 306 | torch.as_tensor(input_ids).cuda(), 307 | model, 308 | tokenizer, 309 | medusa_choices, 310 | temperature, 311 | posterior_threshold, 312 | posterior_alpha, 313 | ) 314 | torch.cuda.synchronize() 315 | total_time = time.time() - start_time 316 | # if model.config.is_encoder_decoder: 317 | # output_ids = output_ids[0] 318 | # else: 319 | output_ids = output_ids[0][len(input_ids[0]) :] 320 | 321 | # be consistent with the template's stop_token_ids 322 | if conv.stop_token_ids: 323 | stop_token_ids_index = [ 324 | i 325 | for i, id in enumerate(output_ids) 326 | if id in conv.stop_token_ids 327 | ] 328 | if len(stop_token_ids_index) > 0: 329 | output_ids = output_ids[: stop_token_ids_index[0]] 330 | 331 | output = tokenizer.decode( 332 | output_ids, 333 | spaces_between_special_tokens=False, 334 | ) 335 | if conv.stop_str and output.find(conv.stop_str) > 0: 336 | output = output[: output.find(conv.stop_str)] 337 | for special_token in tokenizer.special_tokens_map.values(): 338 | if isinstance(special_token, list): 339 | for special_tok in special_token: 340 | output = output.replace(special_tok, "") 341 | else: 342 | output = output.replace(special_token, "") 343 | output = output.strip() 344 | 345 | if conv.name == "xgen" and output.startswith("Assistant:"): 346 | output = output.replace("Assistant:", "", 1).strip() 347 | except RuntimeError as e: 348 | print("ERROR question ID: ", question["question_id"]) 349 | output = "ERROR" 350 | 351 | turns.append(output) 352 | idxs.append(int(idx)) 353 | new_tokens.append(int(new_token)) 354 | wall_time.append(total_time) 355 | conv.messages[-1][-1] = output 356 | # torch.cuda.empty_cache() 357 | choices.append({"index": i, "turns": turns, "idxs": idxs, "new_tokens": new_tokens, "wall_time": wall_time}) 358 | 359 | # Dump answers 360 | os.makedirs(os.path.dirname(answer_file), exist_ok=True) 361 | with open(os.path.expanduser(answer_file), "a") as fout: 362 | ans_json = { 363 | "question_id": question["question_id"], 364 | "answer_id": shortuuid.uuid(), 365 | "model_id": model_id, 366 | "choices": choices, 367 | "tstamp": time.time(), 368 | } 369 | fout.write(json.dumps(ans_json) + "\n") 370 | 371 | 372 | def reorg_answer_file(answer_file): 373 | """Sort by question id and de-duplication""" 374 | answers = {} 375 | with open(answer_file, "r") as fin: 376 | for l in fin: 377 | qid = json.loads(l)["question_id"] 378 | answers[qid] = l 379 | 380 | 
qids = sorted(list(answers.keys())) 381 | with open(answer_file, "w") as fout: 382 | for qid in qids: 383 | fout.write(answers[qid]) 384 | 385 | 386 | if __name__ == "__main__": 387 | parser = argparse.ArgumentParser() 388 | parser.add_argument( 389 | "--model-path", 390 | type=str, 391 | required=True, 392 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 393 | ) 394 | parser.add_argument("--model-id", type=str, required=True) 395 | parser.add_argument( 396 | "--bench-name", 397 | type=str, 398 | default="mt_bench", 399 | help="The name of the benchmark question set.", 400 | ) 401 | parser.add_argument( 402 | "--question-begin", 403 | type=int, 404 | help="A debug option. The begin index of questions.", 405 | ) 406 | parser.add_argument( 407 | "--question-end", type=int, help="A debug option. The end index of questions." 408 | ) 409 | parser.add_argument("--answer-file", type=str, help="The output answer file.") 410 | parser.add_argument( 411 | "--max-new-token", 412 | type=int, 413 | default=1024, 414 | help="The maximum number of new generated tokens.", 415 | ) 416 | parser.add_argument( 417 | "--num-choices", 418 | type=int, 419 | default=1, 420 | help="How many completion choices to generate.", 421 | ) 422 | parser.add_argument( 423 | "--num-gpus-per-model", 424 | type=int, 425 | default=1, 426 | help="The number of GPUs per model.", 427 | ) 428 | parser.add_argument( 429 | "--num-gpus-total", type=int, default=1, help="The total number of GPUs." 430 | ) 431 | parser.add_argument( 432 | "--max-gpu-memory", 433 | type=str, 434 | help="Maxmum GPU memory used for model weights per GPU.", 435 | ) 436 | 437 | # YL: Medusa args 438 | parser.add_argument( 439 | "--temperature", 440 | type=float, 441 | default=0.0, 442 | help="The temperature for medusa sampling.", 443 | ) 444 | 445 | parser.add_argument( 446 | "--posterior-threshold", 447 | type=float, 448 | default=0.09, 449 | help="The posterior threshold for medusa sampling.", 450 | ) 451 | 452 | parser.add_argument( 453 | "--posterior-alpha", 454 | type=float, 455 | default=0.3, 456 | help="The posterior alpha for medusa sampling.", 457 | ) 458 | 459 | parser.add_argument( 460 | "--medusa-choices", 461 | type=str, 462 | default="mc_sim_7b_63", 463 | help="The medusa choices for medusa sampling.", 464 | ) 465 | 466 | 467 | 468 | 469 | args = parser.parse_args() 470 | 471 | args.model_id = args.model_id+"-temperature-"+str(args.temperature)+"-posterior_threshold-"+str(args.posterior_threshold)+"-posterior_alpha-"+str(args.posterior_alpha) 472 | args.medusa_choices = eval(args.medusa_choices) 473 | if args.num_gpus_total // args.num_gpus_per_model > 1: 474 | import ray 475 | 476 | ray.init() 477 | 478 | question_file = f"data/{args.bench_name}/question.jsonl" 479 | if args.answer_file: 480 | answer_file = args.answer_file 481 | else: 482 | answer_file = f"data/{args.bench_name}/model_answer/{args.model_id}.jsonl" 483 | 484 | print(f"Output to {answer_file}") 485 | 486 | run_eval( 487 | args.model_path, 488 | args.model_id, 489 | question_file, 490 | args.question_begin, 491 | args.question_end, 492 | answer_file, 493 | args.max_new_token, 494 | args.num_choices, 495 | args.num_gpus_per_model, 496 | args.num_gpus_total, 497 | args.max_gpu_memory, 498 | 499 | args.temperature, 500 | args.posterior_threshold, 501 | args.posterior_alpha, 502 | args.medusa_choices, 503 | ) 504 | 505 | reorg_answer_file(answer_file) -------------------------------------------------------------------------------- 
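Each record written by `get_model_answers` above stores, per choice, the parallel lists `idxs` (index of the last decoding step for a turn), `new_tokens`, and `wall_time` next to the decoded `turns`. Below is a minimal sketch for turning such an answer file into throughput and accepted-length numbers before moving on to `show_result.py`; the JSONL layout is assumed to match what the script above writes (and, for the optional comparison, the baseline file is assumed to record the same fields), while the helper and flag names are illustrative only.

```
import argparse
import json


def summarize_answer_file(path):
    """Aggregate decoding statistics from a model_answer JSONL file."""
    total_tokens, total_seconds, total_steps = 0, 0.0, 0
    with open(path) as fin:
        for line in fin:
            record = json.loads(line)
            for choice in record["choices"]:
                # One entry per conversation turn in each of these lists.
                for last_step, tokens, seconds in zip(
                    choice["idxs"], choice["new_tokens"], choice["wall_time"]
                ):
                    total_steps += last_step + 1
                    total_tokens += tokens
                    total_seconds += seconds
    return total_tokens, total_seconds, total_steps


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--answer-file", type=str, required=True)
    parser.add_argument(
        "--baseline-file",
        type=str,
        default=None,
        help="Optional baseline answer file for a wall-time comparison.",
    )
    args = parser.parse_args()

    tokens, seconds, steps = summarize_answer_file(args.answer_file)
    print(f"throughput          : {tokens / seconds:.2f} tokens/s")
    print(f"accepted tokens/step: {tokens / steps:.2f}")

    if args.baseline_file:
        b_tokens, b_seconds, _ = summarize_answer_file(args.baseline_file)
        # Compare per-token latency so differing output lengths do not skew the ratio.
        speedup = (b_seconds / b_tokens) / (seconds / tokens)
        print(f"speedup vs baseline : {speedup:.2f}x")
```

Pointed at a file in `data/mt_bench/model_answer/`, this reports tokens per second and the average number of tokens accepted per decoding step; with `--baseline-file` it also prints a per-token wall-time speedup to accompany the comparison described in the README above.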
/llm_judge/show_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 show_result.py --mode [single|pairwise-baseline|pairwise-all] 4 | """ 5 | import argparse 6 | import pandas as pd 7 | 8 | 9 | def display_result_single(args): 10 | if args.input_file is None: 11 | input_file = ( 12 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_single.jsonl" 13 | ) 14 | else: 15 | input_file = args.input_file 16 | 17 | print(f"Input file: {input_file}") 18 | df_all = pd.read_json(input_file, lines=True) 19 | df = df_all[["model", "score", "turn"]] 20 | df = df[df["score"] != -1] 21 | 22 | if args.model_list is not None: 23 | df = df[df["model"].isin(args.model_list)] 24 | 25 | print("\n########## First turn ##########") 26 | df_1 = df[df["turn"] == 1].groupby(["model", "turn"]).mean() 27 | print(df_1.sort_values(by="score", ascending=False)) 28 | 29 | if "mt_bench" in args.bench_name: 30 | print("\n########## Second turn ##########") 31 | df_2 = df[df["turn"] == 2].groupby(["model", "turn"]).mean() 32 | print(df_2.sort_values(by="score", ascending=False)) 33 | 34 | print("\n########## Average ##########") 35 | df_3 = df[["model", "score"]].groupby(["model"]).mean() 36 | print(df_3.sort_values(by="score", ascending=False)) 37 | 38 | 39 | def display_result_pairwise(args): 40 | if args.input_file is None: 41 | input_file = ( 42 | f"data/{args.bench_name}/model_judgment/{args.judge_model}_pair.jsonl" 43 | ) 44 | else: 45 | input_file = args.input_file 46 | 47 | print(f"Input file: {input_file}") 48 | df_all = pd.read_json(input_file, lines=True) 49 | df_all = df_all[(df_all["g1_winner"] != "error") & (df_all["g2_winner"] != "error")] 50 | 51 | model_list = ( 52 | df_all["model_1"].unique().tolist() + df_all["model_2"].unique().tolist() 53 | ) 54 | model_list = list(set(model_list)) 55 | 56 | list_res = [] 57 | # traverse df row by row 58 | for index, row in df_all.iterrows(): 59 | if args.model_list is not None and row["model_1"] not in args.model_list: 60 | continue 61 | if args.baseline_model is not None: 62 | if args.baseline_model not in [row["model_1"], row["model_2"]]: 63 | continue 64 | if row["g1_winner"] == "tie" or row["g1_winner"] != row["g2_winner"]: 65 | list_res.append({"model": row["model_1"], "win": 0, "loss": 0, "tie": 1}) 66 | list_res.append({"model": row["model_2"], "win": 0, "loss": 0, "tie": 1}) 67 | else: 68 | if row["g1_winner"] == "model_1": 69 | winner = row["model_1"] 70 | loser = row["model_2"] 71 | else: 72 | winner = row["model_2"] 73 | loser = row["model_1"] 74 | list_res.append({"model": winner, "win": 1, "loss": 0, "tie": 0}) 75 | list_res.append({"model": loser, "win": 0, "loss": 1, "tie": 0}) 76 | 77 | df = pd.DataFrame(list_res) 78 | df = df.groupby(["model"]).sum() 79 | 80 | # remove baseline model 81 | if args.baseline_model is not None: 82 | df = df[df.index != args.baseline_model] 83 | # add win rate 84 | df["win_rate"] = df["win"] / (df["win"] + df["loss"] + df["tie"]) 85 | df["loss_rate"] = df["loss"] / (df["win"] + df["loss"] + df["tie"]) 86 | # each tie counts as 0.5 win + 0.5 loss 87 | df["win_rate_adjusted"] = (df["win"] + 0.5 * df["tie"]) / ( 88 | df["win"] + df["loss"] + df["tie"] 89 | ) 90 | # print(df.sort_values(by="win_rate", ascending=False)) 91 | # print(df.sort_values(by="loss_rate", ascending=True)) 92 | print(df.sort_values(by="win_rate_adjusted", ascending=False)) 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | 
parser.add_argument("--bench-name", type=str, default="mt_bench") 98 | parser.add_argument("--input-file", type=str) 99 | parser.add_argument("--judge-model", type=str, default="gpt-4") 100 | parser.add_argument("--baseline-model", type=str, default="gpt-3.5-turbo") 101 | parser.add_argument( 102 | "--model-list", 103 | type=str, 104 | nargs="+", 105 | default=None, 106 | help="A list of models to be evaluated", 107 | ) 108 | parser.add_argument( 109 | "--mode", 110 | type=str, 111 | default="single", 112 | choices=["pairwise-baseline", "pairwise-all", "single"], 113 | help=( 114 | "Evaluation mode. " 115 | "`pairwise-baseline` runs pairwise comparision against a baseline. " 116 | "`pairwise-all` runs pairwise comparision between all pairs. " 117 | "`single` runs single answer grading." 118 | ), 119 | ) 120 | args = parser.parse_args() 121 | 122 | if args.mode == "single": 123 | display_result_func = display_result_single 124 | else: 125 | if args.mode == "pairwise-all": 126 | args.baseline_model = None 127 | display_result_func = display_result_pairwise 128 | 129 | print(f"Mode: {args.mode}") 130 | display_result_func(args) -------------------------------------------------------------------------------- /medusa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/__init__.py -------------------------------------------------------------------------------- /medusa/eval/README.md: -------------------------------------------------------------------------------- 1 | 2 | We use [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca_eval/blob/0cd24d711fe90d0c1aae5bde03fe98ee48ae52f8/alpaca_eval.json) dataset for evaluating each head's accuracy during generation in `heads_accuracy.py`. 3 | 4 | ``` 5 | python heads_accuracy.py --model_path 'FasterDecoding/medusa-vicuna-7b-v1.3' --model_name 'medusa-vicuna-7b-v1.3' --medusa_num_heads 5 --data_path '../../data/alpaca_eval.json' 6 | ``` 7 | 8 | 9 | To create the tree and plot the tree (requires `pygraphviz` package), please run: 10 | 11 | ``` 12 | python gen_results.py --accuracy-path '../../data/medusa-vicuna-7b-v1.3_heads_accuracy.pt' --output-path '../../data/graph.jpg' 13 | ``` 14 | 15 | If you want to use the tree, please add the generated tree (in a nested tuple) to `../model/medusa_choices.py`. 16 | 17 | Citation: 18 | 19 | ``` 20 | @misc{alpaca_eval, 21 | author = {Xuechen Li and Tianyi Zhang and Yann Dubois and Rohan Taori and Ishaan Gulrajani and Carlos Guestrin and Percy Liang and Tatsunori B. 
Hashimoto }, 22 | title = {AlpacaEval: An Automatic Evaluator of Instruction-following Models}, 23 | year = {2023}, 24 | publisher = {GitHub}, 25 | journal = {GitHub repository}, 26 | howpublished = {\url{https://github.com/tatsu-lab/alpaca_eval}} 27 | }``` -------------------------------------------------------------------------------- /medusa/eval/gen_results.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import copy 3 | import networkx as nx 4 | import torch 5 | import argparse 6 | 7 | def load_accuracy_table(path): 8 | test_accuracy = torch.load(path) 9 | accuracy_table = [] 10 | for i in range(len(test_accuracy)): 11 | accuracy_table.append(test_accuracy[i].sum(0)/16100) 12 | return torch.stack(accuracy_table) 13 | 14 | def get_node_expectation(accuracies, node): 15 | expectation = copy.deepcopy(accuracies[0, node[0]]) 16 | for i in range(1, len(node)): 17 | expectation *= accuracies[i, node[i]] 18 | return expectation 19 | 20 | def explore_graph(accuracies, max_depth, max_child, num_iterations): 21 | explored_nodes = {} 22 | accept_nodes = [tuple([0])] 23 | expectations = get_node_expectation(accuracies, accept_nodes[0]) 24 | explored_nodes[tuple(accept_nodes[0])] = expectations 25 | 26 | for _ in range(num_iterations): 27 | # find neighbors 28 | neighbors = [] 29 | for node in accept_nodes: 30 | if node[-1] < max_child[len(node) - 1] - 1: 31 | neighbor = list(copy.deepcopy(node)) 32 | neighbor[-1] = neighbor[-1] + 1 33 | neighbors.append(neighbor) 34 | if len(node) < max_depth: 35 | neighbor = list(copy.deepcopy(node)) 36 | neighbor.append(0) 37 | neighbors.append(neighbor) 38 | 39 | # find the best neighbor 40 | best_neighbor = None 41 | best_neighbor_expectation = 0 42 | for neighbor in neighbors: 43 | if tuple(neighbor) in accept_nodes: 44 | continue 45 | if tuple(neighbor) in explored_nodes: 46 | neighbor_expectation = explored_nodes[tuple(neighbor)] 47 | else: 48 | neighbor_expectation = get_node_expectation(accuracies, neighbor) 49 | explored_nodes[tuple(neighbor)] = neighbor_expectation 50 | if neighbor_expectation > best_neighbor_expectation: 51 | best_neighbor = neighbor 52 | best_neighbor_expectation = neighbor_expectation 53 | accept_nodes.append(tuple(best_neighbor)) 54 | expectations += best_neighbor_expectation 55 | 56 | return accept_nodes 57 | 58 | def plot_and_save_graph(accept_nodes, output_path): 59 | plt.figure(figsize=(40, 20)) 60 | 61 | G = nx.DiGraph() 62 | 63 | for path in accept_nodes: 64 | for i in range(len(path)): 65 | if i == 0: 66 | parent = 'root' 67 | else: 68 | parent = tuple(path[:i]) 69 | child = tuple(path[:i+1]) 70 | G.add_edge(parent, child) 71 | 72 | pos = nx.nx_agraph.graphviz_layout(G, prog='dot') 73 | nx.draw(G, pos, with_labels=True, node_size=500, node_color="skyblue", font_size=10, width=2, edge_color="gray") 74 | plt.savefig(output_path) 75 | 76 | def main(): 77 | parser = argparse.ArgumentParser(description="Generate Results.") 78 | parser.add_argument('--accuracy-path', type=str, required=True, help="Path to load accuracy tensor.") 79 | parser.add_argument('--output-path', type=str, required=True, help="Path to save the generated graph.") 80 | parser.add_argument('--max-depth', type=int, default=5, help="Maximum depth of the graph.") 81 | parser.add_argument('--num-iterations', type=int, default=62, help="Number of exploration iterations.") 82 | parser.add_argument('--max-child', nargs='+', type=int, default=[10, 10, 10, 10, 10], help="Maximum number of 
children per depth.") 83 | 84 | args = parser.parse_args() 85 | 86 | accuracies = load_accuracy_table(args.accuracy_path) 87 | accept_nodes = explore_graph(accuracies, args.max_depth, args.max_child, args.num_iterations) 88 | 89 | print("Accepted Nodes:", accept_nodes) 90 | 91 | try: 92 | plot_and_save_graph(accept_nodes, args.output_path) 93 | print(f"Graph saved to {args.output_path}.") 94 | except Exception as e: 95 | print(f"Failed to save the graph due to the following error: {e}") 96 | print("Ensure that Graphviz and pygraphviz are installed and set up correctly.") 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /medusa/eval/heads_accuracy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import json 4 | from contextlib import contextmanager 5 | import numpy as np 6 | from medusa.model.medusa_model import MedusaModel 7 | from medusa.model.kv_cache import * 8 | from medusa.model.utils import * 9 | from medusa.model.medusa_choices import * 10 | from copy import deepcopy 11 | import matplotlib.pyplot as plt 12 | import torch.nn.functional as F 13 | from fastchat.model.model_adapter import get_conversation_template 14 | from tqdm import tqdm 15 | import argparse 16 | 17 | def get_accuracies(medusa, logit): 18 | # get the correct counts of each head 19 | seq_len, choices, topk = medusa.shape 20 | results = [] 21 | for choice in range(choices): 22 | results.append(medusa[:-choice - 1,choice].eq(logit[choice + 1:,0])) 23 | return results 24 | 25 | 26 | 27 | def main(args): 28 | model = MedusaModel.from_pretrained( 29 | args.model_path, 30 | # medusa_num_heads=args.medusa_num_heads, 31 | torch_dtype=torch.float16, 32 | low_cpu_mem_usage=True, 33 | device_map="auto" 34 | ) 35 | tokenizer = model.get_tokenizer() 36 | 37 | 38 | data = json.load(open(args.data_path)) 39 | past_key_values, past_key_values_data, current_length_data = initialize_past_key_values(model.base_model) 40 | model.past_key_values = past_key_values 41 | model.past_key_values_data = past_key_values_data 42 | model.current_length_data = current_length_data 43 | results = None 44 | 45 | for sample in tqdm((data)): 46 | conv = get_conversation_template("vicuna") 47 | conv.messages = [] 48 | conv.append_message(conv.roles[0], sample["instruction"]) 49 | conv.append_message(conv.roles[1], "") 50 | prompt = conv.get_prompt() 51 | steps = args.steps 52 | logits_ids = [] 53 | medusa_topk_ids = [] 54 | 55 | with torch.inference_mode(): 56 | input_ids = tokenizer([prompt]).input_ids 57 | input_ids = torch.as_tensor(input_ids).cuda() 58 | model.current_length_data.zero_() # this is for rerun 59 | reset_medusa_mode(model) 60 | medusa_logits, outputs, logits = model( 61 | input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True 62 | ) 63 | _, medusa_topk = medusa_logits[...,-1,:].topk(20, dim=-1) 64 | input_id = logits[:, -1:].argmax(dim=-1) 65 | logits_ids.append(input_id.detach().cpu()) 66 | medusa_topk_ids.append(medusa_topk.detach().cpu()) 67 | for _ in range(steps): 68 | medusa_logits, outputs, logits = model( 69 | input_id, past_key_values=past_key_values, output_orig=True, medusa_forward=True 70 | ) 71 | _, medusa_topk = medusa_logits[...,-1,:].topk(20, dim=-1) 72 | input_id = logits[:, -1:].argmax(dim=-1) 73 | logits_ids.append(input_id.detach().cpu()) 74 | medusa_topk_ids.append(medusa_topk.detach().cpu()) 75 | logits_ids = torch.stack(logits_ids, 
dim=0) 76 | medusa_topk_ids = torch.stack(medusa_topk_ids, dim=0).squeeze(2) 77 | if results is None: 78 | results = get_accuracies(medusa_topk_ids, logits_ids) 79 | else: 80 | # cat sub results 81 | cur_results = get_accuracies(medusa_topk_ids, logits_ids) 82 | for i in range(len(results)): 83 | results[i] = torch.cat((results[i], cur_results[i]), dim=0) 84 | 85 | save_path = os.path.join(args.save_dir, args.model_name + "_heads_accuracy.pt") 86 | torch.save(results, save_path) 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser(description="Medusa Model Evaluator") 90 | 91 | parser.add_argument("--model_path", type=str, required=True, 92 | help="Path to the pre-trained Medusa model.") 93 | parser.add_argument("--model_name", type=str, required=True, 94 | help="Name of the model.") 95 | parser.add_argument("--medusa_num_heads", type=int, default=5, 96 | help="Number of medusa heads.") 97 | parser.add_argument("--data_path", type=str, required=True, 98 | help="Path to the evaluation data in JSON format.") 99 | parser.add_argument("--save_dir", type=str, default="../../data", 100 | help="Directory to save the results.") 101 | parser.add_argument("--steps", type=int, default=20, 102 | help="Number of steps to run the model.") 103 | args = parser.parse_args() 104 | 105 | # If the save directory doesn't exist, create it 106 | if not os.path.exists(args.save_dir): 107 | os.makedirs(args.save_dir) 108 | main(args) -------------------------------------------------------------------------------- /medusa/hf_utils.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser("Upload Medusa model to HuggingFace Hub") 5 | parser.add_argument("--folder", type=str, help="Path to model folder") 6 | parser.add_argument("--repo", type=str, help="Repo name") 7 | parser.add_argument("--private", action="store_true", help="Make repo private") 8 | 9 | args = parser.parse_args() 10 | 11 | api = HfApi() 12 | 13 | api.create_repo( 14 | repo_id=args.repo, 15 | private=args.private, 16 | exist_ok=True, 17 | ) 18 | 19 | api.upload_folder( 20 | folder_path=args.folder, 21 | repo_id=args.repo, 22 | ) -------------------------------------------------------------------------------- /medusa/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/inference/__init__.py -------------------------------------------------------------------------------- /medusa/inference/cli.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py 2 | """ 3 | Chat with a model with command line interface. 4 | 5 | Usage: 6 | python3 -m medusa.inference.cli --model 7 | Other commands: 8 | - Type "!!exit" or an empty line to exit. 9 | - Type "!!reset" to start a new conversation. 10 | - Type "!!remove" to remove the last prompt. 11 | - Type "!!regen" to regenerate the last message. 12 | - Type "!!save " to save the conversation history to a json file. 13 | - Type "!!load " to load a conversation history from a json file. 
14 | """ 15 | import argparse 16 | import os 17 | import re 18 | import sys 19 | import torch 20 | from fastchat.serve.cli import SimpleChatIO, RichChatIO, ProgrammaticChatIO 21 | from fastchat.model.model_adapter import get_conversation_template 22 | from fastchat.conversation import get_conv_template 23 | import json 24 | from medusa.model.medusa_model import MedusaModel 25 | 26 | 27 | def main(args): 28 | if args.style == "simple": 29 | chatio = SimpleChatIO(args.multiline) 30 | elif args.style == "rich": 31 | chatio = RichChatIO(args.multiline, args.mouse) 32 | elif args.style == "programmatic": 33 | chatio = ProgrammaticChatIO() 34 | else: 35 | raise ValueError(f"Invalid style for console: {args.style}") 36 | try: 37 | model = MedusaModel.from_pretrained( 38 | args.model, 39 | torch_dtype=torch.float16, 40 | low_cpu_mem_usage=True, 41 | device_map="auto", 42 | load_in_8bit=args.load_in_8bit, 43 | load_in_4bit=args.load_in_4bit, 44 | ) 45 | tokenizer = model.get_tokenizer() 46 | conv = None 47 | 48 | def new_chat(): 49 | return get_conversation_template(args.model) 50 | 51 | def reload_conv(conv): 52 | """ 53 | Reprints the conversation from the start. 54 | """ 55 | for message in conv.messages[conv.offset :]: 56 | chatio.prompt_for_output(message[0]) 57 | chatio.print_output(message[1]) 58 | 59 | while True: 60 | if not conv: 61 | conv = new_chat() 62 | 63 | try: 64 | inp = chatio.prompt_for_input(conv.roles[0]) 65 | except EOFError: 66 | inp = "" 67 | 68 | if inp == "!!exit" or not inp: 69 | print("exit...") 70 | break 71 | elif inp == "!!reset": 72 | print("resetting...") 73 | conv = new_chat() 74 | continue 75 | elif inp == "!!remove": 76 | print("removing last message...") 77 | if len(conv.messages) > conv.offset: 78 | # Assistant 79 | if conv.messages[-1][0] == conv.roles[1]: 80 | conv.messages.pop() 81 | # User 82 | if conv.messages[-1][0] == conv.roles[0]: 83 | conv.messages.pop() 84 | reload_conv(conv) 85 | else: 86 | print("No messages to remove.") 87 | continue 88 | elif inp == "!!regen": 89 | print("regenerating last message...") 90 | if len(conv.messages) > conv.offset: 91 | # Assistant 92 | if conv.messages[-1][0] == conv.roles[1]: 93 | conv.messages.pop() 94 | # User 95 | if conv.messages[-1][0] == conv.roles[0]: 96 | reload_conv(conv) 97 | # Set inp to previous message 98 | inp = conv.messages.pop()[1] 99 | else: 100 | # Shouldn't happen in normal circumstances 101 | print("No user message to regenerate from.") 102 | continue 103 | else: 104 | print("No messages to regenerate.") 105 | continue 106 | elif inp.startswith("!!save"): 107 | args = inp.split(" ", 1) 108 | 109 | if len(args) != 2: 110 | print("usage: !!save ") 111 | continue 112 | else: 113 | filename = args[1] 114 | 115 | # Add .json if extension not present 116 | if not "." 
in filename: 117 | filename += ".json" 118 | 119 | print("saving...", filename) 120 | with open(filename, "w") as outfile: 121 | json.dump(conv.dict(), outfile) 122 | continue 123 | elif inp.startswith("!!load"): 124 | args = inp.split(" ", 1) 125 | 126 | if len(args) != 2: 127 | print("usage: !!load ") 128 | continue 129 | else: 130 | filename = args[1] 131 | 132 | # Check if file exists and add .json if needed 133 | if not os.path.exists(filename): 134 | if (not filename.endswith(".json")) and os.path.exists( 135 | filename + ".json" 136 | ): 137 | filename += ".json" 138 | else: 139 | print("file not found:", filename) 140 | continue 141 | 142 | print("loading...", filename) 143 | with open(filename, "r") as infile: 144 | new_conv = json.load(infile) 145 | 146 | conv = get_conv_template(new_conv["template_name"]) 147 | conv.set_system_message(new_conv["system_message"]) 148 | conv.messages = new_conv["messages"] 149 | reload_conv(conv) 150 | continue 151 | 152 | conv.append_message(conv.roles[0], inp) 153 | conv.append_message(conv.roles[1], None) 154 | prompt = conv.get_prompt() 155 | 156 | try: 157 | chatio.prompt_for_output(conv.roles[1]) 158 | input_ids = tokenizer.encode(prompt, return_tensors="pt").to( 159 | model.base_model.device 160 | ) 161 | outputs = chatio.stream_output( 162 | model.medusa_generate( 163 | input_ids, 164 | temperature=args.temperature, 165 | max_steps=args.max_steps, 166 | ) 167 | ) 168 | conv.update_last_message(outputs.strip()) 169 | 170 | except KeyboardInterrupt: 171 | print("stopped generation.") 172 | # If generation didn't finish 173 | if conv.messages[-1][1] is None: 174 | conv.messages.pop() 175 | # Remove last user message, so there isn't a double up 176 | if conv.messages[-1][0] == conv.roles[0]: 177 | conv.messages.pop() 178 | 179 | reload_conv(conv) 180 | 181 | except KeyboardInterrupt: 182 | print("exit...") 183 | 184 | 185 | if __name__ == "__main__": 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument("--model", type=str, required=True, help="Model name or path.") 188 | parser.add_argument( 189 | "--load-in-8bit", action="store_true", help="Use 8-bit quantization" 190 | ) 191 | parser.add_argument( 192 | "--load-in-4bit", action="store_true", help="Use 4-bit quantization" 193 | ) 194 | parser.add_argument( 195 | "--conv-template", type=str, default=None, help="Conversation prompt template." 196 | ) 197 | parser.add_argument( 198 | "--conv-system-msg", type=str, default=None, help="Conversation system message." 199 | ) 200 | parser.add_argument("--temperature", type=float, default=0.7) 201 | parser.add_argument("--max-steps", type=int, default=512) 202 | parser.add_argument("--no-history", action="store_true") 203 | parser.add_argument( 204 | "--style", 205 | type=str, 206 | default="simple", 207 | choices=["simple", "rich", "programmatic"], 208 | help="Display style.", 209 | ) 210 | parser.add_argument( 211 | "--multiline", 212 | action="store_true", 213 | help="Enable multiline input. 
Use ESC+Enter for newline.", 214 | ) 215 | parser.add_argument( 216 | "--mouse", 217 | action="store_true", 218 | help="[Rich Style]: Enable mouse support for cursor positioning.", 219 | ) 220 | parser.add_argument( 221 | "--debug", 222 | action="store_true", 223 | help="Print useful debug information (e.g., prompts)", 224 | ) 225 | args = parser.parse_args() 226 | main(args) 227 | -------------------------------------------------------------------------------- /medusa/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/model/__init__.py -------------------------------------------------------------------------------- /medusa/model/kv_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class KVCache: 5 | """ 6 | A key-value cache for the model. 7 | 8 | This class provides a mechanism to maintain a growing cache of keys and values, 9 | particularly useful for models that benefit from caching previous states, 10 | like transformers during autoregressive decoding. 11 | 12 | Attributes: 13 | data (torch.Tensor): The tensor storing keys and values. 14 | current_length (int): Current length of the data being stored. 15 | """ 16 | 17 | def __init__(self, data, current_length): 18 | """ 19 | Initialize the KVCache. 20 | 21 | Args: 22 | data (torch.Tensor): Initial tensor to store the keys and values. 23 | current_length (int): Initial length of the data. 24 | """ 25 | self.data = data 26 | self.current_length = current_length 27 | 28 | @property 29 | def shape(self): 30 | """Return the shape of the data tensor with updated length.""" 31 | return ( 32 | self.data.shape[0], 33 | self.data.shape[1], 34 | self.current_length.item(), 35 | self.data.shape[3], 36 | ) 37 | 38 | def copy(self, indices: torch.Tensor, prev_length: int, dim: int = 2): 39 | """ 40 | Copy values from the current data at specified indices to a new location. 41 | 42 | Args: 43 | indices (torch.Tensor): Indices of the data tensor to be copied. 44 | prev_length (int): Previous length before adding new data. 45 | dim (int, optional): Dimension along which copying should be performed. Default is 2. 46 | """ 47 | tgt = self.data.index_select(dim, indices) 48 | dst = self.data.narrow(dim, prev_length, tgt.shape[dim]) 49 | dst.copy_(tgt, non_blocking=True) 50 | self.current_length.fill_(prev_length + tgt.shape[dim]) 51 | 52 | def cat(self, tensor: torch.Tensor, dim: int = 2): 53 | """ 54 | Concatenate the given tensor with the current data. 55 | 56 | Args: 57 | tensor (torch.Tensor): The tensor to be concatenated. 58 | dim (int, optional): The dimension along which concatenation should be done. Default is 2. 59 | 60 | Returns: 61 | torch.Tensor: The data tensor after concatenation up to the current length. 62 | """ 63 | dst = self.data.narrow(dim, self.current_length, tensor.shape[dim]) 64 | dst.copy_(tensor) 65 | self.current_length.add_(tensor.shape[dim]) 66 | return torch.narrow(self.data, 2, 0, self.current_length) 67 | 68 | 69 | def initialize_past_key_values(model): 70 | """ 71 | Initialize past key and value states for a given transformer model. 72 | 73 | This function prepares key-value cache structures for the model, allowing it to store and reuse 74 | past key and value states during autoregressive decoding, which can improve efficiency. 
75 | 76 | Args: 77 | model (nn.Module): The transformer model for which past key-value states need to be initialized. 78 | 79 | Returns: 80 | tuple: 81 | - past_key_values (list): A list of KVCache objects for each layer in the model. 82 | - past_key_values_data (torch.Tensor): The tensor that will store all keys and values. 83 | - current_length_data (torch.Tensor): A tensor tracking the current length of keys/values in the cache. 84 | """ 85 | # Extracting configuration from the model 86 | config = model.config 87 | # Initializing the batch size to 1, this can be modified if different batch sizes are required 88 | batch_size = 1 89 | # Initializing a tensor to store past keys and values for all layers 90 | past_key_values_data = torch.zeros( 91 | config.num_hidden_layers * 2, 92 | batch_size, 93 | config.num_key_value_heads, 94 | config.max_position_embeddings, 95 | config.hidden_size // config.num_attention_heads, 96 | device=model.device, 97 | dtype=model.dtype, 98 | ) 99 | # Initialize tensor to store the current length of the cached data for all layers. 100 | # [IMPORTANT] It needs to be kept on CPU for quick access and updates. 101 | current_length_data = torch.zeros( 102 | config.num_hidden_layers * 2, dtype=torch.long, device="cpu" 103 | ) 104 | # Creating a KVCache for each pair of key and value in all layers 105 | past_key_values = [] * config.num_hidden_layers 106 | for i in range(config.num_hidden_layers): 107 | past_key_values.append( 108 | [ 109 | KVCache(past_key_values_data[i * 2 + j], current_length_data[i * 2 + j]) 110 | for j in range(2) 111 | ] 112 | ) 113 | return past_key_values, past_key_values_data, current_length_data 114 | -------------------------------------------------------------------------------- /medusa/model/medusa_choices.py: -------------------------------------------------------------------------------- 1 | mc_sim_7b_63 = [[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]] 2 | vicuna_7b_stage2 = [(0,), (0, 0), (1,), (0, 1), (0, 0, 0), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 3), (3,), (0, 1, 0), (2, 0), (4,), (0, 0, 2), (0, 4), (1, 1), (1, 0, 0), (0, 0, 0, 0), (5,), (0, 0, 3), (0, 5), (0, 2, 0), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 7), (0, 0, 4), (4, 0), (1, 2), (0, 8), (7,), (0, 3, 0), (0, 0, 0, 1), (0, 0, 5), (2, 1), (0, 0, 6), (1, 0, 1), (0, 0, 1, 0), (2, 0, 0), (5, 0), (0, 9), (0, 1, 2), (8,), (0, 4, 0), (0, 2, 1), (1, 3), (0, 0, 7), (0, 0, 0, 2), (0, 0, 8), (1, 1, 0), (0, 1, 0, 0), (6, 0), (9,), (0, 1, 3), (0, 0, 0, 3), (1, 0, 2), (0, 5, 0), (3, 1), (0, 0, 2, 0), (7, 0), (1, 4)] 3 | vicuna_7b_stage1_ablation = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1), (1, 0), (2,), (0, 2), (0, 0, 1), (3,), (0, 3), (0, 1, 0), (2, 0), (0, 0, 2), (0, 4), (4,), (0, 0, 0, 0), (1, 0, 0), (1, 1), (0, 0, 3), (0, 2, 0), (0, 5), (5,), (3, 0), (0, 1, 1), (0, 6), (6,), (0, 0, 4), (1, 2), (0, 0, 0, 1), (4, 0), (0, 0, 5), (0, 7), (0, 8), (0, 3, 0), (0, 0, 1, 0), (1, 0, 1), (7,), (2, 0, 0), (0, 0, 6), (2, 1), (0, 1, 2), (5, 0), (0, 2, 1), (0, 9), (0, 0, 0, 2), (0, 4, 0), (8,), (1, 3), (0, 
0, 7), (0, 1, 0, 0), (1, 1, 0), (6, 0), (9,), (0, 0, 8), (0, 0, 9), (0, 5, 0), (0, 0, 2, 0), (1, 0, 2), (0, 1, 3), (0, 0, 0, 3), (3, 0, 0), (3, 1)] 4 | vicuna_7b_stage1 = [(0,), (0, 0), (1,), (2,), (0, 1), (1, 0), (3,), (0, 2), (4,), (0, 0, 0), (0, 3), (5,), (2, 0), (0, 4), (6,), (0, 5), (1, 1), (0, 0, 1), (7,), (3, 0), (0, 6), (8,), (9,), (0, 1, 0), (0, 7), (0, 8), (4, 0), (0, 0, 2), (1, 2), (0, 9), (2, 1), (5, 0), (1, 0, 0), (0, 0, 3), (1, 3), (0, 2, 0), (0, 1, 1), (0, 0, 4), (6, 0), (1, 4), (0, 0, 5), (2, 2), (0, 3, 0), (3, 1), (0, 0, 6), (7, 0), (1, 5), (1, 0, 1), (2, 0, 0), (0, 0, 7), (8, 0), (0, 0, 0, 0), (4, 1), (0, 1, 2), (0, 4, 0), (9, 0), (0, 2, 1), (2, 3), (1, 6), (0, 0, 8), (0, 5, 0), (3, 2), (5, 1)] 5 | vicuna_13b_stage2 = [(0,), (0, 0), (1,), (0, 0, 0), (0, 1), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 1, 0), (3,), (0, 3), (2, 0), (0, 0, 2), (0, 0, 0, 0), (0, 4), (1, 0, 0), (1, 1), (4,), (0, 0, 3), (0, 5), (0, 2, 0), (5,), (3, 0), (0, 1, 1), (0, 6), (0, 0, 4), (0, 0, 0, 1), (0, 7), (0, 0, 5), (1, 2), (0, 0, 1, 0), (0, 3, 0), (1, 0, 1), (4, 0), (0, 0, 6), (0, 8), (2, 0, 0), (0, 9), (6,), (7,), (2, 1), (5, 0), (0, 1, 2), (0, 0, 0, 2), (8,), (0, 4, 0), (0, 1, 0, 0), (0, 2, 1), (0, 0, 7), (1, 1, 0), (1, 3), (0, 0, 2, 0), (9,), (0, 0, 8), (0, 5, 0), (0, 0, 0, 3), (0, 0, 9), (0, 1, 3), (1, 0, 2), (0, 0, 1, 1), (3, 0, 0), (1, 0, 0, 0)] 6 | vicuna_13b_stage1 = [(0,), (0, 0), (1,), (0, 1), (2,), (1, 0), (0, 0, 0), (0, 2), (3,), (0, 3), (4,), (2, 0), (0, 4), (0, 0, 1), (0, 5), (5,), (1, 1), (0, 1, 0), (6,), (0, 6), (0, 0, 2), (7,), (3, 0), (8,), (0, 7), (0, 8), (1, 0, 0), (0, 0, 3), (4, 0), (1, 2), (9,), (0, 9), (2, 1), (0, 2, 0), (0, 0, 4), (1, 3), (0, 1, 1), (0, 0, 5), (5, 0), (0, 3, 0), (0, 0, 0, 0), (0, 0, 6), (6, 0), (1, 4), (2, 0, 0), (0, 1, 2), (3, 1), (0, 4, 0), (1, 0, 1), (2, 2), (0, 0, 7), (1, 5), (7, 0), (0, 0, 8), (8, 0), (0, 5, 0), (0, 0, 9), (0, 2, 1), (1, 1, 0), (0, 1, 3), (4, 1), (2, 3), (1, 6)] 7 | vicuna_33b_stage2 = [(0,), (0, 0), (1,), (0, 1), (0, 0, 0), (1, 0), (2,), (0, 2), (0, 0, 1), (0, 3), (3,), (0, 1, 0), (2, 0), (0, 4), (4,), (0, 0, 2), (1, 1), (1, 0, 0), (0, 5), (5,), (0, 0, 0, 0), (0, 0, 3), (3, 0), (0, 2, 0), (0, 6), (0, 1, 1), (6,), (0, 0, 4), (0, 7), (7,), (1, 2), (4, 0), (8,), (0, 3, 0), (0, 0, 5), (0, 0, 0, 1), (0, 8), (2, 1), (0, 9), (1, 0, 1), (2, 0, 0), (0, 0, 6), (5, 0), (0, 0, 1, 0), (1, 3), (0, 1, 2), (0, 4, 0), (0, 0, 7), (0, 2, 1), (9,), (1, 1, 0), (0, 0, 0, 2), (6, 0), (0, 0, 8), (0, 1, 0, 0), (7, 0), (0, 1, 3), (0, 5, 0), (1, 4), (0, 0, 9), (3, 1), (1, 0, 2), (2, 2)] 8 | vicuna_33b_stage1 = [(0,), (1,), (0, 0), (2,), (0, 1), (3,), (1, 0), (4,), (0, 2), (5,), (0, 3), (0, 0, 0), (6,), (0, 4), (2, 0), (7,), (1, 1), (0, 5), (3, 0), (8,), (9,), (0, 6), (0, 7), (0, 0, 1), (1, 2), (4, 0), (0, 1, 0), (0, 8), (0, 9), (2, 1), (0, 0, 2), (5, 0), (1, 3), (0, 0, 3), (1, 0, 0), (1, 4), (6, 0), (0, 2, 0), (3, 1), (2, 2), (0, 0, 4), (7, 0), (0, 1, 1), (1, 5), (4, 1), (0, 0, 5), (0, 3, 0), (9, 0), (8, 0), (1, 6), (0, 0, 6), (2, 3), (0, 1, 2), (3, 2), (0, 4, 0), (2, 0, 0), (1, 7), (1, 0, 1), (0, 0, 7), (5, 1), (2, 4), (0, 0, 8), (0, 2, 1)] 9 | zephyr_stage2 = [(0,), (0, 0), (1,), (0, 1), (2,), (0, 0, 0), (1, 0), (0, 2), (3,), (0, 3), (4,), (2, 0), (0, 0, 1), (0, 4), (5,), (0, 5), (0, 1, 0), (1, 1), (6,), (0, 0, 2), (3, 0), (0, 6), (7,), (0, 7), (0, 8), (0, 0, 3), (1, 0, 0), (0, 9), (0, 2, 0), (1, 2), (4, 0), (8,), (9,), (2, 1), (0, 1, 1), (0, 0, 4), (0, 0, 0, 0), (5, 0), (0, 3, 0), (1, 3), (0, 0, 5), (0, 0, 6), (6, 0), (2, 0, 0), (1, 0, 1), (0, 1, 2), (0, 4, 0), 
(1, 4), (3, 1), (2, 2), (0, 0, 7), (7, 0), (0, 2, 1), (0, 0, 8), (0, 1, 3), (0, 5, 0), (1, 5), (0, 0, 9), (1, 1, 0), (0, 0, 0, 1), (0, 0, 1, 0), (4, 1), (2, 3)] 10 | -------------------------------------------------------------------------------- /medusa/model/medusa_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM 4 | from .modeling_mistral_kv import MistralForCausalLM as KVMistralForCausalLM 5 | # import transformers 6 | 7 | # # monkey patch 8 | # transformers.models.llama.modeling_llama.LlamaForCausalLM = KVLlamaForCausalLM 9 | # transformers.models.mistral.modeling_mistral.MistralForCausalLM = KVMistralForCausalLM 10 | 11 | from transformers import PreTrainedModel, PretrainedConfig 12 | from .utils import * 13 | from .kv_cache import initialize_past_key_values 14 | from .medusa_choices import * 15 | from transformers import AutoTokenizer, AutoConfig 16 | import os 17 | from huggingface_hub import hf_hub_download 18 | import warnings 19 | 20 | class MedusaConfig(PretrainedConfig): 21 | """ 22 | Configuration class for Medusa model. 23 | 24 | Args: 25 | medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2. 26 | medusa_num_layers (int, optional): Number of Medusa layers. Default is 1. 27 | base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3". 28 | **kwargs: Additional keyword arguments to be passed to the parent class constructor. 29 | """ 30 | 31 | def __init__( 32 | self, 33 | medusa_num_heads=5, 34 | medusa_num_layers=1, 35 | base_model_name_or_path="lmsys/vicuna-7b-v1.3", 36 | **kwargs, 37 | ): 38 | super().__init__(**kwargs) 39 | self.medusa_num_heads = medusa_num_heads 40 | self.medusa_num_layers = medusa_num_layers 41 | self.base_model_name_or_path = base_model_name_or_path 42 | 43 | class ResBlock(nn.Module): 44 | """ 45 | A Residual Block module. 46 | 47 | This module performs a linear transformation followed by a SiLU activation, 48 | and then adds the result to the original input, creating a residual connection. 49 | 50 | Args: 51 | hidden_size (int): The size of the hidden layers in the block. 52 | """ 53 | 54 | def __init__(self, hidden_size): 55 | super().__init__() 56 | self.linear = nn.Linear(hidden_size, hidden_size) 57 | # Initialize as an identity mapping 58 | torch.nn.init.zeros_(self.linear.weight) 59 | # Use SiLU activation to keep consistent with the Llama model 60 | self.act = nn.SiLU() 61 | 62 | def forward(self, x): 63 | """ 64 | Forward pass of the ResBlock. 65 | 66 | Args: 67 | x (torch.Tensor): Input tensor. 68 | 69 | Returns: 70 | torch.Tensor: Output after the residual connection and activation. 71 | """ 72 | return x + self.act(self.linear(x)) 73 | 74 | 75 | class MedusaModelABC(nn.Module): 76 | """The Medusa Language Model Head. 77 | 78 | This module creates a series of prediction heads (based on the 'medusa' parameter) 79 | on top of a given base model. Each head is composed of a sequence of residual blocks 80 | followed by a linear layer. 
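    Concretely, each head is a stack of medusa_num_layers ResBlock modules followed by a
    Linear(hidden_size, vocab_size, bias=False) projection, and head k is intended to
    propose candidates for the token k positions beyond the one predicted by the base
    LM head. The tuples in medusa_choices.py above describe the sparse candidate tree
    used at inference time: a path such as (0, 1) means "take the first head's top-ranked
    candidate, then the second head's second-ranked candidate on top of it" (see
    generate_medusa_buffers in utils_legacy.py for how the paths become attention masks
    and gather indices).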
81 | """ 82 | 83 | # Load the base model 84 | # base_model_prefix = "model" 85 | # supports_gradient_checkpointing = True 86 | # _no_split_modules = ["LlamaDecoderLayer", "MistralDecoderLayer"] 87 | # _skip_keys_device_placement = "past_key_values" 88 | # _supports_flash_attn_2 = True 89 | 90 | def __init__( 91 | self, 92 | config, 93 | ): 94 | """ 95 | Args: 96 | config (PretrainedConfig): The configuration of the MedusaModel. 97 | """ 98 | super().__init__(config) 99 | # For compatibility with the old APIs 100 | 101 | medusa_num_heads = config.medusa_num_heads 102 | medusa_num_layers = config.medusa_num_layers 103 | base_model_name_or_path = config._name_or_path 104 | self.hidden_size = config.hidden_size 105 | self.vocab_size = config.vocab_size 106 | self.medusa = medusa_num_heads 107 | self.medusa_num_layers = medusa_num_layers 108 | self.base_model_name_or_path = base_model_name_or_path 109 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path) 110 | # Create a list of Medusa heads 111 | self.medusa_head = nn.ModuleList( 112 | [ 113 | nn.Sequential( 114 | *([ResBlock(self.hidden_size)] * medusa_num_layers), 115 | nn.Linear(self.hidden_size, self.vocab_size, bias=False), 116 | ) 117 | for _ in range(medusa_num_heads) 118 | ] 119 | ) 120 | # Add a link named base_model to self 121 | @property 122 | def base_model(self): 123 | return self 124 | @classmethod 125 | def from_pretrained( 126 | cls, 127 | pretrained_model_name_or_path, 128 | *args, 129 | **kwargs, 130 | ): 131 | # Manually load config to ensure that the medusa_num_heads parameter is loaded 132 | try: 133 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path) 134 | return super().from_pretrained( 135 | pretrained_model_name_or_path, 136 | *args, 137 | **kwargs, 138 | config=config, 139 | ) 140 | except: 141 | config = MedusaConfig.from_pretrained(pretrained_model_name_or_path) 142 | base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path) 143 | base_model_config.medusa_num_heads = 5 # TODO: fix the uploaded config (only include 2 heads) 144 | base_model_config.medusa_num_layers = config.medusa_num_layers 145 | model = super().from_pretrained( 146 | config.base_model_name_or_path, 147 | *args, 148 | **kwargs, 149 | config=base_model_config, 150 | ) 151 | medusa_head_path = os.path.join(pretrained_model_name_or_path, "medusa_lm_head.pt") 152 | if os.path.exists(medusa_head_path): 153 | filename = medusa_head_path 154 | else: 155 | filename = hf_hub_download(pretrained_model_name_or_path, "medusa_lm_head.pt") 156 | medusa_head_state_dict = torch.load(filename, map_location=model.device) 157 | model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False) 158 | return model 159 | 160 | 161 | def get_tokenizer(self): 162 | """Get the tokenizer of the base model. 163 | 164 | Returns: 165 | Tokenizer: The tokenizer of the base model. 166 | """ 167 | return self.tokenizer 168 | 169 | 170 | def forward( 171 | self, 172 | input_ids=None, 173 | attention_mask=None, 174 | past_key_values=None, 175 | output_orig=False, 176 | position_ids=None, 177 | medusa_forward=False, 178 | **kwargs, 179 | ): 180 | """Forward pass of the MedusaModel. 181 | 182 | Args: 183 | input_ids (torch.Tensor, optional): Input token IDs. 184 | attention_mask (torch.Tensor, optional): Attention mask. 185 | labels (torch.Tensor, optional): Ground truth labels for loss computation. 186 | past_key_values (tuple, optional): Tuple containing past key and value states for attention. 
187 | output_orig (bool, optional): Whether to also output predictions from the original LM head. 188 | position_ids (torch.Tensor, optional): Position IDs. 189 | 190 | Returns: 191 | torch.Tensor: A tensor containing predictions from all Medusa heads. 192 | (Optional) Original predictions from the base model's LM head. 193 | """ 194 | if not medusa_forward: 195 | return super().forward( 196 | input_ids=input_ids, 197 | attention_mask=attention_mask, 198 | past_key_values=past_key_values, 199 | position_ids=position_ids, 200 | **kwargs, 201 | ) 202 | with torch.inference_mode(): 203 | # Pass input through the base model 204 | outputs = self.base_model.model( 205 | input_ids=input_ids, 206 | attention_mask=attention_mask, 207 | past_key_values=past_key_values, 208 | position_ids=position_ids, 209 | **kwargs, 210 | ) 211 | if output_orig: 212 | orig = self.base_model.lm_head(outputs[0]) 213 | # Clone the output hidden states 214 | hidden_states = outputs[0].clone() 215 | medusa_logits = [] 216 | # TODO: Consider parallelizing this loop for efficiency? 217 | for i in range(self.medusa): 218 | medusa_logits.append(self.medusa_head[i](hidden_states)) 219 | if output_orig: 220 | return torch.stack(medusa_logits, dim=0), outputs, orig 221 | return torch.stack(medusa_logits, dim=0) 222 | def get_medusa_choice(self, model_name): 223 | if 'vicuna' in model_name: 224 | if '7b' in model_name: 225 | return vicuna_7b_stage2 226 | elif '13b' in model_name: 227 | return vicuna_13b_stage2 228 | elif '33b' in model_name: 229 | return vicuna_33b_stage2 230 | elif 'zephyr' in model_name: 231 | return zephyr_stage2 232 | warnings.warn('Please specify medusa choice configuration!') 233 | return mc_sim_7b_63 234 | 235 | def medusa_generate( 236 | self, 237 | input_ids, 238 | attention_mask=None, 239 | temperature=0.0, 240 | max_steps=512, 241 | # The hyperparameters below are for the Medusa 242 | # top-1 prediciton for the next token, top-7 predictions for the next token, top-6 predictions for the next next token. 243 | medusa_choices=None, 244 | posterior_threshold=0.09, # threshold validation of Medusa output 245 | # another threshold hyperparameter, recommended to be sqrt(posterior_threshold) 246 | posterior_alpha=0.3, 247 | top_p=0.8, 248 | sampling = 'typical', 249 | fast = True 250 | ): 251 | """ 252 | Args: 253 | input_ids (torch.Tensor, optional): Input token IDs. 254 | attention_mask (torch.Tensor, optional): Attention mask. 255 | temperature (float, optional): Temperature for typical acceptance. 256 | medusa_choices (list, optional): A list of integers indicating the number of choices for each Medusa head. 257 | posterior_threshold (float, optional): Threshold for posterior validation. 258 | posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold). 259 | top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8. 260 | sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'. 261 | fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False. 262 | Returns: 263 | torch.Tensor: Output token IDs. 264 | 265 | Warning: Only support batch size 1 for now!! 266 | """ 267 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 
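        # High-level flow of each decoding step below:
        #   1. generate_candidates: combine the base model's next-token prediction with the
        #      top-k proposals of every Medusa head, laid out as a sparse candidate tree.
        #   2. tree_decoding: verify all tree candidates in a single forward pass, using the
        #      tree attention mask so each candidate attends only to the verified prefix and
        #      its own ancestors in the tree.
        #   3. evaluate_posterior: choose the candidate path with the longest accepted prefix.
        #   4. update_inference_inputs: append the accepted tokens and compact the KV cache.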
268 | # Avoid modifying the input_ids in-place 269 | input_ids = input_ids.clone() 270 | 271 | # Cache medusa buffers (the fixed patterns for tree attention) 272 | if medusa_choices is None: 273 | medusa_choices = self.get_medusa_choice(self.base_model_name_or_path) 274 | 275 | if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices: 276 | # Load the cached medusa buffer 277 | medusa_buffers = self.medusa_buffers 278 | else: 279 | # Initialize the medusa buffer 280 | medusa_buffers = generate_medusa_buffers( 281 | medusa_choices, device=self.base_model.device 282 | ) 283 | self.medusa_buffers = medusa_buffers 284 | self.medusa_choices = medusa_choices 285 | 286 | # Initialize the past key and value states 287 | if hasattr(self, "past_key_values"): 288 | past_key_values = self.past_key_values 289 | past_key_values_data = self.past_key_values_data 290 | current_length_data = self.current_length_data 291 | # Reset the past key and value states 292 | current_length_data.zero_() 293 | else: 294 | ( 295 | past_key_values, 296 | past_key_values_data, 297 | current_length_data, 298 | ) = initialize_past_key_values(self.base_model) 299 | self.past_key_values = past_key_values 300 | self.past_key_values_data = past_key_values_data 301 | self.current_length_data = current_length_data 302 | 303 | input_len = input_ids.shape[1] 304 | 305 | reset_medusa_mode(self) 306 | # Initialize tree attention mask and process prefill tokens 307 | medusa_logits, logits = initialize_medusa( 308 | input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values 309 | ) 310 | 311 | new_token = 0 312 | last_round_token = 0 313 | 314 | for idx in range(max_steps): 315 | # Generate candidates with topk predictions from Medusa heads 316 | candidates, tree_candidates = generate_candidates( 317 | medusa_logits, 318 | logits, 319 | medusa_buffers["tree_indices"], 320 | medusa_buffers["retrieve_indices"], 321 | temperature=temperature, 322 | posterior_alpha=posterior_alpha, 323 | posterior_threshold=posterior_threshold, 324 | top_p=top_p, 325 | sampling=sampling, 326 | fast=fast, 327 | ) 328 | 329 | # Use tree attention to verify the candidates and get predictions 330 | medusa_logits, logits, outputs = tree_decoding( 331 | self, 332 | tree_candidates, 333 | past_key_values, 334 | medusa_buffers["medusa_position_ids"], 335 | input_ids, 336 | medusa_buffers["retrieve_indices"], 337 | ) 338 | 339 | # Evaluate the posterior of the candidates to select the accepted candidate prefix 340 | best_candidate, accept_length = evaluate_posterior( 341 | logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p=top_p, sampling=sampling, fast=fast 342 | ) 343 | 344 | # Update the input_ids and logits 345 | input_ids, logits, medusa_logits, new_token = update_inference_inputs( 346 | input_ids, 347 | candidates, 348 | best_candidate, 349 | accept_length, 350 | medusa_buffers["retrieve_indices"], 351 | outputs, 352 | logits, 353 | medusa_logits, 354 | new_token, 355 | past_key_values_data, 356 | current_length_data, 357 | ) 358 | 359 | yield { 360 | "text": self.tokenizer.decode( 361 | input_ids[0, input_len:], 362 | skip_special_tokens=True, 363 | spaces_between_special_tokens=False, 364 | clean_up_tokenization_spaces=True, 365 | ) 366 | } 367 | 368 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]: 369 | break 370 | 371 | 372 | class MedusaModelLlama(MedusaModelABC, KVLlamaForCausalLM): 373 | pass 374 | 375 | class MedusaModelMistral(MedusaModelABC, KVMistralForCausalLM): 376 | pass 377 | 
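A minimal usage sketch of the classes above, mirroring how medusa/inference/cli.py drives the model; the checkpoint name, dtype/device settings, and prompt below are illustrative placeholders rather than values taken from this file:

import torch
from medusa.model.medusa_model import MedusaModel

model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",   # placeholder checkpoint name
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = model.get_tokenizer()

# In the CLI the prompt is first built with a conversation template (see inference/cli.py).
prompt = "What is speculative decoding?"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device)

# medusa_generate is a generator; every yielded item carries the full decoded text so far.
text = ""
for step in model.medusa_generate(input_ids, temperature=0.0, max_steps=512):
    text = step["text"]
print(text)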
378 | 379 | class MedusaModel(): 380 | @classmethod 381 | def from_pretrained( 382 | cls, 383 | pretrained_model_name_or_path, 384 | *args, 385 | **kwargs, 386 | ): 387 | # Manually load config to ensure that the medusa_num_heads parameter is loaded 388 | try: 389 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path) 390 | except: 391 | # MEDUSA-v0.1 load 392 | config = MedusaConfig.from_pretrained(pretrained_model_name_or_path) 393 | base_model_config = AutoConfig.from_pretrained(config.base_model_name_or_path) 394 | config.model_type = base_model_config.model_type 395 | 396 | if config.model_type == "llama": 397 | return MedusaModelLlama.from_pretrained( 398 | pretrained_model_name_or_path, 399 | *args, 400 | **kwargs, 401 | ) 402 | elif config.model_type == "mistral": 403 | return MedusaModelMistral.from_pretrained( 404 | pretrained_model_name_or_path, 405 | *args, 406 | **kwargs, 407 | ) 408 | else: 409 | raise ValueError("Only support llama and mistral for now!!") 410 | -------------------------------------------------------------------------------- /medusa/model/medusa_model_legacy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import PreTrainedModel, PretrainedConfig 4 | from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM 5 | from .utils import * 6 | from .kv_cache import initialize_past_key_values 7 | from .medusa_choices import mc_sim_7b_63 8 | from transformers import AutoTokenizer 9 | import os 10 | from huggingface_hub import hf_hub_download 11 | 12 | 13 | class MedusaConfig(PretrainedConfig): 14 | """ 15 | Configuration class for Medusa model. 16 | 17 | Args: 18 | medusa_num_heads (int, optional): Number of heads for the Medusa layer. Default is 2. 19 | medusa_num_layers (int, optional): Number of Medusa layers. Default is 1. 20 | base_model_name_or_path (str, optional): The name or path of the base model. Default is "lmsys/vicuna-7b-v1.3". 21 | **kwargs: Additional keyword arguments to be passed to the parent class constructor. 22 | """ 23 | 24 | def __init__( 25 | self, 26 | medusa_num_heads=4, 27 | medusa_num_layers=1, 28 | version="2", 29 | base_model_name_or_path="lmsys/vicuna-7b-v1.3", 30 | **kwargs, 31 | ): 32 | super().__init__(**kwargs) 33 | self.medusa_num_heads = medusa_num_heads 34 | self.medusa_num_layers = medusa_num_layers 35 | self.version = version 36 | self.base_model_name_or_path = base_model_name_or_path 37 | 38 | 39 | class ResBlock(nn.Module): 40 | """ 41 | A Residual Block module. 42 | 43 | This module performs a linear transformation followed by a SiLU activation, 44 | and then adds the result to the original input, creating a residual connection. 45 | 46 | Args: 47 | hidden_size (int): The size of the hidden layers in the block. 48 | """ 49 | 50 | def __init__(self, hidden_size): 51 | super().__init__() 52 | self.linear = nn.Linear(hidden_size, hidden_size) 53 | # Initialize as an identity mapping 54 | torch.nn.init.zeros_(self.linear.weight) 55 | # Use SiLU activation to keep consistent with the Llama model 56 | self.act = nn.SiLU() 57 | 58 | def forward(self, x): 59 | """ 60 | Forward pass of the ResBlock. 61 | 62 | Args: 63 | x (torch.Tensor): Input tensor. 64 | 65 | Returns: 66 | torch.Tensor: Output after the residual connection and activation. 67 | """ 68 | return x + self.act(self.linear(x)) 69 | 70 | 71 | class MedusaModel(nn.Module): 72 | """The Medusa Language Model Head. 
73 | 74 | This module creates a series of prediction heads (based on the 'medusa' parameter) 75 | on top of a given base model. Each head is composed of a sequence of residual blocks 76 | followed by a linear layer. 77 | """ 78 | 79 | def __init__( 80 | self, 81 | base_model, 82 | medusa_num_heads=4, 83 | medusa_num_layers=1, 84 | base_model_name_or_path="lmsys/vicuna-7b-v1.3", 85 | ): 86 | """ 87 | Args: 88 | base_model (nn.Module): The base language model to be used. 89 | medusa_num_heads (int, optional): Number of additional tokens to predict. Defaults to 3. 90 | medusa_num_layers (int, optional): Number of ResBlock layers for each Medusa head. Defaults to 0. 91 | """ 92 | super().__init__() 93 | self.base_model = base_model 94 | self.config = base_model.config 95 | self.hidden_size = base_model.config.hidden_size 96 | self.vocab_size = base_model.config.vocab_size 97 | self.medusa = medusa_num_heads 98 | self.medusa_num_layers = medusa_num_layers 99 | self.base_model_name_or_path = base_model_name_or_path 100 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path) 101 | # Create a list of Medusa heads 102 | self.medusa_head = nn.ModuleList( 103 | [ 104 | nn.Sequential( 105 | *([ResBlock(self.hidden_size)] * medusa_num_layers), 106 | ) 107 | for _ in range(medusa_num_heads) 108 | ] 109 | ) 110 | 111 | # Ensure medusa_head's dtype and device align with the base_model 112 | self.medusa_head.to(self.base_model.dtype).to(self.base_model.device) 113 | 114 | def get_tokenizer(self): 115 | """Get the tokenizer of the base model. 116 | 117 | Returns: 118 | Tokenizer: The tokenizer of the base model. 119 | """ 120 | return self.tokenizer 121 | 122 | @classmethod 123 | def from_pretrained( 124 | cls, 125 | medusa_head_name_or_path, 126 | base_model=None, 127 | medusa_num_heads=None, 128 | **kwargs, 129 | ): 130 | """ 131 | Args: 132 | medusa_head_name_or_path (str): Name or path of the Medusa head to load. 133 | **kwargs: Additional keyword arguments for loading the base model. 134 | 135 | Returns: 136 | MedusaModel: A MedusaModel instance loaded from the given path. 137 | """ 138 | medusa_config = MedusaConfig.from_pretrained(medusa_head_name_or_path) 139 | if medusa_num_heads is not None: 140 | print("Overriding medusa_num_heads as:", medusa_num_heads) 141 | medusa_config.medusa_num_heads = medusa_num_heads 142 | if base_model is not None: 143 | print("Overriding base_model as:", base_model) 144 | medusa_config.base_model_name_or_path = base_model 145 | 146 | base_model = KVLlamaForCausalLM.from_pretrained( 147 | medusa_config.base_model_name_or_path, **kwargs 148 | ) 149 | 150 | model = cls( 151 | base_model, 152 | medusa_config.medusa_num_heads, 153 | medusa_config.medusa_num_layers, 154 | medusa_config.base_model_name_or_path, 155 | ) 156 | medusa_head_path = os.path.join(medusa_head_name_or_path, "medusa_lm_head.pt") 157 | if os.path.exists(medusa_head_path): 158 | filename = medusa_head_path 159 | else: 160 | filename = hf_hub_download(medusa_head_name_or_path, "medusa_lm_head.pt") 161 | medusa_head_state_dict = torch.load(filename, map_location=base_model.device) 162 | model.medusa_head.load_state_dict(medusa_head_state_dict, strict=False) 163 | 164 | return model 165 | 166 | def forward( 167 | self, 168 | input_ids=None, 169 | attention_mask=None, 170 | labels=None, 171 | past_key_values=None, 172 | output_orig=False, 173 | position_ids=None, 174 | ): 175 | """Forward pass of the MedusaModel. 
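        Note: in this legacy implementation each Medusa head contains only ResBlock
        layers; logits are obtained by re-applying the base model's lm_head to the
        transformed hidden states (see the loop below), whereas in medusa_model.py
        every head carries its own output projection.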
176 | 177 | Args: 178 | input_ids (torch.Tensor, optional): Input token IDs. 179 | attention_mask (torch.Tensor, optional): Attention mask. 180 | labels (torch.Tensor, optional): Ground truth labels for loss computation. 181 | past_key_values (tuple, optional): Tuple containing past key and value states for attention. 182 | output_orig (bool, optional): Whether to also output predictions from the original LM head. 183 | position_ids (torch.Tensor, optional): Position IDs. 184 | 185 | Returns: 186 | torch.Tensor: A tensor containing predictions from all Medusa heads. 187 | (Optional) Original predictions from the base model's LM head. 188 | """ 189 | with torch.no_grad(): 190 | # Pass input through the base model 191 | outputs = self.base_model.model( 192 | input_ids=input_ids, 193 | attention_mask=attention_mask, 194 | past_key_values=past_key_values, 195 | position_ids=position_ids, 196 | ) 197 | if output_orig: 198 | orig = self.base_model.lm_head(outputs[0]) 199 | # Clone the output hidden states 200 | hidden_states = outputs[0].clone() 201 | medusa_logits = [] 202 | # TODO: Consider parallelizing this loop for efficiency? 203 | for i in range(self.medusa): 204 | mhidden_states = self.medusa_head[i](hidden_states) 205 | mlogits = self.base_model.lm_head(mhidden_states) 206 | medusa_logits.append(mlogits) 207 | if output_orig: 208 | return torch.stack(medusa_logits, dim=0), outputs, orig 209 | return torch.stack(medusa_logits, dim=0) 210 | 211 | def medusa_generate( 212 | self, 213 | input_ids, 214 | attention_mask=None, 215 | temperature=0.0, 216 | max_steps=512, 217 | # The hyperparameters below are for the Medusa 218 | # top-1 prediciton for the next token, top-7 predictions for the next token, top-6 predictions for the next next token. 219 | medusa_choices=mc_sim_7b_63, 220 | posterior_threshold=0.09, # threshold validation of Medusa output 221 | # another threshold hyperparameter, recommended to be sqrt(posterior_threshold) 222 | posterior_alpha=0.3, 223 | ): 224 | """ 225 | Args: 226 | input_ids (torch.Tensor, optional): Input token IDs. 227 | attention_mask (torch.Tensor, optional): Attention mask. 228 | temperature (float, optional): Temperature for typical acceptance. 229 | medusa_choices (list, optional): A list of integers indicating the number of choices for each Medusa head. 230 | posterior_threshold (float, optional): Threshold for posterior validation. 231 | posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold). 232 | Returns: 233 | torch.Tensor: Output token IDs. 234 | 235 | Warning: Only support batch size 1 for now!! 236 | """ 237 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 
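        # Acceptance rule applied downstream (cf. evaluate_posterior in utils_legacy.py):
        # with temperature == 0 a candidate token is accepted only if it equals the base
        # model's argmax at that position; otherwise a token x is accepted when
        #     p(x) > min(posterior_threshold, posterior_alpha * exp(-entropy(p)))
        # where p is the base model's softmax(logits / temperature) distribution.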
238 | # Avoid modifying the input_ids in-place 239 | input_ids = input_ids.clone() 240 | 241 | # Cache medusa buffers (the fixed patterns for tree attention) 242 | if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices: 243 | # Load the cached medusa buffer 244 | medusa_buffers = self.medusa_buffers 245 | else: 246 | # Initialize the medusa buffer 247 | medusa_buffers = generate_medusa_buffers( 248 | medusa_choices, device=self.base_model.device 249 | ) 250 | self.medusa_buffers = medusa_buffers 251 | self.medusa_choices = medusa_choices 252 | 253 | 254 | # Initialize the past key and value states 255 | if hasattr(self, "past_key_values"): 256 | past_key_values = self.past_key_values 257 | past_key_values_data = self.past_key_values_data 258 | current_length_data = self.current_length_data 259 | # Reset the past key and value states 260 | current_length_data.zero_() 261 | else: 262 | ( 263 | past_key_values, 264 | past_key_values_data, 265 | current_length_data, 266 | ) = initialize_past_key_values(self.base_model) 267 | self.past_key_values = past_key_values 268 | self.past_key_values_data = past_key_values_data 269 | self.current_length_data = current_length_data 270 | 271 | input_len = input_ids.shape[1] 272 | 273 | reset_medusa_mode(self) 274 | # Initialize tree attention mask and process prefill tokens 275 | medusa_logits, logits = initialize_medusa( 276 | input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values 277 | ) 278 | 279 | new_token = 0 280 | last_round_token = 0 281 | 282 | for idx in range(max_steps): 283 | # Generate candidates with topk predictions from Medusa heads 284 | candidates, tree_candidates = generate_candidates( 285 | medusa_logits, 286 | logits, 287 | medusa_buffers["tree_indices"], 288 | medusa_buffers["retrieve_indices"], 289 | ) 290 | 291 | # Use tree attention to verify the candidates and get predictions 292 | medusa_logits, logits, outputs = tree_decoding( 293 | self, 294 | tree_candidates, 295 | past_key_values, 296 | medusa_buffers["medusa_position_ids"], 297 | input_ids, 298 | medusa_buffers["retrieve_indices"], 299 | ) 300 | 301 | # Evaluate the posterior of the candidates to select the accepted candidate prefix 302 | best_candidate, accept_length = evaluate_posterior( 303 | logits, candidates, temperature, posterior_threshold, posterior_alpha 304 | ) 305 | 306 | # Update the input_ids and logits 307 | input_ids, logits, medusa_logits, new_token = update_inference_inputs( 308 | input_ids, 309 | candidates, 310 | best_candidate, 311 | accept_length, 312 | medusa_buffers["retrieve_indices"], 313 | outputs, 314 | logits, 315 | medusa_logits, 316 | new_token, 317 | past_key_values_data, 318 | current_length_data, 319 | ) 320 | 321 | yield { 322 | "text": self.tokenizer.decode( 323 | input_ids[0, input_len:], 324 | skip_special_tokens=True, 325 | spaces_between_special_tokens=False, 326 | clean_up_tokenization_spaces=True, 327 | ) 328 | } 329 | 330 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]: 331 | break 332 | -------------------------------------------------------------------------------- /medusa/model/medusa_model_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import PreTrainedModel, PretrainedConfig 4 | from .modeling_llama_kv import LlamaForCausalLM as KVLlamaForCausalLM 5 | from .modeling_mistral_kv import MistralForCausalLM as KVMistralForCausalLM 6 | from .utils import * 7 | from .kv_cache 
import initialize_past_key_values 8 | from .medusa_choices import mc_sim_7b_63 9 | from transformers import AutoTokenizer, AutoConfig 10 | import os 11 | from huggingface_hub import hf_hub_download 12 | 13 | 14 | class ResBlock(nn.Module): 15 | """ 16 | A Residual Block module. 17 | 18 | This module performs a linear transformation followed by a SiLU activation, 19 | and then adds the result to the original input, creating a residual connection. 20 | 21 | Args: 22 | hidden_size (int): The size of the hidden layers in the block. 23 | """ 24 | 25 | def __init__(self, hidden_size): 26 | super().__init__() 27 | self.linear = nn.Linear(hidden_size, hidden_size) 28 | # Initialize as an identity mapping 29 | torch.nn.init.zeros_(self.linear.weight) 30 | # Use SiLU activation to keep consistent with the Llama model 31 | self.act = nn.SiLU() 32 | 33 | def forward(self, x): 34 | """ 35 | Forward pass of the ResBlock. 36 | 37 | Args: 38 | x (torch.Tensor): Input tensor. 39 | 40 | Returns: 41 | torch.Tensor: Output after the residual connection and activation. 42 | """ 43 | return x + self.act(self.linear(x)) 44 | 45 | class MedusaModel(PreTrainedModel): 46 | """The Medusa Language Model Head. 47 | 48 | This module creates a series of prediction heads (based on the 'medusa' parameter) 49 | on top of a given base model. Each head is composed of a sequence of residual blocks 50 | followed by a linear layer. 51 | """ 52 | 53 | def __init__( 54 | self, 55 | config, 56 | ): 57 | """ 58 | Args: 59 | config (PretrainedConfig): The configuration of the MedusaModel. 60 | """ 61 | super().__init__(config) 62 | # For compatibility with the old APIs 63 | medusa_num_heads = config.medusa_num_heads 64 | medusa_num_layers = config.medusa_num_layers 65 | base_model_name_or_path = config._name_or_path 66 | self.hidden_size = config.hidden_size 67 | self.vocab_size = config.vocab_size 68 | self.medusa = medusa_num_heads 69 | self.medusa_num_layers = medusa_num_layers 70 | self.base_model_name_or_path = base_model_name_or_path 71 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path) 72 | # Create a list of Medusa heads 73 | self.medusa_head = nn.ModuleList( 74 | [ 75 | nn.Sequential( 76 | *([ResBlock(self.hidden_size)] * medusa_num_layers), 77 | nn.Linear(self.hidden_size, self.vocab_size, bias=False), 78 | ) 79 | for _ in range(medusa_num_heads) 80 | ] 81 | ) 82 | 83 | # Add a link named base_model to self 84 | @property 85 | def base_model(self): 86 | return self 87 | 88 | def get_tokenizer(self): 89 | """Get the tokenizer of the base model. 90 | 91 | Returns: 92 | Tokenizer: The tokenizer of the base model. 93 | """ 94 | return self.tokenizer 95 | 96 | @classmethod 97 | def from_pretrained( 98 | cls, 99 | pretrained_model_name_or_path, 100 | *args, 101 | **kwargs, 102 | ): 103 | # Manually load config to ensure that the medusa_num_heads parameter is loaded 104 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path) 105 | return super().from_pretrained( 106 | pretrained_model_name_or_path, 107 | *args, 108 | **kwargs, 109 | config=config, 110 | ) 111 | 112 | def forward( 113 | self, 114 | input_ids=None, 115 | attention_mask=None, 116 | past_key_values=None, 117 | output_orig=False, 118 | position_ids=None, 119 | medusa_forward=False, 120 | **kwargs, 121 | ): 122 | """Forward pass of the MedusaModel. 123 | 124 | Args: 125 | input_ids (torch.Tensor, optional): Input token IDs. 126 | attention_mask (torch.Tensor, optional): Attention mask. 
127 | labels (torch.Tensor, optional): Ground truth labels for loss computation. 128 | past_key_values (tuple, optional): Tuple containing past key and value states for attention. 129 | output_orig (bool, optional): Whether to also output predictions from the original LM head. 130 | position_ids (torch.Tensor, optional): Position IDs. 131 | 132 | Returns: 133 | torch.Tensor: A tensor containing predictions from all Medusa heads. 134 | (Optional) Original predictions from the base model's LM head. 135 | """ 136 | """Forward pass of the MedusaModel. 137 | 138 | Args: 139 | input_ids (torch.Tensor, optional): Input token IDs. 140 | attention_mask (torch.Tensor, optional): Attention mask. 141 | labels (torch.Tensor, optional): Ground truth labels for loss computation. 142 | past_key_values (tuple, optional): Tuple containing past key and value states for attention. 143 | output_orig (bool, optional): Whether to also output predictions from the original LM head. 144 | position_ids (torch.Tensor, optional): Position IDs. 145 | 146 | Returns: 147 | torch.Tensor: A tensor containing predictions from all Medusa heads. 148 | (Optional) Original predictions from the base model's LM head. 149 | """ 150 | if not medusa_forward: 151 | return super().forward( 152 | input_ids=input_ids, 153 | attention_mask=attention_mask, 154 | past_key_values=past_key_values, 155 | position_ids=position_ids, 156 | **kwargs, 157 | ) 158 | with torch.inference_mode(): 159 | # Pass input through the base model 160 | outputs = self.base_model.model( 161 | input_ids=input_ids, 162 | attention_mask=attention_mask, 163 | past_key_values=past_key_values, 164 | position_ids=position_ids, 165 | **kwargs, 166 | ) 167 | if output_orig: 168 | orig = self.base_model.lm_head(outputs[0]) 169 | # Clone the output hidden states 170 | hidden_states = outputs[0].clone() 171 | medusa_logits = [] 172 | # TODO: Consider parallelizing this loop for efficiency? 173 | for i in range(self.medusa): 174 | medusa_logits.append(self.medusa_head[i](hidden_states)) 175 | if output_orig: 176 | return torch.stack(medusa_logits, dim=0), outputs, orig 177 | return torch.stack(medusa_logits, dim=0) 178 | 179 | 180 | 181 | class MedusaLlamaModel(KVLlamaForCausalLM): 182 | """The Medusa Language Model Head. 183 | 184 | This module creates a series of prediction heads (based on the 'medusa' parameter) 185 | on top of a given base model. Each head is composed of a sequence of residual blocks 186 | followed by a linear layer. 187 | """ 188 | 189 | def __init__( 190 | self, 191 | config, 192 | ): 193 | """ 194 | Args: 195 | config (PretrainedConfig): The configuration of the MedusaModel. 
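                The config is expected to carry medusa_num_heads and medusa_num_layers
                in addition to the usual base-model fields (hidden_size, vocab_size,
                _name_or_path), since those attributes are read directly below.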
196 | """ 197 | # Load the base model 198 | super().__init__(config) 199 | # For compatibility with the old APIs 200 | 201 | medusa_num_heads = config.medusa_num_heads 202 | medusa_num_layers = config.medusa_num_layers 203 | base_model_name_or_path = config._name_or_path 204 | self.hidden_size = config.hidden_size 205 | self.vocab_size = config.vocab_size 206 | self.medusa = medusa_num_heads 207 | self.medusa_num_layers = medusa_num_layers 208 | self.base_model_name_or_path = base_model_name_or_path 209 | self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name_or_path) 210 | # Create a list of Medusa heads 211 | self.medusa_head = nn.ModuleList( 212 | [ 213 | nn.Sequential( 214 | *([ResBlock(self.hidden_size)] * medusa_num_layers), 215 | nn.Linear(self.hidden_size, self.vocab_size, bias=False), 216 | ) 217 | for _ in range(medusa_num_heads) 218 | ] 219 | ) 220 | 221 | # Add a link named base_model to self 222 | @property 223 | def base_model(self): 224 | return self 225 | 226 | def get_tokenizer(self): 227 | """Get the tokenizer of the base model. 228 | 229 | Returns: 230 | Tokenizer: The tokenizer of the base model. 231 | """ 232 | return self.tokenizer 233 | 234 | @classmethod 235 | def from_pretrained( 236 | cls, 237 | pretrained_model_name_or_path, 238 | *args, 239 | **kwargs, 240 | ): 241 | # Manually load config to ensure that the medusa_num_heads parameter is loaded 242 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path) 243 | return super().from_pretrained( 244 | pretrained_model_name_or_path, 245 | *args, 246 | **kwargs, 247 | config=config, 248 | ) 249 | 250 | def forward( 251 | self, 252 | input_ids=None, 253 | attention_mask=None, 254 | past_key_values=None, 255 | output_orig=False, 256 | position_ids=None, 257 | medusa_forward=False, 258 | **kwargs, 259 | ): 260 | """Forward pass of the MedusaModel. 261 | 262 | Args: 263 | input_ids (torch.Tensor, optional): Input token IDs. 264 | attention_mask (torch.Tensor, optional): Attention mask. 265 | labels (torch.Tensor, optional): Ground truth labels for loss computation. 266 | past_key_values (tuple, optional): Tuple containing past key and value states for attention. 267 | output_orig (bool, optional): Whether to also output predictions from the original LM head. 268 | position_ids (torch.Tensor, optional): Position IDs. 269 | 270 | Returns: 271 | torch.Tensor: A tensor containing predictions from all Medusa heads. 272 | (Optional) Original predictions from the base model's LM head. 273 | """ 274 | if not medusa_forward: 275 | return super().forward( 276 | input_ids=input_ids, 277 | attention_mask=attention_mask, 278 | past_key_values=past_key_values, 279 | position_ids=position_ids, 280 | **kwargs, 281 | ) 282 | with torch.inference_mode(): 283 | # Pass input through the base model 284 | outputs = self.base_model.model( 285 | input_ids=input_ids, 286 | attention_mask=attention_mask, 287 | past_key_values=past_key_values, 288 | position_ids=position_ids, 289 | **kwargs, 290 | ) 291 | if output_orig: 292 | orig = self.base_model.lm_head(outputs[0]) 293 | # Clone the output hidden states 294 | hidden_states = outputs[0].clone() 295 | medusa_logits = [] 296 | # TODO: Consider parallelizing this loop for efficiency? 
297 | for i in range(self.medusa): 298 | medusa_logits.append(self.medusa_head[i](hidden_states)) 299 | if output_orig: 300 | return torch.stack(medusa_logits, dim=0), outputs, orig 301 | return torch.stack(medusa_logits, dim=0) 302 | 303 | def medusa_generate( 304 | self, 305 | input_ids, 306 | attention_mask=None, 307 | temperature=0.0, 308 | max_steps=512, 309 | # The hyperparameters below are for the Medusa 310 | # top-1 prediciton for the next token, top-7 predictions for the next token, top-6 predictions for the next next token. 311 | medusa_choices=mc_sim_7b_63, 312 | posterior_threshold=0.09, # threshold validation of Medusa output 313 | # another threshold hyperparameter, recommended to be sqrt(posterior_threshold) 314 | posterior_alpha=0.3, 315 | ): 316 | """ 317 | Args: 318 | input_ids (torch.Tensor, optional): Input token IDs. 319 | attention_mask (torch.Tensor, optional): Attention mask. 320 | temperature (float, optional): Temperature for typical acceptance. 321 | medusa_choices (list, optional): A list of integers indicating the number of choices for each Medusa head. 322 | posterior_threshold (float, optional): Threshold for posterior validation. 323 | posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold). 324 | Returns: 325 | torch.Tensor: Output token IDs. 326 | 327 | Warning: Only support batch size 1 for now!! 328 | """ 329 | assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!" 330 | # Avoid modifying the input_ids in-place 331 | input_ids = input_ids.clone() 332 | 333 | # Cache medusa buffers (the fixed patterns for tree attention) 334 | if hasattr(self, "medusa_choices") and self.medusa_choices == medusa_choices: 335 | # Load the cached medusa buffer 336 | medusa_buffers = self.medusa_buffers 337 | else: 338 | # Initialize the medusa buffer 339 | medusa_buffers = generate_medusa_buffers( 340 | medusa_choices, device=self.base_model.device 341 | ) 342 | self.medusa_buffers = medusa_buffers 343 | self.medusa_choices = medusa_choices 344 | 345 | 346 | # Initialize the past key and value states 347 | if hasattr(self, "past_key_values"): 348 | past_key_values = self.past_key_values 349 | past_key_values_data = self.past_key_values_data 350 | current_length_data = self.current_length_data 351 | # Reset the past key and value states 352 | current_length_data.zero_() 353 | else: 354 | ( 355 | past_key_values, 356 | past_key_values_data, 357 | current_length_data, 358 | ) = initialize_past_key_values(self.base_model) 359 | self.past_key_values = past_key_values 360 | self.past_key_values_data = past_key_values_data 361 | self.current_length_data = current_length_data 362 | 363 | input_len = input_ids.shape[1] 364 | 365 | reset_medusa_mode(self) 366 | # Initialize tree attention mask and process prefill tokens 367 | medusa_logits, logits = initialize_medusa( 368 | input_ids, self, medusa_buffers["medusa_attn_mask"], past_key_values 369 | ) 370 | 371 | new_token = 0 372 | last_round_token = 0 373 | 374 | for idx in range(max_steps): 375 | # Generate candidates with topk predictions from Medusa heads 376 | candidates, tree_candidates = generate_candidates( 377 | medusa_logits, 378 | logits, 379 | medusa_buffers["tree_indices"], 380 | medusa_buffers["retrieve_indices"], 381 | ) 382 | 383 | # Use tree attention to verify the candidates and get predictions 384 | medusa_logits, logits, outputs = tree_decoding( 385 | self, 386 | tree_candidates, 387 | past_key_values, 388 | 
medusa_buffers["medusa_position_ids"], 389 | input_ids, 390 | medusa_buffers["retrieve_indices"], 391 | ) 392 | 393 | # Evaluate the posterior of the candidates to select the accepted candidate prefix 394 | best_candidate, accept_length = evaluate_posterior( 395 | logits, candidates, temperature, posterior_threshold, posterior_alpha 396 | ) 397 | 398 | # Update the input_ids and logits 399 | input_ids, logits, medusa_logits, new_token = update_inference_inputs( 400 | input_ids, 401 | candidates, 402 | best_candidate, 403 | accept_length, 404 | medusa_buffers["retrieve_indices"], 405 | outputs, 406 | logits, 407 | medusa_logits, 408 | new_token, 409 | past_key_values_data, 410 | current_length_data, 411 | ) 412 | 413 | yield { 414 | "text": self.tokenizer.decode( 415 | input_ids[0, input_len:], 416 | skip_special_tokens=True, 417 | spaces_between_special_tokens=False, 418 | clean_up_tokenization_spaces=True, 419 | ) 420 | } 421 | 422 | if self.tokenizer.eos_token_id in input_ids[0, input_len:]: 423 | break 424 | 425 | # Currently only support LlamaModel 426 | MedusaModel = MedusaLlamaModel -------------------------------------------------------------------------------- /medusa/model/utils_legacy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | TOPK=10 # topk for sparse tree 4 | 5 | def pad_path(path, length, pad_value=-2): 6 | """ 7 | Pad the given path list with a specific value up to a specified length. 8 | 9 | Parameters: 10 | - path (list): The original list that needs padding. 11 | - length (int): The desired length of the padded list. 12 | - pad_value (optional, default=-2): The value to use for padding. 13 | 14 | Returns: 15 | - list: A new list based on the original path but padded to the desired length. 16 | 17 | Example: 18 | >>> pad_path([1,2,3], 5) 19 | [1, 2, 3, -2, -2] 20 | 21 | Note: 22 | If the given path is already longer than the specified length, 23 | then no padding occurs, and the original path is returned. 24 | """ 25 | 26 | # Calculate the number of padding values needed by subtracting the length 27 | # of the path from the desired length. 28 | # Append the padding values to the original path and return the new list. 29 | return path + [pad_value] * (length - len(path)) 30 | 31 | def generate_medusa_buffers(medusa_choices, device="cuda"): 32 | """ 33 | Generate buffers for the Medusa structure based on the provided choices. 34 | 35 | Parameters: 36 | - medusa_choices (list): A nested list representing tree in the Medusa structure. 37 | - device (str): Device to which the tensors should be moved. Default is "cuda". 38 | 39 | Returns: 40 | - dict: A dictionary containing buffers related to the Medusa structure. 
41 | """ 42 | 43 | # Sort the medusa_choices based on their lengths and then their values 44 | sorted_medusa_choices = sorted(medusa_choices, key=lambda x: (len(x), x)) 45 | medusa_len = len(sorted_medusa_choices) + 1 46 | 47 | # Initialize depth_counts to keep track of how many choices have a particular depth 48 | depth_counts = [] 49 | prev_depth = 0 50 | for path in sorted_medusa_choices: 51 | depth = len(path) 52 | if depth != prev_depth: 53 | depth_counts.append(0) 54 | depth_counts[depth - 1] += 1 55 | prev_depth = depth 56 | 57 | # Create the attention mask for Medusa 58 | medusa_attn_mask = torch.eye(medusa_len, medusa_len) 59 | medusa_attn_mask[:, 0] = 1 60 | start = 0 61 | for i in range(len(depth_counts)): 62 | for j in range(depth_counts[i]): 63 | cur_medusa_choice = sorted_medusa_choices[start + j] 64 | # retrieve ancestor position 65 | if len(cur_medusa_choice) == 1: 66 | continue 67 | ancestor_idx = [] 68 | for c in range(len(cur_medusa_choice) - 1): 69 | ancestor_idx.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1]) + 1) 70 | medusa_attn_mask[j + start + 1, ancestor_idx] = 1 71 | start += depth_counts[i] 72 | 73 | # Generate tree indices for the Medusa structure 74 | medusa_tree_indices = torch.zeros(medusa_len, dtype=torch.long) 75 | medusa_tree_indices[0] = 0 76 | start = 0 77 | for i in range(len(depth_counts)): 78 | for j in range(depth_counts[i]): 79 | cur_medusa_choice = sorted_medusa_choices[start + j] 80 | medusa_tree_indices[start + j + 1] = cur_medusa_choice[-1] + TOPK * i + 1 81 | start += depth_counts[i] 82 | 83 | # Generate position IDs for the Medusa structure 84 | medusa_position_ids = torch.zeros(medusa_len, dtype=torch.long) 85 | start = 0 86 | for i in range(len(depth_counts)): 87 | medusa_position_ids[start + 1: start + depth_counts[i] + 1] = i + 1 88 | start += depth_counts[i] 89 | 90 | # Generate retrieval indices for Medusa structure verification 91 | retrieve_indices_nest = [] 92 | retrieve_paths = [] 93 | for i in range(len(sorted_medusa_choices)): 94 | cur_medusa_choice = sorted_medusa_choices[-i-1] 95 | retrieve_indice = [] 96 | if cur_medusa_choice in retrieve_paths: 97 | continue 98 | else: 99 | for c in range(len(cur_medusa_choice)): 100 | retrieve_indice.append(sorted_medusa_choices.index(cur_medusa_choice[:c+1])) 101 | retrieve_paths.append(cur_medusa_choice[:c+1]) 102 | retrieve_indices_nest.append(retrieve_indice) 103 | max_length = max([len(x) for x in retrieve_indices_nest]) 104 | retrieve_indices = [pad_path(path, max_length) for path in retrieve_indices_nest] 105 | retrieve_indices = torch.tensor(retrieve_indices, dtype=torch.long) 106 | retrieve_indices = retrieve_indices + 1 107 | retrieve_indices = torch.cat([torch.zeros((retrieve_indices.shape[0], 1), dtype=torch.long), retrieve_indices], dim=1) 108 | 109 | # Aggregate the generated buffers into a dictionary 110 | medusa_buffers = { 111 | "medusa_attn_mask": medusa_attn_mask.unsqueeze(0).unsqueeze(0), 112 | "tree_indices": medusa_tree_indices, 113 | "medusa_position_ids": medusa_position_ids, 114 | "retrieve_indices": retrieve_indices, 115 | } 116 | 117 | # Move the tensors in the dictionary to the specified device 118 | medusa_buffers = { 119 | k: v.clone().to(device) 120 | if isinstance(v, torch.Tensor) 121 | else torch.tensor(v, device=device) 122 | for k, v in medusa_buffers.items() 123 | } 124 | return medusa_buffers 125 | 126 | 127 | def initialize_medusa(input_ids, model, medusa_attn_mask, past_key_values): 128 | """ 129 | Initializes the Medusa structure for a given 
model. 130 | 131 | This function performs the following operations: 132 | 1. Forward pass through the model to obtain the Medusa logits, original model outputs, and logits. 133 | 2. Sets the Medusa attention mask within the base model. 134 | 135 | Args: 136 | - input_ids (torch.Tensor): The input tensor containing token ids. 137 | - model (MedusaLMHead): The model containing the Medusa layers and base model. 138 | - medusa_attn_mask (torch.Tensor): The attention mask designed specifically for the Medusa structure. 139 | - past_key_values (list of torch.Tensor): Contains past hidden states and past attention values. 140 | 141 | Returns: 142 | - medusa_logits (torch.Tensor): Logits from the Medusa heads. 143 | - logits (torch.Tensor): Original logits from the base model. 144 | """ 145 | medusa_logits, outputs, logits = model( 146 | input_ids, past_key_values=past_key_values, output_orig=True, medusa_forward=True 147 | ) 148 | model.base_model.model.medusa_mask = medusa_attn_mask 149 | return medusa_logits, logits 150 | 151 | 152 | def reset_medusa_mode( 153 | model, 154 | ): 155 | """ 156 | Resets the Medusa settings and the past key-values to their initial state. 157 | 158 | This function ensures that after any operations involving Medusa, 159 | the base model and its settings return to their default state. 160 | Specifically, it performs the following tasks: 161 | 1. Clears the Medusa attention mask in the base model. 162 | 2. Resets the Medusa mode in the base model. 163 | 3. Resets the current lengths in the past key-values to zero for all layers. 164 | 165 | Args: 166 | - model (MedusaLMHead): The model containing the Medusa layers and base model. 167 | - past_key_values (list of torch.Tensor): Contains past hidden states and past attention values. 168 | 169 | Returns: 170 | - past_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths. 171 | """ 172 | model.base_model.model.medusa_mask = None 173 | model.base_model.model.medusa_mode = None 174 | 175 | 176 | def reset_past_key_values(passed_key_values): 177 | """ 178 | Resets the current lengths in the passed key-values to zero. 179 | 180 | This function is designed to be used during the evaluation of a baseline model. 181 | It iterates through each layer's key-values and sets their current lengths to zero, 182 | effectively resetting their state. 183 | 184 | Args: 185 | - passed_key_values (list of torch.Tensor): Contains past hidden states and past attention values for each layer. 186 | 187 | Returns: 188 | - passed_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths. 189 | """ 190 | for i in range(len(passed_key_values)): 191 | for j in range(2): 192 | passed_key_values[i][j].current_length.fill_(0) 193 | return passed_key_values 194 | 195 | 196 | def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices): 197 | """ 198 | Generate candidates based on provided logits and indices. 199 | 200 | Parameters: 201 | - medusa_logits (torch.Tensor): Logits associated with the Medusa structure. 202 | - logits (torch.Tensor): Original logits. 203 | - tree_indices (list or torch.Tensor): Indices associated with a tree structure. 204 | - retrieve_indices (list or torch.Tensor): Indices for retrieving candidates. 205 | 206 | Returns: 207 | - tuple: Returns cartesian candidates and tree candidates. 208 | """ 209 | 210 | # Greedy decoding: Select the most probable candidate from the original logits. 
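    # That greedy token becomes element 0 of a flat candidate vector of length
    # 1 + num_heads * TOPK, followed by the TOPK proposals of every Medusa head;
    # tree_indices then picks out the subset forming the sparse tree, and
    # retrieve_indices re-expands the tree into per-path (cartesian) candidate sequences.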
211 | candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0) 212 | 213 | # Extract the TOPK candidates from the medusa logits. 214 | candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim = -1).indices 215 | 216 | # Combine the selected candidate from the original logits with the topk medusa logits. 217 | candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1) 218 | 219 | # Map the combined candidates to the tree indices to get tree candidates. 220 | tree_candidates = candidates[tree_indices] 221 | 222 | # Extend the tree candidates by appending a zero. 223 | tree_candidates_ext = torch.cat([tree_candidates, torch.zeros((1), dtype=torch.long, device=tree_candidates.device)], dim=0) 224 | 225 | # Retrieve the cartesian candidates using the retrieve indices. 226 | cart_candidates = tree_candidates_ext[retrieve_indices] 227 | 228 | # Unsqueeze the tree candidates for dimension consistency. 229 | tree_candidates = tree_candidates.unsqueeze(0) 230 | return cart_candidates, tree_candidates 231 | 232 | 233 | def tree_decoding( 234 | model, 235 | tree_candidates, 236 | past_key_values, 237 | medusa_position_ids, 238 | input_ids, 239 | retrieve_indices, 240 | ): 241 | """ 242 | Decode the tree candidates using the provided model and reorganize the logits. 243 | 244 | Parameters: 245 | - model (nn.Module): Model to be used for decoding the tree candidates. 246 | - tree_candidates (torch.Tensor): Input candidates based on a tree structure. 247 | - past_key_values (torch.Tensor): Past states, such as key and value pairs, used in attention layers. 248 | - medusa_position_ids (torch.Tensor): Positional IDs associated with the Medusa structure. 249 | - input_ids (torch.Tensor): Input sequence IDs. 250 | - retrieve_indices (list or torch.Tensor): Indices for reordering the logits. 251 | 252 | Returns: 253 | - tuple: Returns medusa logits, regular logits, and other outputs from the model. 254 | """ 255 | 256 | # Compute new position IDs by adding the Medusa position IDs to the length of the input sequence. 257 | position_ids = medusa_position_ids + input_ids.shape[1] 258 | 259 | # Use the model to decode the tree candidates. 260 | # The model is expected to return logits for the Medusa structure, original logits, and possibly other outputs. 261 | tree_medusa_logits, outputs, tree_logits = model( 262 | tree_candidates, 263 | output_orig=True, 264 | past_key_values=past_key_values, 265 | position_ids=position_ids, 266 | medusa_forward=True, 267 | ) 268 | 269 | # Reorder the obtained logits based on the retrieve_indices to ensure consistency with some reference ordering. 270 | logits = tree_logits[0, retrieve_indices] 271 | medusa_logits = tree_medusa_logits[:, 0, retrieve_indices] 272 | return medusa_logits, logits, outputs 273 | 274 | 275 | def evaluate_posterior( 276 | logits, candidates, temperature, posterior_threshold, posterior_alpha 277 | ): 278 | """ 279 | Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate. 280 | 281 | Depending on the temperature value, the function either uses greedy decoding or evaluates posterior 282 | probabilities to select the best candidate. 283 | 284 | Args: 285 | - logits (torch.Tensor): Predicted logits of shape (batch_size, sequence_length, vocab_size). 286 | - candidates (torch.Tensor): Candidate token sequences. 287 | - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding. 
288 | - posterior_threshold (float): Threshold for posterior probability. 289 | - posterior_alpha (float): Scaling factor for the threshold. 290 | 291 | Returns: 292 | - best_candidate (torch.Tensor): Index of the chosen best candidate. 293 | - accept_length (int): Length of the accepted candidate sequence. 294 | """ 295 | # Greedy decoding based on temperature value 296 | if temperature == 0: 297 | # Find the tokens that match the maximum logits for each position in the sequence 298 | posterior_mask = ( 299 | candidates[:, 1:] == torch.argmax(logits[:, :-1], dim=-1) 300 | ).int() 301 | candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) 302 | accept_length = candidates_accept_length.max() 303 | # Choose the best candidate 304 | if accept_length == 0: 305 | # Default to the first candidate if none are accepted 306 | best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) 307 | else: 308 | best_candidate = torch.argmax(candidates_accept_length).to(torch.long) 309 | return best_candidate, accept_length 310 | # Calculate posterior probabilities and thresholds for candidate selection 311 | posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1) 312 | candidates_prob = torch.gather( 313 | posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1) 314 | ).squeeze(-1) 315 | posterior_entropy = -torch.sum( 316 | posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1 317 | ) # torch.sum(torch.log(*)) is faster than torch.prod 318 | threshold = torch.minimum( 319 | torch.ones_like(posterior_entropy) * posterior_threshold, 320 | torch.exp(-posterior_entropy) * posterior_alpha, 321 | ) 322 | posterior_mask = candidates_prob > threshold 323 | candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1) 324 | 325 | # Choose the best candidate based on the evaluated posterior probabilities 326 | accept_length = candidates_accept_length.max() 327 | if accept_length == 0: 328 | # If no candidates are accepted, just choose the first one 329 | best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device) 330 | else: 331 | best_candidates = torch.where(candidates_accept_length == accept_length)[0] 332 | # Accept the best one according to likelihood 333 | likelihood = torch.sum( 334 | torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1 335 | ) 336 | best_candidate = best_candidates[torch.argmax(likelihood)] 337 | return best_candidate, accept_length 338 | 339 | 340 | def update_inference_inputs( 341 | input_ids, 342 | candidates, 343 | best_candidate, 344 | accept_length, 345 | retrieve_indices, 346 | outputs, 347 | logits, 348 | medusa_logits, 349 | new_token, 350 | past_key_values_data, 351 | current_length_data, 352 | ): 353 | """ 354 | Update the input sequences and relevant tensors based on the selected best candidate from the inference results. 355 | 356 | Args: 357 | - input_ids (torch.Tensor): Current input token sequences. 358 | - candidates (torch.Tensor): Candidate token sequences generated in the current step. 359 | - best_candidate (int): Index of the chosen best candidate. 360 | - accept_length (int): Length of the accepted candidate sequence. 361 | - retrieve_indices (torch.Tensor): Indices to map tree to a cartesian product. 362 | - outputs, logits, medusa_logits (torch.Tensor): Model's outputs from the previous inference step. 363 | - new_token (int): Counter for the new tokens added during inference. 
364 | - past_key_values_data (torch.Tensor): Tensor containing past hidden states for the transformer model. 365 | - current_length_data (torch.Tensor): Tensor containing the current length of sequences in the batch. 366 | 367 | Returns: 368 | - input_ids (torch.Tensor): Updated input token sequences. 369 | - logits (torch.Tensor): Updated logits. 370 | - medusa_logits (torch.Tensor): Updated medusa logits. 371 | - new_token (int): Updated counter for the new tokens added. 372 | """ 373 | # Calculate the starting position for new tokens based on the previous input length 374 | prev_input_len = input_ids.shape[1] 375 | # Map the best candidate indices to the original indices in the sequence 376 | select_indices = ( 377 | retrieve_indices[best_candidate, : accept_length + 1] + prev_input_len 378 | ) 379 | # Append the tokens from the best candidate to the input sequence 380 | input_ids = torch.cat( 381 | [input_ids, candidates[None, best_candidate, : accept_length + 1]], dim=-1 382 | ) 383 | # Update the past key values based on the selected tokens 384 | # Source tensor that contains relevant past information based on the selected candidate 385 | tgt = past_key_values_data[..., select_indices, :] 386 | # Destination tensor where the relevant past information will be stored 387 | dst = past_key_values_data[..., prev_input_len : prev_input_len + tgt.shape[-2], :] 388 | # Copy relevant past information from the source to the destination 389 | dst.copy_(tgt, non_blocking=True) 390 | 391 | # Update the current length tensor (currently only support batch size is 1) 392 | current_length_data.fill_(prev_input_len + tgt.shape[-2]) 393 | 394 | # Extract logits and medusa logits for the accepted tokens 395 | logits = logits[None, best_candidate, accept_length : accept_length + 1] 396 | medusa_logits = medusa_logits[ 397 | :, None, best_candidate, accept_length : accept_length + 1 398 | ] 399 | # Update the new token counter 400 | new_token += accept_length + 1 401 | 402 | return input_ids, logits, medusa_logits, new_token 403 | -------------------------------------------------------------------------------- /medusa/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FasterDecoding/Medusa/e2a5d20c048a9b0a4092e6933c34313687422518/medusa/train/__init__.py -------------------------------------------------------------------------------- /medusa/train/train_legacy.py: -------------------------------------------------------------------------------- 1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright: 2 | # 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
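# High-level flow of this legacy training script: load a base causal LM, freeze all of its
# parameters, attach `medusa_num_heads` extra decoding heads via MedusaModel, and train only
# those heads. In CustomizedTrainer.compute_loss below, head i is trained against labels
# offset by (2 + i) positions, so each additional head predicts one token further ahead than
# the base model's next-token head. The trained heads are gathered (DeepSpeed ZeRO) and saved
# separately to medusa_lm_head.safetensors at the end of train().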
16 | 17 | # Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py 18 | 19 | from dataclasses import dataclass, field 20 | import json 21 | import math 22 | import pathlib 23 | from typing import Dict, Optional, Sequence 24 | 25 | import numpy as np 26 | import torch 27 | from torch import nn 28 | from torch.utils.data import Dataset 29 | import transformers 30 | from transformers import Trainer, BitsAndBytesConfig 31 | from transformers.trainer_pt_utils import LabelSmoother 32 | from safetensors.torch import save_file 33 | 34 | from fastchat.conversation import SeparatorStyle 35 | from fastchat.model.model_adapter import get_conversation_template 36 | from torch.nn import CrossEntropyLoss 37 | from torch.nn import functional as F 38 | import os 39 | from medusa.model.medusa_model_legacy import MedusaModel, MedusaConfig 40 | 41 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 42 | 43 | 44 | # Customized for training Medusa heads 45 | class CustomizedTrainer(Trainer): 46 | def compute_loss(self, model, inputs, return_outputs=False): 47 | """ 48 | Compute the training loss for the model. 49 | 50 | Args: 51 | model (torch.nn.Module): The model for which to compute the loss. 52 | inputs (dict): The input data, including input IDs, attention mask, and labels. 53 | return_outputs (bool): Whether to return model outputs along with the loss. 54 | 55 | Returns: 56 | Union[float, Tuple[float, torch.Tensor]]: The computed loss, optionally with model outputs. 57 | """ 58 | # DDP will give us model.module 59 | if hasattr(model, "module"): 60 | medusa = model.module.medusa 61 | else: 62 | medusa = model.medusa 63 | 64 | logits = model( 65 | input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] 66 | ) 67 | labels = inputs["labels"] 68 | # Shift so that tokens < n predict n 69 | loss = 0 70 | loss_fct = CrossEntropyLoss() 71 | log = {} 72 | for i in range(medusa): 73 | medusa_logits = logits[i, :, : -(2 + i)].contiguous() 74 | medusa_labels = labels[..., 2 + i :].contiguous() 75 | medusa_logits = medusa_logits.view(-1, logits.shape[-1]) 76 | medusa_labels = medusa_labels.view(-1) 77 | medusa_labels = medusa_labels.to(medusa_logits.device) 78 | loss_i = loss_fct(medusa_logits, medusa_labels) 79 | loss += loss_i 80 | not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID) 81 | medusa_labels = medusa_labels[not_ignore] 82 | 83 | # Add top-k accuracy 84 | for k in range(1, 2): 85 | _, topk = medusa_logits.topk(k, dim=-1) 86 | topk = topk[not_ignore] 87 | correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1) 88 | log[f"medusa{i}_top{k}"] = correct.float().mean().item() 89 | 90 | log[f"medusa{i}_loss"] = loss_i.item() 91 | self.log(log) 92 | return (loss, logits) if return_outputs else loss 93 | 94 | 95 | @dataclass 96 | class ModelArguments: 97 | model_name_or_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.3") 98 | load_in_4bit: bool = field( 99 | default=False, 100 | metadata={"help": "Load in 4 bit."}, 101 | ) 102 | load_in_8bit: bool = field( 103 | default=False, 104 | metadata={"help": "Load in 8 bit."}, 105 | ) 106 | 107 | 108 | @dataclass 109 | class DataArguments: 110 | data_path: str = field( 111 | default="sharegpt_clean.json", 112 | metadata={"help": "Path to the training data."}, 113 | ) 114 | eval_data_path: str = field( 115 | default=None, metadata={"help": "Path to the evaluation data."} 116 | ) 117 | lazy_preprocess: bool = True 118 | 119 | 120 | @dataclass 121 | class TrainingArguments(transformers.TrainingArguments): 122 | cache_dir: Optional[str] = 
field(default=None) 123 | report_to: Optional[str] = None 124 | optim: str = field(default="adamw_torch") 125 | model_max_length: int = field( 126 | default=2048, 127 | metadata={ 128 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 129 | }, 130 | ) 131 | medusa_num_heads: int = field( 132 | default=1, 133 | metadata={"help": "Number of Medusa heads."}, 134 | ) 135 | medusa_num_layers: int = field( 136 | default=1, 137 | metadata={"help": "Number of layers for each Medusa head."}, 138 | ) 139 | 140 | 141 | local_rank = None 142 | 143 | 144 | def rank0_print(*args): 145 | if local_rank == 0: 146 | print(*args) 147 | 148 | 149 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 150 | """ 151 | Save the model's state dictionary to a specified directory. 152 | 153 | Args: 154 | trainer (transformers.Trainer): The Hugging Face Trainer object. 155 | output_dir (str): The directory where the model state dictionary will be saved. 156 | """ 157 | state_dict = trainer.model.state_dict() 158 | if trainer.args.should_save: 159 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 160 | del state_dict 161 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 162 | 163 | def preprocess( 164 | sources, 165 | tokenizer: transformers.PreTrainedTokenizer, 166 | ) -> Dict: 167 | """ 168 | Preprocesses conversation data and tokenizes it for model input. 169 | 170 | Args: 171 | sources: A list of conversation sources. 172 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for tokenization. 173 | 174 | Returns: 175 | Dict: A dictionary containing tokenized inputs, labels, and attention mask. 176 | """ 177 | 178 | # Apply prompt templates 179 | conversations = [] 180 | prompts = [] 181 | # # import pdb; pdb.set_trace() 182 | for i, conversation in enumerate(sources): 183 | prompt = tokenizer.apply_chat_template(conversation, tokenize=False) 184 | prompts.append(prompt) 185 | conversations.append(conversation) 186 | 187 | # Tokenize conversations 188 | encoding = tokenizer( 189 | prompts, 190 | return_tensors="pt", 191 | padding="max_length", 192 | truncation=True, 193 | return_offsets_mapping=True, 194 | ) 195 | # Set everything to be ignored, except the assistant part 196 | targets = torch.full_like(encoding.input_ids, IGNORE_TOKEN_ID) 197 | input_ids = encoding.input_ids 198 | 199 | # Mask targets. Only compute loss on the assistant outputs. 200 | for conv_index, (conversation, target, prompt) in enumerate(zip(conversations, targets, prompts)): 201 | 202 | for turn in conversation: 203 | if turn["role"] == "assistant": 204 | content = turn["content"] 205 | # Unfortunate strip() necessary because chat templates are doing the same. 206 | start = prompt.index(content.strip()) 207 | stop = start + len(content) 208 | indices= [] 209 | for tok_index, (tok_start, tok_stop) in enumerate(encoding.offset_mapping[conv_index]): 210 | if tok_stop >= start or tok_start < tok_stop: 211 | indices.append(tok_index) 212 | target[indices] = encoding.input_ids[conv_index][indices] 213 | 214 | 215 | return dict( 216 | input_ids=input_ids, 217 | labels=targets, 218 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 219 | ) 220 | 221 | 222 | class SupervisedDataset(Dataset): 223 | """Dataset for supervised fine-tuning. 224 | 225 | Args: 226 | raw_data (list): A list of raw data examples. 227 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. 
228 | """ 229 | 230 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): 231 | super(SupervisedDataset, self).__init__() 232 | 233 | rank0_print("Formatting inputs...") 234 | sources = raw_data 235 | data_dict = preprocess(sources, tokenizer) 236 | 237 | self.input_ids = data_dict["input_ids"] 238 | self.labels = data_dict["labels"] 239 | self.attention_mask = data_dict["attention_mask"] 240 | 241 | def __len__(self): 242 | return len(self.input_ids) 243 | 244 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 245 | return dict( 246 | input_ids=self.input_ids[i], 247 | labels=self.labels[i], 248 | attention_mask=self.attention_mask[i], 249 | ) 250 | 251 | 252 | class LazySupervisedDataset(Dataset): 253 | """Lazy dataset for supervised fine-tuning. 254 | 255 | This dataset loads data on-the-fly when requested, which can be memory-efficient but slower. 256 | 257 | Args: 258 | raw_data (list): A list of raw data examples. 259 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. 260 | """ 261 | 262 | def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer): 263 | super(LazySupervisedDataset, self).__init__() 264 | self.tokenizer = tokenizer 265 | 266 | rank0_print("Formatting inputs...Skip in lazy mode") 267 | self.tokenizer = tokenizer 268 | self.raw_data = raw_data 269 | self.cached_data_dict = {} 270 | 271 | def __len__(self): 272 | return len(self.raw_data) 273 | 274 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 275 | if i in self.cached_data_dict: 276 | return self.cached_data_dict[i] 277 | 278 | ret = preprocess([self.raw_data[i]], self.tokenizer) 279 | ret = dict( 280 | input_ids=ret["input_ids"][0], 281 | labels=ret["labels"][0], 282 | attention_mask=ret["attention_mask"][0], 283 | ) 284 | self.cached_data_dict[i] = ret 285 | 286 | return ret 287 | 288 | 289 | def make_supervised_data_module( 290 | tokenizer: transformers.PreTrainedTokenizer, data_args 291 | ) -> Dict: 292 | """Make dataset and collator for supervised fine-tuning. 293 | 294 | Args: 295 | tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. 296 | data_args: Data arguments. 297 | 298 | Returns: 299 | dict: A dictionary containing train and eval datasets. 
300 | """ 301 | dataset_cls = ( 302 | LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset 303 | ) 304 | rank0_print("Loading data...") 305 | 306 | train_json = json.load(open(data_args.data_path, "r")) 307 | train_dataset = dataset_cls(train_json, tokenizer=tokenizer) 308 | 309 | if data_args.eval_data_path: 310 | eval_json = json.load(open(data_args.eval_data_path, "r")) 311 | eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer) 312 | else: 313 | eval_dataset = None 314 | 315 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset) 316 | 317 | 318 | def train(): 319 | global local_rank 320 | 321 | parser = transformers.HfArgumentParser( 322 | (ModelArguments, DataArguments, TrainingArguments) 323 | ) 324 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 325 | local_rank = training_args.local_rank 326 | 327 | # Set RoPE scaling factor 328 | config = transformers.AutoConfig.from_pretrained( 329 | model_args.model_name_or_path, 330 | cache_dir=training_args.cache_dir, 331 | ) 332 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 333 | if orig_ctx_len and training_args.model_max_length > orig_ctx_len: 334 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 335 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 336 | config.use_cache = False 337 | 338 | tokenizer = transformers.AutoTokenizer.from_pretrained( 339 | model_args.model_name_or_path, 340 | cache_dir=training_args.cache_dir, 341 | model_max_length=training_args.model_max_length, 342 | padding_side="right", 343 | use_fast=True, 344 | ) 345 | tokenizer.pad_token = tokenizer.unk_token 346 | tokenizer.pad_token = tokenizer.eos_token 347 | 348 | # Making sure the tokenizer works before loading the model. 
349 | print(tokenizer(["This is a test", "secondary"], padding=True)) 350 | print(tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}])) 351 | 352 | # Load model and tokenizer 353 | model = transformers.AutoModelForCausalLM.from_pretrained( 354 | model_args.model_name_or_path, 355 | config=config, 356 | cache_dir=training_args.cache_dir, 357 | torch_dtype=torch.bfloat16, 358 | ) 359 | 360 | # Freeze the base model 361 | for param in model.base_model.parameters(): 362 | param.requires_grad = False 363 | 364 | # Add Medusa heads 365 | medusa_lm_head = MedusaModel( 366 | model, 367 | medusa_num_heads=training_args.medusa_num_heads, 368 | medusa_num_layers=training_args.medusa_num_layers, 369 | base_model_name_or_path=model_args.model_name_or_path, 370 | ) 371 | 372 | # Format output dir 373 | training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}" 374 | 375 | 376 | # Load data 377 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 378 | 379 | # Generate Medusa config for pushing to HF hub 380 | medusa_config = MedusaConfig( 381 | medusa_num_heads=training_args.medusa_num_heads, 382 | medusa_num_layers=training_args.medusa_num_layers, 383 | base_model_name_or_path=model_args.model_name_or_path, 384 | version="2" 385 | ) 386 | 387 | # Save Medusa config 388 | medusa_config.save_pretrained(training_args.output_dir) 389 | 390 | # Start trainner 391 | trainer = CustomizedTrainer( 392 | model=medusa_lm_head, tokenizer=tokenizer, args=training_args, **data_module 393 | ) 394 | 395 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 396 | trainer.train(resume_from_checkpoint=True) 397 | else: 398 | trainer.train() 399 | model.config.use_cache = True 400 | # trainer.save_state() 401 | # safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 402 | # Save MedusaHead seperately 403 | if hasattr(medusa_lm_head, "module"): 404 | lm_head = medusa_lm_head.module.medusa_head 405 | else: 406 | lm_head = medusa_lm_head.medusa_head 407 | import deepspeed 408 | with deepspeed.zero.GatheredParameters(lm_head.parameters()): 409 | state_dict = lm_head.state_dict() 410 | 411 | # Save Medusa heads 412 | if local_rank == 0: 413 | # Modify the tokenizer internal state before saving. 
414 | tokenizer.encode("Test", truncation=None, padding="do_not_pad") 415 | tokenizer.save_pretrained(training_args.output_dir) 416 | save_file( 417 | state_dict, 418 | os.path.join(training_args.output_dir, "medusa_lm_head.safetensors"), 419 | ) 420 | 421 | 422 | if __name__ == "__main__": 423 | train() 424 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "medusa-llm" 7 | version = "1.0" 8 | description = "Simple Framework for Accelerating LLM Generation with Multiple Decoding Heads" 9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "fschat", "torch", "transformers>=4.34", "accelerate", "sentencepiece", "protobuf" 17 | ] 18 | 19 | [project.optional-dependencies] 20 | train = ["bitsandbytes", "wandb", "scipy"] 21 | 22 | [project.urls] 23 | "Homepage" = "https://github.com/FasterDecoding/Medusa" 24 | "Blog" = "https://sites.google.com/view/medusa-llm" 25 | 26 | [tool.setuptools.packages.find] 27 | exclude = ["assets*", "notebooks*", "scripts*", "llm_judge"] -------------------------------------------------------------------------------- /scripts/train_vicuna_33b_8bit.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-33b-v1.3 \ 2 | --data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \ 3 | --bf16 True \ 4 | --output_dir test \ 5 | --num_train_epochs 1 \ 6 | --per_device_train_batch_size 8 \ 7 | --per_device_eval_batch_size 8 \ 8 | --gradient_accumulation_steps 4 \ 9 | --evaluation_strategy "no" \ 10 | --save_strategy "no" \ 11 | --learning_rate 1e-3 \ 12 | --weight_decay 0.0 \ 13 | --warmup_ratio 0.1 \ 14 | --lr_scheduler_type "cosine" \ 15 | --logging_steps 1 \ 16 | --tf32 True \ 17 | --model_max_length 2048 \ 18 | --lazy_preprocess True \ 19 | --medusa_num_heads 3 \ 20 | --medusa_num_layers 1 \ 21 | --load_in_8bit -------------------------------------------------------------------------------- /scripts/train_vicuna_7b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node=4 medusa/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \ 2 | --data_path ShareGPT_Vicuna_unfiltered/ShareGPT_V4.3_unfiltered_cleaned_split.json \ 3 | --bf16 True \ 4 | --output_dir test \ 5 | --num_train_epochs 1 \ 6 | --per_device_train_batch_size 8 \ 7 | --per_device_eval_batch_size 8 \ 8 | --gradient_accumulation_steps 4 \ 9 | --evaluation_strategy "no" \ 10 | --save_strategy "no" \ 11 | --learning_rate 1e-3 \ 12 | --weight_decay 0.0 \ 13 | --warmup_ratio 0.1 \ 14 | --lr_scheduler_type "cosine" \ 15 | --logging_steps 1 \ 16 | --tf32 True \ 17 | --model_max_length 2048 \ 18 | --lazy_preprocess True \ 19 | --medusa_num_heads 3 \ 20 | --medusa_num_layers 1 -------------------------------------------------------------------------------- /simple_gradio_interface.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import time 3 | import torch 4 | from medusa.model.medusa_model import MedusaModel 5 | from 
fastchat.model.model_adapter import get_conversation_template 6 | 7 | # Global variables 8 | chat_history = "" 9 | model = None 10 | tokenizer = None 11 | conv = None 12 | 13 | 14 | def load_model_function(model_name, load_in_8bit=False, load_in_4bit=False): 15 | model_name = model_name or "FasterDecoding/medusa-vicuna-7b-v1.3" 16 | global model, tokenizer, conv 17 | 18 | try: 19 | model = MedusaModel.from_pretrained( 20 | model_name, 21 | torch_dtype=torch.float16, 22 | low_cpu_mem_usage=True, 23 | device_map="auto", 24 | load_in_8bit=load_in_8bit, 25 | load_in_4bit=load_in_4bit 26 | ) 27 | tokenizer = model.get_tokenizer() 28 | conv = get_conversation_template("vicuna") 29 | return "Model loaded successfully!" 30 | except: 31 | return "Error loading the model. Please check the model name and try again." 32 | 33 | 34 | def reset_conversation(): 35 | """ 36 | Reset the global conversation and chat history 37 | """ 38 | global conv, chat_history 39 | conv = get_conversation_template("vicuna") 40 | chat_history = "" 41 | 42 | 43 | def medusa_chat_interface(user_input, temperature, max_steps, no_history): 44 | global model, tokenizer, conv, chat_history 45 | 46 | # Reset the conversation if no_history is checked 47 | if no_history: 48 | reset_conversation() 49 | 50 | if not model or not tokenizer: 51 | return "Error: Model not loaded!", chat_history 52 | 53 | chat_history += "\nYou: " + user_input 54 | conv.append_message(conv.roles[0], user_input) 55 | conv.append_message(conv.roles[1], '') 56 | prompt = conv.get_prompt() 57 | 58 | input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device) 59 | 60 | outputs = model.medusa_generate(input_ids, temperature=temperature, max_steps=max_steps) 61 | response = "" 62 | for output in outputs: 63 | response = output['text'] 64 | yield response, chat_history 65 | time.sleep(0.01) 66 | 67 | chat_history += "\nMedusa: " + response.strip() 68 | 69 | return response, chat_history 70 | 71 | 72 | if __name__ == "__main__": 73 | load_model_interface = gr.Interface( 74 | load_model_function, 75 | [ 76 | gr.components.Textbox(placeholder="FasterDecoding/medusa-vicuna-7b-v1.3", label="Model Name"), 77 | gr.components.Checkbox(label="Use 8-bit Quantization"), 78 | gr.components.Checkbox(label="Use 4-bit Quantization"), 79 | ], 80 | gr.components.Textbox(label="Model Load Status", type="text"), 81 | description="Load Medusa Model", 82 | title="Medusa Model Loader", 83 | live=False, 84 | api_name="load_model" 85 | ) 86 | 87 | # Chat Interface 88 | chat_interface = gr.Interface( 89 | medusa_chat_interface, 90 | [ 91 | gr.components.Textbox(placeholder="Ask Medusa...", label="User Input"), 92 | gr.components.Slider(minimum=0, maximum=1.5, label="Temperature"), 93 | gr.components.Slider(minimum=50, maximum=1000, label="Max Steps"), 94 | gr.components.Checkbox(label="No History"), 95 | ], 96 | [ 97 | gr.components.Textbox(label="Medusa's Response", type="text"), 98 | gr.components.Textbox(label="Chat History", type="text") 99 | ], 100 | live=False, 101 | description="Chat with Medusa", 102 | title="Medusa Chatbox", 103 | api_name="chat" 104 | ) 105 | 106 | # Combine the interfaces in a TabbedInterface 107 | combined_interface = gr.TabbedInterface([load_model_interface, chat_interface], 108 | ["Load Model", "Chat"]) 109 | 110 | # Launch the combined interface 111 | combined_interface.queue().launch() 112 | --------------------------------------------------------------------------------
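The Gradio app above is a thin wrapper around `MedusaModel.medusa_generate`. Below is a minimal command-line sketch of the same streaming loop (an illustrative addition, not a file from the repository); the checkpoint name, prompt, and generation settings are assumptions, while the API calls mirror those used in `simple_gradio_interface.py`.

# Minimal streaming sketch; assumes the FasterDecoding/medusa-vicuna-7b-v1.3 checkpoint,
# fastchat, and a GPU are available.
import torch
from fastchat.model.model_adapter import get_conversation_template
from medusa.model.medusa_model import MedusaModel

model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
tokenizer = model.get_tokenizer()

# Build a Vicuna-style prompt the same way the Gradio interface does.
conv = get_conversation_template("vicuna")
conv.append_message(conv.roles[0], "Explain speculative decoding in one paragraph.")
conv.append_message(conv.roles[1], "")
prompt = conv.get_prompt()

input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.base_model.device)

# medusa_generate yields dicts whose "text" field holds the full decoded output so far,
# so print only the newly added suffix at each step.
printed = ""
for step in model.medusa_generate(input_ids, temperature=0.0, max_steps=512):
    text = step["text"]
    print(text[len(printed):], end="", flush=True)
    printed = text
print()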