├── .gitignore
├── .vscode
│   ├── launch.json
│   └── settings.json
├── LICENSE.txt
├── README.md
├── licenses
│   └── MosaicML-mpt-7b-chat-hf-space.Apache.LICENSE.txt
├── requirements.txt
├── scripts
│   └── wizard_play.py
└── src
    └── callback_text_iterator_streamer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: WizardCoder-7B",
9 | "type": "python",
10 | "request": "launch",
11 | "module": "scripts.wizard_play",
12 | "justMyCode": false,
13 | "args": [
14 | "--flash",
15 | "--prompt_style", "wizardcoder-python",
16 | ]
17 | },
18 | {
19 | "name": "Python: WizardCoder-34B",
20 | "type": "python",
21 | "request": "launch",
22 | "module": "scripts.wizard_play",
23 | "justMyCode": false,
24 | "args": [
25 | "--model_name_or_path", "WizardLM/WizardCoder-Python-34B-V1.0",
26 | "--flash",
27 | "--prompt_style", "wizardcoder-python",
28 | ]
29 | },
30 | {
31 | "name": "Python: CodeLlama-7B",
32 | "type": "python",
33 | "request": "launch",
34 | "module": "scripts.wizard_play",
35 | "justMyCode": false,
36 | "args": [
37 | "--model_name_or_path", "codellama/CodeLlama-7b-Instruct-hf",
38 | "--flash",
39 | "--prompt_style", "codellama-instruct",
40 | "--chat_memory",
41 | ]
42 | },
43 | {
44 | "name": "Python: CodeLlama-34B",
45 | "type": "python",
46 | "request": "launch",
47 | "module": "scripts.wizard_play",
48 | "justMyCode": false,
49 | "args": [
50 | "--model_name_or_path", "codellama/CodeLlama-34b-Instruct-hf",
51 | "--flash",
52 | "--prompt_style", "codellama-instruct",
53 | "--chat_memory",
54 | ]
55 | },
56 | {
57 | "name": "Python: CodeLlama-7B few-shot",
58 | "type": "python",
59 | "request": "launch",
60 | "module": "scripts.wizard_play",
61 | "justMyCode": false,
62 | "args": [
63 | "--model_name_or_path", "codellama/CodeLlama-7b-Instruct-hf",
64 | "--flash",
65 | "--prompt_style", "codellama-instruct",
66 | "--chat_memory",
67 | "--shot0_input", "Read user's name from stdin",
68 | "--shot0_response", "import sys; name = input(\"Enter your name: \"); print(\"Your name is:\", name)",
69 | ]
70 | },
71 | ]
72 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.analysis.extraPaths": ["src"]
3 | }
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Alex Birch
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # WizardCoder-Play
2 |
3 |
4 |
5 | A Python script demonstrating how to invoke models such as WizardCoder from the command line, with bitsandbytes 4-bit quantization.
6 |
7 | Intends to support the following models:
8 |
9 | - [`WizardLM/WizardCoder-Python-7B-V1.0`](https://huggingface.co/WizardLM/WizardCoder-Python-7B-V1.0)
10 | - [`WizardLM/WizardCoder-Python-13B-V1.0`](https://huggingface.co/WizardLM/WizardCoder-Python-13B-V1.0)
11 | - [`WizardLM/WizardCoder-Python-34B-V1.0`](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0)
12 | - [`codellama/CodeLlama-7b-Instruct-hf`](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf)
13 | - [`codellama/CodeLlama-13b-Instruct-hf`](https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf)
14 | - [`codellama/CodeLlama-34b-Instruct-hf`](https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf)
15 |
16 | CodeLlama models were [trained on 16000 token sequences](https://ai.meta.com/blog/code-llama-large-language-model-coding/).
17 | WizardCoder was [finetuned on 2048 token sequences](https://arxiv.org/abs/2306.08568).
18 |
19 | WizardCoder-Python-34B-V1.0 [surpasses](https://huggingface.co/WizardLM/WizardCoder-Python-34B-V1.0) GPT-4, ChatGPT-3.5 and Claude 2 on the HumanEval benchmark.
20 |
21 | ## Setup
22 |
23 | All instructions are written assuming your command-line shell is bash.
24 |
25 | Clone repository:
26 |
27 | ```bash
28 | git clone https://github.com/Birch-san/wizardcoder-play.git
29 | cd wizardcoder-play
30 | ```
31 |
32 | ### Create + activate a new virtual environment
33 |
34 | This is to avoid interfering with your current Python environment (other Python scripts on your computer might not appreciate it if you update a bunch of packages they were relying on).
35 |
36 | Follow the instructions for virtualenv, or conda, or neither (if you don't care what happens to other Python scripts on your computer).
37 |
38 | #### Using `venv`
39 |
40 | **Create environment**:
41 |
42 | ```bash
43 | python -m venv venv
45 | ```
46 |
47 | **Activate environment**:
48 |
49 | ```bash
50 | . ./venv/bin/activate
51 | ```
52 |
53 | **(First-time) update environment's `pip`**:
54 |
55 | ```bash
56 | pip install --upgrade pip
57 | ```
58 |
59 | #### Using `conda`
60 |
61 | **Download [conda](https://www.anaconda.com/products/distribution).**
62 |
63 | _Skip this step if you already have conda._
64 |
65 | **Install conda**:
66 |
67 | _Skip this step if you already have conda._
68 |
69 | Assuming you're using a `bash` shell:
70 |
71 | ```bash
72 | # Linux installs Anaconda via this shell script. Mac installs by running a .pkg installer.
73 | bash Anaconda-latest-Linux-x86_64.sh
74 | # this step probably works on both Linux and Mac.
75 | eval "$(~/anaconda3/bin/conda shell.bash hook)"
76 | conda config --set auto_activate_base false
77 | conda init
78 | ```
79 |
80 | **Create environment**:
81 |
82 | ```bash
83 | conda create -n p311-llama python=3.11
84 | ```
85 |
86 | **Activate environment**:
87 |
88 | ```bash
89 | conda activate p311-llama
90 | ```
91 |
92 | ### Install package dependencies
93 |
94 | **Ensure you have activated the environment you created above.**
95 |
96 | Install dependencies:
97 |
98 | ```bash
99 | pip install -r requirements.txt
100 | ```
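
As a quick, optional sanity check that the install worked (and whether CUDA is visible to PyTorch):

```bash
python -c 'import torch; print(torch.__version__, torch.cuda.is_available())'
```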
101 |
102 | #### (Optional) install PyTorch nightly
103 |
104 | The PyTorch nightlies may be more performant. Until [PyTorch 2.1.0 stable comes out (~October 4th)](https://github.com/pytorch/pytorch/issues/86566#issuecomment-1706075651), nightlies are the best way to get CUDA 12.1 support:
105 |
106 | ```bash
107 | # CUDA
108 | pip install --upgrade --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cu121
109 | ```
110 |
111 | #### (Optional) install flash attention 2
112 |
113 | To accelerate inference and reduce memory usage, install `flash-attn`.
114 |
115 | First we install the package itself:
116 |
117 | ```bash
118 | pip install flash-attn --no-build-isolation
119 | ```
120 |
121 | Then we build its rotary embeddings kernel from source (there is no officially-distributed wheel):
122 |
123 | ```bash
124 | MAX_JOBS=2 pip install git+https://github.com/HazyResearch/flash-attention.git#subdirectory=csrc/rotary
125 | ```
126 |
127 | **[Building `rotary` from source] `error: expected template-name before ‘<’ token`:**
128 | If you compiled flash-attn source using nvcc 12.x (i.e. CUDA Toolkit 12), you will [encounter the following error](https://github.com/pybind/pybind11/issues/4606) whilst compiling pybind11's `cast.h` header:
129 |
130 | ```
131 | /home/birch/anaconda3/envs/p311-cu121-bnb-opt/lib/python3.11/site-packages/torch/include/pybind11/detail/../cast.h: In function ‘typename pybind11::detail::type_caster::type>::cast_op_type pybind11::detail::cast_op(make_caster&)’:
132 | /home/birch/anaconda3/envs/p311-cu121-bnb-opt/lib/python3.11/site-packages/torch/include/pybind11/detail/../cast.h:45:120: error: expected template-name before ‘<’ token
133 | 45 | return caster.operator typename make_caster::template cast_op_type();
134 | ```
135 |
136 | Solution [here](https://github.com/Dao-AILab/flash-attention/issues/484#issuecomment-1706843478).
137 |
138 | ## Run
139 |
140 | From root of repository:
141 |
142 | ```bash
143 | python -m scripts.wizard_play
144 | ```
145 |
146 | Fun command-line options:
147 |
148 | - `--model_name_or_path WizardLM/WizardCoder-Python-7B-V1.0 --prompt_style wizardcoder-python`: use WizardCoder 7B with WizardCoder prompting style
149 | - `--model_name_or_path codellama/CodeLlama-7b-Instruct-hf --prompt_style codellama-instruct`: use CodeLlama-7b-Instruct with CodeLlama-Instruct prompting style
150 | - `--flash --trust_remote_code`: enables flash attention 2 via the `flash-attn` library and ([my fork of](https://huggingface.co/Birchlabs/flash_llama)) [togethercomputer's `modeling_flash_llama.py`](https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/modeling_flash_llama.py)
151 | - `--max_new_tokens 2048`: modify maximum response length
152 | - `--chat_memory`: enable conversation history, for multi-turn conversations (CodeLlama-Instruct was trained on this, but WizardCoder was not)
153 | - `--initial_input 'Write a function which computes the Fibonacci sequence.'`: you can buffer a prompt to be submitted as soon as the model's loaded.
154 |
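For example, combining several of these options into one invocation (adjust the model and flags to taste):

```bash
python -m scripts.wizard_play \
  --model_name_or_path codellama/CodeLlama-7b-Instruct-hf \
  --prompt_style codellama-instruct \
  --flash --trust_remote_code \
  --max_new_tokens 2048 \
  --chat_memory \
  --initial_input 'Write a function which computes the Fibonacci sequence.'
```
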
155 | You can press Ctrl+C whilst the model is generating a response, to interrupt it. If `--chat_memory` is enabled, the unfinished message **does** get persisted into the conversation history.
156 | If the model is **not** generating a response, then Ctrl+C will exit the software.
157 |
158 | ### Few-shotting
159 |
160 | You can seed the conversation history with a previous input and forced response from the model:
161 |
162 | ```bash
163 | python -m scripts.wizard_play --model_name_or_path codellama/CodeLlama-7b-Instruct-hf --prompt_style codellama-instruct --shot0_input "Read user's name from stdin" --shot0_response 'import sys
164 |
165 | name = input("Enter your name: ")
166 | print("Your name is:", name)'
167 | ```
168 |
169 | This achieves two things:
170 |
171 | - creates a memory in the conversation
172 | - sets an expectation for the style of response you prefer.
173 |
174 | You can see this in action by asking the model to iterate on the solution you placed into its history:
175 |
176 | ```
177 | [seed=64]$ Print their age too.
178 | import sys
179 |
180 | name = input("Enter your name: ")
181 | age = input("Enter your age: ")
182 | print("Your name is:", name, ",", "and", "your age:", age)
183 | ```
184 |
185 | Note: this won't necessarily work so well for WizardCoder, which isn't trained on multi-turn conversations.
186 |
187 | ### Troubleshooting
188 |
189 | **`cannot import name 'translate_llvmir_to_hsaco'`:**
190 | You [need a triton nightly](https://github.com/openai/triton/issues/2002).
191 | ```
192 | Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
193 | Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
194 | cannot import name 'translate_llvmir_to_hsaco' from 'triton._C.libtriton.triton' (unknown location)
195 | ```
196 |
197 | ```bash
198 | pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
199 | ```
200 |
201 | **`ImportError`:**
202 | Recent flash-attn releases encounter [errors _importing_ rotary embed](https://github.com/Dao-AILab/flash-attention/issues/519). You may need to copy Dao-AILab's [`ops/triton`](https://github.com/Dao-AILab/flash-attention/tree/main/flash_attn/ops/triton) directory into the flash-attn distribution you installed to site-packages.
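
One way to do that (a sketch, assuming `flash_attn` is importable from your active environment; the clone path is arbitrary):

```bash
git clone https://github.com/Dao-AILab/flash-attention.git /tmp/flash-attention
FLASH_ATTN_DIR=$(python -c 'import flash_attn, os; print(os.path.dirname(flash_attn.__file__))')
cp -r /tmp/flash-attention/flash_attn/ops/triton "$FLASH_ATTN_DIR/ops/"
```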
203 |
204 | ## License
205 |
206 | This repository is itself MIT-licensed.
207 |
208 | Includes:
209 |
210 | - MIT-licensed code copied from Artidoro Pagnoni's [qlora](https://github.com/artidoro/qlora)
211 | - MIT-licensed code copied from Scott Logic's [qlora fork](https://github.com/scottlogic-alex/qlora) (specifically [`evaluate.py`](https://github.com/scottlogic-alex/qlora/blob/stepwise/evaluate.py)).
212 | - [Apache-licensed](licenses/MosaicML-mpt-7b-chat-hf-space.Apache.LICENSE.txt) code copied from MosaicML's [mpt-7b-chat](https://huggingface.co/spaces/mosaicml/mpt-7b-chat/blob/main/app.py) Huggingface Space
213 |
--------------------------------------------------------------------------------
/licenses/MosaicML-mpt-7b-chat-hf-space.Apache.LICENSE.txt:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | transformers
3 | accelerate
4 | bitsandbytes
5 | scipy
--------------------------------------------------------------------------------
/scripts/wizard_play.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional, TypedDict, NamedTuple, List, Dict, Union, TypeAlias, Literal
3 | import torch
4 | from torch import LongTensor
5 | from transformers import (
6 | AutoConfig,
7 | AutoModelForCausalLM,
8 | AutoTokenizer,
9 | BitsAndBytesConfig,
10 | GenerationConfig,
11 | HfArgumentParser,
12 | set_seed,
13 | StoppingCriteria,
14 | StoppingCriteriaList,
15 | LlamaForCausalLM,
16 | LlamaTokenizerFast
17 | )
18 | from src.callback_text_iterator_streamer import CallbackTextIteratorStreamer
19 | import logging
20 | from enum import Enum
21 | import sys
22 | from time import perf_counter
23 | from itertools import pairwise
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 | class TokenizerOutput(TypedDict):
28 | input_ids: LongTensor
29 | attention_mask: LongTensor
30 |
31 | class PromptStyle(Enum):
32 | Bare = 'bare'
33 | WizardCoderPython = 'wizardcoder-python'
34 | CodeLlamaInstruct = 'codellama-instruct'
35 | # I am not proud of this, but when I attempted to specify Enum fields on the arg dataclasses:
36 | # hfparser.parse_args_into_dataclasses() turned the enum instances into string values.
37 | # so we make some types to capture what we're actually going to receive.
38 | PromptStyleLiteral: TypeAlias = Literal['bare', 'wizardcoder-python', 'codellama-instruct']
39 |
40 | class Dtype(Enum):
41 | Bf16 = 'bf16'
42 | Fp16 = 'fp16'
43 | Fp32 = 'fp32'
44 | DtypeLiteral: TypeAlias = Literal['bf16', 'fp16', 'fp32']
45 |
46 | def reify_dtype(dtype: DtypeLiteral) -> torch.dtype:
47 | match(dtype):
48 | case 'bf16':
49 | return torch.bfloat16
50 | case 'fp16':
51 | return torch.float16
52 | case 'fp32':
53 | return torch.float32
54 |
55 | class Participant(Enum):
56 | User = 'user'
57 | Assistant = 'assistant'
58 | System = 'system'
59 |
60 | class Message(NamedTuple):
61 | participant: Participant
62 | message: str
63 |
64 | @dataclass
65 | class StopOnTokens(StoppingCriteria):
66 | stop_token_ids: List[int]
67 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
68 | for stop_id in self.stop_token_ids:
69 | if input_ids[0][-1] == stop_id:
70 | return True
71 | return False
72 |
73 | class SufficientResponse(BaseException): ...
74 |
75 | @dataclass
76 | class ModelArguments:
77 | model_name_or_path: Optional[str] = field(
78 | default="WizardLM/WizardCoder-Python-7B-V1.0"
79 | )
80 | cache_dir: Optional[str] = field(
81 | default=None,
82 | metadata={"help": "Which directory to use as your HuggingFace cache. Defaults to ~/.cache/huggingface, probably. Use this if you want to download models to a specific location."}
83 | )
84 | trust_remote_code: Optional[bool] = field(
85 | default=False,
86 |         metadata={"help": "Enable execution of arbitrary (remote) model code in AutoModelForCausalLM#from_pretrained."}
87 | )
88 | double_quant: bool = field(
89 | default=True,
90 | metadata={"help": "Compress the quantization statistics through double quantization."}
91 | )
92 | quant_type: str = field(
93 | default="nf4",
94 | metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
95 | )
96 | bits: int = field(
97 | default=4,
98 | metadata={"help": "How many bits to use.", "choices": [4, 8, 16, 32]}
99 | )
100 | model_dtype: DtypeLiteral = field(
101 | default=Dtype.Fp16.value,
102 |         metadata={"help": "Compute type of the model. Used for non-quantized computations. Float16 may be better than bfloat16 for inference.", "choices": [p.value for p in Dtype]}
103 | )
104 | bnb_compute_dtype: DtypeLiteral = field(
105 | default=Dtype.Fp16.value,
106 | metadata={"help": "Compute type used for computations over dequantized weights. Float16 should be better than bfloat16. Float32 can be slightly better than float16.", "choices": [p.value for p in Dtype]}
107 | )
108 | flash: Optional[bool] = field(
109 | default=False,
110 | metadata={"help": "Whether to replace the model code with togethercomputer's modeling_flash_llama.py, which uses Flash Attention 2 (via flash-attn) to accelerate model inference and reduce memory usage."}
111 | )
112 |
113 | @dataclass
114 | class MiscArguments:
115 | seed: Optional[int] = field(
116 | default=64,
117 | metadata={"help": "Random seed, for deterministic generation."}
118 | )
119 | compile: bool = field(
120 | default=False,
121 | metadata={"help": "Invoke torch.compile() on the model, with mode='max-autotune'. Requires PyTorch 2, CUDA, and either Python 3.10 or Python 3.11 with a recent torch nightly. Will make the first inference from the model take a bit longer, but subsequent inferences will be faster."}
122 | )
123 | system_prompt: Optional[str] = field(
124 | default=None,
125 | metadata={"help": "The context which precedes the chat history. Can be used to influence the chatbot's responses. If unspecified: defaults to the standard system prompt for the prompt style."}
126 | )
127 | initial_input: Optional[str] = field(
128 | default=None,
129 | metadata={"help": "Initial message sent to the model. For example: Read user's name from stdin"}
130 | )
131 | shot0_input: Optional[str] = field(
132 | default=None,
133 | metadata={"help": "[to be used with --shot0_response] Use few-shotting to populate conversation history with an example of the kind of input+response you prefer. This arg exemplifies an input that the user sent previously to the model. For example: Read user's name from stdin"}
134 | )
135 | shot0_response: Optional[str] = field(
136 | default=None,
137 |         metadata={"help": "[to be used with --shot0_input] Use few-shotting to populate conversation history with an example of the kind of input+response you prefer. This arg exemplifies a reply that the model produced in response to the user's previous input, --shot0_input. For example: import sys\n\nname = input(\"Enter your name: \")\nprint(\"Your name is:\", name)"}
138 | )
139 | # if you actually set the type hint to PromptStyle: you will find that HF/argparse assign a string anyway
140 | prompt_style: PromptStyleLiteral = field(
141 | default=PromptStyle.WizardCoderPython.value,
142 | metadata={"choices": [p.value for p in PromptStyle]}
143 | )
144 | chat_memory: bool = field(
145 | default=False,
146 | metadata={"help": "Whether chat sequence should accumulate a conversation context, or reset each time"}
147 | )
148 | reseed_each_prompt: bool = field(
149 | default=True,
150 | metadata={"help": "Reset seed before each user input"}
151 | )
152 | show_seed: bool = field(
153 | default=True,
154 | metadata={"help": "Show seed in prompt"}
155 | )
156 | measure_perf: bool = field(
157 | default=True,
158 | metadata={"help": "Print inference speed"}
159 | )
160 |
161 | @dataclass
162 | class GenerationArguments:
163 | # For more hyperparameters check:
164 | # https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
165 | # Length arguments
166 | max_new_tokens: Optional[int] = field(
167 | default=2048,
168 | metadata={"help": "Maximum number of new tokens to be generated in evaluation or prediction loops"
169 |                           " if predict_with_generate is set."}
170 | )
171 | min_new_tokens : Optional[int] = field(
172 | default=None,
173 | metadata={"help": "Minimum number of new tokens to generate."}
174 | )
175 |
176 | # Generation strategy
177 | do_sample: Optional[bool] = field(default=False)
178 | num_beams: Optional[int] = field(default=1)
179 | num_beam_groups: Optional[int] = field(default=1)
180 | penalty_alpha: Optional[float] = field(default=None)
181 | use_cache: Optional[bool] = field(default=True)
182 |
183 | # Hyperparameters for logit manipulation
184 | temperature: Optional[float] = field(default=1.0)
185 | top_k: Optional[int] = field(default=50)
186 | top_p: Optional[float] = field(default=1.0)
187 | typical_p: Optional[float] = field(default=1.0)
188 | diversity_penalty: Optional[float] = field(default=0.0)
189 | repetition_penalty: Optional[float] = field(default=1.0)
190 | length_penalty: Optional[float] = field(default=1.0)
191 | no_repeat_ngram_size: Optional[int] = field(default=0)
192 |
193 | def get_model(args: ModelArguments) -> LlamaForCausalLM:
194 | config = AutoConfig.from_pretrained(
195 | args.model_name_or_path,
196 | trust_remote_code=args.trust_remote_code,
197 | cache_dir=args.cache_dir,
198 | )
199 |
200 | if args.flash and config.model_type == 'llama':
201 | updates: Dict[str, Union[str, int, float, bool, None]] = {}
202 | flash_model_name = 'Birchlabs/flash_llama--modeling_flash_llama.LlamaForCausalLM'
203 | if 'num_key_value_heads' not in config.__dict__:
204 | updates['num_key_value_heads'] = config.num_attention_heads
205 | if 'auto_map' in config.__dict__:
206 | if not ('AutoModelForCausalLM' in config.auto_map and 'flash' in config.auto_map['AutoModelForCausalLM']):
207 |                 updates['auto_map'] = {**config.auto_map, 'AutoModelForCausalLM': flash_model_name}
208 | else:
209 | updates['auto_map'] = { 'AutoModelForCausalLM': flash_model_name }
210 | if 'rope_scaling' not in config.__dict__:
211 | # CodeLlama-Instruct was trained on 16000 token sequences:
212 | # https://ai.meta.com/blog/code-llama-large-language-model-coding/
213 | # WizardCoder was trained on 2048 token sequences (see section 4.2):
214 | # https://arxiv.org/abs/2306.08568
215 | # but both of their HF models report 16384 as the max position embeddings.
216 | # whatever; let's leave the rope scaling as default.
217 | # if you want to do different scaling, I think you'd compute it like this:
218 | # factor = desired_context_length/config.max_position_embeddings
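            #   e.g. for a 32768-token context on a model reporting 16384 max position embeddings: factor = 32768/16384 = 2.0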
219 | updates['rope_scaling'] = { 'factor': 1., 'type': 'linear' }
220 | if 'pretraining_tp' not in config.__dict__:
221 | updates['pretraining_tp'] = 1
222 | if updates:
223 | config.update(updates)
224 |
225 | cuda_avail = torch.cuda.is_available()
226 | load_in_4bit = args.bits == 4 and cuda_avail
227 | load_in_8bit = args.bits == 8 and cuda_avail
228 |
229 | bnb_compute_dtype: torch.dtype = reify_dtype(args.bnb_compute_dtype)
230 |
231 | quantization_config: Optional[BitsAndBytesConfig] = BitsAndBytesConfig(
232 | load_in_4bit=load_in_4bit,
233 | load_in_8bit=load_in_8bit,
234 | llm_int8_threshold=6.0,
235 | llm_int8_has_fp16_weight=False,
236 | bnb_4bit_compute_dtype=bnb_compute_dtype,
237 | bnb_4bit_use_double_quant=args.double_quant,
238 | bnb_4bit_quant_type=args.quant_type,
239 | ) if cuda_avail else None
240 |
241 | if not cuda_avail:
242 | logger.warning("You don't have CUDA, so we have turned off quantization. If you happen to be on a Mac: maybe you have enough unified memory to run in fp16 anyway…")
243 |
244 | model_dtype: torch.dtype = reify_dtype(args.model_dtype)
245 |
246 | model: LlamaForCausalLM = AutoModelForCausalLM.from_pretrained(
247 | args.model_name_or_path,
248 | config=config,
249 | load_in_4bit=load_in_4bit,
250 | load_in_8bit=load_in_8bit,
251 | device_map='auto',
252 | quantization_config=quantization_config,
253 | torch_dtype=model_dtype,
254 | trust_remote_code=args.trust_remote_code,
255 | cache_dir=args.cache_dir,
256 | ).eval()
257 | model.config.torch_dtype=model_dtype
258 |
259 | return model
260 |
261 | def main():
262 | hfparser = HfArgumentParser((ModelArguments, GenerationArguments, MiscArguments))
263 | model_args, generation_args, misc_args, extra_args = hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
264 | if extra_args:
265 |         raise ValueError(f"Received unsupported command-line args: {extra_args}")
266 | generation_config = GenerationConfig(**vars(generation_args))
267 |
268 | model: LlamaForCausalLM = get_model(model_args)
269 |
270 | set_seed(misc_args.seed)
271 | if misc_args.compile:
272 |         model = torch.compile(model, mode='max-autotune')
273 |
274 | tokenizer: LlamaTokenizerFast = AutoTokenizer.from_pretrained(
275 | model_args.model_name_or_path,
276 | # fast tokenizer required for WizardLM/WizardCoder-Python-34B-V1.0, because slow tokenizer doesn't come with added_tokens (required for {'[PAD]': 32000})
277 | use_fast=True,
278 | cache_dir=model_args.cache_dir,
279 | )
280 |     # WizardCoder defines {'[PAD]': 32000}, but CodeLlama doesn't define any pad token, so we fall back to EOS.
281 | generation_config.pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
282 |
283 | stop_token_ids: List[int] = [tokenizer.eos_token_id]
284 | stop = StopOnTokens(stop_token_ids)
285 | stopping_criteria=StoppingCriteriaList([stop])
286 |
287 | system_prompt: Optional[str] = misc_args.system_prompt
288 | if misc_args.system_prompt is None:
289 | match misc_args.prompt_style:
290 | case PromptStyle.WizardCoderPython.value:
291 | system_prompt: str = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'
292 | case PromptStyle.CodeLlamaInstruct.value:
293 | system_prompt: str = 'Provide answers in Python'
294 | case PromptStyle.Bare.value:
295 | pass
296 | case _:
297 | raise ValueError(f'Never heard of a {misc_args.prompt_style} PromptStyle.')
298 |
299 | # The CodeLlama blog post suggests it's fine to not specify a system prompt:
300 | # https://huggingface.co/blog/codellama
301 | # whereas WizardCoder seems to always use the same (Alpaca-style) system prompt (I wouldn't recommend erasing WizardCoder's system prompt)
302 | optional_system_message: List[Message] = [Message(Participant.System, system_prompt)] if system_prompt else []
303 | history: List[Message] = []
304 |
305 | if misc_args.shot0_input is not None:
306 | assert misc_args.shot0_response is not None, "few-shotting requires you to specify the entire previous turn of the conversation (both --shot0_input and --shot0_response)."
307 | history += [
308 | Message(Participant.User, misc_args.shot0_input),
309 | Message(Participant.Assistant, misc_args.shot0_response),
310 | ]
311 |
312 | reset_ansi='\x1b[0m'
313 | cyan_ansi='\x1b[31;36m'
314 | blue_ansi='\x1b[31;34m'
315 | green_ansi='\x1b[31;32m'
316 | purple_ansi='\x1b[31;35m'
317 |
318 | participant_names: Dict[Participant, str] = {
319 | Participant.User: 'Instruction',
320 | Participant.Assistant: 'Response',
321 | }
322 |
323 | def alpaca_section(envelope: Message) -> str:
324 | participant, message = envelope
325 | if participant is Participant.System:
326 | return message
327 | return f'### {participant_names[participant]}:\n{message}'
328 |
329 | def codellama_turn(user_msg: Message, assistant_msg: Message, is_first: bool) -> str:
330 |         preamble = f'<<SYS>>\n{system_prompt}\n<</SYS>>\n\n' if is_first and system_prompt else ''
331 | return f'[INST] {preamble}{user_msg.message} [/INST] {assistant_msg.message}'
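    # e.g. with the default CodeLlama system prompt, the first rendered turn looks roughly like:
    #   [INST] <<SYS>>
    #   Provide answers in Python
    #   <</SYS>>
    #
    #   Read user's name from stdin [/INST] {assistant reply; empty for the turn being generated}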
332 |
333 | next_seed: Optional[int] = None
334 |
335 | first = True
336 | while True:
337 | seed: int = misc_args.seed if next_seed is None else next_seed
338 | if misc_args.reseed_each_prompt or first or next_seed is not None:
339 | set_seed(seed)
340 |
341 | try:
342 | prompt_ctx: str = f'[seed={seed}]' if misc_args.show_seed else ''
343 | if first and misc_args.initial_input is not None:
344 | user_input = misc_args.initial_input
345 | quote: str = f'{purple_ansi}{prompt_ctx}> '
346 | print(f'{quote}{user_input}')
347 | else:
348 | prompt: str = f'{purple_ansi}{prompt_ctx}$ '
349 | user_input = input(f'{blue_ansi}Type a message to begin the conversation…{reset_ansi}\n{prompt}' if first else prompt)
350 | except (KeyboardInterrupt, EOFError):
351 | sys.exit(0)
352 | print(reset_ansi, end='')
353 |
354 | first = False
355 |
356 | user_message = Message(Participant.User, user_input)
357 |
358 | match misc_args.prompt_style:
359 | case PromptStyle.WizardCoderPython.value:
360 | chat_to_complete: str = '\n\n'.join([
361 | alpaca_section(message) for message in [
362 | *optional_system_message,
363 | *history,
364 | user_message,
365 | Message(Participant.Assistant, ''),
366 | ]
367 | ])
368 | case PromptStyle.CodeLlamaInstruct.value:
369 |                 turns: List[Message] = [*history, user_message, Message(Participant.Assistant, '')]
370 |                 # pair up messages as non-overlapping (user, assistant) turns
371 |                 chat_to_complete: str = ' '.join([
372 |                     codellama_turn(user_msg, assist_msg, ix == 0)
373 |                     for ix, (user_msg, assist_msg) in enumerate(zip(turns[::2], turns[1::2]))])
374 | case PromptStyle.Bare.value:
375 | chat_to_complete: str = user_input
376 | case _:
377 | raise ValueError(f'Never heard of a {misc_args.prompt_style} PromptStyle.')
378 |
379 | tokenized_prompts: TokenizerOutput = tokenizer([chat_to_complete], return_tensors='pt', truncation=True)
380 |
381 | print(green_ansi, end='', flush=True)
382 |
383 | response = ''
384 | def on_text(message: str, stream_end = False):
385 | nonlocal response
386 | response += message
387 | print(message, end='', flush=True)
388 |
389 | streamer = CallbackTextIteratorStreamer(tokenizer, callback=on_text, skip_prompt=True, skip_special_tokens=True)
390 |
391 | try:
392 | inference_start: float = perf_counter()
393 | prediction: LongTensor = model.generate(
394 | input_ids=tokenized_prompts.input_ids.to(model.device),
395 | attention_mask=tokenized_prompts.attention_mask.to(model.device),
396 | generation_config=generation_config,
397 | do_sample=generation_config.temperature > 0.,
398 | stopping_criteria=stopping_criteria,
399 | streamer=streamer,
400 | )
401 | # reset ANSI control sequence (plus line break)
402 | print(reset_ansi)
403 | # if you wanted to see the result, you can do so like this:
404 | # decode: List[str] = tokenizer.decode(prediction[0,tokenized_prompts.input_ids.size(-1):], skip_special_tokens=False, clean_up_tokenization_spaces=True)
405 | # print(decode)
406 | # pass
407 | # but we're already streaming it to the console via our callback
408 | inference_duration: float = perf_counter()-inference_start
409 | token_in_count: int = tokenized_prompts.input_ids.size(-1)
410 | token_out_count: int = prediction.size(-1) - token_in_count
411 | tokens_out_per_sec: float = token_out_count/inference_duration
412 | if misc_args.measure_perf:
413 | print(f'{cyan_ansi}ctx length: {token_in_count}\ntokens out: {token_out_count}\nduration: {inference_duration:.2f} secs\nspeed: {tokens_out_per_sec:.2f} tokens/sec{reset_ansi}')
414 | except (KeyboardInterrupt, SufficientResponse, EOFError):
415 | # reset ANSI control sequence (plus line break)
416 | print(reset_ansi)
417 |
418 | # we disable accumulation of conversation history by default, because WizardCoder is not advertised as being finetuned on multi-turn conversations,
419 |         # but more importantly because I'd rather spend our 4k context length on a detailed answer for a single turn than an incomplete answer for multiple turns.
420 | if misc_args.chat_memory:
421 | history += [
422 | user_message,
423 | Message(Participant.Assistant, response)
424 | ]
425 |
426 | if __name__ == "__main__":
427 | main()
--------------------------------------------------------------------------------
/src/callback_text_iterator_streamer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, TextIteratorStreamer
2 | from typing import Optional, Protocol
3 |
4 | class TextCallback(Protocol):
5 | def __call__(self, text: str, stream_end: bool = False) -> None: ...
6 |
7 |
8 | class CallbackTextIteratorStreamer(TextIteratorStreamer):
9 | callback: TextCallback
10 | def __init__(
11 | self, tokenizer: AutoTokenizer, callback: TextCallback, skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
12 | ):
13 |         super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
14 | self.callback = callback
15 |
16 | def on_finalized_text(self, text: str, stream_end: bool = False):
17 | self.callback(text, stream_end=stream_end)
18 |
--------------------------------------------------------------------------------