├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── LICENSE.txt ├── README.md ├── licenses └── MosaicML-mpt-7b-chat-hf-space.Apache.LICENSE.txt ├── requirements.txt ├── scripts └── wizard_play.py └── src └── callback_text_iterator_streamer.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: WizardCoder-7B", 9 | "type": "python", 10 | "request": "launch", 11 | "module": "scripts.wizard_play", 12 | "justMyCode": false, 13 | "args": [ 14 | "--flash", 15 | "--prompt_style", "wizardcoder-python", 16 | ] 17 | }, 18 | { 19 | "name": "Python: WizardCoder-34B", 20 | "type": "python", 21 | "request": "launch", 22 | "module": "scripts.wizard_play", 23 | "justMyCode": false, 24 | "args": [ 25 | "--model_name_or_path", "WizardLM/WizardCoder-Python-34B-V1.0", 26 | "--flash", 27 | "--prompt_style", "wizardcoder-python", 28 | ] 29 | }, 30 | { 31 | "name": "Python: CodeLlama-7B", 32 | "type": "python", 33 | "request": "launch", 34 | "module": "scripts.wizard_play", 35 | "justMyCode": false, 36 | "args": [ 37 | "--model_name_or_path", "codellama/CodeLlama-7b-Instruct-hf", 38 | "--flash", 39 | "--prompt_style", "codellama-instruct", 40 | "--chat_memory", 41 | ] 42 | }, 43 | { 44 | "name": "Python: CodeLlama-34B", 45 | "type": "python", 46 | "request": "launch", 47 | "module": "scripts.wizard_play", 48 | "justMyCode": false, 49 | "args": [ 50 | "--model_name_or_path", "codellama/CodeLlama-34b-Instruct-hf", 51 | "--flash", 52 | "--prompt_style", 
"codellama-instruct", 53 | "--chat_memory", 54 | ] 55 | }, 56 | { 57 | "name": "Python: CodeLlama-7B few-shot", 58 | "type": "python", 59 | "request": "launch", 60 | "module": "scripts.wizard_play", 61 | "justMyCode": false, 62 | "args": [ 63 | "--model_name_or_path", "codellama/CodeLlama-7b-Instruct-hf", 64 | "--flash", 65 | "--prompt_style", "codellama-instruct", 66 | "--chat_memory", 67 | "--shot0_input", "Read user's name from stdin", 68 | "--shot0_response", "import sys; name = input(\"Enter your name: \"); print(\"Your name is:\", name)", 69 | ] 70 | }, 71 | ] 72 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.extraPaths": ["src"] 3 | } -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alex Birch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WizardCoder-Play 2 | 3 | image 4 | 5 | Python script to demonstrate how to invoke models such as WizardCoder from the command-line, with bitsandbytes 4-bit quantization. 6 | 7 | Intends to support the following models: 8 | 9 | - [`WizardLM/WizardCoder-Python-7B-V1.0`](https://huggingface.co/WizardLM/WizardCoder-Python-7B-V1.0) 10 | - [`WizardLM/WizardCoder-Python-13B-V1.0`](https://huggingface.co/WizardLM/WizardCoder-Python-13B-V1.0) 11 | - [`WizardLM/WizardCoder-Python-34B-V1.0`](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0) 12 | - [`codellama/CodeLlama-7b-Instruct-hf`](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) 13 | - [`codellama/CodeLlama-13b-Instruct-hf`](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf) 14 | - [`codellama/CodeLlama-34b-Instruct-hf`](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf) 15 | 16 | CodeLlama models were [trained on 16000 token sequences](https://ai.meta.com/blog/code-llama-large-language-model-coding/). 17 | WizardCoder was [finetuned on 2048 token sequences](https://arxiv.org/abs/2306.08568). 18 | 19 | WizardCoder-Python-34B-V1.0 [surpasses](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0) GPT4, ChatGPT-3.5 and Claude2 on HumanEval benchmarks. 20 | 21 | ## Setup 22 | 23 | All instructions are written assuming your command-line shell is bash. 
24 | 25 | Clone repository: 26 | 27 | ```bash 28 | git clone https://github.com/Birch-san/wizardcoder-play.git 29 | cd wizardcoder-play 30 | ``` 31 | 32 | ### Create + activate a new virtual environment 33 | 34 | This is to avoid interfering with your current Python environment (other Python scripts on your computer might not appreciate it if you update a bunch of packages they were relying on). 35 | 36 | Follow the instructions for venv, or conda, or neither (if you don't care what happens to other Python scripts on your computer). 37 | 38 | #### Using `venv` 39 | 40 | **Create environment**: 41 | 42 | ```bash 43 | python -m venv venv 44 | # activate the environment (next step) before running any pip commands 45 | ``` 46 | 47 | **Activate environment**: 48 | 49 | ```bash 50 | . ./venv/bin/activate 51 | ``` 52 | 53 | **(First-time) update environment's `pip`**: 54 | 55 | ```bash 56 | pip install --upgrade pip 57 | ``` 58 | 59 | #### Using `conda` 60 | 61 | **Download [conda](https://www.anaconda.com/products/distribution).** 62 | 63 | _Skip this step if you already have conda._ 64 | 65 | **Install conda**: 66 | 67 | _Skip this step if you already have conda._ 68 | 69 | Assuming you're using a `bash` shell: 70 | 71 | ```bash 72 | # Linux installs Anaconda via this shell script. Mac installs by running a .pkg installer. 73 | bash Anaconda-latest-Linux-x86_64.sh 74 | # this step probably works on both Linux and Mac.
75 | eval "$(~/anaconda3/bin/conda shell.bash hook)" 76 | conda config --set auto_activate_base false 77 | conda init 78 | ``` 79 | 80 | **Create environment**: 81 | 82 | ```bash 83 | conda create -n p311-llama python=3.11 84 | ``` 85 | 86 | **Activate environment**: 87 | 88 | ```bash 89 | conda activate p311-llama 90 | ``` 91 | 92 | ### Install package dependencies 93 | 94 | **Ensure you have activated the environment you created above.** 95 | 96 | Install dependencies: 97 | 98 | ```bash 99 | pip install -r requirements.txt 100 | ``` 101 | 102 | #### (Optional) install PyTorch nightly 103 | 104 | The PyTorch nightlies may be more performant. Until [PyTorch 2.1.0 stable comes out (~October 4th)](https://github.com/pytorch/pytorch/issues/86566#issuecomment-1706075651), nightlies are the best way to get CUDA 12.1 support: 105 | 106 | ```bash 107 | # CUDA 108 | pip install --upgrade --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu121 109 | ``` 110 | 111 | #### (Optional) install flash attention 2 112 | 113 | To accelerate inference and reduce memory usage, install `flash-attn`. 114 | 115 | First we install the package itself: 116 | 117 | ```bash 118 | pip install flash-attn --no-build-isolation 119 | ``` 120 | 121 | Then we build-from-source its rotary embeddings kernel (there is no officially-distributed wheel): 122 | 123 | ```bash 124 | MAX_JOBS=2 pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary 125 | ``` 126 | 127 | **[Building `rotary` from source] `error: expected template-name before ‘<’ token`:** 128 | If you compiled flash-attn source using nvcc 12.x (i.e. 
CUDA Toolkit 12), you will [encounter the following error](https://github.com/pybind/pybind11/issues/4606) whilst compiling pybind11's `cast.h` header: 129 | 130 | ``` 131 | /home/birch/anaconda3/envs/p311-cu121-bnb-opt/lib/python3.11/site-packages/torch/include/pybind11/detail/../cast.h: In function ‘typename pybind11::detail::type_caster::type>::cast_op_type pybind11::detail::cast_op(make_caster&)’: 132 | /home/birch/anaconda3/envs/p311-cu121-bnb-opt/lib/python3.11/site-packages/torch/include/pybind11/detail/../cast.h:45:120: error: expected template-name before ‘<’ token 133 | 45 | return caster.operator typename make_caster::template cast_op_type(); 134 | ``` 135 | 136 | Solution [here](https://github.com/Dao-AILab/flash-attention/issues/484#issuecomment-1706843478). 137 | 138 | ## Run: 139 | 140 | From root of repository: 141 | 142 | ```bash 143 | python -m scripts.wizard_play 144 | ``` 145 | 146 | Fun command-line options: 147 | 148 | - `--model_name_or_path WizardLM/WizardCoder-Python-7B-V1.0 --prompt_style wizardcoder-python`: use WizardCoder 7B with WizardCoder prompting style 149 | - `--model_name_or_path codellama/CodeLlama-7b-Instruct-hf --prompt_style codellama-instruct`: use CodeLlama-7b-Instruct with CodeLlama-Instruct prompting style 150 | - `--flash --trust_remote_code`: enables flash attention 2 via `flash-attn` library and ([my fork of](https://huggingface.co/Birchlabs/flash_llama)) [togethercomputer's `modeling_flash_llama.py`](https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py) 151 | - `--max_new_tokens 2048`: modify maximum response length 152 | - `--chat_memory`: enable conversation history, for multi-turn conversations (CodeLlama-Instruct was trained on this, but WizardCoder was not) 153 | - `--initial_input 'Write a function which computes the Fibonacci sequence.'`: you can buffer a prompt to be submitted as soon as the model's loaded. 
154 | 155 | You can press Ctrl+C whilst the model is generating a response, to interrupt it. If `--chat_memory` is enabled: the unfinished message **does** get persisted into the conversation history. 156 | If the model is **not** generating a response, then Ctrl+C will exit the software. 157 | 158 | ### Few-shotting 159 | 160 | You can seed the conversation history with a previous input and forced response from the model: 161 | 162 | ```bash 163 | python -m scripts.wizard_play --model_name_or_path codellama/CodeLlama-7b-Instruct-hf --prompt_style codellama-instruct --shot0_input "Read user's name from stdin" --shot0_response 'import sys 164 | 165 | name = input("Enter your name: ") 166 | print("Your name is:", name)' 167 | ``` 168 | 169 | This achieves two things: 170 | 171 | - creates a memory in the conversation 172 | - sets an expectation for the style of response you prefer. 173 | 174 | You can see this in action by asking the model to iterate on the solution you placed into its history: 175 | 176 | ``` 177 | [seed=64]$ Print their age too. 178 | import sys 179 | 180 | name = input("Enter your name: ") 181 | age = input("Enter your age: ") 182 | print("Your name is:", name, ",", "and", "your age:", age) 183 | ``` 184 | 185 | Note: this won't necessarily work so well for WizardCoder, which wasn't trained on multi-turn conversations. 186 | 187 | ### Troubleshooting 188 | 189 | **`cannot import name 'translate_llvmir_to_hsaco'`:** 190 | You [need a triton nightly](https://github.com/openai/triton/issues/2002).
191 | ``` 192 | Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback): 193 | Failed to import transformers.generation.utils because of the following error (look up to see its traceback): 194 | cannot import name 'translate_llvmir_to_hsaco' from 'triton._C.libtriton.triton' (unknown location) 195 | ``` 196 | 197 | ```bash 198 | pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly 199 | ``` 200 | 201 | **`ImportError`:** 202 | Recent flash-attn releases encounter [errors _importing_ rotary embed](https://github.com/Dao-AILab/flash-attention/issues/519). You may need to copy Dao-AILab's [`ops/triton`](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/ops/triton) directory into the flash-attn distribution you installed to site-packages. 203 | 204 | ## License 205 | 206 | This repository is itself MIT-licensed. 207 | 208 | Includes: 209 | 210 | - MIT-licensed code copied from Artidoro Pagnoni's [qlora](https://github.com/artidoro/qlora) 211 | - MIT-licensed code copied from Scott Logic's [qlora fork](https://github.com/scottlogic-alex/qlora) (specifically [`evaluate.py`](https://github.com/scottlogic-alex/qlora/blob/stepwise/evaluate.py)). 212 | - [Apache-licensed](licenses/MosaicML-mpt-7b-chat-hf-space.Apache.LICENSE.txt) code copied from MosaicML's [mpt-7b-chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat/blob/main/app.py) Huggingface Space 213 | -------------------------------------------------------------------------------- /licenses/MosaicML-mpt-7b-chat-hf-space.Apache.LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | accelerate 4 | bitsandbytes 5 | scipy -------------------------------------------------------------------------------- /scripts/wizard_play.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional, TypedDict, NamedTuple, List, Dict, Union, TypeAlias, Literal 3 | import torch 4 | from torch import LongTensor 5 | from transformers import ( 6 | AutoConfig, 7 | AutoModelForCausalLM, 8 | AutoTokenizer, 9 | BitsAndBytesConfig, 10 | GenerationConfig, 11 | HfArgumentParser, 12 | set_seed, 13 | StoppingCriteria, 14 | StoppingCriteriaList, 15 | LlamaForCausalLM, 16 | LlamaTokenizerFast 17 | ) 18 | from src.callback_text_iterator_streamer import CallbackTextIteratorStreamer 19 | import logging 20 | from enum import Enum 21 | import sys 22 | from time import perf_counter 23 | from itertools import pairwise 24 | 25 | logger = 
logging.getLogger(__name__) 26 | 27 | class TokenizerOutput(TypedDict): 28 | input_ids: LongTensor 29 | attention_mask: LongTensor 30 | 31 | class PromptStyle(Enum): 32 | Bare = 'bare' 33 | WizardCoderPython = 'wizardcoder-python' 34 | CodeLlamaInstruct = 'codellama-instruct' 35 | # I am not proud of this, but when I attempted to specify Enum fields on the arg dataclasses: 36 | # hfparser.parse_args_into_dataclasses() turned the enum instances into string values. 37 | # so we make some types to capture what we're actually going to receive. 38 | PromptStyleLiteral: TypeAlias = Literal['bare', 'wizardcoder-python', 'codellama-instruct'] 39 | 40 | class Dtype(Enum): 41 | Bf16 = 'bf16' 42 | Fp16 = 'fp16' 43 | Fp32 = 'fp32' 44 | DtypeLiteral: TypeAlias = Literal['bf16', 'fp16', 'fp32'] 45 | 46 | def reify_dtype(dtype: DtypeLiteral) -> torch.dtype: 47 | match(dtype): 48 | case 'bf16': 49 | return torch.bfloat16 50 | case 'fp16': 51 | return torch.float16 52 | case 'fp32': 53 | return torch.float32 54 | 55 | class Participant(Enum): 56 | User = 'user' 57 | Assistant = 'assistant' 58 | System = 'system' 59 | 60 | class Message(NamedTuple): 61 | participant: Participant 62 | message: str 63 | 64 | @dataclass 65 | class StopOnTokens(StoppingCriteria): 66 | stop_token_ids: List[int] 67 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 68 | for stop_id in self.stop_token_ids: 69 | if input_ids[0][-1] == stop_id: 70 | return True 71 | return False 72 | 73 | class SufficientResponse(BaseException): ... 74 | 75 | @dataclass 76 | class ModelArguments: 77 | model_name_or_path: Optional[str] = field( 78 | default="WizardLM/WizardCoder-Python-7B-V1.0" 79 | ) 80 | cache_dir: Optional[str] = field( 81 | default=None, 82 | metadata={"help": "Which directory to use as your HuggingFace cache. Defaults to ~/.cache/huggingface, probably. 
Use this if you want to download models to a specific location."} 83 | ) 84 | trust_remote_code: Optional[bool] = field( 85 | default=False, 86 | metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."} 87 | ) 88 | double_quant: bool = field( 89 | default=True, 90 | metadata={"help": "Compress the quantization statistics through double quantization."} 91 | ) 92 | quant_type: str = field( 93 | default="nf4", 94 | metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} 95 | ) 96 | bits: int = field( 97 | default=4, 98 | metadata={"help": "How many bits to use.", "choices": [4, 8, 16, 32]} 99 | ) 100 | model_dtype: DtypeLiteral = field( 101 | default=Dtype.Fp16.value, 102 | metadata={"help": "Compute type of the model. Used for non-quantized computations. Float16 may be more better than bfloat16 for inference.", "choices": [p.value for p in Dtype]} 103 | ) 104 | bnb_compute_dtype: DtypeLiteral = field( 105 | default=Dtype.Fp16.value, 106 | metadata={"help": "Compute type used for computations over dequantized weights. Float16 should be better than bfloat16. Float32 can be slightly better than float16.", "choices": [p.value for p in Dtype]} 107 | ) 108 | flash: Optional[bool] = field( 109 | default=False, 110 | metadata={"help": "Whether to replace the model code with togethercomputer's modeling_flash_llama.py, which uses Flash Attention 2 (via flash-attn) to accelerate model inference and reduce memory usage."} 111 | ) 112 | 113 | @dataclass 114 | class MiscArguments: 115 | seed: Optional[int] = field( 116 | default=64, 117 | metadata={"help": "Random seed, for deterministic generation."} 118 | ) 119 | compile: bool = field( 120 | default=False, 121 | metadata={"help": "Invoke torch.compile() on the model, with mode='max-autotune'. Requires PyTorch 2, CUDA, and either Python 3.10 or Python 3.11 with a recent torch nightly. 
Will make the first inference from the model take a bit longer, but subsequent inferences will be faster."} 122 | ) 123 | system_prompt: Optional[str] = field( 124 | default=None, 125 | metadata={"help": "The context which precedes the chat history. Can be used to influence the chatbot's responses. If unspecified: defaults to the standard system prompt for the prompt style."} 126 | ) 127 | initial_input: Optional[str] = field( 128 | default=None, 129 | metadata={"help": "Initial message sent to the model. For example: Read user's name from stdin"} 130 | ) 131 | shot0_input: Optional[str] = field( 132 | default=None, 133 | metadata={"help": "[to be used with --shot0_response] Use few-shotting to populate conversation history with an example of the kind of input+response you prefer. This arg exemplifies an input that the user sent previously to the model. For example: Read user's name from stdin"} 134 | ) 135 | shot0_response: Optional[str] = field( 136 | default=None, 137 | metadata={"help": "[to be used with --shot0_input] Use few-shotting to populate conversation history with an example of the kind of input+response you prefer. This arg exemplifies a reply that the model produced in reponse to the user's previous input, --shot0_input. 
For example: import sys\n\nname = input(\"Enter your name: \")\nprint(\"Your name is:\", name)"} 138 | ) 139 | # if you actually set the type hint to PromptStyle: you will find that HF/argparse assign a string anyway 140 | prompt_style: PromptStyleLiteral = field( 141 | default=PromptStyle.WizardCoderPython.value, 142 | metadata={"choices": [p.value for p in PromptStyle]} 143 | ) 144 | chat_memory: bool = field( 145 | default=False, 146 | metadata={"help": "Whether chat sequence should accumulate a conversation context, or reset each time"} 147 | ) 148 | reseed_each_prompt: bool = field( 149 | default=True, 150 | metadata={"help": "Reset seed before each user input"} 151 | ) 152 | show_seed: bool = field( 153 | default=True, 154 | metadata={"help": "Show seed in prompt"} 155 | ) 156 | measure_perf: bool = field( 157 | default=True, 158 | metadata={"help": "Print inference speed"} 159 | ) 160 | 161 | @dataclass 162 | class GenerationArguments: 163 | # For more hyperparameters check: 164 | # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig 165 | # Length arguments 166 | max_new_tokens: Optional[int] = field( 167 | default=2048, 168 | metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops" 169 | "if predict_with_generate is set."} 170 | ) 171 | min_new_tokens : Optional[int] = field( 172 | default=None, 173 | metadata={"help": "Minimum number of new tokens to generate."} 174 | ) 175 | 176 | # Generation strategy 177 | do_sample: Optional[bool] = field(default=False) 178 | num_beams: Optional[int] = field(default=1) 179 | num_beam_groups: Optional[int] = field(default=1) 180 | penalty_alpha: Optional[float] = field(default=None) 181 | use_cache: Optional[bool] = field(default=True) 182 | 183 | # Hyperparameters for logit manipulation 184 | temperature: Optional[float] = field(default=1.0) 185 | top_k: Optional[int] = field(default=50) 186 | top_p: Optional[float] = 
field(default=1.0) 187 | typical_p: Optional[float] = field(default=1.0) 188 | diversity_penalty: Optional[float] = field(default=0.0) 189 | repetition_penalty: Optional[float] = field(default=1.0) 190 | length_penalty: Optional[float] = field(default=1.0) 191 | no_repeat_ngram_size: Optional[int] = field(default=0) 192 | 193 | def get_model(args: ModelArguments) -> LlamaForCausalLM: 194 | config = AutoConfig.from_pretrained( 195 | args.model_name_or_path, 196 | trust_remote_code=args.trust_remote_code, 197 | cache_dir=args.cache_dir, 198 | ) 199 | 200 | if args.flash and config.model_type == 'llama': 201 | updates: Dict[str, Union[str, int, float, bool, None]] = {} 202 | flash_model_name = 'Birchlabs/flash_llama--modeling_flash_llama.LlamaForCausalLM' 203 | if 'num_key_value_heads' not in config.__dict__: 204 | updates['num_key_value_heads'] = config.num_attention_heads 205 | if 'auto_map' in config.__dict__: 206 | if not ('AutoModelForCausalLM' in config.auto_map and 'flash' in config.auto_map['AutoModelForCausalLM']): 207 | updates['auto_map']['AutoModelForCausalLM'] = flash_model_name 208 | else: 209 | updates['auto_map'] = { 'AutoModelForCausalLM': flash_model_name } 210 | if 'rope_scaling' not in config.__dict__: 211 | # CodeLlama-Instruct was trained on 16000 token sequences: 212 | # https://ai.meta.com/blog/code-llama-large-language-model-coding/ 213 | # WizardCoder was trained on 2048 token sequences (see section 4.2): 214 | # https://arxiv.org/abs/2306.08568 215 | # but both of their HF models report 16384 as the max position embeddings. 216 | # whatever; let's leave the rope scaling as default. 
217 | # if you want to do different scaling, I think you'd compute it like this: 218 | # factor = desired_context_length/config.max_position_embeddings 219 | updates['rope_scaling'] = { 'factor': 1., 'type': 'linear' } 220 | if 'pretraining_tp' not in config.__dict__: 221 | updates['pretraining_tp'] = 1 222 | if updates: 223 | config.update(updates) 224 | 225 | cuda_avail = torch.cuda.is_available() 226 | load_in_4bit = args.bits == 4 and cuda_avail 227 | load_in_8bit = args.bits == 8 and cuda_avail 228 | 229 | bnb_compute_dtype: torch.dtype = reify_dtype(args.bnb_compute_dtype) 230 | 231 | quantization_config: Optional[BitsAndBytesConfig] = BitsAndBytesConfig( 232 | load_in_4bit=load_in_4bit, 233 | load_in_8bit=load_in_8bit, 234 | llm_int8_threshold=6.0, 235 | llm_int8_has_fp16_weight=False, 236 | bnb_4bit_compute_dtype=bnb_compute_dtype, 237 | bnb_4bit_use_double_quant=args.double_quant, 238 | bnb_4bit_quant_type=args.quant_type, 239 | ) if cuda_avail else None 240 | 241 | if not cuda_avail: 242 | logger.warning("You don't have CUDA, so we have turned off quantization. 
If you happen to be on a Mac: maybe you have enough unified memory to run in fp16 anyway…") 243 | 244 | model_dtype: torch.dtype = reify_dtype(args.model_dtype) 245 | 246 | model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained( 247 | args.model_name_or_path, 248 | config=config, 249 | load_in_4bit=load_in_4bit, 250 | load_in_8bit=load_in_8bit, 251 | device_map='auto', 252 | quantization_config=quantization_config, 253 | torch_dtype=model_dtype, 254 | trust_remote_code=args.trust_remote_code, 255 | cache_dir=args.cache_dir, 256 | ).eval() 257 | model.config.torch_dtype=model_dtype 258 | 259 | return model 260 | 261 | def main(): 262 | hfparser = HfArgumentParser((ModelArguments, GenerationArguments, MiscArguments)) 263 | model_args, generation_args, misc_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True) 264 | if extra_args: 265 | raise f"Received unsupported command-line args: {extra_args}" 266 | generation_config = GenerationConfig(**vars(generation_args)) 267 | 268 | model: LlamaForCausalLM = get_model(model_args) 269 | 270 | set_seed(misc_args.seed) 271 | if misc_args.compile: 272 | torch.compile(model, mode='max-autotune') 273 | 274 | tokenizer: LlamaTokenizerFast = AutoTokenizer.from_pretrained( 275 | model_args.model_name_or_path, 276 | # fast tokenizer required for WizardLM/WizardCoder-Python-34B-V1.0, because slow tokenizer doesn't come with added_tokens (required for {'[PAD]': 32000}) 277 | use_fast=True, 278 | cache_dir=model_args.cache_dir, 279 | ) 280 | # WizardCoder defines {'[PAD]': 32000}, but CodeLLama doesn't define any pad token, so we fall back to EOS. 
281 | generation_config.pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id 282 | 283 | stop_token_ids: List[int] = [tokenizer.eos_token_id] 284 | stop = StopOnTokens(stop_token_ids) 285 | stopping_criteria=StoppingCriteriaList([stop]) 286 | 287 | system_prompt: Optional[str] = misc_args.system_prompt 288 | if misc_args.system_prompt is None: 289 | match misc_args.prompt_style: 290 | case PromptStyle.WizardCoderPython.value: 291 | system_prompt: str = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' 292 | case PromptStyle.CodeLlamaInstruct.value: 293 | system_prompt: str = 'Provide answers in Python' 294 | case PromptStyle.Bare.value: 295 | pass 296 | case _: 297 | raise ValueError(f'Never heard of a {misc_args.prompt_style} PromptStyle.') 298 | 299 | # The CodeLlama blog post suggests it's fine to not specify a system prompt: 300 | # https://huggingface.co/blog/codellama 301 | # whereas WizardCoder seems to always use the same (Alpaca-style) system prompt (I wouldn't recommend erasing WizardCoder's system prompt) 302 | optional_system_message: List[Message] = [Message(Participant.System, system_prompt)] if system_prompt else [] 303 | history: List[Message] = [] 304 | 305 | if misc_args.shot0_input is not None: 306 | assert misc_args.shot0_response is not None, "few-shotting requires you to specify the entire previous turn of the conversation (both --shot0_input and --shot0_response)." 
307 | history += [ 308 | Message(Participant.User, misc_args.shot0_input), 309 | Message(Participant.Assistant, misc_args.shot0_response), 310 | ] 311 | 312 | reset_ansi='\x1b[0m' 313 | cyan_ansi='\x1b[31;36m' 314 | blue_ansi='\x1b[31;34m' 315 | green_ansi='\x1b[31;32m' 316 | purple_ansi='\x1b[31;35m' 317 | 318 | participant_names: Dict[Participant, str] = { 319 | Participant.User: 'Instruction', 320 | Participant.Assistant: 'Response', 321 | } 322 | 323 | def alpaca_section(envelope: Message) -> str: 324 | participant, message = envelope 325 | if participant is Participant.System: 326 | return message 327 | return f'### {participant_names[participant]}:\n{message}' 328 | 329 | def codellama_turn(user_msg: Message, assistant_msg: Message, is_first: bool) -> str: 330 | preamble = f'<>\n{system_prompt}\n<>\n\n' if is_first and system_prompt else '' 331 | return f'[INST] {preamble}{user_msg.message} [/INST] {assistant_msg.message}' 332 | 333 | next_seed: Optional[int] = None 334 | 335 | first = True 336 | while True: 337 | seed: int = misc_args.seed if next_seed is None else next_seed 338 | if misc_args.reseed_each_prompt or first or next_seed is not None: 339 | set_seed(seed) 340 | 341 | try: 342 | prompt_ctx: str = f'[seed={seed}]' if misc_args.show_seed else '' 343 | if first and misc_args.initial_input is not None: 344 | user_input = misc_args.initial_input 345 | quote: str = f'{purple_ansi}{prompt_ctx}> ' 346 | print(f'{quote}{user_input}') 347 | else: 348 | prompt: str = f'{purple_ansi}{prompt_ctx}$ ' 349 | user_input = input(f'{blue_ansi}Type a message to begin the conversation…{reset_ansi}\n{prompt}' if first else prompt) 350 | except (KeyboardInterrupt, EOFError): 351 | sys.exit(0) 352 | print(reset_ansi, end='') 353 | 354 | first = False 355 | 356 | user_message = Message(Participant.User, user_input) 357 | 358 | match misc_args.prompt_style: 359 | case PromptStyle.WizardCoderPython.value: 360 | chat_to_complete: str = '\n\n'.join([ 361 | 
alpaca_section(message) for message in [ 362 | *optional_system_message, 363 | *history, 364 | user_message, 365 | Message(Participant.Assistant, ''), 366 | ] 367 | ]) 368 | case PromptStyle.CodeLlamaInstruct.value: 369 | chat_to_complete: str = ' '.join([codellama_turn(user_msg, assist_msg, ix == 0) for ix, (user_msg, assist_msg) in enumerate(pairwise([ 370 | *history, 371 | user_message, 372 | Message(Participant.Assistant, ''), 373 | ]))]) 374 | case PromptStyle.Bare.value: 375 | chat_to_complete: str = user_input 376 | case _: 377 | raise ValueError(f'Never heard of a {misc_args.prompt_style} PromptStyle.') 378 | 379 | tokenized_prompts: TokenizerOutput = tokenizer([chat_to_complete], return_tensors='pt', truncation=True) 380 | 381 | print(green_ansi, end='', flush=True) 382 | 383 | response = '' 384 | def on_text(message: str, stream_end = False): 385 | nonlocal response 386 | response += message 387 | print(message, end='', flush=True) 388 | 389 | streamer = CallbackTextIteratorStreamer(tokenizer, callback=on_text, skip_prompt=True, skip_special_tokens=True) 390 | 391 | try: 392 | inference_start: float = perf_counter() 393 | prediction: LongTensor = model.generate( 394 | input_ids=tokenized_prompts.input_ids.to(model.device), 395 | attention_mask=tokenized_prompts.attention_mask.to(model.device), 396 | generation_config=generation_config, 397 | do_sample=generation_config.temperature > 0., 398 | stopping_criteria=stopping_criteria, 399 | streamer=streamer, 400 | ) 401 | # reset ANSI control sequence (plus line break) 402 | print(reset_ansi) 403 | # if you wanted to see the result, you can do so like this: 404 | # decode: List[str] = tokenizer.decode(prediction[0,tokenized_prompts.input_ids.size(-1):], skip_special_tokens=False, clean_up_tokenization_spaces=True) 405 | # print(decode) 406 | # pass 407 | # but we're already streaming it to the console via our callback 408 | inference_duration: float = perf_counter()-inference_start 409 | token_in_count: int = 
tokenized_prompts.input_ids.size(-1) 410 | token_out_count: int = prediction.size(-1) - token_in_count 411 | tokens_out_per_sec: float = token_out_count/inference_duration 412 | if misc_args.measure_perf: 413 | print(f'{cyan_ansi}ctx length: {token_in_count}\ntokens out: {token_out_count}\nduration: {inference_duration:.2f} secs\nspeed: {tokens_out_per_sec:.2f} tokens/sec{reset_ansi}') 414 | except (KeyboardInterrupt, SufficientResponse, EOFError): 415 | # reset ANSI control sequence (plus line break) 416 | print(reset_ansi) 417 | 418 | # we disable accumulation of conversation history by default, because WizardCoder is not advertised as being finetuned on multi-turn conversations, 419 | # but more importantly because I'd rather spend our 4k context length on a detailed answer for a single-turn than an incomplete answer for multiple turns. 420 | if misc_args.chat_memory: 421 | history += [ 422 | user_message, 423 | Message(Participant.Assistant, response) 424 | ] 425 | 426 | if __name__ == "__main__": 427 | main() -------------------------------------------------------------------------------- /src/callback_text_iterator_streamer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, TextIteratorStreamer 2 | from typing import Optional, Protocol 3 | 4 | class TextCallback(Protocol): 5 | def __call__(self, text: str, stream_end: bool = False) -> None: ... 
class CallbackTextIteratorStreamer(TextIteratorStreamer):
    """Streamer that hands every finalized chunk of decoded text to a callback.

    Args:
        tokenizer: tokenizer used to decode generated token ids.
        callback: invoked as ``callback(text, stream_end=...)`` for each chunk.
        skip_prompt: whether to skip echoing the prompt tokens.
        timeout: forwarded to ``TextIteratorStreamer``.
        **decode_kwargs: forwarded to the tokenizer's decode (e.g. ``skip_special_tokens``).

    NOTE(review): ``on_finalized_text`` is overridden without calling the parent
    implementation, so the parent's iterator interface is presumably never fed.
    Consume output via ``callback``, not by iterating this object.
    """
    callback: TextCallback

    def __init__(
        self, tokenizer: AutoTokenizer, callback: TextCallback, skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        # `timeout` used to be accepted and then silently dropped; pass it through to the parent.
        super().__init__(tokenizer, skip_prompt, timeout=timeout, **decode_kwargs)
        self.callback = callback

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Forward a decoded chunk (and the end-of-stream flag) to the callback."""
        self.callback(text, stream_end=stream_end)