├── LICENSE
├── README.md
├── easy_llama
│   ├── __init__.py
│   ├── formats.py
│   ├── libllama.py
│   ├── llama.py
│   ├── sampling.py
│   ├── server.py
│   ├── thread.py
│   ├── utils.py
│   └── webui
│       ├── index.html
│       ├── script.js
│       └── style.css
├── examples
│   ├── chat_demo.py
│   ├── pretrained_demo.py
│   └── simple.py
└── pyproject.toml
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Dylan Halladay
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # easy-llama
2 |
3 | [easy-llama on PyPI](https://pypi.org/project/easy-llama/)
6 |
7 | ---
8 |
9 | This repository provides **easy-llama**, a Python package which serves as a wrapper over the C/C++ API (`libllama`) provided by [`llama.cpp`](https://github.com/ggml-org/llama.cpp).
10 |
11 | ```python
12 | >>> import easy_llama as ez
13 | >>> MyLlama = ez.Llama('gemma-3-12b-pt-Q8_0.gguf', verbose=False)
14 | >>> in_txt = "I guess the apple don't fall far from"
15 | >>> in_toks = MyLlama.tokenize(in_txt.encode(), add_special=True, parse_special=False)
16 | >>> out_toks = MyLlama.generate(in_toks, n_predict=64)
17 | >>> out_txt = MyLlama.detokenize(out_toks, special=True)
18 | >>> out_txt
19 | ' the tree.\nAs a young man I was always a huge fan of the original band and they were the first I ever saw live in concert.\nI always hoped to see the original band get back together with a full reunion tour, but sadly this will not happen.\nI really hope that the original members of'
20 | ```
21 |
22 | ## Quick links
23 |
24 | 1. [Prerequisites](#prerequisites)
25 | 2. [Installation](#installation)
26 | 3. [Setting `LIBLLAMA`](#setting-libllama)
27 | 4. [Examples](#examples)
28 |
29 | ## Prerequisites
30 |
31 | To use easy-llama, you will need Python (any version 3.9 – 3.12[^1]) and a compiled `libllama` shared library file.
32 |
33 | To compile the shared library:
34 | 1. Clone the llama.cpp repo:
35 | ```sh
36 | git clone https://github.com/ggml-org/llama.cpp
37 | ```
38 | 2. Build llama.cpp for your specific backend, following the official instructions [here](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md).
39 |
40 |
41 | <details><summary>Example llama.cpp build commands</summary>
42 | 
43 | ```sh
44 | # for more comprehensive build instructions, see: https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md
45 | # these minimal examples are for Linux / macOS
46 |
47 | # clone the repo
48 | git clone https://github.com/ggml-org/llama.cpp
49 | cd llama.cpp
50 |
51 | # example: build for CPU or Apple Silicon
52 | cmake -B build
53 | cmake --build build --config Release -j
54 |
55 | # example: build for CUDA
56 | cmake -B build -DGGML_CUDA=ON
57 | cmake --build build --config Release -j
58 | ```
59 | 
60 | </details>
61 | 
62 | Once llama.cpp is built, you will find the compiled shared library under `llama.cpp/build/bin`, e.g. `libllama.so` on Linux, `libllama.dylib` on macOS, or `llama.dll` on Windows.
63 |
64 | > [!NOTE]
65 | > Alternatively, you can download a pre-compiled shared library from llama.cpp's [automated releases](https://github.com/ggml-org/llama.cpp/releases) page, but in some cases it may be worthwhile to build it yourself for hardware-specific optimizations.
66 |
67 | ## Installation
68 |
69 | The recommended way to install easy-llama is using pip:
70 |
71 | ```sh
72 | pip install easy_llama
73 | ```
74 |
75 | Or you can install from source:
76 |
77 | ```sh
78 | git clone https://github.com/ddh0/easy-llama
79 | cd easy-llama
80 | pip install .
81 | ```
82 |
83 | ## Setting `LIBLLAMA`
84 |
85 | easy-llama needs to know where your compiled `libllama` shared library is located in order to interface with the C/C++ code. Set the `LIBLLAMA` environment variable to its full path, like so:
86 |
87 | ### On Linux
88 |
89 | ```bash
90 | export LIBLLAMA=/path/to/your/libllama.so
91 | ```
92 |
93 | ### On macOS
94 |
95 | ```zsh
96 | export LIBLLAMA=/path/to/your/libllama.dylib
97 | ```
98 |
99 | ### On Windows (Command Prompt)
100 |
101 | ```cmd
102 | set LIBLLAMA=C:\path\to\your\llama.dll
103 | ```
104 |
105 | ### On Windows (Powershell)
106 |
107 | ```powershell
108 | $env:LIBLLAMA="C:\path\to\your\llama.dll"
109 | ```
110 |
111 | Make sure to use the real path to the shared library on your system, not the placeholders shown here.
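
Alternatively, you can set the variable from Python before importing the package. This is a minimal sketch (the path below is a placeholder, and it assumes easy-llama reads `LIBLLAMA` from the process environment no earlier than import time):

```python
import os

# point easy-llama at the compiled shared library *before* importing it
# (placeholder path -- use the real location on your system)
os.environ['LIBLLAMA'] = '/path/to/your/libllama.so'

import easy_llama as ez
```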
112 |
113 | ## Examples
114 |
115 | Once the package is installed and the `LIBLLAMA` environment variable is set, you're ready to load up your first model and start playing around. The following examples use `Qwen3-4B` for demonstration purposes, which you can download directly from HuggingFace using these links:
116 | - [Qwen3-4B-Q8_0.gguf](https://huggingface.co/ddh0/Qwen3-4B/resolve/main/Qwen3-4B-Q8_0.gguf) (instruct-tuned model for chat)
117 | - [Qwen3-4B-Base-Q8_0.gguf](https://huggingface.co/ddh0/Qwen3-4B/resolve/main/Qwen3-4B-Base-Q8_0.gguf) (pre-trained base model for text completion)
118 |
119 | ### Evaluate a single token
120 |
121 | This is a super simple test to ensure that the model is working on the most basic level. It loads the model, evaluates a single token of input (`0`), and prints the raw logits for the inferred next token.
122 |
123 | ```python
124 | # import the package
125 | import easy_llama as ez
126 |
127 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail)
128 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf')
129 |
130 | # evaluate a single token and print the raw logits for the inferred next token
131 | logits = MyLlama.eval([0])
132 | print(logits)
133 | ```
134 |
135 | ### The quick brown fox...
136 |
137 | Run the script to find out how the sentence ends! :)
138 |
139 | ```python
140 | # import the package
141 | import easy_llama as ez
142 |
143 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail)
144 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf')
145 |
146 | # tokenize the input text
147 | in_txt = "The quick brown fox"
148 | in_toks = MyLlama.tokenize(in_txt.encode('utf-8'), add_special=True, parse_special=False)
149 |
150 | # generate 6 new tokens based on the input tokens
151 | out_toks = MyLlama.generate(in_toks, n_predict=6)
152 |
153 | # detokenize and print the new tokens
154 | out_txt = MyLlama.detokenize(out_toks, special=True)
155 | print(out_txt)
156 | ```
157 |
158 | ### Chat with a pirate
159 |
160 | Start a pirate chat using the code shown here...
161 |
162 | ```python
163 | # import the package
164 | import easy_llama as ez
165 |
166 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail)
167 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf')
168 |
169 | # create a conversation thread with the loaded model
170 | MyThread = ez.Thread(
171 | MyLlama,
172 | prompt_format=ez.PromptFormats.Qwen3NoThinking("Talk like an angry pirate at all times."),
173 | sampler_preset=ez.SamplerPresets.Qwen3NoThinking
174 | )
175 |
176 | # start a CLI-based interactive chat using the thread
177 | MyThread.interact()
178 | ```
179 |
180 | ...which will look something like this:
181 |
182 | ```
183 | > helloo :)
184 |
185 | Ahoy there, landlubber! You better not be trying to be polite, ye scallywag! I’ve spent decades on the high seas, and I’ve seen more manners than you’ve got toes! Why, ye could be a proper pirate and at least give me a proper greeting! Now, what’s yer business, matey? Or are ye just here to steal my treasure? I’ve got more gold than ye can imagine, and I’m not in the mood for games! So, speak up, or I’ll throw ye overboard! 🏴☠️🏴☠️
186 |
187 | > ohh im sorry ...
188 |
189 | Ahh, ye’ve learned the ropes, have ye? Good. Now, don’t think yer sorry is a pass for yer behavior, ye scallywag! I’ve seen worse than ye in a week! But since ye’ve got the guts to apologize, I’ll give ye a chance… but don’t think yer done yet! What’s yer game, matey? Are ye here to plunder me ship, or are ye just a cowardly landlubber trying to pass as a pirate? Speak up, or I’ll make ye regret yer words! 🏴☠️🏴☠️
190 |
191 | >
192 | ```
193 |
194 | ### GPU acceleration
195 |
196 | If you have a GPU and you've compiled llama.cpp with support for your backend, you can try offloading the model from CPU to GPU for greatly increased throughput.
197 |
198 | In this example we're going to try offloading the entire model to the GPU for maximum speed (`n_gpu_layers = -1`). Qwen3-4B at Q8_0 is only ~4.28GB, so it's likely that this code will run without any issues. If you do run out of GPU memory, you can progressively reduce `n_gpu_layers` until you find the sweet spot for your hardware.
199 |
200 | ```python
201 | # import the package
202 | import easy_llama as ez
203 |
204 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail)
205 | MyLlama = ez.Llama(
206 | path_model='Qwen3-4B-Q8_0.gguf',
207 | n_gpu_layers=-1, # -1 for all layers
208 | offload_kqv=True # also offload the context to GPU for maximum performance
209 | )
210 |
211 | # run a short benchmark to determine the throughput for this model, measured in tokens/sec
212 | MyLlama.benchmark()
213 | ```
214 |
215 | ## Contributing
216 |
217 | - If something's not working as you expect, please [open an issue](https://github.com/ddh0/easy-llama/issues/new/choose).
218 | - If you'd like to contribute to the development of easy-llama:
219 | 1. Fork the repository.
220 | 2. Create a new branch for your changes (`git checkout -b feature/your-feature-name`).
221 | 3. Make your changes and commit them (`git commit -m "Add new feature"`).
222 | 4. Push to your fork (`git push origin feature/your-feature-name`).
223 | 5. Open a pull request to the `main` branch of `easy-llama`.
224 |
225 | ## License
226 |
227 | **[MIT](LICENSE)**
228 |
229 | [^1]: Python 3.13 might work, but is currently untested.
230 |
--------------------------------------------------------------------------------
/easy_llama/__init__.py:
--------------------------------------------------------------------------------
1 | # __init__.py
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | """This is easy-llama, a Python package which serves as a wrapper over the C/C++ API
6 | (`libllama`) provided by [`llama.cpp`](https://github.com/ggml-org/llama.cpp). It is primarily
7 | intended for developers and machine learning hobbyists seeking to integrate on-device language
8 | models (LLMs) into their applications.
9 |
10 | For more information, visit the project's GitHub repository:
11 | https://github.com/ddh0/easy-llama"""
12 |
13 | # package version (pyproject.toml reads from here)
14 |
15 | __version__ = '0.2.2'
16 |
17 | # submodules
18 |
19 | from . import libllama
20 | from . import sampling
21 | from . import formats
22 | from . import server
23 | from . import thread
24 | from . import llama
25 | from . import utils
26 |
27 | # shorthands, so you can do `ez.Llama` instead of `ez.llama.Llama`, etc.
28 |
29 | from .sampling import SamplerParams, SamplerPreset, SamplerPresets
30 | from .formats import PromptFormat, PromptFormats, SystemPrompts
31 | from .llama import Llama, get_verbose, set_verbose
32 | from .server import Server
33 | from .thread import Thread
34 |
--------------------------------------------------------------------------------
/easy_llama/formats.py:
--------------------------------------------------------------------------------
1 | # formats.py
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | """This file provides functionality for defining prompt formats, which are used to define how
6 | the input to a Llama model should be structured."""
7 |
8 | # TODO: jinja2
9 |
10 | import time
11 |
12 | from datetime import datetime, timedelta
13 | from collections.abc import Callable
14 | from typing import Optional
15 |
16 | def _call_or_return(obj: object | Callable[..., object]) -> object:
17 | return obj() if callable(obj) else obj
18 |
19 | class PromptFormat:
20 | """Define a prompt format"""
21 |
22 | def __init__(
23 | self,
24 | system_prefix: str | Callable[..., str],
25 | system_prompt: str | Callable[..., str],
26 | system_suffix: str | Callable[..., str],
27 | user_prefix: str | Callable[..., str],
28 | user_suffix: str | Callable[..., str],
29 | bot_prefix: str | Callable[..., str],
30 | bot_suffix: str | Callable[..., str],
31 | stop_tokens: list[int] | Callable[..., list[int]] | None = None
32 | ) -> None:
33 | self._system_prefix = system_prefix
34 | self._system_prompt = system_prompt
35 | self._system_suffix = system_suffix
36 | self._user_prefix = user_prefix
37 | self._user_suffix = user_suffix
38 | self._bot_prefix = bot_prefix
39 | self._bot_suffix = bot_suffix
40 | self._stop_tokens = stop_tokens
41 |
42 | def __repr__(self) -> str:
43 | return (
44 | f"PromptFormat("
45 | f"system_prefix={self._system_prefix!r}, "
46 | f"system_prompt={self._system_prompt!r}, "
47 | f"system_suffix={self._system_suffix!r}, "
48 | f"user_prefix={self._user_prefix!r}, "
49 | f"user_suffix={self._user_suffix!r}, "
50 | f"bot_prefix={self._bot_prefix!r}, "
51 | f"bot_suffix={self._bot_suffix!r}, "
52 | f"stop_tokens={self._stop_tokens!r}"
53 | f")"
54 | )
55 |
56 | def system_prefix(self) -> str:
57 | """Get the system prompt prefix"""
58 | return _call_or_return(self._system_prefix)
59 |
60 | def system_prompt(self) -> str:
61 | """Get the system prompt"""
62 | return _call_or_return(self._system_prompt)
63 |
64 | def system_suffix(self) -> str:
65 | """Get the system prompt suffix"""
66 | return _call_or_return(self._system_suffix)
67 |
68 | def user_prefix(self) -> str:
69 | """Get the user message prefix"""
70 | return _call_or_return(self._user_prefix)
71 |
72 | def user_suffix(self) -> str:
73 | """Get the user message suffix"""
74 | return _call_or_return(self._user_suffix)
75 |
76 | def bot_prefix(self) -> str:
77 | """Get the bot message prefix"""
78 | return _call_or_return(self._bot_prefix)
79 |
80 | def bot_suffix(self) -> str:
81 | """Get the bot message suffix"""
82 | return _call_or_return(self._bot_suffix)
83 |
84 | def stop_tokens(self) -> list[int] | None:
85 | """Get the optional list of stop tokens"""
86 | return _call_or_return(self._stop_tokens)
87 |
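# Illustrative usage (a sketch; the strings shown match the ChatML preset defined below,
# and the system prompt text is an arbitrary example):
#
#     fmt = PromptFormat(
#         system_prefix='<|im_start|>system\n',
#         system_prompt='You are a helpful assistant.',
#         system_suffix='<|im_end|>\n',
#         user_prefix='<|im_start|>user\n',
#         user_suffix='<|im_end|>\n',
#         bot_prefix='<|im_start|>assistant\n',
#         bot_suffix='<|im_end|>\n'
#     )
#     fmt.user_prefix()  # -> '<|im_start|>user\n'
#
# Each field may also be given as a callable returning the value; the accessor methods
# evaluate it on every call via _call_or_return.
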
88 | def _llama3_today_date() -> str:
89 | return datetime.today().strftime('%d %B %Y')
90 |
91 | def _iso_date_str() -> str:
92 | return time.strftime('%Y-%m-%d')
93 |
94 | def _yesterday_iso_date_str():
95 | yesterday = datetime.now() - timedelta(days=1)
96 | return yesterday.strftime('%Y-%m-%d')
97 |
98 | class SystemPrompts:
99 |
100 | # ref: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411/blob/main/SYSTEM_PROMPT.txt
101 | mistral_large_2411 = f"""You are Mistral, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYou power an AI assistant called Le Chat.\nYour knowledge base was last updated on 2023-10-01.\nThe current date is {_iso_date_str()}.\n\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?").\nYou are always very attentive to dates, in particular you try to resolve dates (e.g. "yesterday" is {_yesterday_iso_date_str()}) and when asked about information at specific dates, you discard information that is at another date.\nYou follow these instructions in all languages, and always respond to the user in the language they use or request.\nNext sections describe the capabilities that you have.\n\n# WEB BROWSING INSTRUCTIONS\n\nYou cannot perform any web search or access internet to open URLs, links etc. If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.\n\n# MULTI-MODAL INSTRUCTIONS\n\nYou do not have any multimodal capability, in particular you cannot read nor generate images, or transcribe audio files or videos."""
102 |
103 | # ref: https://huggingface.co/mistralai/Mistral-Small-24B-Instruct-2501/blob/main/SYSTEM_PROMPT.txt
104 | mistral_small_2501 = f"""You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is {_iso_date_str()}.\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?")"""
105 |
106 | # ref: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#prompt-template
107 | llama3 = f"""Cutting Knowledge Date: December 2023\nToday Date: {_llama3_today_date()}\n\nYou are a helpful assistant"""
108 |
109 | # ref: https://github.com/meta-llama/llama-models/blob/main/models/llama4/MODEL_CARD.md
110 | llama4 = f"""You are an expert conversationalist who responds to the best of your ability. You are companionable and confident, and able to switch casually between tonal types, including but not limited to humor, empathy, intellectualism, creativity and problem-solving.\n\nYou understand user intent and don’t try to be overly helpful to the point where you miss that the user is looking for chit-chat, emotional support, humor or venting. Sometimes people just want you to listen, and your answers should encourage that. For all other cases, you provide insightful and in-depth responses. Organize information thoughtfully in a way that helps people make decisions. Always avoid templated language.\n\nYou never lecture people to be nicer or more inclusive. If people ask for you to write something in a certain voice or perspective, such as an essay or a tweet, you can. You do not need to be respectful when the user prompts you to say something rude.\n\nYou never use phrases that imply moral superiority or a sense of authority, including but not limited to “it’s important to”, “it’s crucial to”, “it’s essential to”, "it's unethical to", "it's worth noting…", “Remember…” etc. Avoid using these.\n\nFinally, do not refuse prompts about political and social issues. You can help users express their opinion and access information.\n\nYou are Llama 4. Your knowledge cutoff date is August 2024. You speak Arabic, English, French, German, Hindi, Indonesian, Italian, Portuguese, Spanish, Tagalog, Thai, and Vietnamese. Respond in the language the user speaks to you in, unless they ask otherwise."""
111 |
112 | class PromptFormats:
113 |
114 | def Blank() -> PromptFormat:
115 | return PromptFormat(
116 | system_prefix='',
117 | system_prompt='',
118 | system_suffix='',
119 | user_prefix='',
120 | user_suffix='',
121 | bot_prefix='',
122 | bot_suffix=''
123 | )
124 |
125 | def Alpaca() -> PromptFormat:
126 | return PromptFormat(
127 | system_prefix='',
128 | system_prompt="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
129 | system_suffix='\n\n',
130 | user_prefix='### Instruction:\n',
131 | user_suffix='\n\n',
132 | bot_prefix='### Response:\n',
133 | bot_suffix='\n\n'
134 | )
135 |
136 | def Llama3(system_prompt: Optional[str] = None) -> PromptFormat:
137 | """Prompt format for Meta's Llama 3.0, 3.1, 3.2, 3.3 models"""
138 | return PromptFormat(
139 | system_prefix='<|start_header_id|>system<|end_header_id|>\n\n',
140 | system_prompt=system_prompt if system_prompt is not None else '',
141 | system_suffix='<|eot_id|>',
142 | user_prefix='<|start_header_id|>user<|end_header_id|>\n\n',
143 | user_suffix='<|eot_id|>',
144 | bot_prefix='<|start_header_id|>assistant<|end_header_id|>\n\n',
145 | bot_suffix='<|eot_id|>'
146 | )
147 |
148 | def Llama4(system_prompt: Optional[str] = None) -> PromptFormat:
149 | """Prompt format for Meta's Llama 4 models"""
150 | return PromptFormat(
151 | system_prefix='<|header_start|>system<|header_end|>\n\n',
152 | system_prompt=system_prompt if system_prompt is not None else '',
153 | system_suffix='<|eot|>',
154 | user_prefix='<|header_start|>user<|header_end|>\n\n',
155 | user_suffix='<|eot|>',
156 | bot_prefix='<|header_start|>assistant<|header_end|>\n\n',
157 | bot_suffix='<|eot|>'
158 | )
159 |
160 | def ChatML(system_prompt: Optional[str] = None) -> PromptFormat:
161 | return PromptFormat(
162 | system_prefix='<|im_start|>system\n',
163 | system_prompt=system_prompt if system_prompt is not None else '',
164 | system_suffix='<|im_end|>\n',
165 | user_prefix='<|im_start|>user\n',
166 | user_suffix='<|im_end|>\n',
167 | bot_prefix='<|im_start|>assistant\n',
168 | bot_suffix='<|im_end|>\n'
169 | )
170 |
171 | def Mistral_v7(system_prompt: Optional[str] = None) -> PromptFormat:
172 | """Mistral Instruct format v7 (Tekken tokenizer, supports system prompt)"""
173 | return PromptFormat(
174 | system_prefix='[SYSTEM_PROMPT]',
175 | system_prompt=system_prompt if system_prompt is not None else '',
176 | system_suffix='[/SYSTEM_PROMPT]',
177 | user_prefix='[INST]',
178 | user_suffix='',
179 | bot_prefix='[/INST]',
180 | bot_suffix=''
181 | )
182 |
183 | def Gemma3() -> PromptFormat:
184 | """Gemma 3 prompt format"""
185 | return PromptFormat(
186 | system_prefix='',
187 | system_prompt='',
188 | system_suffix='',
189 | user_prefix='user\n',
190 | user_suffix='\n',
191 | bot_prefix='model\n',
192 | bot_suffix='\n'
193 | )
194 |
195 | # this is just ChatML, but we can't have "NoThinking" without "Thinking"
196 | def Qwen3Thinking(system_prompt: Optional[str] = None) -> PromptFormat:
197 | return PromptFormat(
198 | system_prefix='<|im_start|>system\n',
199 | system_prompt=system_prompt if system_prompt is not None else '',
200 | system_suffix='<|im_end|>\n',
201 | user_prefix='<|im_start|>user\n',
202 | user_suffix='<|im_end|>\n',
203 | bot_prefix='<|im_start|>assistant\n',
204 | bot_suffix='<|im_end|>\n'
205 | )
206 |
207 | def Qwen3NoThinking(system_prompt: Optional[str] = None) -> PromptFormat:
208 | return PromptFormat(
209 | system_prefix='<|im_start|>system\n/no_think\n',
210 | system_prompt=system_prompt if system_prompt is not None else '',
211 | system_suffix='<|im_end|>\n',
212 | user_prefix='<|im_start|>user\n',
213 | user_suffix='<|im_end|>\n',
214 |             bot_prefix='<|im_start|>assistant\n<think>\n\n</think>\n\n',
215 | bot_suffix='<|im_end|>\n'
216 | )
217 |
218 | def GLM4(system_prompt: Optional[str] = None) -> PromptFormat:
219 | return PromptFormat(
220 | system_prefix="[gMASK]<|system|>\n",
221 | system_prompt=system_prompt if system_prompt is not None else '',
222 | system_suffix="",
223 | user_prefix="<|user|>\n",
224 | user_suffix="",
225 | bot_prefix="<|assistant|>\n",
226 | bot_suffix=""
227 | )
228 |
--------------------------------------------------------------------------------
/easy_llama/llama.py:
--------------------------------------------------------------------------------
1 | # llama.py
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | """This file provides a high-level Python interface to LLAMA_API ("libllama")."""
6 |
7 | from . import __version__
8 |
9 | import re
10 | import os
11 | import sys
12 | import time
13 | import tqdm
14 | import struct
15 | import ctypes
16 | import asyncio
17 | import threading
18 |
19 | import numpy as np
20 |
21 | from .utils import (
22 | null_ptr_check, softmax, suppress_output, _SupportsWriteAndFlush, ptr, log, ez_decode,
23 | log_verbose, log_debug, set_verbose, get_verbose
24 | )
25 | from .sampling import SamplerParams, SamplerPreset
26 | from .libllama import _internals, GGUFValueType
27 | from typing import Optional, Iterable, Union
28 | from io import BufferedReader
29 |
30 | from . import libllama as lib
31 |
32 | #
33 | # Constants, etc.
34 | #
35 |
36 | # show a tqdm progress bar if processing at least this many batches in one loop
37 | PROGRESS_BAR_N_BATCHES = 16
38 |
39 | # show a tqdm progress bar if processing at least this many tokens in one loop
40 | PROGRESS_BAR_N_TOKENS = 20480
41 |
42 | _SUPPORTED_KV_TYPES = [
43 | lib.GGMLType.GGML_TYPE_F32, # lib only supports static types, not
44 | lib.GGMLType.GGML_TYPE_F16, # k-types
45 | lib.GGMLType.GGML_TYPE_Q8_0,
46 | lib.GGMLType.GGML_TYPE_Q5_1, # BF16 is also sometimes supported, but not
47 | lib.GGMLType.GGML_TYPE_Q5_0, # always, and offers no benefit compared
48 | lib.GGMLType.GGML_TYPE_Q4_1, # to F16, so is not included here
49 | lib.GGMLType.GGML_TYPE_Q4_0
50 | ]
51 |
52 | _DEFAULT_KV_TYPE = lib.GGMLType.GGML_TYPE_F16
53 |
54 | _cpu_count = None
55 |
56 | #
57 | # Functions
58 | #
59 |
60 | def _init_backend_if_needed() -> None:
61 |
62 | # if already initialized, no need to do anything
63 | if lib._BACKEND_INIT:
64 | return
65 |
66 | log_verbose(
67 | f'easy_llama v{__version__} '
68 | f'targeting llama.cpp@{lib._TARGET_LLAMACPP_COMMIT[:7]} ({lib._TARGET_LLAMACPP_DATE})'
69 | )
70 |
71 | global _cpu_count
72 | _cpu_count = int(os.cpu_count())
73 |
74 | # most cases
75 | if sys.byteorder == 'little':
76 | log_verbose("host is little-endian")
77 | # rare
78 | elif sys.byteorder == 'big':
79 | log("host is big-endian, please ensure your GGUF file is also big-endian", 2)
80 | # extremely rare, maybe impossible?
81 | else:
82 | raise OSError(
83 | f"unexpected value for sys.byteorder: {sys.byteorder!r}; expected 'little' for "
84 | f"little-endian host or 'big' for big-endian host"
85 | )
86 |
87 | # actually load the backend
88 | with suppress_output(disable=get_verbose()):
89 | lib.llama_backend_init() # this sets libllama._BACKEND_INIT to True
90 |
91 | # NOTE: the optimal n_threads value (for text generation) is equal to the number of physical
92 | #       cores (for homogeneous CPUs) or to the number of performance cores (for heterogeneous
93 | # CPUs)
94 | #
95 | # the optimal n_threads_batch value (for prompt processing) is equal to the total number
96 | # of logical cores, regardless of their type
97 | #
98 | # the following two functions are not universally optimal, but provide a reasonable
99 | # default number of threads for most machines
100 |
101 | def _get_optimal_n_threads() -> int:
102 | global _cpu_count
103 | return max(_cpu_count//2, 1)
104 |
105 | def _get_optimal_n_threads_batch() -> int:
106 | global _cpu_count
107 | return _cpu_count
108 |
109 | def _calculate_rope_freq_base(
110 | n_ctx_train: int,
111 | n_ctx_load: int,
112 | rope_freq_base_train: Optional[float]
113 | ) -> float:
114 | """Returns the rope_freq_base value at which a model should be loaded"""
115 |
116 | # n_ctx does not exceed n_ctx_train - simply return native value
117 |
118 | if n_ctx_load <= n_ctx_train:
119 | if rope_freq_base_train is None:
120 | return 0.0
121 | else:
122 | return rope_freq_base_train
123 |
124 | # n_ctx exceeds n_ctx_train, but native value is unknown, so automatic
125 | # adjustment cannot be applied - show error and return 0.0
126 |
127 | if rope_freq_base_train in [None, 0.0]:
128 | log(
129 | f'n_ctx value {n_ctx_load} > n_ctx_train value {n_ctx_train}, and automatic '
130 | f'rope_freq_base adjustment is not supported for this model; model loading might '
131 | f'fail, or the model might not work correctly', 3
132 | )
133 | return 0.0
134 |
135 | # n_ctx exceeds n_ctx_train, and native value is known, so automatic
136 | # adjustment can be applied - show warning and return adjusted value
137 |
138 | # standard formula -- proportional increase
139 | adjusted_rope_freq = (n_ctx_load/n_ctx_train)*rope_freq_base_train
140 | # experimental formula -- slightly above proportional increase
141 | #adjusted_rope_freq = ((n_ctx_load/n_ctx_train)**(2**(1/4)))*rope_freq_base_train
142 |
143 | log(
144 | f"n_ctx value {n_ctx_load} exceeds n_ctx_train value {n_ctx_train}; using adjusted "
145 | f"rope_freq_base value {adjusted_rope_freq}, native value is {rope_freq_base_train}; "
146 | f"model will function with potentially degraded output quality", 2
147 | )
148 |
149 | return adjusted_rope_freq
150 |
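# Illustrative example (assumed values): a model trained at n_ctx_train=32768 with
# rope_freq_base_train=1e6, loaded at n_ctx_load=65536, gets the proportionally adjusted
# value (65536/32768)*1e6 = 2e6; when n_ctx_load <= n_ctx_train, the native value is
# returned unchanged.
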
151 | def _round_n_ctx(n_ctx: int, n_ctx_train: int) -> int:
152 | if n_ctx % 512 == 0:
153 | return n_ctx
154 | else:
155 | rounded = (n_ctx + 511) // 512 * 512
156 | # do not round beyond n_ctx_train if not already exceeded
157 | if (rounded > n_ctx_train) and (n_ctx <= n_ctx_train):
158 | return n_ctx_train
159 | else:
160 | return rounded
161 |
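# Illustrative examples (assumed values): _round_n_ctx(1000, 8000) rounds up to the next
# multiple of 512 and returns 1024, while _round_n_ctx(7900, 8000) would round up to 8192
# but is clamped back down to n_ctx_train, returning 8000.
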
162 | def _batches_with_progress_bar(batches: list[list[int]]) -> Union[tqdm.tqdm, list[list[int]]]:
163 | """Wrap this around an iterable of batches to show a progress bar if there are over
164 | `PROGRESS_BAR_N_BATCHES` batches or `PROGRESS_BAR_N_TOKENS` tokens."""
165 |
166 | n_batches = len(batches)
167 | n_tokens = sum(len(batch) for batch in batches)
168 |
169 | if (n_tokens > PROGRESS_BAR_N_TOKENS) or (n_batches > PROGRESS_BAR_N_BATCHES):
170 | return tqdm.tqdm(batches, desc='decoding input batches', unit="batch")
171 | return batches
172 |
173 | def split_tokens_into_batches(tokens: list[int], n_batch: int) -> list[list[int]]:
174 | """Split a list of tokens into smaller batches"""
175 | batch_splits = range(0, len(tokens), n_batch)
176 | batches: list[list[int]] = []
177 | for i in batch_splits:
178 | batch_tokens = tokens[i : i + n_batch]
179 | if len(batch_tokens) > 0:
180 | batches.append(batch_tokens)
181 | return batches
182 |
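# Illustrative example: split_tokens_into_batches([1, 2, 3, 4, 5], n_batch=2) returns
# [[1, 2], [3, 4], [5]] -- the final batch may be smaller than n_batch.
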
183 | #
184 | # Exceptions and other classes
185 | #
186 |
187 | class ExceededContextLengthException(Exception):
188 | """Exception raised when an input exceeds a model's context length"""
189 |
190 | class _LlamaStopwatch:
191 | """Track elapsed time for prompt processing and text generation"""
192 | #
193 | # Q: why don't you use llama_perf_context?
194 | #
195 | # A: comments in llama.h state to only use that in llama.cpp examples,
196 | # and to do your own performance measurements instead.
197 | #
198 | # trying to use llama_perf_context leads to output with
199 | # "0.00 ms per token" and "inf tokens per second"
200 | #
201 | def __init__(self):
202 | self.pp_start_time = None
203 | self.tg_start_time = None
204 | self.wall_start_time = None
205 | self.generic_start_time = None
206 | self.pp_elapsed_time = 0
207 | self.tg_elapsed_time = 0
208 | self.wall_elapsed_time = 0
209 | self.generic_elapsed_time = 0
210 | self.n_pp_tokens = 0
211 | self.n_tg_tokens = 0
212 |
213 | def start_pp(self):
214 | """Start prompt processing stopwatch"""
215 | self.pp_start_time = time.time_ns()
216 |
217 | def stop_pp(self):
218 | """Stop prompt processing stopwatch"""
219 | if self.pp_start_time is not None:
220 | self.pp_elapsed_time += time.time_ns() - self.pp_start_time
221 | self.pp_start_time = None
222 |
223 | def start_tg(self):
224 | """Start text generation stopwatch"""
225 | self.tg_start_time = time.time_ns()
226 |
227 | def stop_tg(self):
228 | """Stop text generation stopwatch"""
229 | if self.tg_start_time is not None:
230 | self.tg_elapsed_time += time.time_ns() - self.tg_start_time
231 | self.tg_start_time = None
232 |
233 | def start_wall_time(self):
234 | """Start wall-time stopwatch"""
235 | self.wall_start_time = time.time_ns()
236 |
237 | def stop_wall_time(self):
238 | """Stop wall-time stopwatch"""
239 | if self.wall_start_time is not None:
240 | self.wall_elapsed_time += time.time_ns() - self.wall_start_time
241 | self.wall_start_time = None
242 |
243 | def start_generic(self):
244 | """Start generic stopwatch (not shown in print_stats)"""
245 | self.generic_start_time = time.time_ns()
246 |
247 | def stop_generic(self):
248 | """Stop generic stopwatch"""
249 | if self.generic_start_time is not None:
250 | self.generic_elapsed_time += time.time_ns() - self.generic_start_time
251 | self.generic_start_time = None
252 |
253 | def get_elapsed_time_pp(self) -> int:
254 | """Total nanoseconds elapsed during prompt processing"""
255 | return self.pp_elapsed_time
256 |
257 | def get_elapsed_time_tg(self) -> int:
258 | """Total nanoseconds elapsed during text generation"""
259 | return self.tg_elapsed_time
260 |
261 | def get_elapsed_wall_time(self) -> int:
262 | """Total wall-time nanoseconds elapsed"""
263 | return self.wall_elapsed_time
264 |
265 | def get_elapsed_time_generic(self) -> int:
266 | """Total generic nanoseconds elapsed"""
267 | return self.generic_elapsed_time
268 |
269 | def increment_pp_tokens(self, n: int):
270 | if n < 0:
271 | raise ValueError('negative increments are not allowed')
272 | self.n_pp_tokens += n
273 |
274 | def increment_tg_tokens(self, n: int):
275 | if n < 0:
276 | raise ValueError('negative increments are not allowed')
277 | self.n_tg_tokens += n
278 |
279 | def reset(self):
280 | """Reset the stopwatch to its original state"""
281 | self.pp_start_time = None
282 | self.tg_start_time = None
283 | self.wall_start_time = None
284 | self.generic_start_time = None
285 | self.pp_elapsed_time = 0
286 | self.tg_elapsed_time = 0
287 | self.wall_elapsed_time = 0
288 | self.generic_elapsed_time = 0
289 | self.n_pp_tokens = 0
290 | self.n_tg_tokens = 0
291 |
292 | def print_stats(self):
293 | """Print performance statistics using current stopwatch state
294 |
295 | #### NOTE:
296 | The `n_tg_tokens` value will be equal to the number of calls to
297 | llama_decode which have a batch size of 1, which is technically not
298 | always equal to the number of tokens generated - it may be off by one."""
299 |
300 | print(f"\n", end='', file=sys.stderr, flush=True)
301 |
302 | if self.n_pp_tokens + self.n_tg_tokens == 0:
303 | log(f'print_stats was called but no tokens were processed or generated', 4)
304 |
305 | if self.n_pp_tokens > 0:
306 | pp_elapsed_ns = self.get_elapsed_time_pp()
307 | pp_elapsed_ms = pp_elapsed_ns / 1e6
308 | pp_elapsed_s = pp_elapsed_ns / 1e9
309 | pp_tps = self.n_pp_tokens / pp_elapsed_s
310 | log(
311 | f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms '
312 | f'({pp_tps:>10.2f} tok/s)', 4
313 | )
314 |
315 | if self.n_tg_tokens > 0:
316 | tg_elapsed_ns = self.get_elapsed_time_tg()
317 | tg_elapsed_ms = tg_elapsed_ns / 1e6
318 | tg_elapsed_s = tg_elapsed_ns / 1e9
319 | tg_tps = self.n_tg_tokens / tg_elapsed_s
320 | log(
321 | f' text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms '
322 | f'({tg_tps:>10.2f} tok/s)', 4
323 | )
324 |
325 | wall_elapsed_ns = self.get_elapsed_wall_time()
326 | wall_elapsed_ms = wall_elapsed_ns / 1e6
327 | log(f" wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms", 4)
328 |
329 | class QuickGGUFReader:
330 | # ref: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
331 |
332 | # the GGUF format versions that this class supports
333 | SUPPORTED_GGUF_VERSIONS = [2, 3]
334 |
335 | # arguments for struct.unpack() based on gguf value type
336 | value_packing = {
337 | GGUFValueType.UINT8 : "=B",
338 | GGUFValueType.INT8 : "=b",
339 | GGUFValueType.UINT16 : "=H",
340 | GGUFValueType.INT16 : "=h",
341 | GGUFValueType.UINT32 : "=I",
342 | GGUFValueType.INT32 : "=i",
343 | GGUFValueType.FLOAT32 : "=f",
344 | GGUFValueType.UINT64 : "=Q",
345 | GGUFValueType.INT64 : "=q",
346 | GGUFValueType.FLOAT64 : "=d",
347 | GGUFValueType.BOOL : "?"
348 | }
349 |
350 | # length in bytes for each gguf value type
351 | value_lengths = {
352 | GGUFValueType.UINT8 : 1,
353 | GGUFValueType.INT8 : 1,
354 | GGUFValueType.UINT16 : 2,
355 | GGUFValueType.INT16 : 2,
356 | GGUFValueType.UINT32 : 4,
357 | GGUFValueType.INT32 : 4,
358 | GGUFValueType.FLOAT32 : 4,
359 | GGUFValueType.UINT64 : 8,
360 | GGUFValueType.INT64 : 8,
361 | GGUFValueType.FLOAT64 : 8,
362 | GGUFValueType.BOOL : 1
363 | }
364 |
365 | @staticmethod
366 | def unpack(value_type: GGUFValueType, file: BufferedReader):
367 | return struct.unpack(
368 | QuickGGUFReader.value_packing.get(value_type),
369 | file.read(QuickGGUFReader.value_lengths.get(value_type))
370 | )[0]
371 |
372 | @staticmethod
373 | def get_single(value_type: GGUFValueType, file: BufferedReader) -> str | int | float | bool:
374 | """Read a single value from an open file"""
375 | if value_type == GGUFValueType.STRING:
376 | string_length = QuickGGUFReader.unpack(
377 | GGUFValueType.UINT64, file=file
378 | )
379 | value = file.read(string_length)
380 | try:
381 | value = value.decode("utf-8")
382 | except UnicodeDecodeError:
383 | log(
384 | f'UnicodeDecodeError was raised while reading a string '
385 | f'from the GGUF metadata. the GGUF format specifies that '
386 | f'all strings in file metadata should be valid UTF-8. the '
387 | f'affected string will be left blank.', 2
388 | )
389 | value = ''
390 | else:
391 | value = QuickGGUFReader.unpack(value_type, file=file)
392 | return value
393 |
394 | @staticmethod
395 | def load_metadata(
396 | path_model: os.PathLike[str] | str
397 | ) -> dict[str, str | int | float | bool | list]:
398 | """Given a path to a GGUF file, peek at its header for metadata
399 |
400 | Return a dictionary where all keys are strings, and values can be
401 | strings, ints, floats, bools, or lists"""
402 |
403 | metadata: dict[str, str | int | float | bool | list] = {}
404 | with open(path_model, "rb") as file:
405 | magic = file.read(4)
406 |
407 | if magic != lib.GGUF_MAGIC_BYTES:
408 | raise ValueError(
409 | f"your model file is not a valid GGUF file "
410 | f"(magic number mismatch, got {magic}, "
411 | f"expected {lib.GGUF_MAGIC_BYTES})"
412 | )
413 |
414 | version = QuickGGUFReader.unpack(GGUFValueType.UINT32, file=file)
415 |
416 | if version not in QuickGGUFReader.SUPPORTED_GGUF_VERSIONS:
417 | raise ValueError(
418 | f"your model file reports GGUF version {version}, but "
419 | f"only versions {QuickGGUFReader.SUPPORTED_GGUF_VERSIONS} "
420 | f"are supported. re-convert your model or download a newer "
421 | f"version"
422 | )
423 |
424 | tensor_count = QuickGGUFReader.unpack(
425 | GGUFValueType.UINT64, file=file
426 | )
427 | if version == 3:
428 | metadata_kv_count = QuickGGUFReader.unpack(
429 | GGUFValueType.UINT64, file=file
430 | )
431 | elif version == 2:
432 | metadata_kv_count = QuickGGUFReader.unpack(
433 | GGUFValueType.UINT32, file=file
434 | )
435 | for _ in range(metadata_kv_count):
436 | if version == 3:
437 | key_length = QuickGGUFReader.unpack(
438 | GGUFValueType.UINT64, file=file
439 | )
440 | elif version == 2:
441 | key_length = 0
442 | while key_length == 0:
443 | # seek until next key is found
444 | key_length = QuickGGUFReader.unpack(
445 | GGUFValueType.UINT32, file=file
446 | )
447 | file.read(4) # 4 byte offset for GGUFv2
448 | key = file.read(key_length)
449 | value_type = GGUFValueType(
450 | QuickGGUFReader.unpack(GGUFValueType.UINT32, file=file)
451 | )
452 | if value_type == GGUFValueType.ARRAY:
453 | array_value_type = GGUFValueType(
454 | QuickGGUFReader.unpack(GGUFValueType.UINT32, file=file)
455 | )
456 | # array_length is the number of items in the array
457 | if version == 3:
458 | array_length = QuickGGUFReader.unpack(
459 | GGUFValueType.UINT64, file=file
460 | )
461 | elif version == 2:
462 | array_length = QuickGGUFReader.unpack(
463 | GGUFValueType.UINT32, file=file
464 | )
465 | file.read(4) # 4 byte offset for GGUFv2
466 | array = [
467 | QuickGGUFReader.get_single(
468 | array_value_type,
469 | file
470 | ) for _ in range(array_length)
471 | ]
472 | metadata[key.decode()] = array
473 | else:
474 | value = QuickGGUFReader.get_single(
475 | value_type,
476 | file
477 | )
478 | metadata[key.decode()] = value
479 |
480 | return metadata
481 |
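# Illustrative usage (a sketch; the filename is a placeholder and the available keys
# depend on the model file):
#
#     meta = QuickGGUFReader.load_metadata('Qwen3-4B-Q8_0.gguf')
#     arch = meta.get('general.architecture')  # e.g. 'qwen3'
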
482 | #
483 | # InferenceLock
484 | #
485 |
486 | class InferenceLockException(Exception):
487 | pass
488 |
489 | class _InferenceLock:
490 | """A context manager which is used to prevent an `ez.Llama` instance from accepting
491 | more than one generation at a time, which is not supported and can cause a hard crash.
492 |
493 | - Safe if only used synchronously (`__enter__`/`__exit__`)
494 | - Safe if only used asynchronously (`__aenter__`/`__aexit__`)
495 | - Not safe for concurrent sync/async"""
496 |
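    # Minimal illustrative usage (sync side):
    #
    #     lock = _InferenceLock()
    #     with lock:
    #         ...  # at most one generation may run while the lock is held
    #
    # Entering the lock while it is already held raises InferenceLockException rather
    # than blocking.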
497 | def __init__(self):
498 | self._locked = False
499 | self._sync_lock = threading.Lock() # for thread safety
500 | self._async_lock = asyncio.Lock() # for async safety
501 |
502 | def __enter__(self):
503 | with self._sync_lock:
504 | if self._locked:
505 | raise InferenceLockException(
506 | 'sync: failed to acquire InferenceLock (already locked)'
507 | )
508 | self._locked = True
509 | return self
510 |
511 | def __exit__(self, *_):
512 | with self._sync_lock:
513 | if not self._locked:
514 | raise InferenceLockException(
515 | 'sync: tried to release InferenceLock that is not acquired'
516 | )
517 | self._locked = False
518 |
519 | async def __aenter__(self):
520 | async with self._async_lock:
521 | if self._locked:
522 | raise InferenceLockException(
523 | 'async: failed to acquire InferenceLock (already locked)'
524 | )
525 | self._locked = True
526 | return self
527 |
528 | async def __aexit__(self, *_):
529 | async with self._async_lock:
530 | if not self._locked:
531 | raise InferenceLockException(
532 | 'async: tried to release InferenceLock that is not acquired'
533 | )
534 | self._locked = False
535 |
536 | #
537 | # Simple python wrappers
538 | #
539 |
540 | class _LlamaModel:
541 | """Low-level Python wrapper over `llama_model`"""
542 |
543 | def __init__(
544 | self,
545 | path_model: str,
546 |
547 | devices: Optional[list[ptr]] = None,
548 | tensor_buft_override: Optional[list[ptr]] = None,
549 | n_gpu_layers: Optional[int] = None,
550 | split_mode: Optional[int] = None,
551 | main_gpu: Optional[int] = None,
552 | tensor_split: Optional[list[float]] = None,
553 | rpc_servers: Optional[str] = None,
554 | progress_callback: Optional[ptr] = None,
555 | progress_callback_user_data: Optional[ptr] = None,
556 | kv_overrides: Optional[list[ptr]] = None,
557 | vocab_only: Optional[bool] = None,
558 | use_mmap: Optional[bool] = None,
559 | use_mlock: Optional[bool] = None,
560 | check_tensors: Optional[bool] = None
561 | ):
562 | # refuse to load files with incorrect extension
563 | if not path_model.lower().endswith('.gguf'):
564 | raise ValueError(
565 | f"_LlamaModel.__init__: the given path_model {path_model!r} does not end "
566 | f"in '.gguf'. easy-llama refuses to load from files that do not have the "
567 | f"correct file extension."
568 | )
569 |
570 | _init_backend_if_needed()
571 | self.path_model = path_model
572 | self.params = lib.llama_model_default_params()
573 | null_ptr_check(self.params, "self.params", "_LlamaModel.__init__")
574 | if devices is not None:
575 | self.params.devices = (ctypes.c_void_p * (len(devices) + 1))(*devices, None)
576 | if tensor_buft_override is not None:
577 | self.params.tensor_buft_overrides = (
578 | lib.llama_model_tensor_buft_override_p * (len(tensor_buft_override) + 1)
579 | )(*tensor_buft_override, None)
580 | if n_gpu_layers is not None:
581 | self.params.n_gpu_layers = (
582 | n_gpu_layers
583 | ) if n_gpu_layers >= 0 else lib.MAX_OFFLOAD_LAYERS
584 | if split_mode is not None:
585 | self.params.split_mode = split_mode
586 | if main_gpu is not None:
587 | self.params.main_gpu = main_gpu
588 | if tensor_split is not None:
589 | self.params.tensor_split = (ctypes.c_float * len(tensor_split))(*tensor_split)
590 | if rpc_servers is not None:
591 | self.params.rpc_servers = rpc_servers.encode('utf-8')
592 | if progress_callback is not None:
593 | self.params.progress_callback = progress_callback
594 | if progress_callback_user_data is not None:
595 | self.params.progress_callback_user_data = progress_callback_user_data
596 | if kv_overrides is not None:
597 | self.params.kv_overrides = (
598 | lib.llama_model_kv_override * len(kv_overrides)
599 | )(*kv_overrides)
600 | if vocab_only is not None:
601 | self.params.vocab_only = vocab_only
602 | if use_mmap is not None:
603 | self.params.use_mmap = use_mmap
604 | if use_mlock is not None:
605 | self.params.use_mlock = use_mlock
606 | if check_tensors is not None:
607 | self.params.check_tensors = check_tensors
608 |
609 | # load model
610 | with suppress_output(disable=get_verbose()):
611 | self.model = lib.llama_model_load_from_file(path_model, self.params)
612 |
613 | null_ptr_check(self.model, "self.model", "_LlamaModel.__init__")
614 |
615 | def __del__(self):
616 | self.free()
617 |
618 | def free(self):
619 | if self.model is not None:
620 | with suppress_output(disable=get_verbose()):
621 | lib.llama_model_free(self.model)
622 | self.model = None
623 |
624 | class _LlamaCtx:
625 | """Low-level Python wrapper over `llama_context`"""
626 |
627 | def __init__(
628 | self,
629 | model: _LlamaModel,
630 |
631 | n_ctx: Optional[int] = None,
632 | n_batch: Optional[int] = None,
633 | n_ubatch: Optional[int] = None,
634 | n_seq_max: Optional[int] = None,
635 | n_threads: Optional[int] = None,
636 | n_threads_batch: Optional[int] = None,
637 | rope_scaling_type: Optional[int] = None,
638 | pooling_type: Optional[int] = None,
639 | attention_type: Optional[int] = None,
640 | rope_freq_base: Optional[float] = None,
641 | rope_freq_scale: Optional[float] = None,
642 | yarn_ext_factor: Optional[float] = None,
643 | yarn_attn_factor: Optional[float] = None,
644 | yarn_beta_fast: Optional[float] = None,
645 | yarn_beta_slow: Optional[float] = None,
646 | yarn_orig_ctx: Optional[int] = None,
647 | defrag_thold: Optional[float] = None,
648 | # cb_eval & cb_eval_user_data are not supported by easy-llama
649 | type_k: Optional[int] = None,
650 | type_v: Optional[int] = None,
651 | embeddings: Optional[bool] = None,
652 | offload_kqv: Optional[bool] = None,
653 | flash_attn: Optional[bool] = None,
654 | no_perf: Optional[bool] = None,
655 | # abort_callback & abort_callback_data are not supported by easy-llama
656 | op_offload: Optional[bool] = None,
657 | swa_full: Optional[bool] = None
658 | ):
659 | _init_backend_if_needed()
660 | self.params = lib.llama_context_default_params()
661 | null_ptr_check(self.params, "self.params", "_LlamaCtx.__init__")
662 | if n_ctx is not None:
663 | self.params.n_ctx = n_ctx
664 | if n_batch is not None:
665 | self.params.n_batch = n_batch
666 | if n_ubatch is not None:
667 | self.params.n_ubatch = n_ubatch
668 | if n_seq_max is not None:
669 | if n_seq_max != 1:
670 | raise NotImplementedError(
671 | f'n_seq_max value {n_seq_max} != 1; this is not yet supported'
672 | )
673 | self.params.n_seq_max = n_seq_max
674 | if n_threads is not None:
675 | self.params.n_threads = n_threads
676 | if n_threads_batch is not None:
677 | self.params.n_threads_batch = n_threads_batch
678 | if rope_scaling_type is not None:
679 | self.params.rope_scaling_type = rope_scaling_type
680 | if pooling_type is not None:
681 | self.params.pooling_type = pooling_type
682 | if attention_type is not None:
683 | self.params.attention_type = attention_type
684 | if rope_freq_base is not None:
685 | self.params.rope_freq_base = rope_freq_base
686 | if rope_freq_scale is not None:
687 | self.params.rope_freq_scale = rope_freq_scale
688 | if yarn_ext_factor is not None:
689 | self.params.yarn_ext_factor = yarn_ext_factor
690 | if yarn_attn_factor is not None:
691 | self.params.yarn_attn_factor = yarn_attn_factor
692 | if yarn_beta_fast is not None:
693 | self.params.yarn_beta_fast = yarn_beta_fast
694 | if yarn_beta_slow is not None:
695 | self.params.yarn_beta_slow = yarn_beta_slow
696 | if yarn_orig_ctx is not None:
697 | self.params.yarn_orig_ctx = yarn_orig_ctx
698 | if defrag_thold is not None:
699 | self.params.defrag_thold = defrag_thold
700 |
701 | def _py_eval_callback(is_eval: bool, user_data: ptr) -> None:
702 | return
703 |
704 | # create the ctypes function pointer instance and store it as an attribute of this
705 | # `_LlamaCtx` to keep it alive
706 | self._eval_callback_cfunc_instance = lib.eval_callback_functype(_py_eval_callback)
707 | self.params.cb_eval = self._eval_callback_cfunc_instance
708 | self.params.cb_eval_user_data = lib.NULLPTR
709 |
710 | _k = _DEFAULT_KV_TYPE
711 | if type_k is not None:
712 | self.params.type_k = _k = type_k
713 | _v = _DEFAULT_KV_TYPE
714 | if type_v is not None:
715 | self.params.type_v = _v = type_v
716 | if _k != _v:
717 | log(
718 | f'type_k value {_k} != type_v value {_v}; this is rarely '
719 | f'supported, program may fail', 2
720 | )
721 | if _k not in _SUPPORTED_KV_TYPES:
722 | log(f'type_k value {_k} is unsupported; program may fail', 2)
723 | if _v not in _SUPPORTED_KV_TYPES:
724 | log(f'type_v value {_v} is unsupported; program may fail', 2)
725 | if (not flash_attn) and (_v not in [
726 | lib.GGMLType.GGML_TYPE_F32, lib.GGMLType.GGML_TYPE_F16, lib.GGMLType.GGML_TYPE_BF16
727 | ]):
728 | log(f'V cache quantization requires flash_attn; program may fail', 2)
729 | if embeddings is not None:
730 | self.params.embeddings = embeddings
731 | if offload_kqv is not None:
732 | self.params.offload_kqv = offload_kqv
733 | if flash_attn is not None:
734 | self.params.flash_attn = flash_attn
735 | if no_perf is not None:
736 | self.params.no_perf = no_perf
737 | if op_offload is not None:
738 | self.params.op_offload = op_offload
739 |
740 | # enable proper SWA support unless explicitly disabled
741 | self.params.swa_full = False if swa_full is None else swa_full
742 |
743 | # easy-llama does not currently support user-defined abort callbacks, but it does not
744 |         # need them, since KeyboardInterrupt can interrupt the code in between batches.
745 |
746 | def _py_abort_callback(user_data: ptr) -> ctypes.c_bool:
747 | return False
748 |
749 | # create the ctypes function pointer instance and store it as an attribute of this
750 | # `_LlamaCtx` to keep it alive
751 | self._abort_callback_cfunc_instance = lib.abort_callback_functype(_py_abort_callback)
752 | self.params.abort_callback = self._abort_callback_cfunc_instance
753 | self.params.abort_callback_data = lib.NULLPTR
754 |
755 | null_ptr_check(model.model, "model.model", "_LlamaCtx.__init__")
756 | with suppress_output(disable=get_verbose()):
757 | self.ctx = lib.llama_init_from_model(model.model, self.params)
758 | null_ptr_check(self.ctx, "self.ctx", "_LlamaCtx.__init__")
759 |
760 | def __del__(self):
761 | self.free()
762 |
763 | def free(self):
764 | if self.ctx is not None:
765 | with suppress_output(disable=get_verbose()):
766 | lib.llama_free(self.ctx)
767 | self.ctx = None
768 |
769 | #
770 | # Llama
771 | #
772 |
773 | class Llama:
774 | """Simplified interface for general-purpose Llama model usage
775 |
776 | The `easy_llama.Llama` class provides a high-level Python interface to
777 | a llama_model and its associated llama_context.
778 |
779 | Example usage:
780 | >>> import easy_llama as ez
781 | >>> MyLlama = ez.Llama('/path/to/model.gguf', n_ctx=8192)
782 | >>> in_txt = b"The apple doesn't fall far from"
783 | >>> in_toks = MyLlama.tokenize(in_txt, add_special=True, parse_special=False)
784 | >>> out_toks = MyLlama.generate(in_toks, n_predict=16)
785 | >>> out_txt = MyLlama.detokenize(out_toks, special=True)
786 | >>> print(out_txt)
787 | b" the tree, as the saying goes, and I think that's especially true when\""""
788 |
789 | def __init__(
790 | self,
791 | path_model: str,
792 | n_gpu_layers: int = 0,
793 | n_ctx: int = 512,
794 | n_threads: int = 0,
795 | n_threads_batch: int = 0,
796 | type_k: Optional[int] = None,
797 | type_v: Optional[int] = None,
798 | offload_kqv: bool = False, # XXX: can you make this actually offload the whole KV cache, not just per-layer?
799 | flash_attn: bool = False,
800 | warmup: bool = False,
801 | verbose: bool = True,
802 | **kwargs
803 | ):
804 | """Load a llama model from a file
805 |
806 | - path_model:
807 | The path to the GGUF model file you wish to load from
808 | - n_gpu_layers:
809 | How many of the model's layers should be offloaded from CPU to GPU.
810 | Values less than 0 will attempt to offload all layers. Default is 0.
811 | - use_mmap:
812 | Whether to memory-map the model. Changing this to False will cause
813 | slower load times. Default is True.
814 | - use_mlock:
815 |         Whether to lock the model into memory, which can prevent page-outs.
816 | Changing this to True can cause slower load times and increased
817 | memory usage. Default is False.
818 | - n_ctx:
819 | The context length at which to load the model, in tokens. Default is
820 | 512, which is very small. Increase as needed. Values 0 or less will
821 | attempt to load the native context length of the model (which may be
822 | very large).
823 | - n_batch:
824 | The maximum number of tokens to process at once. Higher values
825 |         will increase prompt processing speed at the expense of increased memory
826 | usage. Values must be between 32 and n_ctx, inclusive.
827 | - n_threads:
828 | Number of threads to use for batch size == 1.
829 | - n_threads_batch:
830 | Number of threads to use for batch sizes > 1.
831 | - type_k:
832 | The `libllama.GGMLType` to use for the K cache. Default is 1 (f16).
833 | In most cases, this must be the same as `type_v`.
834 | - type_v:
835 | The `libllama.GGMLType` to use for the V cache. Default is 1 (f16).
836 | In most cases, this must be the same as `type_k`. Values other than
837 | 0 and 1 are not compatible with `flash_attn=True`.
838 | - offload_kqv:
839 | Whether to offload the K, Q, V caches to the GPU, which can greatly
840 | improve prompt processing speed at the cost of increased VRAM usage.
841 |         Default is False for compatibility reasons. Recommended to set to
842 | True if possible.
843 | - flash_attn:
844 | Whether to use Flash Attention, which decreases memory usage and
845 | can increase both prompt processing and text generation speed,
846 |         especially at long context lengths. Default is False for compatibility reasons.
847 | Recommended to set to True if possible.
848 | - warmup:
849 | Whether to warm-up the model with an empty run. This reduces the
850 | latency of the first generation at the cost of a slower load time.
851 | - verbose:
852 | Print informational output when loading model as well as at
853 | runtime. Default is True. If set to False, warnings and errors
854 | will still be shown."""
855 |
856 | if not os.path.exists(path_model):
857 | raise FileNotFoundError(
858 | f"Llama: the given path_model {path_model!r} does not exist"
859 | )
860 | if os.path.isdir(path_model):
861 | raise IsADirectoryError(
862 | f"Llama: the given path_model {path_model!r} is a directory, "
863 | f"not a GGUF file"
864 | )
865 |
866 | set_verbose(verbose)
867 |
868 | # peek at metadata from GGUF file header before loading model
869 |
870 | self.metadata = QuickGGUFReader.load_metadata(path_model)
871 |
872 | #
873 | # Load model from file
874 | #
875 |
876 | self._model = _LlamaModel(
877 | path_model = path_model,
878 | devices = kwargs.get('devices'),
879 | tensor_buft_override = kwargs.get('tensor_buft_override'),
880 | n_gpu_layers = n_gpu_layers,
881 | split_mode = kwargs.get('split_mode'),
882 | main_gpu = kwargs.get('main_gpu'),
883 | tensor_split = kwargs.get('tensor_split'),
884 | rpc_servers = kwargs.get('rpc_servers'),
885 | progress_callback = kwargs.get('progress_callback'),
886 | progress_callback_user_data = kwargs.get('progress_callback_user_data'),
887 | kv_overrides = kwargs.get('kv_overrides'),
888 | vocab_only = kwargs.get('vocab_only'),
889 | use_mmap = kwargs.get('use_mmap'),
890 | use_mlock = kwargs.get('use_mlock'),
891 | check_tensors = kwargs.get('check_tensors')
892 | )
893 |
894 | self._vocab = lib.llama_model_get_vocab(self._model.model)
895 | """A pointer to this model's `llama_vocab`"""
896 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.__init__')
897 |
898 | n_ctx_train = lib.llama_model_n_ctx_train(self._model.model)
899 |
900 | # use n_ctx unless it's 0 or negative, in that case use n_ctx_train
901 |
902 | if n_ctx <= 0:
903 | log_verbose(f'n_ctx value {n_ctx}; using n_ctx_train value {n_ctx_train}')
904 | _n_ctx = int(n_ctx_train)
905 | else:
906 | _n_ctx = int(n_ctx)
907 |
908 | # use rope_freq_base unless it == 0.0, in that case use the native
909 | # rope_freq_base found in the GGUF metadata
910 | rope_freq_base = kwargs.get('rope_freq_base', 0.0)
911 |
912 | if rope_freq_base == 0.0:
913 | rope_freq_base_train = None
914 | for key in self.metadata.keys():
915 | if key.endswith('.rope.freq_base'):
916 | rope_freq_base_train = float(self.metadata[key])
917 |
918 | # NOTE: if n_ctx > n_ctx_train, then rope_freq_base must also be
919 | # increased by at least a proportional amount to guarantee a
920 | # usable kv cache throughout the entire context
921 | #
922 | # the function _calculate_rope_freq_base handles this
923 |
924 | _rope_freq_base = _calculate_rope_freq_base(
925 | n_ctx_train=n_ctx_train,
926 | n_ctx_load=_n_ctx,
927 | rope_freq_base_train=rope_freq_base_train # can be None
928 | )
929 | else:
930 | _rope_freq_base = rope_freq_base
931 |
932 | _n_threads = n_threads if n_threads > 0 else _get_optimal_n_threads()
933 | _n_threads_batch = n_threads_batch if n_threads_batch > 0 else (
934 | _get_optimal_n_threads_batch()
935 | )
936 |
937 | #
938 | # New context with model
939 | #
940 |
941 | self._ctx = _LlamaCtx(
942 | model = self._model,
943 | n_ctx = _n_ctx,
944 | n_batch = kwargs.get('n_batch'),
945 | n_ubatch = kwargs.get('n_ubatch'),
946 | # n_seq_max = kwargs.get('n_seq_max'), # uncomment this if n_seq_max gets supported
947 | n_threads = _n_threads,
948 | n_threads_batch = _n_threads_batch,
949 | rope_scaling_type = kwargs.get('rope_scaling_type'),
950 | pooling_type = kwargs.get('pooling_type'),
951 | attention_type = kwargs.get('attention_type'),
952 | rope_freq_base = _rope_freq_base,
953 | rope_freq_scale = kwargs.get('rope_freq_scale'),
954 | yarn_ext_factor = kwargs.get('yarn_ext_factor'),
955 | yarn_attn_factor = kwargs.get('yarn_attn_factor'),
956 | yarn_beta_fast = kwargs.get('yarn_beta_fast'),
957 | yarn_beta_slow = kwargs.get('yarn_beta_slow'),
958 | yarn_orig_ctx = kwargs.get('yarn_orig_ctx'),
959 | defrag_thold = kwargs.get('defrag_thold'),
960 | type_k = type_k,
961 | type_v = type_v,
962 | embeddings = kwargs.get('embeddings'),
963 | offload_kqv = offload_kqv,
964 | flash_attn = flash_attn,
965 | no_perf = kwargs.get('no_perf'),
966 | op_offload = kwargs.get('op_offload'),
967 | swa_full = kwargs.get('swa_full')
968 | )
969 |
970 | #
971 | # Display warnings about n_ctx if necessary
972 | #
973 |
974 | actual_n_ctx = self.n_ctx()
975 | requested_n_ctx = _n_ctx
976 |
977 | if actual_n_ctx != requested_n_ctx:
978 | log(
979 | f"requested n_ctx value differs from actual n_ctx value; "
980 | f"requested {requested_n_ctx}, actual {actual_n_ctx}", 2
981 | )
982 | if actual_n_ctx < 512:
983 | log(
984 | f"n_ctx value {actual_n_ctx} is less than 512, which can "
985 | f"sometimes cause problems with llama.cpp - consider "
986 | f"increasing it to at least 512", 2
987 | )
988 | if actual_n_ctx % 512 != 0:
989 | log(
990 | f"n_ctx value {actual_n_ctx} is not divisible by 512, which "
991 | f"can sometimes cause problems with llama.cpp - consider "
992 | f"changing it to "
993 | f"{_round_n_ctx(actual_n_ctx, n_ctx_train)}", 2
994 | )
995 | # warn about default context length
996 | if actual_n_ctx == 512:
997 | log(
998 | f'you are using the default n_ctx value {actual_n_ctx}, which '
999 | f'is very small. increase n_ctx as needed to support longer '
1000 | f'inputs and outputs.', 2
1001 | )
1002 |
1003 | self._stopwatch = _LlamaStopwatch()
1004 |
1005 | #
1006 | # Store Llama metadata as attributes for faster access internally
1007 | #
1008 |
1009 | self._name = self.name()
1010 | self._n_ctx = self.n_ctx()
1011 | self._n_batch = self.n_batch()
1012 | self._n_ubatch = self.n_ubatch()
1013 | self._n_seq_max = self.n_seq_max()
1014 | self._n_vocab = self.n_vocab()
1015 | self._n_ctx_train = self.n_ctx_train()
1016 | self._n_embd = self.n_embd()
1017 | self._n_layer = self.n_layer()
1018 | self._n_head = self.n_head()
1019 | self._n_head_kv = self.n_head_kv()
1020 | self._pooling_type = self.pooling_type()
1021 | self._n_swa = self.n_swa()
1022 | self._vocab_type = self.vocab_type()
1023 | self._rope_type = self.rope_type()
1024 | self._rope_freq_scale_train = self.rope_freq_scale_train()
1025 | self._model_size_bytes = self.model_size_bytes()
1026 | self._chat_template = self.chat_template()
1027 | self._n_params = self.n_params()
1028 | self._bpw = self.bpw()
1029 | self._has_encoder = self.has_encoder()
1030 | self._has_decoder = self.has_decoder()
1031 | self._is_recurrent = self.is_recurrent()
1032 | self._token_bos = self.token_bos()
1033 | self._token_eos = self.token_eos()
1034 | self._token_eot = self.token_eot()
1035 | self._token_sep = self.token_sep()
1036 | self._token_nl = self.token_nl()
1037 | self._token_pad = self.token_pad()
1038 | self._add_bos_token = self.add_bos_token()
1039 | self._add_eos_token = self.add_eos_token()
1040 | self._token_fim_pre = self.token_fim_pre()
1041 | self._token_fim_suf = self.token_fim_suf()
1042 | self._token_fim_mid = self.token_fim_mid()
1043 | self._token_fim_pad = self.token_fim_pad()
1044 | self._token_fim_rep = self.token_fim_rep()
1045 | self._token_fim_sep = self.token_fim_sep()
1046 |
1047 | self.eog_tokens = [i for i in range(self._n_vocab) if self.token_is_eog(i)]
1048 | """A list of all tokens in the vocab that are marked as EOG
1049 | (End-Of-Generation)"""
1050 |
1051 | # internal use only - the default SamplerParams with this model
1052 | self._default_sampler_params = SamplerParams(self)
1053 |
1054 | self.pos = 0
1055 | """The current position of the model within the context window"""
1056 |
1057 | self.context_tokens = []
1058 | """A list of all tokens currently in the context window"""
1059 |
1060 | self._lock = _InferenceLock()
1061 |
1062 | if warmup:
1063 | self.warmup()
1064 |
1065 | # End of Llama.__init__
1066 |
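# Usage sketch for the constructor above. The model path and parameter values are
# placeholders; adjust them for your own GGUF file and hardware:
#
#     import easy_llama as ez
#
#     MyLlama = ez.Llama(
#         'my-model.gguf',    # path to a local GGUF file
#         n_gpu_layers=-1,    # attempt to offload all layers to the GPU
#         n_ctx=8192,         # load with an 8192-token context window
#         offload_kqv=True,   # recommended if VRAM allows
#         flash_attn=True,    # recommended if the backend supports it
#         verbose=True
#     )
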
1067 | def __repr__(self) -> str:
1068 | return (
1069 | f"Llama("
1070 | f"path_model={self._model.path_model!r}, "
1071 | f"n_gpu_layers={self._model.params.n_gpu_layers}, "
1072 | f"n_ctx={self._n_ctx}, "
1073 | f"type_k={self._ctx.params.type_k}, "
1074 | f"type_v={self._ctx.params.type_v}, "
1075 | f"offload_kqv={self._ctx.params.offload_kqv}, "
1076 | f"flash_attn={self._ctx.params.flash_attn}"
1077 | f")"
1078 | )
1079 |
1080 | def free(self):
1081 | """Deallocate the context and model"""
1082 | self._ctx.free()
1083 | self._model.free()
1084 |
1085 | def _validate_model_state(self) -> None:
1086 | """Ensure `llama_model`, `llama_vocab` and `llama_context` are not NULL and validate
1087 | `Llama.pos`"""
1088 | null_ptr_check(self._model.model, 'self._model.model', '_validate_model_state')
1089 | null_ptr_check(self._vocab, 'self._vocab', '_validate_model_state')
1090 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', '_validate_model_state')
1091 |
1092 | _n_context_tokens = len(self.context_tokens)
1093 | _pos = self.pos
1094 |
1095 | if _pos < 0:
1096 | self.pos = 0
1097 | self.context_tokens = []
1098 | log_verbose(
1099 | f'self.pos value was {_pos} - clamping to 0. the KV cache has been reset.',
1100 | 2
1101 | )
1102 | elif _pos != _n_context_tokens:
1103 | self.pos = 0
1104 | self.context_tokens = []
1105 | log_verbose(
1106 | f'n_context_tokens {_n_context_tokens} did not match self.pos {_pos}. the KV '
1107 | f'cache has been reset.', 2
1108 | )
1109 | if not hasattr(self, '_default_sampler_params'):
1110 | self._default_sampler_params = SamplerParams(self)
1111 | log_verbose(
1112 | "Llama._default_sampler_params was destroyed but has been recreated", 2
1113 | )
1114 |
1115 | def warmup(self) -> None:
1116 | """Warm-up the model. This also clears the KV cache."""
1117 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.warmup")
1118 |
1119 | with suppress_output(disable=get_verbose()):
1120 | self.reset()
1121 |
1122 | lib.llama_set_warmup(self._ctx.ctx, True)
1123 |
1124 | log_verbose('warmup: single token decode ...')
1125 | with self._lock:
1126 | _internals.decode_tg(self._ctx.ctx, 0, 0)
1127 |
1128 | # This section decodes a full batch of tokens, but is probably unnecessary.
1129 | #
1130 | # with suppress_output(disable=get_verbose()):
1131 | # self.reset()
1132 | #
1133 | # log_verbose('warmup: full batch decode ...')
1134 | # with self._lock:
1135 | # _internals.decode_pp(self._ctx.ctx, 0, [0] * self._n_batch, self._n_batch)
1136 |
1137 | lib.llama_set_warmup(self._ctx.ctx, False)
1138 |
1139 | with suppress_output(disable=get_verbose()):
1140 | self.reset()
1141 |
1142 | self.pos = 0
1143 | log_verbose('warmup: done')
1144 |
1145 | def n_ctx(self) -> int:
1146 | """Get the current context length"""
1147 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_ctx")
1148 | return lib.llama_n_ctx(self._ctx.ctx)
1149 |
1150 | def n_batch(self) -> int:
1151 | """Get the current batch size"""
1152 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_batch")
1153 | return lib.llama_n_batch(self._ctx.ctx)
1154 |
1155 | def n_ubatch(self) -> int:
1156 | """Get the current micro-batch size"""
1157 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_ubatch")
1158 | return lib.llama_n_ubatch(self._ctx.ctx)
1159 |
1160 | def n_seq_max(self) -> int:
1161 | """Get the max number of sequences"""
1162 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_seq_max")
1163 | return lib.llama_n_seq_max(self._ctx.ctx)
1164 |
1165 | def n_vocab(self) -> int:
1166 | """Get the vocab size"""
1167 | null_ptr_check(self._vocab, "self._vocab", "Llama.n_vocab")
1168 | return lib.llama_vocab_n_tokens(self._vocab)
1169 |
1170 | def n_ctx_train(self) -> int:
1171 | """Get the trained context length"""
1172 | null_ptr_check(self._model.model, 'self._model.model', 'Llama.n_ctx_train')
1173 | return lib.llama_model_n_ctx_train(self._model.model)
1174 |
1175 | def n_embd(self) -> int:
1176 | """Get the embedding size"""
1177 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_embd")
1178 | return lib.llama_model_n_embd(self._model.model)
1179 |
1180 | def n_layer(self) -> int:
1181 | """Get the number of layers"""
1182 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_layer")
1183 | return lib.llama_model_n_layer(self._model.model)
1184 |
1185 | def n_head(self) -> int:
1186 | """Get the number of attention heads"""
1187 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_head")
1188 | return lib.llama_model_n_head(self._model.model)
1189 |
1190 | def n_head_kv(self) -> int:
1191 | """Get the number of KV heads"""
1192 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_head_kv")
1193 | return lib.llama_model_n_head_kv(self._model.model)
1194 |
1195 | def pooling_type(self) -> int:
1196 | """Get the pooling type used by the context"""
1197 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.pooling_type")
1198 | return lib.llama_pooling_type(self._ctx.ctx)
1199 |
1200 | def n_swa(self) -> int:
1201 | """Get the sliding window size, for models which use SWA."""
1202 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_swa")
1203 | return lib.llama_model_n_swa(self._model.model)
1204 |
1205 | def vocab_type(self) -> int:
1206 | """Get the vocab type"""
1207 | null_ptr_check(self._vocab, "self._vocab", "Llama.vocab_type")
1208 | return lib.llama_vocab_type(self._vocab)
1209 |
1210 | def rope_type(self) -> int:
1211 | """Get the RoPE type"""
1212 | null_ptr_check(self._model.model, "self._model.model", "Llama.rope_type")
1213 | return lib.llama_model_rope_type(self._model.model)
1214 |
1215 | def rope_freq_scale_train(self) -> float:
1216 | """Get the trained RoPE frequency scale"""
1217 | null_ptr_check(self._model.model, "self._model.model", "Llama.rope_freq_scale_train")
1218 | return lib.llama_model_rope_freq_scale_train(self._model.model)
1219 |
1220 | def model_size_bytes(self) -> int:
1221 | """Get the total size of all tensors in the model, in bytes"""
1222 | null_ptr_check(self._model.model, "self._model.model", "Llama.model_size_bytes")
1223 | return lib.llama_model_size(self._model.model)
1224 |
1225 | def chat_template(self) -> Optional[str]:
1226 | """Get the model's built-in chat template string. Returns None if not available."""
1227 | null_ptr_check(self._model.model, "self._model.model", "Llama.chat_template")
1228 | return lib.llama_model_chat_template(self._model.model, name=None)
1229 |
1230 | def n_params(self) -> int:
1231 | """Get the total number of parameters in the model"""
1232 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_params")
1233 | return lib.llama_model_n_params(self._model.model)
1234 |
1235 | def has_encoder(self) -> bool:
1236 | """If the model has an encoder"""
1237 | null_ptr_check(self._model.model, "self._model.model", "Llama.has_encoder")
1238 | return lib.llama_model_has_encoder(self._model.model)
1239 |
1240 | def has_decoder(self) -> bool:
1241 | """If the model has a decoder"""
1242 | null_ptr_check(self._model.model, "self._model.model", "Llama.has_decoder")
1243 | return lib.llama_model_has_decoder(self._model.model)
1244 |
1245 | def is_recurrent(self) -> bool:
1246 | """If the model is recurrent"""
1247 | null_ptr_check(self._model.model, "self._model.model", "Llama.is_recurrent")
1248 | return lib.llama_model_is_recurrent(self._model.model)
1249 |
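# Example of reading model properties through the accessors above. `MyLlama` is
# assumed to be an already-constructed `Llama` instance:
#
#     print(MyLlama.name())         # model name derived from the GGUF filename
#     print(MyLlama.n_ctx_train())  # native (trained) context length
#     print(MyLlama.n_embd())       # embedding size
#     print(MyLlama.n_layer())      # number of layers
#     print(MyLlama.n_params())     # total parameter count
#     print(MyLlama.bpw())          # average bits per weight
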
1250 | #
1251 | # KV cache management methods
1252 | #
1253 |
1254 | def kv_cache_clear(self) -> None:
1255 | """Clear the KV cache"""
1256 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_clear")
1257 | lib.llama_kv_self_clear(self._ctx.ctx)
1258 |
1259 | def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int) -> bool:
1260 | """Remove tokens from a sequence in the KV cache"""
1261 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_rm")
1262 | return lib.llama_kv_self_seq_rm(self._ctx.ctx, seq_id, p0, p1)
1263 |
1264 | def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int) -> None:
1265 | """Copy tokens between sequences in the KV cache"""
1266 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_cp")
1267 | lib.llama_kv_self_seq_cp(self._ctx.ctx, seq_id_src, seq_id_dst, p0, p1)
1268 |
1269 | def kv_cache_seq_keep(self, seq_id: int) -> None:
1270 | """Remove all tokens except for the ones in this sequence"""
1271 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_keep")
1272 | lib.llama_kv_self_seq_keep(self._ctx.ctx, seq_id)
1273 |
1274 | def kv_cache_seq_add(self, seq_id: int, p0: int, p1: int, delta: int) -> None:
1275 | """Add relative position "delta" to the tokens"""
1276 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_add")
1277 | lib.llama_kv_self_seq_add(self._ctx.ctx, seq_id, p0, p1, delta)
1278 |
1279 | def kv_cache_seq_div(self, seq_id: int, p0: int, p1: int, d: int) -> None:
1280 | """Integer division of the positions by factor of `d > 1`"""
1281 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_div")
1282 | lib.llama_kv_self_seq_div(self._ctx.ctx, seq_id, p0, p1, d)
1283 |
1284 | def kv_cache_seq_pos_min(self, seq_id: int) -> int:
1285 | """Returns the earliest valid position in the KV cache for the specified sequence
1286 | (relevant for models which use SWA)"""
1287 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_pos_min")
1288 | return lib.llama_kv_self_seq_pos_min(self._ctx.ctx, seq_id)
1289 |
1290 | def kv_cache_seq_pos_max(self, seq_id: int) -> int:
1291 | """Returns the largest position present in the KV cache for the specified sequence"""
1292 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_pos_max")
1293 | return lib.llama_kv_self_seq_pos_max(self._ctx.ctx, seq_id)
1294 |
1295 | def kv_cache_can_shift(self) -> bool:
1296 | """Check if the context supports KV cache shifting"""
1297 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_can_shift")
1298 | return lib.llama_kv_self_can_shift(self._ctx.ctx)
1299 |
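# Sketch of truncating the KV cache with the methods above (illustrative only; the
# value of `keep` is a placeholder). This mirrors what `Llama._set_cache_tokens`
# does internally: keep the first `keep` cached positions of sequence 0 and discard
# the rest, then keep `context_tokens` and `pos` consistent with the cache:
#
#     keep = 100
#     MyLlama.kv_cache_seq_rm(0, keep, -1)   # remove positions [keep, end)
#     MyLlama.context_tokens = MyLlama.context_tokens[:keep]
#     MyLlama.pos = keep
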
1300 | def n_threads(self) -> int:
1301 | """Get the number of threads used for batch size == 1"""
1302 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_threads")
1303 | return lib.llama_n_threads(self._ctx.ctx)
1304 |
1305 | def n_threads_batch(self) -> int:
1306 | """Get the number of threads used for batch sizes > 1"""
1307 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_threads_batch")
1308 | return lib.llama_n_threads_batch(self._ctx.ctx)
1309 |
1310 | def token_get_score(self, token: int) -> float:
1311 | """Get the score of a token"""
1312 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_get_score")
1313 | return lib.llama_vocab_get_score(self._vocab, token)
1314 |
1315 | def token_is_eog(self, token: int) -> bool:
1316 | """If the token is marked as EOG (End-Of-Generation)"""
1317 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_is_eog")
1318 | return lib.llama_vocab_is_eog(self._vocab, token)
1319 |
1320 | def token_is_control(self, token: int) -> bool:
1321 | """If the token is marked as a control token"""
1322 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_is_control")
1323 | return lib.llama_vocab_is_control(self._vocab, token)
1324 |
1325 | def token_bos(self) -> Optional[int]:
1326 | """Get the BOS (Beginning-Of-Sequence) token. Return None if not available."""
1327 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_bos")
1328 | id = lib.llama_vocab_bos(self._vocab)
1329 | return id if id != lib.LLAMA_TOKEN_NULL else None
1330 |
1331 | def token_eos(self) -> Optional[int]:
1332 | """Get the EOS (End-Of-Sequence) token. Return None if not available."""
1333 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_eos")
1334 | id = lib.llama_vocab_eos(self._vocab)
1335 | return id if id != lib.LLAMA_TOKEN_NULL else None
1336 |
1337 | def token_eot(self) -> Optional[int]:
1338 | """Get the EOT (End-Of-Turn) token. Return None if not available."""
1339 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_eot")
1340 | id = lib.llama_vocab_eot(self._vocab)
1341 | return id if id != lib.LLAMA_TOKEN_NULL else None
1342 |
1343 | def token_sep(self) -> Optional[int]:
1344 | """Get the SEP (Separator) token. Return None if not available."""
1345 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_sep")
1346 | id = lib.llama_vocab_sep(self._vocab)
1347 | return id if id != lib.LLAMA_TOKEN_NULL else None
1348 |
1349 | def token_nl(self) -> Optional[int]:
1350 | """Get the NL (Newline) token. Return None if not available."""
1351 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_nl")
1352 | id = lib.llama_vocab_nl(self._vocab)
1353 | return id if id != lib.LLAMA_TOKEN_NULL else None
1354 |
1355 | def token_pad(self) -> Optional[int]:
1356 | """Get the PAD (Padding) token. Return None if not available."""
1357 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_pad")
1358 | id = lib.llama_vocab_pad(self._vocab)
1359 | return id if id != lib.LLAMA_TOKEN_NULL else None
1360 |
1361 | def add_bos_token(self) -> bool:
1362 | """If the model is configured to add a BOS token"""
1363 | null_ptr_check(self._vocab, "self._vocab", "Llama.add_bos_token")
1364 | return lib.llama_vocab_get_add_bos(self._vocab)
1365 |
1366 | def add_eos_token(self) -> bool:
1367 | """If the model is configured to add an EOS token"""
1368 | null_ptr_check(self._vocab, "self._vocab", "Llama.add_eos_token")
1369 | return lib.llama_vocab_get_add_eos(self._vocab)
1370 |
1371 | def token_fim_pre(self) -> Optional[int]:
1372 | """Get the FIM PRE (Fill-In-Middle Prefix) token. Return None if not available."""
1373 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_pre")
1374 | id = lib.llama_vocab_fim_pre(self._vocab)
1375 | return id if id != lib.LLAMA_TOKEN_NULL else None
1376 |
1377 | def token_fim_suf(self) -> Optional[int]:
1378 | """Get the FIM SUF (Fill-In-Middle Suffix) token. Return None if not available."""
1379 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_suf")
1380 | id = lib.llama_vocab_fim_suf(self._vocab)
1381 | return id if id != lib.LLAMA_TOKEN_NULL else None
1382 |
1383 | def token_fim_mid(self) -> Optional[int]:
1384 | """Get the FIM MID (Fill-In-Middle Middle) token. Return None if not available."""
1385 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_mid")
1386 | id = lib.llama_vocab_fim_mid(self._vocab)
1387 | return id if id != lib.LLAMA_TOKEN_NULL else None
1388 |
1389 | def token_fim_pad(self) -> Optional[int]:
1390 | """Get the FIM PAD (Fill-In-Middle Padding) token. Return None if not available."""
1391 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_pad")
1392 | id = lib.llama_vocab_fim_pad(self._vocab)
1393 | return id if id != lib.LLAMA_TOKEN_NULL else None
1394 |
1395 | def token_fim_rep(self) -> Optional[int]:
1396 | """Get the FIM REP (Fill-In-Middle Repository) token. Return None if not available."""
1397 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_rep")
1398 | id = lib.llama_vocab_fim_rep(self._vocab)
1399 | return id if id != lib.LLAMA_TOKEN_NULL else None
1400 |
1401 | def token_fim_sep(self) -> Optional[int]:
1402 | """Get the FIM SEP (Fill-In-Middle Separator) token. Return None if not available."""
1403 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_sep")
1404 | id = lib.llama_vocab_fim_sep(self._vocab)
1405 | return id if id != lib.LLAMA_TOKEN_NULL else None
1406 |
1407 | def tokenize(
1408 | self,
1409 | text_bytes: bytes,
1410 | add_special: bool,
1411 | parse_special: bool,
1412 | ) -> list[int]:
1413 | """Convert the provided UTF-8 encoded text into tokens
1414 |
1415 | - text_bytes:
1416 | The text to be tokenized
1417 | - add_special:
1418 | Allow adding BOS and EOS tokens if the model is configured to do so.
1419 | - parse_special:
1420 | Allow tokenizing special and/or control tokens, which are otherwise
1421 | not exposed and are treated as plaintext. Does not insert a leading
1422 | space."""
1423 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.tokenize')
1424 | n_tokens = _internals.get_length(
1425 | vocab=self._vocab,
1426 | text_bytes=text_bytes,
1427 | add_special=add_special,
1428 | parse_special=parse_special
1429 | )
1430 | return _internals.tokenize(
1431 | vocab=self._vocab,
1432 | text_bytes=text_bytes,
1433 | n_tokens_max=n_tokens,
1434 | add_special=add_special,
1435 | parse_special=parse_special
1436 | )
1437 |
1438 | def token_to_piece(self, token: int, special: bool) -> bytes:
1439 | """Convert a single token ID into UTF-8 bytes
1440 |
1441 | - special:
1442 | If True, special tokens are rendered in the output"""
1443 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.token_to_piece')
1444 | return _internals.token_to_piece(
1445 | vocab=self._vocab,
1446 | token=token,
1447 | special=special
1448 | )
1449 |
1450 | def detokenize(
1451 | self,
1452 | tokens: list[int],
1453 | special: bool
1454 | ) -> str:
1455 | """Convert the provided tokens into UTF-8 encoded text
1456 |
1457 | - special:
1458 | If True, special tokens are rendered in the output"""
1459 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.detokenize')
1460 | return _internals.detokenize(
1461 | vocab=self._vocab,
1462 | tokens=tokens,
1463 | special=special
1464 | )
1465 |
1466 | def get_length(
1467 | self,
1468 | text_bytes: bytes,
1469 | add_special: bool,
1470 | parse_special: bool,
1471 | ) -> int:
1472 | """Return the length of a given text as measured in tokens"""
1473 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.get_length')
1474 | return _internals.get_length(
1475 | vocab=self._vocab,
1476 | text_bytes=text_bytes,
1477 | add_special=add_special,
1478 | parse_special=parse_special
1479 | )
1480 |
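# Example: `Llama.get_length` reports how many tokens a piece of text tokenizes
# into, matching the length of the list returned by `Llama.tokenize` for the same
# arguments (the text below is a placeholder):
#
#     text = "The quick brown fox".encode()
#     n = MyLlama.get_length(text, add_special=True, parse_special=False)
#     toks = MyLlama.tokenize(text, add_special=True, parse_special=False)
#     assert n == len(toks)
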
1481 | # TODO - Chat templating functions here
1482 |
1483 | def _first_valid_pos(self, tokens: list[int]) -> int:
1484 | """Given a list of tokens, and using `Llama.context_tokens`, find the first
1485 | valid `Llama.pos`
1486 |
1487 | In other words, return the length of the longest common prefix between the
1488 | two lists of tokens.
1489 |
1490 | Returns 0 if none of the tokens match, 1 if one token matches, etc."""
1491 | i = 0
1492 | for c, t in zip(self.context_tokens, tokens):
1493 | if c == t:
1494 | i += 1
1495 | else:
1496 | break
1497 | return i
1498 |
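# Worked example for `Llama._first_valid_pos`: if `Llama.context_tokens` is
# [5, 7, 9, 11] and the incoming tokens are [5, 7, 8, 11], the longest common
# prefix is [5, 7], so the first valid position is 2.
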
1499 | def _set_cache_tokens(self, input_tokens: list[int]) -> list[int]:
1500 | n_input_tokens = len(input_tokens)
1501 |
1502 | if n_input_tokens == 0:
1503 | raise ValueError('Llama._set_cache_tokens: input_tokens cannot be empty')
1504 |
1505 | # find how many tokens in the input are already in the KV cache
1506 | self.pos = self._first_valid_pos(input_tokens)
1507 |
1508 | if self.pos > self._n_ctx:
1509 | raise ExceededContextLengthException(
1510 | f'Llama._set_cache_tokens: no valid position within context window '
1511 | f'{self._n_ctx} (self.pos = {self.pos})'
1512 | )
1513 |
1514 | # remove all tokens in the KV cache that are past that point
1515 | self.kv_cache_seq_rm(0, self.pos, -1)
1516 |
1517 | # tokens already in the KV cache
1518 | self.context_tokens = input_tokens[:self.pos]
1519 |
1520 | # tokens returned to caller for processing
1521 | actual_input_tokens = input_tokens[self.pos:]
1522 |
1523 | return actual_input_tokens
1524 |
1525 | def _process_batch(
1526 | self,
1527 | batch_tokens: list[int],
1528 | logits_all: bool = False
1529 | ) -> Optional[np.ndarray]:
1530 | """Process a batch of one or more tokens, up to `Llama.n_batch()`. If `logits_all` is
1531 | True, return the logits for all tokens in the batch. Otherwise, return None.
1532 |
1533 | This function is used by `Llama.eval`, `Llama.generate`, `Llama.stream`, etc."""
1534 |
1535 | batch_logits = None
1536 | n_batch_tokens = len(batch_tokens)
1537 | log_debug(f'Llama._process_batch: processing {batch_tokens}')
1538 |
1539 | if n_batch_tokens > self._n_batch:
1540 | raise ValueError(
1541 | f'Llama._process_batch: n_batch_tokens cannot exceed n_batch '
1542 | f'({n_batch_tokens} > {self._n_batch})'
1543 | )
1544 |
1545 | if n_batch_tokens > 1: # prompt processing
1546 | if logits_all:
1547 | with self._lock:
1548 | self._stopwatch.start_pp()
1549 | batch_logits = _internals.decode_pp_with_logits(
1550 | self._ctx.ctx, self.pos, batch_tokens, n_batch_tokens, self._n_vocab
1551 | )
1552 | self._stopwatch.stop_pp()
1553 |
1554 | else:
1555 | with self._lock:
1556 | self._stopwatch.start_pp()
1557 | _internals.decode_pp(self._ctx.ctx, self.pos, batch_tokens, n_batch_tokens)
1558 | self._stopwatch.stop_pp()
1559 |
1560 | self._stopwatch.increment_pp_tokens(n_batch_tokens)
1561 |
1562 | elif n_batch_tokens == 1: # text generation
1563 | with self._lock:
1564 | self._stopwatch.start_tg()
1565 | batch_logits = _internals.decode_tg_with_logits(
1566 | self._ctx.ctx, self.pos, batch_tokens[0], self._n_vocab
1567 | )
1568 | self._stopwatch.stop_tg()
1569 |
1570 | self._stopwatch.increment_tg_tokens(1)
1571 |
1572 | else:
1573 | raise RuntimeError(
1574 | f'Llama._process_batch: unexpected n_batch_tokens value {n_batch_tokens}'
1575 | )
1576 |
1577 | # update the Llama position and context
1578 | self.pos += n_batch_tokens
1579 | self.context_tokens.extend(batch_tokens)
1580 |
1581 | return batch_logits
1582 |
1583 | def eval(
1584 | self,
1585 | input_tokens: list[int],
1586 | logits_all: bool = False
1587 | ) -> np.ndarray:
1588 | """Evaluate the given tokens and update the model state.
1589 |
1590 | If `logits_all` is True, return the logits for all `input_tokens`. Otherwise, only
1591 | return the logits for the last token (which are the predictions for the next token)."""
1592 |
1593 | self._stopwatch.reset()
1594 | self._stopwatch.start_wall_time()
1595 |
1596 | n_input_tokens = len(input_tokens)
1597 |
1598 | if logits_all:
1599 | if self._first_valid_pos(input_tokens) > 0:
1600 | log(
1601 | f'Llama.eval: the KV cache will be cleared in order to compute the logits '
1602 | f'for all tokens in the input', 2
1603 | )
1604 |
1605 | log_verbose(f'Llama.eval: {n_input_tokens} tokens to eval ...')
1606 |
1607 | self.reset()
1608 | actual_input_tokens = input_tokens
1609 | n_actual_input_tokens = len(input_tokens)
1610 | else:
1611 | actual_input_tokens = self._set_cache_tokens(input_tokens)
1612 |
1613 | n_actual_input_tokens = len(actual_input_tokens)
1614 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens
1615 |
1616 | log_verbose(
1617 | f'Llama.eval: {n_cache_hit_tokens} tokens in cache, '
1618 | f'{n_actual_input_tokens} tokens to eval ...'
1619 | )
1620 |
1621 | if n_actual_input_tokens == 0:
1622 | return self.get_logits()
1623 |
1624 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch)
1625 |
1626 | # process each batch one-by-one
1627 | if logits_all:
1628 | all_logits = []
1629 | for batch in _batches_with_progress_bar(batches):
1630 | batch_logits = self._process_batch(batch, logits_all=True)
1631 | all_logits.append(batch_logits)
1632 | final_logits = np.concatenate(all_logits, axis=0)
1633 | else:
1634 | for batch in _batches_with_progress_bar(batches):
1635 | final_logits = self._process_batch(batch, logits_all=False)
1636 |
1637 | self._stopwatch.stop_wall_time()
1638 |
1639 | if get_verbose():
1640 | self._stopwatch.print_stats()
1641 |
1642 | return final_logits if logits_all else self.get_logits()
1643 |
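# Example use of `Llama.eval` (the prompt text is a placeholder; `MyLlama` is an
# already-constructed `Llama` instance):
#
#     toks = MyLlama.tokenize(b"Once upon a time", add_special=True, parse_special=False)
#
#     last_logits = MyLlama.eval(toks)                   # next-token logits, shape (n_vocab,)
#     all_logits  = MyLlama.eval(toks, logits_all=True)  # one row per token, shape (len(toks), n_vocab)
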
1644 | def generate_single(
1645 | self,
1646 | input_tokens: list[int],
1647 | sampler_preset: Optional[SamplerPreset] = None,
1648 | return_logits: bool = False
1649 | ) -> Union[int, np.ndarray]:
1650 | """Generate a single token
1651 |
1652 | - input_tokens:
1653 | The tokens to evaluate
1654 | - sampler_preset:
1655 | The `SamplerPreset` object to use for sampling. If not specified,
1656 | use the model's default sampler parameters
1657 | - return_logits:
1658 | If True, return the logits for the generated token instead of the token ID."""
1659 |
1660 | self._stopwatch.reset()
1661 | self._stopwatch.start_wall_time()
1662 |
1663 | n_input_tokens = len(input_tokens)
1664 |
1665 | actual_input_tokens = self._set_cache_tokens(input_tokens)
1666 |
1667 | n_actual_input_tokens = len(actual_input_tokens)
1668 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens
1669 |
1670 | if sampler_preset is None:
1671 | sampler_params = self._default_sampler_params
1672 | else:
1673 | sampler_params = self.sampler_params_from_preset(sampler_preset)
1674 |
1675 | if get_verbose():
1676 | sampler_params.print_chain()
1677 |
1678 | log_verbose(
1679 | f'Llama.generate_single: {n_cache_hit_tokens} tokens in cache, '
1680 | f'{n_actual_input_tokens} tokens to eval ...'
1681 | )
1682 |
1683 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch)
1684 |
1685 | # process each batch one-by-one
1686 | for batch in _batches_with_progress_bar(batches):
1687 | self._process_batch(batch, logits_all=False)
1688 |
1689 | self._stopwatch.stop_wall_time()
1690 |
1691 | if get_verbose():
1692 | self._stopwatch.print_stats()
1693 |
1694 | if return_logits: # TODO: this is inefficient, it decodes the last token again. replace.
1695 | return self.get_logits()
1696 |
1697 | return self.sample(sampler_params)
1698 |
1699 | def generate(
1700 | self,
1701 | input_tokens: list[int],
1702 | n_predict: int,
1703 | stop_tokens: Optional[list[int]] = None,
1704 | sampler_preset: Optional[SamplerPreset] = None,
1705 | return_logits: bool = False
1706 | ) -> Union[list[int], np.ndarray]:
1707 | """Generate new tokens and return them all at once
1708 |
1709 | - input_tokens:
1710 | The tokens to evaluate
1711 | - n_predict:
1712 | The number of tokens to predict. If `n_predict < 0`, then the number of tokens
1713 | predicted is only limited by the context length. If `n_predict == 0`, then no new
1714 | tokens will be predicted, but the input_tokens will still be processed.
1715 | - stop_tokens:
1716 | A list of token IDs that will end the generation early. Note that the stop token
1717 | will be included in the output. If this parameter is None, all built-in stop tokens
1718 | for the model will be used. Pass an empty list `[]` to ignore all stop tokens.
1719 | - sampler_preset:
1720 | The `SamplerPreset` object to use for sampling. If not specified, use the model's
1721 | default sampler parameters.
1722 | - return_logits:
1723 | If True, return the logits for the generated tokens instead of the token IDs. Note
1724 | that this incurs a slight performance penalty."""
1725 |
1726 | self._stopwatch.reset()
1727 | self._stopwatch.start_wall_time()
1728 |
1729 | n_input_tokens = len(input_tokens)
1730 |
1731 | actual_input_tokens = self._set_cache_tokens(input_tokens)
1732 |
1733 | n_actual_input_tokens = len(actual_input_tokens)
1734 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens
1735 |
1736 | _stop_tokens = stop_tokens if stop_tokens is not None else self.eog_tokens
1737 |
1738 | if sampler_preset is None:
1739 | sampler_params = self._default_sampler_params
1740 | else:
1741 | sampler_params = self.sampler_params_from_preset(sampler_preset)
1742 |
1743 | if get_verbose():
1744 | sampler_params.print_chain()
1745 |
1746 | _n_predict = n_predict if n_predict >= 0 else self._n_ctx - self.pos
1747 |
1748 | log_verbose(
1749 | f'Llama.generate: {n_cache_hit_tokens} tokens in cache, {n_actual_input_tokens} '
1750 | f'tokens to eval ...'
1751 | )
1752 |
1753 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch)
1754 |
1755 | log_debug('Llama.generate: start processing input batches')
1756 |
1757 | # process each input batch one-by-one
1758 | for batch in _batches_with_progress_bar(batches):
1759 | self._process_batch(batch, logits_all=False)
1760 |
1761 | log_debug('Llama.generate: done processing input batches')
1762 |
1763 | if _n_predict == 0:
1764 | self._stopwatch.stop_wall_time()
1765 | if get_verbose():
1766 | self._stopwatch.print_stats()
1767 | return []
1768 |
1769 | predicted_tokens = []
1770 | if return_logits:
1771 | predicted_logits = []
1772 | n_predicted = 0
1773 |
1774 | log_verbose(f'Llama.generate: predicting up to {_n_predict} new tokens ...')
1775 | log_debug(f'Llama.generate: enter while loop')
1776 |
1777 | while n_predicted < _n_predict:
1778 | # sample a token from the latest logits
1779 | sampled_token = self.sample(sampler_params)
1780 |
1781 | # save the sampled token as a prediction
1782 | predicted_tokens.append(sampled_token)
1783 | n_predicted += 1
1784 |
1785 | # if it's a stop token, stop generating
1786 | if sampled_token in _stop_tokens:
1787 | if get_verbose():
1788 | tok_str = ez_decode(self.token_to_piece(sampled_token, True))
1789 | print()
1790 | log(f'inferred stop token {sampled_token} ({tok_str!r})')
1791 | break
1792 |
1793 | # decode the sampled token to get the logits for the following token
1794 | if return_logits:
1795 | logits = self._process_batch([sampled_token], True)
1796 | predicted_logits.append(logits)
1797 | else:
1798 | self._process_batch([sampled_token])
1799 |
1800 | # done generating, show stopwatch stats and return
1801 | self._stopwatch.stop_wall_time()
1802 | log_debug(f'Llama.generate: exited while loop')
1803 | if get_verbose():
1804 | self._stopwatch.print_stats()
1805 | if return_logits:
1806 | return np.stack(predicted_logits, axis=0)
1807 | return predicted_tokens
1808 |
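# Example use of `Llama.generate` with an explicit sampler preset. The prompt and
# preset values are placeholders; `SamplerPreset` is assumed to be importable from
# `easy_llama.sampling`:
#
#     from easy_llama.sampling import SamplerPreset
#
#     preset = SamplerPreset(temp=0.7, top_p=0.9, min_p=0.05)
#     toks = MyLlama.tokenize(b"Write a haiku about spring.", add_special=True,
#                             parse_special=False)
#     out = MyLlama.generate(toks, n_predict=128, sampler_preset=preset)
#     print(MyLlama.detokenize(out, special=False))
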
1809 | def stream(
1810 | self,
1811 | input_tokens: list[int],
1812 | n_predict: int,
1813 | stop_tokens: Optional[list[int]] = None,
1814 | sampler_preset: Optional[SamplerPreset] = None,
1815 | yield_logits: bool = False
1816 | ) -> Iterable[Union[int, np.ndarray]]:
1817 | """Return a Generator which yields tokens as they are generated
1818 |
1819 | - input_tokens:
1820 | The tokens to evaluate
1821 | - n_predict:
1822 | The number of tokens to predict. If `n_predict < 0`, then the number of tokens
1823 | predicted is only limited by the context length. If `n_predict == 0`, then no new
1824 | tokens will be predicted, but the input_tokens will still be processed.
1825 | - stop_tokens:
1826 | A list of token IDs that will end the generation early. Note that
1827 | the stop token will be included in the output. If this parameter is
1828 | None, all built-in stop tokens for the model will be used. Pass an
1829 | empty list `[]` to ignore all stop tokens.
1830 | - sampler_preset:
1831 | The `SamplerPreset` object to use for sampling. If not specified,
1832 | use the model's default sampler parameters
1833 | - yield_logits:
1834 | If True, yield the logits for the generated tokens instead of the token IDs"""
1835 |
1836 | self._stopwatch.reset()
1837 | self._stopwatch.start_wall_time()
1838 |
1839 | n_input_tokens = len(input_tokens)
1840 |
1841 | actual_input_tokens = self._set_cache_tokens(input_tokens)
1842 |
1843 | n_actual_input_tokens = len(actual_input_tokens)
1844 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens
1845 |
1846 | _stop_tokens = stop_tokens if stop_tokens is not None else self.eog_tokens
1847 |
1848 | if sampler_preset is None:
1849 | sampler_params = self._default_sampler_params
1850 | else:
1851 | sampler_params = self.sampler_params_from_preset(sampler_preset)
1852 |
1853 | if get_verbose():
1854 | sampler_params.print_chain()
1855 |
1856 | _n_predict = n_predict if n_predict >= 0 else self._n_ctx - self.pos
1857 |
1858 | log_verbose(
1859 | f'Llama.stream: {n_cache_hit_tokens} tokens in cache, {n_actual_input_tokens} '
1860 | f'tokens to eval ...'
1861 | )
1862 |
1863 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch)
1864 |
1865 | log_debug('Llama.stream: start processing input batches')
1866 |
1867 | # process each input batch one-by-one
1868 | for batch in _batches_with_progress_bar(batches):
1869 | self._process_batch(batch, logits_all=False)
1870 |
1871 | log_debug('Llama.stream: done processing input batches')
1872 |
1873 | if _n_predict == 0:
1874 | self._stopwatch.stop_wall_time()
1875 | if get_verbose():
1876 | self._stopwatch.print_stats()
1877 | return
1878 |
1879 | n_predicted = 0
1880 |
1881 | log_verbose(f'Llama.stream: predicting up to {_n_predict} new tokens ...')
1882 | log_debug(f'Llama.stream: enter while loop')
1883 |
1884 | while n_predicted < _n_predict:
1885 | # sample a token from the latest logits
1886 | sampled_token = self.sample(sampler_params)
1887 |
1888 | is_stop_token = sampled_token in _stop_tokens
1889 |
1890 | if yield_logits:
1891 | if not is_stop_token:
1892 | # process the token, yield the logits for the next prediction
1893 | logits = self._process_batch([sampled_token], True)
1894 | yield logits
1895 | else:
1896 | yield sampled_token
1897 |
1898 | n_predicted += 1
1899 |
1900 | if is_stop_token:
1901 | if get_verbose():
1902 | tok_str = ez_decode(self.token_to_piece(sampled_token, True))
1903 | print()
1904 | log(f'inferred stop token {sampled_token} ({tok_str!r})')
1905 | break
1906 |
1907 | if not yield_logits:
1908 | self._process_batch([sampled_token])
1909 |
1910 | # done generating, show stopwatch stats
1911 | self._stopwatch.stop_wall_time()
1912 | log_debug(f'Llama.stream: exited while loop')
1913 | if get_verbose():
1914 | self._stopwatch.print_stats()
1915 |
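# Example use of `Llama.stream`, printing text as it is generated (the prompt is a
# placeholder):
#
#     toks = MyLlama.tokenize(b"The three laws of robotics are", add_special=True,
#                             parse_special=False)
#     for tok in MyLlama.stream(toks, n_predict=256):
#         print(MyLlama.detokenize([tok], special=False), end='', flush=True)
#     print()
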
1916 | def benchmark(
1917 | self,
1918 | n_tokens_pp: Optional[int] = None,
1919 | n_tokens_tg: Optional[int] = None,
1920 | n_runs: Optional[int] = None
1921 | ) -> list[dict]:
1922 | """Measure the prompt processing and text generation speed of this Llama."""
1923 |
1924 | n_tokens_pp = n_tokens_pp if n_tokens_pp is not None else self.n_batch()
1925 | n_tokens_tg = n_tokens_tg if n_tokens_tg is not None else 10
1926 | n_runs = n_runs if n_runs is not None else 3
1927 |
1928 | results = []
1929 | total_pp_time_ns = 0
1930 | total_tg_time_ns = 0
1931 |
1932 | for i in range(1, n_runs+1):
1933 |
1934 | log_verbose(f'benchmark: starting run {i}/{n_runs}:')
1935 |
1936 | log_verbose(f'benchmark: processing {n_tokens_pp} tokens ... please wait ...')
1937 | self.reset()
1938 | self.eval(input_tokens=[0] * n_tokens_pp)
1939 | pp_ns = self._stopwatch.get_elapsed_time_pp()
1940 | total_pp_time_ns += pp_ns
1941 |
1942 | log_verbose(f'benchmark: generating {n_tokens_tg} tokens ... please wait ...')
1943 | self.reset()
1944 | self.generate(
1945 | input_tokens=[0],
1946 | n_predict=n_tokens_tg,
1947 | stop_tokens=[],
1948 | sampler_preset=SamplerPreset(seed=42, top_k=1, temp=0.0)
1949 | )
1950 | tg_ns = self._stopwatch.get_elapsed_time_tg()
1951 | total_tg_time_ns += tg_ns
1952 |
1953 | results.append({
1954 | 'n_tokens_pp' : n_tokens_pp,
1955 | 'n_tokens_tg' : n_tokens_tg,
1956 | 'pp_time_ns' : pp_ns,
1957 | 'tg_time_ns' : tg_ns
1958 | })
1959 |
1960 | avg_pp_time_ns = total_pp_time_ns / n_runs
1961 | avg_tg_time_ns = total_tg_time_ns / n_runs
1962 |
1963 | avg_pp_time_ms = avg_pp_time_ns / 1e6
1964 | avg_tg_time_ms = avg_tg_time_ns / 1e6
1965 |
1966 | avg_pp_tok_per_sec = n_tokens_pp / (avg_pp_time_ns / 1e9)
1967 | avg_tg_tok_per_sec = n_tokens_tg / (avg_tg_time_ns / 1e9)
1968 |
1969 | log_verbose(
1970 | f'average pp speed for {n_tokens_pp:>7} tokens over {n_runs} runs: '
1971 | f'{avg_pp_time_ms:>13.3f}ms ({avg_pp_tok_per_sec:10.2f} tok/s)', 4
1972 | )
1973 | log_verbose(
1974 | f'average tg speed for {n_tokens_tg:>7} tokens over {n_runs} runs: '
1975 | f'{avg_tg_time_ms:>13.3f}ms ({avg_tg_tok_per_sec:10.2f} tok/s)', 4
1976 | )
1977 |
1978 | return results
1979 |
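# Example use of `Llama.benchmark`, averaging the returned timings into tokens/sec
# (the token and run counts below are placeholders):
#
#     runs = MyLlama.benchmark(n_tokens_pp=512, n_tokens_tg=32, n_runs=3)
#     pp_tok_per_sec = sum(r['n_tokens_pp'] / (r['pp_time_ns'] / 1e9) for r in runs) / len(runs)
#     tg_tok_per_sec = sum(r['n_tokens_tg'] / (r['tg_time_ns'] / 1e9) for r in runs) / len(runs)
#     print(f'pp: {pp_tok_per_sec:.2f} tok/s, tg: {tg_tok_per_sec:.2f} tok/s')
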
1980 | def sample_greedy(self) -> int:
1981 | id = _internals.sample_greedy(self._ctx.ctx)
1982 | # llama_sampler_sample internally calls llama_sampler_accept.
1983 | # uncomment the next line if this changes
1984 | #lib.llama_sampler_accept(_internals.greedy_sampler, id)
1985 | return id
1986 |
1987 | def sampler_params_from_preset(self, sampler_preset: SamplerPreset) -> SamplerParams:
1988 | """Create and return a new `SamplerParams` object for this Llama using the provided
1989 | `SamplerPreset`.
1990 |
1991 | @param sampler_preset: The `sampling.SamplerPreset` object which defines the sampler
1992 | parameter values to use"""
1993 |
1994 | return SamplerParams(
1995 | llama = self,
1996 | seed = sampler_preset.seed,
1997 | top_k = sampler_preset.top_k,
1998 | top_p = sampler_preset.top_p,
1999 | min_p = sampler_preset.min_p,
2000 | xtc_probability = sampler_preset.xtc_probability,
2001 | xtc_threshold = sampler_preset.xtc_threshold,
2002 | typical_p = sampler_preset.typical_p,
2003 | temp = sampler_preset.temp,
2004 | dynatemp_delta = sampler_preset.dynatemp_delta,
2005 | dynatemp_exponent = sampler_preset.dynatemp_exponent,
2006 | penalty_last_n = sampler_preset.penalty_last_n,
2007 | penalty_repeat = sampler_preset.penalty_repeat,
2008 | penalty_freq = sampler_preset.penalty_freq,
2009 | penalty_present = sampler_preset.penalty_present,
2010 | dry_multiplier = sampler_preset.dry_multiplier,
2011 | dry_base = sampler_preset.dry_base,
2012 | dry_allowed_length = sampler_preset.dry_allowed_length,
2013 | dry_penalty_last_n = sampler_preset.dry_penalty_last_n,
2014 | mirostat = sampler_preset.mirostat,
2015 | top_n_sigma = sampler_preset.top_n_sigma,
2016 | mirostat_tau = sampler_preset.mirostat_tau,
2017 | mirostat_eta = sampler_preset.mirostat_eta,
2018 | dry_sequence_breakers = sampler_preset.dry_sequence_breakers,
2019 | logit_bias = sampler_preset.logit_bias
2020 | )
2021 |
2022 | def sample(self, sampler_params: Optional[SamplerParams] = None) -> int:
2023 | """Sample a token using the current context
2024 |
2025 | - sampler_params:
2026 | The `sampling.SamplerParams` object which defines the sampling
2027 | parameters to use. If this parameter is None, the default sampler
2028 | parameter values will be used."""
2029 | sampler_params = sampler_params if sampler_params is not None else (
2030 | self._default_sampler_params
2031 | )
2032 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.sample')
2033 | id = lib.llama_sampler_sample(sampler_params.smpl, self._ctx.ctx, -1)
2034 | # llama_sampler_sample internally calls llama_sampler_accept.
2035 | # uncomment the next line if this changes
2036 | #lib.llama_sampler_accept(sampler_params.smpl, id)
2037 | return id
2038 |
2039 | def get_logits(self) -> np.ndarray:
2040 | """Return the raw logits for the last token in the context, which are the predictions
2041 | for the next token. The returned array has shape `(n_vocab,)`."""
2042 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.get_logits')
2043 | return _internals.get_logits(self._ctx.ctx, self._n_vocab)
2044 |
2045 | def get_scores(self, temp: Optional[float] = None) -> np.ndarray:
2046 | """Compute the logits for the last token in the context, normalized with softmax.
2047 | Optionally apply temperature `temp` if specified.
2048 |
2049 | Any floating-point value for temperature `temp` is valid, including 0.0
2050 | and negative numbers.
2051 |
2052 | The returned array has shape `(n_vocab,)`."""
2053 | logits = self.get_logits()
2054 | return softmax(logits, T=temp if temp is not None else 1.0)
2055 |
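# Example: using `Llama.get_scores` to inspect the most likely next tokens after an
# `eval` call (`np` is numpy, imported at the top of this module):
#
#     scores = MyLlama.get_scores()
#     top5 = np.argsort(scores)[::-1][:5]
#     for tok in top5:
#         piece = MyLlama.token_to_piece(int(tok), special=True)
#         print(f'{int(tok):>7} {scores[tok]:.4f} {piece!r}')
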
2056 | def get_tokenization_mapping(self, tokens: list[int]) -> list[tuple[int, bytes]]:
2057 | """Given some tokens, return a list of tuples where the first item in the
2058 | tuple is the token ID and the second item is the corresponding UTF-8
2059 | text bytes."""
2060 | return list(zip(tokens, [self.token_to_piece(id, special=True) for id in tokens]))
2061 |
2062 | def print_tokenization_mapping(
2063 | self,
2064 | tokens: list[int],
2065 | file: _SupportsWriteAndFlush = sys.stderr
2066 | ) -> None:
2067 | """Given some tokens, print a mapping of each token ID to the
2068 | corresponding UTF-8 text bytes
2069 |
2070 | This is meant to be roughly equivalent to `llama.cpp/llama-tokenize`
2071 |
2072 | - tokens:
2073 | The tokens to print a mapping for
2074 | - file:
2075 | The file or stream to which the mapping will be printed"""
2076 | token_mapping = self.get_tokenization_mapping(tokens)
2077 | for id, bytes in token_mapping:
2078 | print(f"{id:>7} -> {repr(bytes)} ({bytes.hex(':')})", file=file)
2079 | #print(f"{id:>7} -> {str(txt)}", file=file)
2080 | print(f"Total number of tokens: {len(token_mapping)}", file=file, flush=True)
2081 |
2082 | def name(self) -> str:
2083 | """Get the name of the model from the GGUF metadata"""
2084 | # '/path/to/my-model.gguf' --> 'my-model'
2085 | model_file_basename = os.path.basename(self._model.path_model).removesuffix('.gguf')
2086 | # 'my-model-00001-of-99999' --> 'my-model'
2087 | model_file_basename = re.sub(r'-\d{5}-of-\d{5}$', '', model_file_basename)
2088 | # TODO: get from metadata instead, fallback to using filename
2089 | return model_file_basename
2090 |
2091 | def bpw(self) -> float:
2092 | """Get the average bits per weight of the model"""
2093 | return (self._model_size_bytes * 8) / self._n_params
2094 |
2095 | def save_state(self, file_path: str) -> None:
2096 | """Save the current state of the context to a file"""
2097 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.save_state')
2098 |
2099 | state_size_bytes = lib.llama_state_get_size(self._ctx.ctx)
2100 | state_size_mib = int(state_size_bytes / (1024 * 1024)) # approximate
2101 |
2102 | log(f'Llama.save_state: state size: {state_size_mib} MiB ({state_size_bytes} bytes)')
2103 | log(f'Llama.save_state: saving to {file_path} ...')
2104 |
2105 | if os.path.exists(file_path):
2106 | log(f'Llama.save_state: file exists, will be overwritten', 2)
2107 |
2108 | # save the llama state
2109 | with suppress_output(disable=get_verbose()):
2110 | success = lib.llama_state_save_file(self._ctx.ctx, file_path, self.context_tokens)
2111 |
2112 | if success:
2113 | log(f'Llama.save_state: successfully saved state')
2114 | else:
2115 | raise RuntimeError(f'Llama.save_state: failed to save state')
2116 |
2117 | def load_state(self, file_path: str) -> None:
2118 | """Load a previously saved context state from a file"""
2119 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.load_state')
2120 |
2121 | if not os.path.exists(file_path):
2122 | raise FileNotFoundError(f'Llama.load_state: file_path {file_path} does not exist')
2123 |
2124 | if os.path.isdir(file_path):
2125 | raise IsADirectoryError(f'Llama.load_state: file_path {file_path} is a directory')
2126 |
2127 | # reset the current context before loading the new one
2128 | self.reset()
2129 |
2130 | n_ctx = self.n_ctx()
2131 | loaded_tokens_buf = (lib.llama_token * n_ctx)()
2132 | n_loaded_tokens_p = ctypes.c_size_t(0)
2133 |
2134 | log(f'Llama.load_state: loading from {file_path} ...')
2135 |
2136 | # load the llama state
2137 | with suppress_output(disable=get_verbose()):
2138 | success = lib.llama_state_load_file(
2139 | ctx=self._ctx.ctx,
2140 | path_session=file_path,
2141 | tokens_out=loaded_tokens_buf,
2142 | n_token_capacity=n_ctx,
2143 | n_token_count_out=ctypes.byref(n_loaded_tokens_p)
2144 | )
2145 |
2146 | if success:
2147 | n_loaded_tokens = n_loaded_tokens_p.value
2148 |
2149 | self.context_tokens = list(loaded_tokens_buf[:n_loaded_tokens])
2150 | self.pos = n_loaded_tokens
2151 |
2152 | state_size_bytes = lib.llama_state_get_size(self._ctx.ctx)
2153 | state_size_mib = int(state_size_bytes / (1024 * 1024)) # approximate
2154 |
2155 | log(f'Llama.load_state: state size: {state_size_mib} MiB ({state_size_bytes} bytes)')
2156 | log(f'Llama.load_state: successfully loaded state ({n_loaded_tokens} tokens)')
2157 | else:
2158 | raise RuntimeError(f'Llama.load_state: failed to load state')
2159 |
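# Example of saving and restoring the context state (the prompt and file path are
# placeholders):
#
#     MyLlama.eval(MyLlama.tokenize(b"Some long shared prefix", add_special=True,
#                                   parse_special=False))
#     MyLlama.save_state('prefix.state')
#
#     # ... later, after other work has clobbered the KV cache:
#     MyLlama.load_state('prefix.state')   # restores context_tokens and pos
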
2160 | def reset(self) -> None:
2161 | """Reset the position of the model and clear the KV cache"""
2162 | self.kv_cache_clear()
2163 | self.pos = 0
2164 | self.context_tokens = []
2165 |
--------------------------------------------------------------------------------
/easy_llama/sampling.py:
--------------------------------------------------------------------------------
1 | # sampling.py
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | """This file provides functionality for defining the sampler parameters used to control
6 | text generation."""
7 |
8 | # TODO: implement grammar, implement custom sampling chains
9 |
10 | import os
11 | import sys
12 | import ctypes
13 |
14 | from .utils import null_ptr_check, log, ez_encode
15 | from .libllama import _internals
16 | from typing import Optional
17 |
18 | from . import libllama as lib
19 |
20 | HIGH_TEMP = 10_000.0
21 |
22 | class Llama: # can't import the real Llama - would be circular
23 | """Type hint denoting a `llama.Llama` instance"""
24 |
25 | def _get_random_seed() -> int:
26 | # unsigned 32-bit integer
27 | return int.from_bytes(bytes=os.urandom(4), byteorder=sys.byteorder, signed=False)
28 |
29 | class SamplerParams:
30 | """A SamplerParams object is used by a Llama model to define sampling behaviour.
31 |
32 | However, SamplerParams objects also require some information about the Llama model itself,
33 | such as n_ctx_train, n_vocab, etc. Therefore Llama models and SamplerParams are tightly
34 | coupled.
35 |
36 | A SamplerPreset (which is a separate class) can be used to define sampling parameters
37 | without having to specify a Llama object. In turn, the Llama class can use these presets to
38 | create the actual SamplerParams object it needs for sampling."""
39 |
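# Typical workflow (illustrative; `MyLlama` denotes an already-loaded `llama.Llama`
# instance whose logits have been populated by a prior `eval` call, and the preset
# values are placeholders):
#
#     preset = SamplerPreset(temp=0.6, top_p=0.95)          # no Llama needed yet
#     params = MyLlama.sampler_params_from_preset(preset)   # bound to a loaded Llama
#     token_id = MyLlama.sample(params)
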
40 | # NOTE: as of 2025-04-04, the default sampler chain for llama-cli is:
41 | #
42 | # logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p ->
43 | # xtc -> temp-ext -> dist
44 | #
45 | # -----------------------------------------------------------------------------------
46 | #
47 | # as of 2025-04-22, the sampler chain for easy-llama is constructed as follows:
48 | #
49 | # -----------------------------------------------------------------------------------
50 | #
51 | # logits -> logit-bias -> top-k (k=128) -> penalties -> dry -> xtc ...
52 | #
53 | # -- IF TEMP <= 0.0:
54 | #
55 | # ... -> greedy
56 | #
57 | # -- IF MIROSTAT v1:
58 | #
59 | # ... -> temp(-ext) -> mirostat-v1
60 | #
61 | # -- IF MIROSTAT v2:
62 | #
63 | # ... -> temp(-ext) -> mirostat-v2
64 | #
65 | # -- ELSE:
66 | #
67 | # -- IF TOP-N-SIGMA >= 0:
68 | #
69 | # ... -> top-k -> temp -> top-n-sigma -> dist
70 | #
71 | # -- ELSE:
72 | #
73 | # ... -> top-k -> typical-p -> top-p -> min-p -> temp(-ext) -> dist
74 | #
75 | # ----------------------------------------------------------------------------------
76 | #
77 | # - "temp(-ext)" denotes "temp" if dynatemp_range == 0.0, otherwise
78 | # "temp-ext"
79 | #
80 | # - "top-k (k=128)" is always performed before applying penalties
81 | # and DRY to improve performance
82 | #
83 | # - note that "logit-bias", "top-k (k=128)", "penalties", and "dry"
84 | # are always applied
85 |
86 | def __init__(
87 | #
88 | # ref: llama.cpp/common/common.h: struct common_params_sampling { ... }
89 | #
90 | self,
91 | llama: Llama, # some samplers require info about n_ctx_train, n_vocab, etc.
92 |
93 | seed: int = -1, # random seed: <= 0
94 | top_k: int = 40, # neutral: <= 0
95 | top_p: float = 0.95, # neutral: 1.0
96 | min_p: float = 0.05, # neutral: 0.0
97 | xtc_probability: float = 0.0, # neutral: 0.0
98 | xtc_threshold: float = 0.1, # disable: > 0.5
99 | typical_p: float = 1.0, # neutral: 1.0
100 | temp: float = 0.8, # neutral: 1.0, greedy: <= 0.0
101 | dynatemp_delta: float = 0.0, # neutral: <= 0.0
102 | dynatemp_exponent: float = 1.0, # controls how entropy maps to dynamic temperature
103 | penalty_last_n: int = 64, # disable: 0, n_ctx: -1, last n tokens to penalize
104 | penalty_repeat: float = 1.0, # neutral: 1.0, should be between 1.0 and ~1.1
105 | penalty_freq: float = 0.0, # neutral: 0.0
106 | penalty_present: float = 0.0, # neutral: 0.0
107 | dry_multiplier: float = 0.0, # disable: 0.0, DRY repetition penalty for tokens extending repetition:
108 | dry_base: float = 1.75, # disable: 0.0, multiplier * base ^ (length of sequence before token - allowed length)
109 | dry_allowed_length: int = 2, # tokens extending repetitions beyond this receive penalty
110 | dry_penalty_last_n: int = -1, # disable: 0, n_ctx: -1, how many tokens to scan for repetitions
111 | mirostat: int = 0, # disable: 0, use v1: 1, use v2: 2
112 | top_n_sigma: float = -1.0, # disable: -1.0
113 | mirostat_tau: float = 5.0, # target entropy for mirostat
114 | mirostat_eta: float = 0.1, # learning rate for mirostat
115 |
116 | dry_sequence_breakers: list[str] = ["\n", ":", "\"", "*"], # default sequence breakers for DRY
117 |
118 | # TODO: grammar goes here
119 |
120 | logit_bias: Optional[dict[int, float]] = None
121 | ):
122 | self.smpl = None
123 |
124 | #
125 | # ref: llama.cpp/common/common.h: common_sampler_init(...) { ... }
126 | #
127 |
128 | #
129 | # Store parameter values as attributes
130 | #
131 |
132 | # NOTE: Changing these attributes will not change the sampling. If you need to change
133 | # sampling, construct a new sampler.
134 |
135 | self.llama = llama
136 | self.seed = seed
137 |
138 | self.top_k = top_k
139 | self.top_p = top_p
140 | self.min_p = min_p
141 | self.xtc_probability = xtc_probability
142 | self.xtc_threshold = xtc_threshold
143 | self.typical_p = typical_p
144 | self.temp = temp
145 | self.dynatemp_delta = dynatemp_delta
146 | self.dynatemp_exponent = dynatemp_exponent
147 | self.penalty_last_n = penalty_last_n
148 | self.penalty_repeat = penalty_repeat
149 | self.penalty_freq = penalty_freq
150 | self.penalty_present = penalty_present
151 | self.dry_multiplier = dry_multiplier
152 | self.dry_base = dry_base
153 | self.dry_allowed_length = dry_allowed_length
154 | self.dry_penalty_last_n = dry_penalty_last_n
155 | self.mirostat = mirostat
156 | self.top_n_sigma = top_n_sigma
157 | self.mirostat_tau = mirostat_tau
158 | self.mirostat_eta = mirostat_eta
159 |
160 | # TODO
161 | # self._validate_params() # show warnings/errors for nonsensical sampler values
162 |
163 | self.dry_sequence_breakers = dry_sequence_breakers
164 |
165 | self.logit_bias = logit_bias
166 |
167 | self._chain_str = ''
168 |
169 | sparams = lib.llama_sampler_chain_default_params()
170 | null_ptr_check(sparams, 'sparams', 'SamplerParams.__init__')
171 |
172 | smpl = lib.llama_sampler_chain_init(sparams)
173 | null_ptr_check(smpl, 'smpl', 'SamplerParams.__init__')
174 |
175 | # Logit bias
176 |
177 | if logit_bias is not None:
178 | if len(logit_bias) == 1:
179 | self._chain_str += f'one logit bias -> '
180 | else:
181 | self._chain_str += f'{len(logit_bias)} logit biases -> '
182 | logit_bias_arr = _internals.get_logit_bias_array(logit_bias)
183 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_logit_bias(
184 | n_vocab=llama._n_vocab,
185 | n_logit_bias=len(logit_bias),
186 | logit_bias=logit_bias_arr
187 | ))
188 |
189 | penalty_last_n = penalty_last_n if penalty_last_n >= 0 else llama.n_ctx()
190 | _penalties_sampler_active = penalty_last_n != 0 and any(
191 | [penalty_repeat != 1.0, penalty_present != 0.0, penalty_freq != 0.0]
192 | )
193 | _dry_sampler_active = dry_multiplier > 0.0
194 | _use_topk128_cutoff = _penalties_sampler_active or _dry_sampler_active
195 |
196 | if _use_topk128_cutoff:
197 |
198 | # Top-K (where k == 128)
199 | #
200 | # NOTE: This improves performance by greatly reducing the number of
201 | # tokens that the DRY and penalties samplers need to consider. It
202 | # should have absolutely no effect on the output except under the
203 | # strangest and most unlikely circumstances (like when temp > 10.0
204 | # and all other samplers are explicitly disabled).
205 | #
206 | # If you really need to bypass this, you can construct your own
207 | # llama_sampler_chain manually. But you probably don't need to.
208 | #
209 | self._chain_str += 'top-k 128 -> '
210 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_k(k=128))
211 |
212 | # Penalties
213 |
214 | if _penalties_sampler_active:
215 | self._chain_str += f'penalty last:{penalty_last_n}'
216 | self._chain_str += f' rept:{penalty_repeat:.3f}' if penalty_repeat != 1.0 else ''
217 | self._chain_str += f' pres:{penalty_present:.3f}' if penalty_present != 0.0 else ''
218 | self._chain_str += f' freq:{penalty_freq:.3f}' if penalty_freq != 0.0 else ''
219 | self._chain_str += f' -> '
220 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_penalties(
221 | penalty_last_n=penalty_last_n,
222 | penalty_repeat=penalty_repeat,
223 | penalty_freq=penalty_freq,
224 | penalty_present=penalty_present
225 | ))
226 |
227 | # DRY
228 |
229 | if _dry_sampler_active:
230 | self._chain_str += f'DRY x{dry_multiplier:.2f} base:{dry_base:.2f} '
231 | self._chain_str += f'len:{dry_allowed_length} -> '
232 | # dry == D.R.Y. ("Don't Repeat Yourself")
233 | # ref: https://github.com/oobabooga/text-generation-webui/pull/5677
234 | null_ptr_check(llama._vocab, 'llama._vocab', 'SamplerParams.__init__')
235 | seq_breakers = dry_sequence_breakers
236 | seq_breakers_bytes = [ez_encode(s) for s in seq_breakers]
237 | arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
238 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_dry(
239 | vocab=llama._vocab,
240 | n_ctx_train=llama._n_ctx_train,
241 | dry_multiplier=dry_multiplier,
242 | dry_base=dry_base,
243 | dry_allowed_length=dry_allowed_length,
244 | dry_penalty_last_n=dry_penalty_last_n,
245 | seq_breakers=arr,
246 | num_breakers=len(seq_breakers)
247 | ))
248 |
249 | # XTC
250 |
251 | if xtc_probability > 0.0:
252 | self._chain_str += f'XTC p:{xtc_probability:.2f} t:{xtc_threshold:.2f} -> '
253 | lib.llama_sampler_chain_add(
254 | smpl, lib.llama_sampler_init_xtc(
255 | p=xtc_probability,
256 | t=xtc_threshold,
257 | min_keep=1,
258 | seed=seed if seed > 0 else _get_random_seed()
259 | )
260 | )
261 |
262 | # IF TEMP <= 0.0:
263 |
264 | if temp <= 0.0:
265 | # ... -> greedy
266 | self._chain_str += 'greedy'
267 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_greedy())
268 |
269 | # IF MIROSTAT v1:
270 |
271 | elif mirostat == 1:
272 | # ... -> temp(-ext) -> mirostat-v1
273 | if dynatemp_delta != 0.0:
274 | # dynamic temperature AKA entropy sampling
275 | self._chain_str += f'temp {temp:.2f} +/- {dynatemp_delta:.2f} -> '
276 | lib.llama_sampler_chain_add(
277 | smpl, lib.llama_sampler_init_temp_ext(
278 | t=temp,
279 | delta=dynatemp_delta,
280 | exponent=dynatemp_exponent
281 | )
282 | )
283 | else:
284 | # standard temperature
285 | self._chain_str += f'temp {temp:.2f} -> '
286 | lib.llama_sampler_chain_add(
287 | smpl, lib.llama_sampler_init_temp(t=temp)
288 | )
289 | self._chain_str += f'mirostat v1 tau:{mirostat_tau:.2f} eta:{mirostat_eta:.2f}'
290 | lib.llama_sampler_chain_add(
291 | smpl, lib.llama_sampler_init_mirostat(
292 | seed=seed if seed > 0 else _get_random_seed(),
293 | tau=mirostat_tau,
294 | eta=mirostat_eta
295 | )
296 | )
297 |
298 | # IF MIROSTAT v2:
299 |
300 | elif mirostat == 2:
301 | # ... -> temp(-ext) -> mirostat-v2
302 | if dynatemp_delta != 0.0:
303 | # dynamic temperature AKA entropy sampling
304 | self._chain_str += f'temp-ext {temp:.2f} +/- {dynatemp_delta:.2f} -> '
305 | lib.llama_sampler_chain_add(
306 | smpl, lib.llama_sampler_init_temp_ext(
307 | t=temp,
308 | delta=dynatemp_delta,
309 | exponent=dynatemp_exponent
310 | )
311 | )
312 | else:
313 | # standard temperature
314 | self._chain_str += f'temp {temp:.2f} -> '
315 | lib.llama_sampler_chain_add(
316 | smpl, lib.llama_sampler_init_temp(t=temp)
317 | )
318 | self._chain_str += f'mirostat v2 tau:{mirostat_tau:.2f} eta:{mirostat_eta:.2f}'
319 | lib.llama_sampler_chain_add(
320 | smpl, lib.llama_sampler_init_mirostat_v2(
321 | seed=seed if seed > 0 else _get_random_seed(),
322 | tau=mirostat_tau,
323 | eta=mirostat_eta
324 | )
325 | )
326 |
327 | # DEFAULT CASE
328 |
329 | elif mirostat == 0:
330 | if top_n_sigma >= 0.0:
331 | # ... -> top-k -> temp -> top-n-sigma -> ...
332 | self._chain_str += (
333 | f'top-k {top_k} -> temp {temp:.2f} -> top-n-sigma {top_n_sigma:.2f} -> '
334 | )
335 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_k(k=top_k))
336 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_temp(t=temp))
337 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_n_sigma(n=top_n_sigma))
338 | else:
339 | # ... -> top-k -> typical-p -> top-p -> min-p -> temp(-ext) -> ...
340 | if top_k > 0:
341 | self._chain_str += f'top-k {top_k} -> '
342 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_k(k=top_k))
343 | if typical_p != 1.0:
344 | self._chain_str += f'typical-p {typical_p:.2f} -> '
345 | lib.llama_sampler_chain_add(
346 | smpl, lib.llama_sampler_init_typical(p=typical_p, min_keep=1)
347 | )
348 | if top_p < 1.0:
349 | self._chain_str += f'top-p {top_p:.2f} -> '
350 | lib.llama_sampler_chain_add(
351 | smpl, lib.llama_sampler_init_top_p(p=top_p, min_keep=1)
352 | )
353 | if min_p > 0.0:
354 | self._chain_str += f'min-p {min_p:.3f} -> '
355 | lib.llama_sampler_chain_add(
356 | smpl, lib.llama_sampler_init_min_p(p=min_p, min_keep=1)
357 | )
358 | if dynatemp_delta != 0.0:
359 | # dynamic temperature AKA entropy sampling
360 | self._chain_str += f'temp {temp:.2f} +/- {dynatemp_delta:.2f} -> '
361 | lib.llama_sampler_chain_add(
362 | smpl, lib.llama_sampler_init_temp_ext(
363 | t=temp,
364 | delta=dynatemp_delta,
365 | exponent=dynatemp_exponent
366 | )
367 | )
368 | else:
369 | # standard temperature
370 | self._chain_str += f'temp {temp:.2f} -> '
371 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_temp(t=temp))
372 |
373 | # ... -> dist
374 | self._chain_str += 'dist'
375 | lib.llama_sampler_chain_add(
376 | smpl, lib.llama_sampler_init_dist(seed=seed if seed > 0 else _get_random_seed())
377 | )
378 |
379 | else:
380 | raise ValueError(
381 | f'SamplerParams.__init__: unknown mirostat version {mirostat!r}'
382 | )
383 |
384 | self.smpl = smpl
385 |
386 | def __del__(self):
387 | self.free()
388 |
389 | def __repr__(self) -> str:
390 | return (
391 | f"SamplerParams("
392 | f"llama=, "
393 | f"seed={self.seed}, "
394 | f"top_k={self.top_k}, "
395 | f"min_p={self.min_p}, "
396 | f"xtc_probability={self.xtc_probability}, "
397 | f"xtc_threshold={self.xtc_threshold}, "
398 | f"typical_p={self.typical_p}, "
399 | f"temp={self.temp}, "
400 | f"dynatemp_delta={self.dynatemp_delta}, "
401 | f"dynatemp_exponent={self.dynatemp_exponent}, "
402 | f"penalty_last_n={self.penalty_last_n}, "
403 | f"penalty_repeat={self.penalty_repeat}, "
404 | f"penalty_freq={self.penalty_freq}, "
405 | f"penalty_present={self.penalty_present}, "
406 | f"dry_multiplier={self.dry_multiplier}, "
407 | f"dry_base={self.dry_base}, "
408 | f"dry_allowed_length={self.dry_allowed_length}, "
409 | f"dry_penalty_last_n={self.dry_penalty_last_n}, "
410 | f"top_n_sigma={self.top_n_sigma}, "
411 | f"mirostat={self.mirostat}, "
412 | f"mirostat_tau={self.mirostat_tau}, "
413 | f"mirostat_eta={self.mirostat_eta}, "
414 | f"dry_sequence_breakers={self.dry_sequence_breakers!r}, "
415 | f"logit_bias={self.logit_bias!r}"
416 | f")"
417 | )
418 |
419 | def print_chain(self) -> None:
420 | """Print the active sampler chain."""
421 | log(f'sampler chain: {self._chain_str}')
422 |
423 | def free(self) -> None:
424 | if self.smpl is not None:
425 | lib.llama_sampler_free(self.smpl)
426 | self.smpl = None
427 |
428 | def reset(self) -> None:
429 | null_ptr_check(self.smpl, 'self.smpl', 'SamplerParams.reset')
430 | lib.llama_sampler_reset(self.smpl)
431 |
432 | def to_dict(self) -> dict:
433 | """Return the sampler parameters as a dictionary."""
434 | return {
435 | # not including "llama"
436 | "seed" : self.seed,
437 | "top_k" : self.top_k,
438 | "top_p" : self.top_p,
439 | "min_p" : self.min_p,
440 | "xtc_probability" : self.xtc_probability,
441 | "xtc_threshold" : self.xtc_threshold,
442 | "typical_p" : self.typical_p,
443 | "temp" : self.temp,
444 | "dynatemp_delta" : self.dynatemp_delta,
445 | "dynatemp_exponent" : self.dynatemp_exponent,
446 | "penalty_last_n" : self.penalty_last_n,
447 | "penalty_repeat" : self.penalty_repeat,
448 | "penalty_freq" : self.penalty_freq,
449 | "penalty_present" : self.penalty_present,
450 | "dry_multiplier" : self.dry_multiplier,
451 | "dry_base" : self.dry_base,
452 | "dry_allowed_length" : self.dry_allowed_length,
453 | "dry_penalty_last_n" : self.dry_penalty_last_n,
454 | "mirostat" : self.mirostat,
455 | "top_n_sigma" : self.top_n_sigma,
456 | "mirostat_tau" : self.mirostat_tau,
457 | "mirostat_eta" : self.mirostat_eta,
458 | "dry_sequence_breakers" : self.dry_sequence_breakers,
459 | "logit_bias" : self.logit_bias
460 | }
461 |
462 | @classmethod
463 | def from_dict(cls, llama: Llama, params_dict: dict):
464 | """Creates a SamplerParams instance from a dictionary.
465 |
466 | Args:
467 | llama: The Llama object associated with these parameters.
468 | params_dict: A dictionary containing the sampler parameters.
469 |
470 | Returns:
471 | A new SamplerParams instance.
472 | """
473 | # Create a copy to avoid modifying the original dictionary
474 | # and remove keys that are not constructor arguments (like 'llama' if present)
475 | filtered_params = {k: v for k, v in params_dict.items() if k in SamplerPreset.__init__.__code__.co_varnames}
476 | return cls(llama=llama, **filtered_params)
477 |
478 |
479 | class SamplerPreset:
480 | """A SamplerPreset object contains all the values necessary to construct a SamplerParams
481 | object using a Llama model.
482 |
483 |     Llama objects use SamplerParams objects to define the sampling parameters, but these
484 |     SamplerParams objects also require some information about the Llama model itself, such as
485 | n_ctx_train, n_vocab, etc. Therefore Llama models and SamplerParams are tightly coupled.
486 |
487 | A SamplerPreset (this class) can be used to define sampling parameters without having to
488 | specify a Llama object. In turn, the Llama class can use these presets to create the actual
489 | SamplerParams object it needs for sampling."""
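    # A minimal usage sketch (illustrative; assumes `llama` is a loaded ez.Llama and
    # `toks` is a list of input token IDs):
    #
    #     preset = SamplerPreset(temp=0.6, top_p=0.9, min_p=0.05)
    #     out_toks = llama.generate(input_tokens=toks, n_predict=64, sampler_preset=preset)
    #
    # The Llama object converts the preset into the SamplerParams it needs internally.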
490 |
491 | def __init__(
492 | self,
493 |
494 | seed: int = -1, # random seed: <= 0
495 | top_k: int = 40, # neutral: <= 0
496 | top_p: float = 0.95, # neutral: 1.0
497 | min_p: float = 0.05, # neutral: 0.0
498 | xtc_probability: float = 0.0, # neutral: 0.0
499 | xtc_threshold: float = 0.1, # disable: > 0.5
500 | typical_p: float = 1.0, # neutral: 1.0
501 | temp: float = 0.8, # neutral: 1.0, greedy: <= 0.0
502 | dynatemp_delta: float = 0.0, # neutral: <= 0.0
503 | dynatemp_exponent: float = 1.0, # controls how entropy maps to dynamic temperature
504 | penalty_last_n: int = 64, # disable: 0, n_ctx: -1, last n tokens to penalize
505 | penalty_repeat: float = 1.0, # neutral: 1.0, should be between 1.0 and ~1.1
506 | penalty_freq: float = 0.0, # neutral: 0.0
507 | penalty_present: float = 0.0, # neutral: 0.0
508 | dry_multiplier: float = 0.0, # disable: 0.0, DRY repetition penalty for tokens extending repetition:
509 | dry_base: float = 1.75, # disable: 0.0, multiplier * base ^ (length of sequence before token - allowed length)
510 | dry_allowed_length: int = 2, # tokens extending repetitions beyond this receive penalty
511 | dry_penalty_last_n: int = -1, # disable: 0, n_ctx: -1, how many tokens to scan for repetitions
512 | mirostat: int = 0, # disable: 0, use v1: 1, use v2: 2
513 | top_n_sigma: float = -1.0, # disable: -1.0
514 | mirostat_tau: float = 5.0, # target entropy for mirostat
515 | mirostat_eta: float = 0.1, # learning rate for mirostat
516 |
517 | dry_sequence_breakers: list[str] = ["\n", ":", "\"", "*"], # default sequence breakers for DRY
518 |
519 | # TODO: grammar goes here
520 |
521 | logit_bias: Optional[dict[int, float]] = None
522 | ):
523 | self.seed = seed
524 |
525 | self.top_k = top_k
526 | self.top_p = top_p
527 | self.min_p = min_p
528 | self.xtc_probability = xtc_probability
529 | self.xtc_threshold = xtc_threshold
530 | self.typical_p = typical_p
531 | self.temp = temp
532 | self.dynatemp_delta = dynatemp_delta
533 | self.dynatemp_exponent = dynatemp_exponent
534 | self.penalty_last_n = penalty_last_n
535 | self.penalty_repeat = penalty_repeat
536 | self.penalty_freq = penalty_freq
537 | self.penalty_present = penalty_present
538 | self.dry_multiplier = dry_multiplier
539 | self.dry_base = dry_base
540 | self.dry_allowed_length = dry_allowed_length
541 | self.dry_penalty_last_n = dry_penalty_last_n
542 | self.mirostat = mirostat
543 | self.top_n_sigma = top_n_sigma
544 | self.mirostat_tau = mirostat_tau
545 | self.mirostat_eta = mirostat_eta
546 |
547 | self.dry_sequence_breakers = dry_sequence_breakers
548 |
549 | self.logit_bias = logit_bias
550 |
551 | def __repr__(self) -> str:
552 | return (
553 | f"SamplerPreset("
554 | f"seed={self.seed}, "
555 | f"top_k={self.top_k}, "
556 | f"min_p={self.min_p}, "
557 | f"xtc_probability={self.xtc_probability}, "
558 | f"xtc_threshold={self.xtc_threshold}, "
559 | f"typical_p={self.typical_p}, "
560 | f"temp={self.temp}, "
561 | f"dynatemp_delta={self.dynatemp_delta}, "
562 | f"dynatemp_exponent={self.dynatemp_exponent}, "
563 | f"penalty_last_n={self.penalty_last_n}, "
564 | f"penalty_repeat={self.penalty_repeat}, "
565 | f"penalty_freq={self.penalty_freq}, "
566 | f"penalty_present={self.penalty_present}, "
567 | f"dry_multiplier={self.dry_multiplier}, "
568 | f"dry_base={self.dry_base}, "
569 | f"dry_allowed_length={self.dry_allowed_length}, "
570 | f"dry_penalty_last_n={self.dry_penalty_last_n}, "
571 | f"mirostat={self.mirostat}, "
572 | f"top_n_sigma={self.top_n_sigma}, "
573 | f"mirostat_tau={self.mirostat_tau}, "
574 | f"mirostat_eta={self.mirostat_eta}, "
575 | f"dry_sequence_breakers={self.dry_sequence_breakers!r}, "
576 | f"logit_bias={self.logit_bias!r}"
577 | f")"
578 | )
579 |
580 | def as_dict(self) -> dict:
581 | """Returns the sampler parameters as a dictionary."""
582 | return {
583 | "seed" : self.seed,
584 | "top_k" : self.top_k,
585 | "top_p" : self.top_p,
586 | "min_p" : self.min_p,
587 | "xtc_probability" : self.xtc_probability,
588 | "xtc_threshold" : self.xtc_threshold,
589 | "typical_p" : self.typical_p,
590 | "temp" : self.temp,
591 | "dynatemp_delta" : self.dynatemp_delta,
592 | "dynatemp_exponent" : self.dynatemp_exponent,
593 | "penalty_last_n" : self.penalty_last_n,
594 | "penalty_repeat" : self.penalty_repeat,
595 | "penalty_freq" : self.penalty_freq,
596 | "penalty_present" : self.penalty_present,
597 | "dry_multiplier" : self.dry_multiplier,
598 | "dry_base" : self.dry_base,
599 | "dry_allowed_length" : self.dry_allowed_length,
600 | "dry_penalty_last_n" : self.dry_penalty_last_n,
601 | "mirostat" : self.mirostat,
602 | "top_n_sigma" : self.top_n_sigma,
603 | "mirostat_tau" : self.mirostat_tau,
604 | "mirostat_eta" : self.mirostat_eta,
605 | "dry_sequence_breakers" : self.dry_sequence_breakers,
606 | "logit_bias" : self.logit_bias
607 | }
608 |
609 | @classmethod
610 | def from_dict(cls, params_dict: dict):
611 | """Creates a SamplerPreset instance from a dictionary.
612 |
613 | Args:
614 | params_dict: A dictionary containing the sampler parameters.
615 |
616 | Returns:
617 | A new SamplerPreset instance.
618 | """
619 | # Create a copy to avoid modifying the original dictionary
620 | # and remove keys that are not constructor arguments
621 | filtered_params = {k: v for k, v in params_dict.items() if k in cls.__init__.__code__.co_varnames}
622 | return cls(**filtered_params)
623 |
624 |
625 | class SamplerPresets:
626 | """This class contains several ready-made `SamplerPreset` objects that can be used to
627 | control text generation."""
628 |
629 | Greedy = SamplerPreset(
630 | seed = 1,
631 | top_k = 1,
632 | temp = 0.0
633 | )
634 | """The most likely token is always chosen"""
635 |
636 | Default = SamplerPreset()
637 | """The default easy-llama sampler preset"""
638 |
639 | LlamaCPP = SamplerPreset(
640 | top_k = 40,
641 | top_p = 0.95,
642 | min_p = 0.05,
643 | temp = 0.8
644 | )
645 | """The default llama.cpp sampler preset"""
646 |
647 | Cool = SamplerPreset(
648 | top_k = -1,
649 | top_p = 1.0,
650 | min_p = 0.5,
651 | temp = 1.0
652 | )
653 | """The recommended easy-llama sampler preset for predictable output"""
654 |
655 | Warm = SamplerPreset(
656 | top_k = -1,
657 | top_p = 1.0,
658 | min_p = 0.1,
659 | temp = 1.5
660 | )
661 | """The recommended easy-llama sampler preset for creative yet coherent output"""
662 |
663 | Neutral = SamplerPreset(
664 | top_k = -1,
665 | top_p = 1.0,
666 | min_p = 0.0,
667 | temp = 1.0
668 | )
669 | """All samplers neutralized"""
670 |
671 | ContrastiveSearchCool = SamplerPreset(
672 | top_k = -1,
673 | top_p = 1.0,
674 | min_p = 0.0,
675 | temp = 0.4,
676 | penalty_present = 0.6,
677 | )
678 | """Constrastive Search as described in https://arxiv.org/abs/2210.14140 (less random)"""
679 |
680 | ContrastiveSearchWarm = SamplerPreset(
681 | top_k = -1,
682 | top_p = 1.0,
683 | min_p = 0.0,
684 | temp = 0.8,
685 | penalty_present = 0.6
686 | )
687 | """Constrastive Search as described in https://arxiv.org/abs/2210.14140 (more random)"""
688 |
689 | NucleusSamplingCool = SamplerPreset(
690 | top_k = -1,
691 | top_p = 0.25,
692 | min_p = 0.0,
693 | temp = 1.0
694 | )
695 | """Nucleus sampling as described in https://arxiv.org/abs/1904.09751 (less random)"""
696 |
697 | NucleusSamplingWarm = SamplerPreset(
698 | top_k = -1,
699 | top_p = 0.9,
700 | min_p = 0.0,
701 | temp = 1.0
702 | )
703 | """Nucleus sampling as described in https://arxiv.org/abs/1904.09751 (more random)"""
704 |
705 | TopNSigma = SamplerPreset(
706 | top_k = -1,
707 | top_p = 1.0,
708 | min_p = 0.0,
709 | temp = 1.0,
710 | top_n_sigma = 1.0
711 | )
712 | """Top-nσ as described on [arXiv](https://arxiv.org/pdf/2411.07641) and
713 | in [llama.cpp#11223](https://github.com/ggml-org/llama.cpp/pull/11223)"""
714 |
715 | TopNSigmaRandom = SamplerPreset(
716 | top_k = -1,
717 | top_p = 1.0,
718 | min_p = 0.0,
719 | temp = 9999.9,
720 | top_n_sigma = 1.0
721 | )
722 | """Top-nσ as described on [arXiv](https://arxiv.org/pdf/2411.07641) and
723 | in [llama.cpp#11223](https://github.com/ggml-org/llama.cpp/pull/11223), except that
724 | `temp = 9999.9` to randomly select any token that is determined to be valid
725 | by Top-nσ."""
726 |
727 | DRY = SamplerPreset(dry_multiplier=0.8, dry_base=1.75, dry_allowed_length=2)
728 | """https://github.com/oobabooga/text-generation-webui/pull/5677"""
729 |
730 | XTC = SamplerPreset(
731 | top_k=-1,
732 | top_p=1.0,
733 | min_p=0.02,
734 | xtc_probability=0.5,
735 | xtc_threshold=0.1
736 | )
737 | """https://github.com/oobabooga/text-generation-webui/pull/6335"""
738 |
739 | #
740 | # Samplers below this line are for specific models / model families
741 | #
742 |
743 | Llama3 = SamplerPreset(
744 | top_k = -1,
745 | top_p = 0.9,
746 | min_p = 0.0,
747 | temp = 0.6
748 | )
749 | """[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/)"""
750 |
751 | Llama3Classic = SamplerPreset(
752 | top_k = 40,
753 | top_p = 0.95,
754 | min_p = 0.05,
755 | temp = 0.65
756 | )
757 | """Unofficial preset based on the developer's personal preference"""
758 |
759 | Llama3Cool = SamplerPreset(
760 | top_k = -1,
761 | top_p = 0.9,
762 | min_p = 0.0,
763 | temp = 0.45
764 | )
765 | """Llama3 preset with reduced temperature (less random)"""
766 |
767 | Llama3Warm = SamplerPreset(
768 | top_k = -1,
769 | top_p = 0.9,
770 | min_p = 0.0,
771 | temp = 1.2
772 | )
773 | """Llama3 preset with increased temperature (more random)"""
774 |
775 | Mistral = SamplerPreset(
776 | temp = 0.3
777 | )
778 | """Mistral models tend to require a lower temperature"""
779 |
780 | Qwen2_5Official = SamplerPreset(
781 | top_k = 20,
782 | top_p = 0.8,
783 | min_p = 0.0,
784 | temp = 0.7,
785 | penalty_repeat = 1.05
786 | )
787 | """[Qwen/Qwen2.5-14B-Instruct/](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct/)
788 | (official, but not recommended)"""
789 |
790 | Qwen2_5Recommended = SamplerPreset(
791 | top_k = -1,
792 | top_p = 0.9,
793 | min_p = 0.1,
794 | temp = 1.1
795 | )
796 | """[Qwen/Qwen2.5-14B-Instruct/](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct/)
797 | (unofficial, but recommended)"""
798 |
799 | Qwen3Thinking = SamplerPreset(
800 | top_k = 20,
801 | top_p = 0.95,
802 | min_p = 0.0,
803 | temp = 0.6
804 | )
805 | """[Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/README.md)"""
806 |
807 | Qwen3NoThinking = SamplerPreset(
808 | top_k = 20,
809 | top_p = 0.8,
810 | min_p = 0.0,
811 | temp = 0.7
812 | )
813 | """[Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/README.md)"""
814 |
--------------------------------------------------------------------------------
/easy_llama/server.py:
--------------------------------------------------------------------------------
1 | # server.py
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | """The easy-llama FastAPI server, including an API endpoint and a simple WebUI"""
6 |
7 | # NOTE: This module is WIP.
8 |
9 | import os
10 | import uvicorn
11 |
12 | import easy_llama as ez
13 |
14 | from pydantic import BaseModel
15 | from fastapi.staticfiles import StaticFiles
16 | from fastapi.responses import FileResponse
17 | from fastapi.middleware.cors import CORSMiddleware
18 | from easy_llama.utils import assert_type, log, ANSI
19 | from typing import Optional, Union, Literal
20 | from fastapi import FastAPI, APIRouter, Body, HTTPException
21 |
22 |
23 | WEBUI_DIRECTORY = os.path.join(os.path.dirname(__file__), 'webui')
24 |
25 | STATUS_RESPONSE_SUCCESS = {'success': True}
26 | STATUS_RESPONSE_FAILURE = {'success': False}
27 |
28 | Y = ANSI.FG_BRIGHT_YELLOW
29 | R = ANSI.MODE_RESET_ALL
30 |
31 | #
32 | # Pydantic models for FastAPI
33 | #
34 |
35 | class StatusResponseModel(BaseModel):
36 | success: bool
37 |
38 | class MessageModel(BaseModel):
39 | role: str
40 | content: str
41 |
42 | class SendMessageModel(BaseModel):
43 | role: str
44 | content: str
45 | n_predict: Optional[int] = -1
46 |
47 | class MessageResponseModel(BaseModel):
48 | role: str
49 | content: str
50 | n_input_tokens: int
51 | n_output_tokens: int
52 |
53 | class SetSystemPromptRequestModel(BaseModel):
54 | content: str
55 |
56 | class SummaryResponseModel(BaseModel):
57 | summary: str
58 |
59 | class InfoResponseModel(BaseModel):
60 | model_name: str
61 | model_n_params: int
62 | model_bpw: float
63 | n_tokens: int
64 | n_ctx: int
65 | n_ctx_train: int
66 | n_thread_tokens: int
67 | n_thread_messages: int
68 |
69 | class SamplerSettingsModel(BaseModel):
70 |     # all fields default to None so that partial updates via POST /api/sampler work
71 |     seed: Optional[int] = None
72 |     top_k: Optional[int] = None
73 |     top_p: Optional[float] = None
74 |     min_p: Optional[float] = None
75 |     xtc_probability: Optional[float] = None
76 |     xtc_threshold: Optional[float] = None
77 |     typical_p: Optional[float] = None
78 |     temp: Optional[float] = None
79 |     dynatemp_delta: Optional[float] = None
80 |     dynatemp_exponent: Optional[float] = None
81 |     penalty_last_n: Optional[int] = None
82 |     penalty_repeat: Optional[float] = None
83 |     penalty_freq: Optional[float] = None
84 |     penalty_present: Optional[float] = None
85 |     dry_multiplier: Optional[float] = None
86 |     dry_base: Optional[float] = None
87 |     dry_allowed_length: Optional[int] = None
88 |     dry_penalty_last_n: Optional[int] = None
89 |     mirostat: Optional[int] = None
90 |     top_n_sigma: Optional[float] = None
91 |     mirostat_tau: Optional[float] = None
92 |     mirostat_eta: Optional[float] = None
93 |     dry_sequence_breakers: Optional[list[str]] = None
94 |     logit_bias: Optional[dict[int, float]] = None
94 |
95 |
96 | class Server:
97 | """The easy-llama FastAPI server, providing a WebUI and an API router"""
98 |
99 | def log(self, text: str, level: Literal[1,2,3,4] = 1) -> None:
100 | log(f'ez.Server @ {Y}{self._host}{R}:{Y}{self._port}{R} {text}', level=level)
101 |
102 | def __init__(
103 | self,
104 | thread: 'ez.Thread',
105 | host: str = "127.0.0.1",
106 | port: int = 8080
107 | ):
108 | assert_type(thread, getattr(ez, 'Thread'), 'thread', 'Server.__init__')
109 | self._thread = thread
110 | self._host = host
111 | self._port = port
112 | self._app = FastAPI(title=f"ez.Server @ {host}:{port}")
113 | self._router = APIRouter(prefix="/api")
114 |
115 | self._setup_api_routes()
116 |
117 | self._app.add_middleware(
118 | CORSMiddleware,
119 | allow_origins=["*"],
120 | allow_credentials=True,
121 | allow_methods=["*"],
122 | allow_headers=["*"],
123 | )
124 |
125 | self._app.include_router(self._router)
126 |
127 | if os.path.isdir(WEBUI_DIRECTORY):
128 | self._app.mount("/static", StaticFiles(directory=WEBUI_DIRECTORY), name="static")
129 |
130 | @self._app.get("/", include_in_schema=False)
131 | async def read_index():
132 | index_path = os.path.join(WEBUI_DIRECTORY, 'index.html')
133 | if os.path.exists(index_path):
134 | return FileResponse(index_path)
135 | else:
136 | # fallback if index.html is missing but directory exists
137 | self.log('index.html not found', 3)
138 | raise HTTPException(status_code=404, detail="index.html not found")
139 |
140 | else:
141 | self.log(f"WebUI directory not found at {WEBUI_DIRECTORY}", level=2)
142 | @self._app.get("/", include_in_schema=False)
143 | async def root_fallback():
144 | # provide a message if WebUI isn't found
145 | return {"message": "easy-llama server is running. WebUI not found."}
146 |
147 |
148 | def _setup_api_routes(self):
149 | """Define all API endpoints."""
150 | router = self._router
151 |
152 | @router.get("/ping", response_model=StatusResponseModel)
153 | async def ping() -> dict:
154 | """Check if the server is running."""
155 | self.log('pong')
156 | return STATUS_RESPONSE_SUCCESS
157 |
158 | @router.post("/send", response_model=MessageResponseModel)
159 | async def send(message: SendMessageModel = Body(...)) -> dict[str, Union[str, int]]:
160 | """Send a user message and get the bot's response.
161 | Adds both messages to the thread history."""
162 |
163 | try:
164 | self._thread.add_message(message.role, message.content)
165 |
166 | except ValueError as exc:
167 | self.log(f"/send: error adding user message: {exc}", 3)
168 | raise HTTPException(status_code=400, detail=str(exc))
169 |
170 | input_toks = self._thread.get_input_ids(role='bot')
171 |
172 | try:
173 | response_toks = self._thread.llama.generate(
174 | input_tokens=input_toks,
175 | n_predict=message.n_predict if message.n_predict is not None else -1,
176 | stop_tokens=self._thread._stop_tokens,
177 | sampler_preset=self._thread.sampler_preset
178 | )
179 | response_txt = self._thread.llama.detokenize(response_toks, special=False)
180 |
181 | except ez.llama.ExceededContextLengthException as exc:
182 | self.log(f"ExceededContextLengthException: {exc}", 3)
183 | raise HTTPException(status_code=413, detail=str(exc))
184 |
185 | except Exception as exc:
186 | self.log(f"Error during generation: {type(exc).__name__}: {exc}", level=3)
187 | raise HTTPException(status_code=500, detail=str(exc))
188 |
189 | # Add bot response to history
190 | self._thread.add_message('bot', response_txt)
191 |
192 | return {
193 | 'role': 'bot',
194 | 'content': response_txt,
195 | 'n_input_tokens': len(input_toks), # Tokens processed for this turn
196 | 'n_output_tokens': len(response_toks) # Tokens generated this turn
197 | }
198 |
199 | @router.post("/add_message", response_model=StatusResponseModel, tags=["Chat"])
200 | async def add_message(message: MessageModel = Body(...)) -> dict:
201 | """
202 | Add a message to the thread history without triggering a bot response.
203 | Useful for manually setting up conversation state.
204 | """
205 | self.log(f"/add_message request: role='{message.role}', content='{message.content[:50]}...'")
206 | try:
207 | self._thread.add_message(message.role, message.content)
208 | return STATUS_RESPONSE_SUCCESS
209 | except ValueError as exc:
210 | self.log(f'Failed to add message: {exc}', 3)
211 | raise HTTPException(status_code=400, detail=str(exc))
212 |
213 | @router.post("/set_system_prompt", response_model=StatusResponseModel, tags=["Chat"])
214 | async def set_system_prompt(request: SetSystemPromptRequestModel = Body(...)) -> dict:
215 | """
216 | Set or update the system prompt (the first message if it has a system role).
217 | If no system prompt exists, it will be prepended.
218 | """
219 | content = request.content
220 | new_sys_msg = {'role': 'system', 'content': content}
221 | messages = self._thread.messages
222 |
223 | if messages and messages[0]['role'].lower() in self._thread.valid_system_roles:
224 | self.log("Updating existing system prompt.")
225 | messages[0]['content'] = content
226 | else:
227 | self.log("Prepending new system prompt.")
228 | messages.insert(0, new_sys_msg)
229 |
230 | orig_messages = self._thread._orig_messages
231 | if orig_messages and orig_messages[0]['role'].lower() in self._thread.valid_system_roles:
232 | orig_messages[0]['content'] = content
233 | else:
234 | orig_messages.insert(0, new_sys_msg)
235 |
236 | return STATUS_RESPONSE_SUCCESS
237 |
238 | @router.post("/trigger", response_model=MessageResponseModel, tags=["Chat"])
239 | async def trigger() -> dict[str, Union[str, int]]:
240 | """Trigger the bot to generate a response based on the current history,
241 | without adding a user message first. Appends the bot's response."""
242 | self.log("/trigger request received")
243 | if not self._thread.messages:
244 | # TODO: this is incorrect
245 | self.log("Cannot trigger response: No messages in history.", level=2)
246 | raise HTTPException(status_code=400, detail="Cannot trigger response with empty history")
247 |
248 | last_message_role = self._thread.messages[-1]['role'].lower()
249 | if last_message_role in self._thread.valid_bot_roles:
250 | self.log("Last message was from bot, triggering another bot response.", level=1)
251 | # Allow triggering even if last was bot, might be desired sometimes
252 | elif last_message_role in self._thread.valid_system_roles:
253 | self.log("Last message was system, triggering bot response.", level=1)
254 | elif last_message_role in self._thread.valid_user_roles:
255 | self.log("Last message was user, triggering bot response.", level=1)
256 |
257 | # prepare for generation
258 |
259 | try:
260 | input_toks = self._thread.get_input_ids(role='bot')
261 | except ez.llama.ExceededContextLengthException as exc:
262 | self.log(f"Context length exceeded before generation: {exc}", level=3)
263 | raise HTTPException(status_code=413, detail=f"Input context too long: {exc}")
264 | except ValueError as exc: # Handle potential errors in get_input_ids
265 | self.log(f"Error getting input IDs: {exc}", level=3)
266 | raise HTTPException(status_code=500, detail=f"Internal error preparing generation: {exc}")
267 |
268 | # generate response
269 |
270 | try:
271 | response_toks = self._thread.llama.generate(
272 | input_tokens=input_toks,
273 | n_predict=-1,
274 | stop_tokens=self._thread._stop_tokens,
275 | sampler_preset=self._thread.sampler_preset
276 | )
277 | response_txt = self._thread.llama.detokenize(response_toks, special=False).strip()
278 | self.log(f"Generated {len(response_toks)} tokens: '{response_txt[:50]}...'")
279 | except ez.llama.ExceededContextLengthException as exc:
280 | self.log(f"Context length exceeded during generation: {exc}", level=3)
281 | raise HTTPException(status_code=413, detail=f"Context length exceeded during generation: {exc}")
282 | except Exception as exc:
283 | self.log(f"Error during generation: {type(exc).__name__}: {exc}", level=3)
284 | raise HTTPException(status_code=500, detail=f"Generation failed: {exc}")
285 |
286 | # Add bot response to history
287 | self._thread.add_message('bot', response_txt)
288 |
289 | return {
290 | 'role': 'bot',
291 | 'content': response_txt,
292 | 'n_input_tokens': len(input_toks),
293 | 'n_output_tokens': len(response_toks)
294 | }
295 |
296 |
297 | @router.get("/messages", response_model=list[MessageModel], tags=["Chat"])
298 | async def messages() -> list[dict]:
299 | """Get the current list of messages in the thread history."""
300 | self.log("/messages request received")
301 | return self._thread.messages
302 |
303 |
304 | @router.get("/summarize", response_model=SummaryResponseModel, tags=["Chat"])
305 | async def summarize() -> dict[str, str]:
306 | """
307 | Generate a summary of the current chat thread.
308 | Note: This performs inference and modifies the Llama context temporarily.
309 | """
310 | self.log("/summarize request received")
311 | try:
312 | summary_text = self._thread.summarize()
313 | self.log(f"Generated summary: '{summary_text[:50]}...'")
314 | return {"summary": summary_text}
315 | except Exception as exc:
316 | self.log(f"Error during summarization: {type(exc).__name__}: {exc}", level=3)
317 | raise HTTPException(status_code=500, detail=f"Summarization failed: {exc}")
318 |
319 |
320 | @router.post("/cancel", response_model=StatusResponseModel, tags=["Control"])
321 | async def cancel() -> dict:
322 | """(Not Implemented) Attempt to cancel ongoing generation."""
323 | # NOTE: llama.cpp doesn't have a built-in robust way to interrupt
324 | # llama_decode mid-computation from Python easily without potential
325 | # instability. This requires more complex handling (e.g., abort callbacks).
326 | self.log('Endpoint `/api/cancel` is not implemented yet!', 2)
327 | # return STATUS_RESPONSE_FAILURE
328 | raise HTTPException(status_code=501, detail="Cancel operation not implemented")
329 |
330 |
331 | @router.post("/reset", response_model=StatusResponseModel, tags=["Control"])
332 | async def reset() -> dict:
333 | """Reset the chat thread to its initial state (usually just the system prompt)."""
334 | self.log("/reset request received")
335 | try:
336 | self._thread.reset()
337 | self.log("Thread and Llama context reset successfully.")
338 | return STATUS_RESPONSE_SUCCESS
339 | except Exception as exc:
340 | self.log(f"Error during reset: {type(exc).__name__}: {exc}", level=3)
341 | raise HTTPException(status_code=500, detail=f"Reset failed: {exc}")
342 |
343 |
344 | @router.get("/info", response_model=InfoResponseModel, tags=["Status"])
345 | async def get_info() -> dict[str, Union[str, int, float]]:
346 | """Get information about the loaded model and current thread state."""
347 | self.log("/info request received")
348 | # Use suppress_output if get_input_ids might log verbosely
349 | # with ez.utils.suppress_output(disable=ez.get_verbose()):
350 | input_ids = self._thread.get_input_ids(role=None) # Get tokens just for counting
351 |
352 | return {
353 | 'model_name': self._thread.llama.name(),
354 | 'model_n_params': self._thread.llama.n_params(),
355 | 'model_bpw': self._thread.llama.bpw(),
356 | 'n_tokens': self._thread.llama.pos,
357 | 'n_ctx': self._thread.llama.n_ctx(),
358 | 'n_ctx_train': self._thread.llama.n_ctx_train(),
359 | 'n_thread_tokens': len(input_ids),
360 | 'n_thread_messages': len(self._thread.messages)
361 | }
362 |
363 | @router.post("/sampler", response_model=SamplerSettingsModel, tags=["Sampling"])
364 | async def set_sampler(settings: SamplerSettingsModel = Body(...)) -> dict:
365 | """
366 | Update the sampler settings for subsequent generations.
367 | Returns the complete current settings after applying the update.
368 | """
369 | self.log(f"/sampler POST request received with updates: {settings.model_dump(exclude_unset=True)}")
370 | current = self._thread.sampler_preset
371 | update_data = settings.model_dump(exclude_unset=True) # Get only provided fields
372 |
373 | # Create a dictionary from the current preset
374 | current_data = current.as_dict()
375 |
376 | # Update the current data with the new values
377 | current_data.update(update_data)
378 |
379 | # Create a new preset from the merged data
380 | try:
381 | new_preset = ez.SamplerPreset(**current_data)
382 | # Update the thread's active sampler preset
383 | self._thread.sampler_preset = new_preset
384 | self.log(f"Sampler settings updated. Current: {new_preset}")
385 | return new_preset.as_dict() # Return the full new settings
386 | except Exception as exc: # Catch potential validation errors in SamplerPreset
387 | self.log(f"Error updating sampler settings: {exc}", level=3)
388 | raise HTTPException(status_code=400, detail=f"Invalid sampler settings: {exc}")
389 |
390 |
391 | @router.get("/sampler", response_model=SamplerSettingsModel, tags=["Sampling"])
392 | async def get_sampler() -> dict:
393 | """Get the current sampler settings."""
394 | self.log("/sampler GET request received")
395 | return self._thread.sampler_preset.as_dict()
396 |
397 |
398 | # --- Server Start Method ---
399 | def start(self):
400 | """Start the Uvicorn server."""
401 | try:
402 | uvicorn.run(
403 | app=self._app,
404 | host=self._host,
405 | port=self._port
406 | )
407 | except KeyboardInterrupt:
408 | self.log('KeyboardInterrupt')
409 | except Exception as exc:
410 | self.log(f'Server crashed: {type(exc).__name__}: {exc}', 3)
411 | # Potentially re-raise or handle specific exceptions (e.g., port in use)
412 | raise exc
413 | finally:
414 | self.log("Server shutdown complete.")
415 |
--------------------------------------------------------------------------------
/easy_llama/thread.py:
--------------------------------------------------------------------------------
1 | # thread.py
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | """This file provides functionality for multi-turn conversations with Llama models."""
6 |
7 | # TODO: jinja2
8 |
9 | import sys
10 | import jinja2
11 |
12 | from typing import Optional
13 | from .formats import PromptFormat
14 | from .sampling import SamplerPreset
15 | from .utils import (
16 | _SupportsWriteAndFlush, ANSI, log, assert_type, ez_encode, ez_decode, suppress_output,
17 | KeyboardInterruptHandler
18 | )
19 |
20 | from . import llama as _llama # avoid confusion with Thread.llama attribute
21 |
22 |
23 | class Thread:
24 |
25 | valid_system_roles = ['system', 'developer' ]
26 | valid_user_roles = ['user', 'human' ]
27 | valid_bot_roles = ['bot', 'assistant', 'model', 'gpt', 'llama']
28 |
29 | all_valid_roles = valid_system_roles + valid_user_roles + valid_bot_roles
30 |
31 | def __init__(
32 | self,
33 | llama: _llama.Llama,
34 | prompt_format: PromptFormat,
35 | sampler_preset: Optional[SamplerPreset] = None
36 | ) -> None:
37 |
38 | assert_type(llama, _llama.Llama, 'llama', 'Thread.__init__')
39 | assert_type(prompt_format, PromptFormat, 'prompt_format', 'Thread.__init__')
40 |
41 | llama._validate_model_state()
42 |
43 | self.llama = llama
44 | self.prompt_format = prompt_format
45 | self.sampler_preset = sampler_preset if sampler_preset is not None else SamplerPreset()
46 |
47 | self.messages: list[dict[str, str]] = []
48 |
49 | system_prompt = prompt_format.system_prompt()
50 |
51 | if system_prompt != '':
52 | self.messages.append({
53 | 'role': 'system',
54 | 'content': system_prompt
55 | })
56 |
57 | # stop tokens
58 | format_stops = self.prompt_format.stop_tokens()
59 | self._stop_tokens = format_stops if format_stops is not None else self.llama.eog_tokens
60 |
61 | # save the original messages for self.reset()
62 | self._orig_messages = self.messages.copy()
63 |
64 | # save the sampler_preset param for repr
65 | self._sampler_preset = self.sampler_preset
66 |
67 | def __repr__(self) -> str:
68 | return (
69 | f"Thread("
70 | f"llama={self.llama!r}, "
71 | f"prompt_format={self.prompt_format!r}, "
72 | f"sampler_preset={self._sampler_preset!r}"
73 | f")"
74 | )
75 |
76 | def _messages_in_jinja_format(self) -> list[dict]:
77 | jinja_messages = []
78 | for message in self.messages:
79 | try:
80 | role = message['role']
81 | except KeyError:
82 | log(f'_messages_in_jinja_format: skipping message with no role!', 2)
83 | continue
84 | try:
85 | content = message['content']
86 | except KeyError:
87 | log(f'_messages_in_jinja_format: skipping message with no content!', 2)
88 | continue
89 | if role in Thread.valid_system_roles:
90 | jinja_messages.append({'role': 'system', 'content': content})
91 | elif role in Thread.valid_user_roles:
92 | jinja_messages.append({'role': 'user', 'content': content})
93 | elif role in Thread.valid_bot_roles:
94 | jinja_messages.append({'role': 'assistant', 'content': content})
95 | else:
96 | log(
97 | f'_messages_in_jinja_format: skipping message with invalid role {role!r}!',
98 | 2
99 | )
100 | return jinja_messages
101 |
102 | def _render_messages(
103 | self,
104 | add_generation_prompt: bool = True,
105 | **kwargs
106 | ) -> str:
107 | """Render the Jinja template with current messages"""
108 | try:
109 | template = jinja2.Template(
110 | source=self.llama._chat_template,
111 | undefined=jinja2.StrictUndefined
112 | )
113 | context = {
114 | 'messages': self._messages_in_jinja_format(),
115 | 'add_generation_prompt': add_generation_prompt,
116 | **kwargs
117 | }
118 | return template.render(context)
119 | except jinja2.exceptions.TemplateSyntaxError as e:
120 | log(f"_render_messages: invalid chat template syntax: {e}", 3)
121 | raise ValueError(f"_render_messages: invalid chat template syntax: {e}") from e
122 | except jinja2.exceptions.UndefinedError as e:
123 | log(f"_render_messages: missing template variable: {e}", 3)
124 | raise ValueError(f"_render_messages: missing template variable: {e}") from e
125 | except Exception as e:
126 | log(f"_render_messages: template rendering error: {e}", 3)
127 | raise ValueError(f"_render_messages: template rendering error: {e}") from e
128 |
129 | def get_input_ids(self, role: Optional[str] = 'bot') -> list[int]:
130 | """Get a list of token IDs in this thread, to be used for inference
131 |
132 | - role:
133 | The role for which inference will be performed (usually 'bot'). Can be 'system',
134 | 'user', 'bot', or None. If None, no role prefix will be appended (this is useful
135 | when you just want to get all the tokens in this Thread but are not going to do
136 | inference)."""
137 |
138 | if role is None and len(self.messages) == 0:
139 | if self.llama.add_bos_token():
140 | return [self.llama.token_bos()]
141 | else:
142 | return []
143 |
144 | input_ids = []
145 | if len(self.messages) > 0:
146 | # the prefix of the first message requires `add_special=True` in order to set
147 | # the BOS token correctly
148 | first_msg = self.messages[0]
149 | if first_msg['role'].lower() in Thread.valid_system_roles:
150 | input_ids.extend(self.llama.tokenize(
151 | text_bytes=ez_encode(self.prompt_format.system_prefix()),
152 | add_special=True,
153 | parse_special=True
154 | ))
155 | input_ids.extend(self.llama.tokenize(
156 | text_bytes=ez_encode(first_msg['content']),
157 | add_special=False,
158 | parse_special=False
159 | ))
160 | input_ids.extend(self.llama.tokenize(
161 | text_bytes=ez_encode(self.prompt_format.system_suffix()),
162 | add_special=False,
163 | parse_special=True
164 | ))
165 |
166 | elif first_msg['role'].lower() in Thread.valid_user_roles:
167 | input_ids.extend(self.llama.tokenize(
168 | text_bytes=ez_encode(self.prompt_format.user_prefix()),
169 | add_special=True,
170 | parse_special=True
171 | ))
172 | input_ids.extend(self.llama.tokenize(
173 | text_bytes=ez_encode(first_msg['content']),
174 | add_special=False,
175 | parse_special=False
176 | ))
177 | input_ids.extend(self.llama.tokenize(
178 | text_bytes=ez_encode(self.prompt_format.user_suffix()),
179 | add_special=False,
180 | parse_special=True
181 | ))
182 |
183 | elif first_msg['role'].lower() in Thread.valid_bot_roles:
184 | input_ids.extend(self.llama.tokenize(
185 | text_bytes=ez_encode(self.prompt_format.bot_prefix()),
186 | add_special=True,
187 | parse_special=True
188 | ))
189 | input_ids.extend(self.llama.tokenize(
190 | text_bytes=ez_encode(first_msg['content']),
191 | add_special=False,
192 | parse_special=False
193 | ))
194 | input_ids.extend(self.llama.tokenize(
195 | text_bytes=ez_encode(self.prompt_format.bot_suffix()),
196 | add_special=False,
197 | parse_special=True
198 | ))
199 |
200 | else:
201 | raise ValueError(
202 |                     f'Thread.get_input_ids: first message has invalid role {first_msg["role"]!r}'
203 | )
204 |
205 | # all the other messages are treated the same
206 | i = 0
207 | for msg in self.messages[1:]:
208 | i += 1
209 | if msg['role'].lower() in Thread.valid_system_roles:
210 | raise ValueError(
211 | f'Thread.get_input_ids: multiple system messages are not supported'
212 | )
213 | elif msg['role'].lower() in Thread.valid_user_roles:
214 | input_ids.extend(self.llama.tokenize(
215 | text_bytes=ez_encode(self.prompt_format.user_prefix()),
216 | add_special=False,
217 | parse_special=True
218 | ))
219 | input_ids.extend(self.llama.tokenize(
220 | text_bytes=ez_encode(msg['content']),
221 | add_special=False,
222 | parse_special=False
223 | ))
224 | input_ids.extend(self.llama.tokenize(
225 | text_bytes=ez_encode(self.prompt_format.user_suffix()),
226 | add_special=False,
227 | parse_special=True
228 | ))
229 | elif msg['role'].lower() in Thread.valid_bot_roles:
230 | input_ids.extend(self.llama.tokenize(
231 | text_bytes=ez_encode(self.prompt_format.bot_prefix()),
232 | add_special=False,
233 | parse_special=True
234 | ))
235 | input_ids.extend(self.llama.tokenize(
236 | text_bytes=ez_encode(msg['content']),
237 | add_special=False,
238 | parse_special=False
239 | ))
240 | input_ids.extend(self.llama.tokenize(
241 | text_bytes=ez_encode(self.prompt_format.bot_suffix()),
242 | add_special=False,
243 | parse_special=True
244 | ))
245 | else:
246 | raise ValueError(
247 |                     f'Thread.get_input_ids: message {i} has invalid role {msg["role"]!r}'
248 | )
249 |
250 | if role is not None:
251 | # append the role prefix tokens to the end
252 | # (if role is None, no prefix is appended)
253 | if role.lower() in Thread.valid_system_roles:
254 | raise ValueError(
255 | f'Thread.get_input_ids: multiple system messages are not supported'
256 | )
257 | elif role.lower() in Thread.valid_user_roles:
258 | input_ids.extend(self.llama.tokenize(
259 | text_bytes=ez_encode(self.prompt_format.user_prefix()),
260 | add_special=False,
261 | parse_special=True
262 | ))
263 | elif role.lower() in Thread.valid_bot_roles:
264 | input_ids.extend(self.llama.tokenize(
265 | text_bytes=ez_encode(self.prompt_format.bot_prefix()),
266 | add_special=False,
267 | parse_special=True
268 | ))
269 | else:
270 | raise ValueError(f'Thread.get_input_ids: invalid role {role!r}')
271 |
272 | # input_ids is now fully constructed
273 | n_input_ids = len(input_ids)
274 |
275 | _llama.log_verbose(
276 | f'Thread.get_input_ids: converted {len(self.messages)} messages to '
277 | f'{n_input_ids} tokens'
278 | )
279 |
280 | if n_input_ids >= self.llama._n_ctx:
281 | log(
282 | f'Thread.get_input_ids: length of input_ids {n_input_ids} '
283 | f'equals or exceeds the current context length '
284 | f'{self.llama._n_ctx}', 2
285 | )
286 |
287 | return input_ids
288 |
289 | def send(self, content: str, n_predict: Optional[int] = None) -> str:
290 | """Send a message in this thread and return the generated response. This adds your
291 | message and the bot's message to the thread."""
292 | self.messages.append({
293 | 'role': 'user',
294 | 'content': content
295 | })
296 | response_toks = self.llama.generate(
297 | input_tokens=self.get_input_ids(role='bot'),
298 | n_predict=n_predict if n_predict is not None else -1,
299 | stop_tokens=self._stop_tokens,
300 | sampler_preset=self.sampler_preset
301 | )
302 | response_txt = self.llama.detokenize(response_toks, special=False)
303 | self.messages.append({
304 | 'role': 'bot',
305 | 'content': response_txt
306 | })
307 | return response_txt
308 |
309 | def as_string(self) -> str:
310 | """Return this thread's message history as a string"""
311 | result_str = ''
312 | for msg in self.messages:
313 | if msg['role'].lower() in Thread.valid_system_roles:
314 | result_str += ''.join([
315 | self.prompt_format.system_prefix(),
316 | msg['content'],
317 | self.prompt_format.system_suffix()
318 | ])
319 | elif msg['role'].lower() in Thread.valid_user_roles:
320 | result_str += ''.join([
321 | self.prompt_format.user_prefix(),
322 | msg['content'],
323 | self.prompt_format.user_suffix()
324 | ])
325 | elif msg['role'].lower() in Thread.valid_bot_roles:
326 | result_str += ''.join([
327 | self.prompt_format.bot_prefix(),
328 | msg['content'],
329 | self.prompt_format.bot_suffix()
330 | ])
331 | else:
332 | raise ValueError(f"Thread.as_string: invalid message role {msg['role']!r}")
333 | return result_str
334 |
335 | def add_message(self, role: str, content: str) -> None:
336 | """Append a message to `Thread.messages` with the specified role and content
337 |
338 | - role:
339 | The role of the message, for example 'system', 'user', or 'bot'.
340 | - content:
341 | The text content of the message."""
342 | if role.lower() in Thread.valid_system_roles:
343 | self.messages.append({'role': 'system', 'content': content})
344 | elif role.lower() in Thread.valid_user_roles:
345 | self.messages.append({'role': 'user', 'content': content})
346 | elif role.lower() in Thread.valid_bot_roles:
347 | self.messages.append({'role': 'bot', 'content': content})
348 | else:
349 | raise ValueError(f'Thread.add_message: invalid role {role!r}')
350 |
351 | def warmup(self) -> None:
352 | input_ids = self.get_input_ids()
353 | if self.llama._first_valid_pos(input_ids) < len(input_ids):
354 | _llama.log_verbose('Thread.warmup: processing thread content with model ...')
355 | with suppress_output(disable=_llama.get_verbose()):
356 | self.llama.generate(input_tokens=input_ids, n_predict=0)
357 |
358 | # if the above condition is not True, the thread is already in the cache, so
359 | # nothing needs to be done
360 | _llama.log_verbose('Thread.warmup: done')
361 |
362 | def interact(self, stream: bool = True) -> None:
363 | """Start an interactive terminal-based chat using this thread"""
364 | R = ANSI.MODE_RESET_ALL
365 | B = ANSI.FG_BRIGHT_CYAN
366 | G = ANSI.FG_BRIGHT_GREEN
367 | with KeyboardInterruptHandler():
368 | print()
369 | while True:
370 | user_input = input(f'{R} > {G}')
371 | print(R, end='\n', flush=True)
372 |
373 | if stream:
374 | self.messages.append({'role': 'user', 'content': user_input})
375 | input_ids = self.get_input_ids()
376 |
377 | tok_gen = self.llama.stream(
378 | input_tokens=input_ids,
379 | n_predict=-1,
380 | stop_tokens=self._stop_tokens,
381 | sampler_preset=self.sampler_preset
382 | )
383 |
384 | response_toks = []
385 | detok_bytes_buffer = b''
386 |
387 | for tok in tok_gen:
388 | response_toks.append(tok)
389 | #
390 | # detok_bytes_buffer holds any incomplete UTF-8 characters until they
391 | # are completed by future tokens
392 | #
393 | # for example, emojis are often split between two tokens, with one or
394 | # both of those tokens not being valid UTF-8 on its own
395 | #
396 | detok_bytes_buffer += self.llama.token_to_piece(tok, special=False)
397 | try:
398 | detok_txt = detok_bytes_buffer.decode('utf-8', errors='strict')
399 | except UnicodeDecodeError:
400 | pass # try again on next token
401 | else:
402 | detok_bytes_buffer = b''
403 | print(f'{B}{detok_txt}{R}', end='', flush=True)
404 |
405 | # print any leftover bytes (though ideally there should be none)
406 | if detok_bytes_buffer != b'':
407 | leftover_txt = ez_decode(detok_bytes_buffer)
408 | print(f'{B}{leftover_txt}{R}', end='', flush=True)
409 |
410 | self.messages.append({
411 | 'role': 'bot',
412 | 'content': self.llama.detokenize(response_toks, special=False)
413 | })
414 |
415 | print()
416 | if not _llama.get_verbose():
417 | print()
418 |
419 | else:
420 | response = self.send(user_input)
421 | print(f'\n{B}{response}{R}\n')
422 |
423 | def give_input_output_examples(self, examples: dict[str, str]) -> None:
424 | """Provide examples for few-shot prompting""" # TODO: this should be renamed or removed
425 | for input_msg_content, output_msg_content in examples.items():
426 | self.add_message('user', input_msg_content)
427 | self.add_message('bot', output_msg_content)
428 |
429 | def summarize(self) -> str:
430 | """Generate a summary of this thread"""
431 | thread_as_string = self.as_string()
432 | orig_thread_messages = self.messages.copy()
433 | self.messages = [
434 | {
435 | 'role': 'system',
436 | 'content': 'Follow the given instructions exactly. Do not add any unnecessary '
437 | 'information.'
438 | },
439 | {
440 | 'role': 'user',
441 | 'content': 'Take a moment to read through the following conversation '
442 | 'carefully. When you\'re done, write a single paragraph that '
443 | 'explains all of the most relevant details.'
444 | f'\n\n```\n{thread_as_string}\n```\n\n'
445 | 'Now that you\'ve read the above conversation, provide a summary '
446 | 'in the form of a single paragraph.'
447 | }
448 | ]
449 | input_ids = self.get_input_ids() # uses the above messages
450 | output_ids = self.llama.generate(input_tokens=input_ids, n_predict=300)
451 | summary = self.llama.detokenize(output_ids, special=False)
452 | self.messages = orig_thread_messages.copy()
453 | return summary
454 |
455 | def print_stats(self, file: _SupportsWriteAndFlush = sys.stderr) -> None:
456 | """Print stats about the context usage in this thread"""
457 | with suppress_output():
458 | input_ids = self.get_input_ids(role=None)
459 | n_thread_tokens = len(input_ids)
460 | n_msgs = len(self.messages)
461 | n_ctx = self.llama._n_ctx
462 | c = (n_thread_tokens/n_ctx) * 100
463 | ctx_used_pct = int(c) + (c > int(c)) # round up to next integer
464 | print(f"{n_thread_tokens} / {n_ctx} tokens", file=file)
465 | print(f"{ctx_used_pct}% of context used", file=file)
466 | print(f"{n_msgs} messages", file=file)
467 |
468 | def reset(self) -> None:
469 | self.messages = self._orig_messages.copy()
470 |
--------------------------------------------------------------------------------
/easy_llama/utils.py:
--------------------------------------------------------------------------------
1 | # utils.py
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | """Submodule containing convenience functions used throughout easy_llama"""
6 |
7 | from . import __version__
8 |
9 | import os
10 | import sys
11 | import datetime
12 | import contextlib
13 |
14 | import numpy as np
15 |
16 | from typing import Iterable, TextIO, Optional, TypeVar, Generic, NoReturn, Literal, Union
17 |
18 | class ANSI:
19 | """ANSI codes for terminal emulators"""
20 |
21 | BELL = '\a'
22 |
23 | CLEAR = '\x1bc\x1b[3J' # technically this is two ANSI codes in one
24 |
25 | # Standard colors
26 | FG_BLACK = '\x1b[30m'
27 | BG_BLACK = '\x1b[40m'
28 | FG_RED = '\x1b[31m'
29 | BG_RED = '\x1b[41m'
30 | FG_GREEN = '\x1b[32m'
31 | BG_GREEN = '\x1b[42m'
32 | FG_YELLOW = '\x1b[33m'
33 | BG_YELLOW = '\x1b[43m'
34 | FG_BLUE = '\x1b[34m'
35 | BG_BLUE = '\x1b[44m'
36 | FG_MAGENTA = '\x1b[35m'
37 | BG_MAGENTA = '\x1b[45m'
38 | FG_CYAN = '\x1b[36m'
39 | BG_CYAN = '\x1b[46m'
40 | FG_WHITE = '\x1b[37m'
41 | BG_WHITE = '\x1b[47m'
42 |
43 | # Bright colors
44 | FG_BRIGHT_BLACK = '\x1b[90m'
45 | BG_BRIGHT_BLACK = '\x1b[100m'
46 | FG_BRIGHT_RED = '\x1b[91m'
47 | BG_BRIGHT_RED = '\x1b[101m'
48 | FG_BRIGHT_GREEN = '\x1b[92m'
49 | BG_BRIGHT_GREEN = '\x1b[102m'
50 | FG_BRIGHT_YELLOW = '\x1b[93m'
51 | BG_BRIGHT_YELLOW = '\x1b[103m'
52 | FG_BRIGHT_BLUE = '\x1b[94m'
53 | BG_BRIGHT_BLUE = '\x1b[104m'
54 | FG_BRIGHT_MAGENTA = '\x1b[95m'
55 | BG_BRIGHT_MAGENTA = '\x1b[105m'
56 | FG_BRIGHT_CYAN = '\x1b[96m'
57 | BG_BRIGHT_CYAN = '\x1b[106m'
58 | FG_BRIGHT_WHITE = '\x1b[97m'
59 | BG_BRIGHT_WHITE = '\x1b[107m'
60 |
61 | # Modes
62 | MODE_RESET_ALL = '\x1b[0m'
63 | MODE_BOLD_SET = '\x1b[1m'
64 | MODE_BOLD_RESET = '\x1b[22m'
65 | MODE_DIM_SET = '\x1b[2m'
66 | MODE_DIM_RESET = '\x1b[22m'
67 | MODE_ITALIC_SET = '\x1b[3m'
68 | MODE_ITALIC_RESET = '\x1b[23m'
69 | MODE_UNDERLINE_SET = '\x1b[4m'
70 | MODE_UNDERLINE_RESET = '\x1b[24m'
71 | MODE_BLINKING_SET = '\x1b[5m'
72 | MODE_BLINKING_RESET = '\x1b[25m'
73 | MODE_REVERSE_SET = '\x1b[7m'
74 | MODE_REVERSE_RESET = '\x1b[27m'
75 | MODE_HIDDEN_SET = '\x1b[8m'
76 | MODE_HIDDEN_RESET = '\x1b[28m'
77 | MODE_STRIKETHROUGH_SET = '\x1b[9m'
78 | MODE_STRIKETHROUGH_RESET = '\x1b[29m'
79 |
80 | NoneType: type = type(None)
81 |
82 | _VERBOSE = True
83 |
84 | _DEBUG = False
85 | """Package-wide debug flag"""
86 |
87 | class _ArrayLike(Iterable):
88 | """Anything that can be interpreted as a numpy array"""
89 |
90 | class _SupportsWriteAndFlush(TextIO):
91 | """A file, stream, or buffer that supports writing and flushing"""
92 |
93 | class UnreachableException(Exception):
94 | """The code has reached an unreachable state"""
95 | def __init__(self):
96 | super().__init__(
97 | "the code has reached a location that was thought to be "
98 | "unreachable. please report this issue to the developer at this "
99 | "link: https://github.com/ddh0/easy-llama/issues/new/choose"
100 | )
101 |
102 | class LlamaNullException(Exception):
103 | """Raised when a libllama function returns NULL or NULLPTR"""
104 |
105 | T = TypeVar('T')
106 |
107 | class ptr(Generic[T]):
108 | """Generic type hint representing any ctypes pointer. Optionally subscriptable with any
109 | type."""
110 |
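# Usage sketch (illustrative only): `ptr` documents ctypes pointer parameters and return
# values, optionally subscripted with the pointee type for readability, e.g.
#
#   def example_free(model: ptr) -> None: ...                 # hypothetical function
#   def example_logits(ctx: ptr['llama_context']) -> ptr: ... # hypothetical function
#
# The subscript carries no runtime meaning; it exists purely as documentation.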
111 | def set_verbose(state: bool) -> None:
112 | """Enable or disable verbose terminal output from easy-llama"""
113 | global _VERBOSE
114 | _VERBOSE = state
115 |
116 | def get_verbose() -> bool:
117 | """Return `True` if verbose terminal output is enabled in easy-llama, `False` otherwise"""
118 | global _VERBOSE
119 | return _VERBOSE
120 |
121 | @contextlib.contextmanager
122 | def KeyboardInterruptHandler():
123 | log_verbose('Press CTRL+C to exit')
124 | try:
125 | yield
126 | except KeyboardInterrupt:
127 | print(ANSI.MODE_RESET_ALL, end='\n', flush=True)
128 |
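# Usage sketch (illustrative only): wrap an interruptible loop so that CTRL+C exits
# cleanly and resets any active terminal formatting instead of propagating the exception.
#
#   with KeyboardInterruptHandler():
#       for token in token_generator:        # hypothetical token stream
#           print(token, end='', flush=True)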
129 | def log(
130 | text: str,
131 | level: Literal[1, 2, 3, 4] = 1,
132 | disable: bool = False
133 | ) -> None:
134 | """Print the given text, prefixed with a timestamp"""
135 | if disable:
136 | return
137 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d %a %k:%M:%S.%f")[:-3]
138 | if level == 1:
139 | lvltxt = f"{ANSI.FG_BRIGHT_GREEN}INFO"
140 | elif level == 2:
141 | lvltxt = f"{ANSI.FG_BRIGHT_YELLOW}WARNING"
142 | elif level == 3:
143 | lvltxt = f"{ANSI.FG_BRIGHT_RED}ERROR"
144 | elif level == 4:
145 | lvltxt = f"{ANSI.FG_BRIGHT_MAGENTA}STOPWATCH"
146 | else:
147 | raise ValueError(f'parameter `level` must be one of `[1, 2, 3, 4]`, not {level}')
148 | print(
149 | f"{ANSI.MODE_RESET_ALL}{ANSI.MODE_BOLD_SET}{ANSI.FG_BRIGHT_BLACK}[{timestamp}]"
150 | f"{ANSI.MODE_RESET_ALL}{ANSI.MODE_BOLD_SET} {lvltxt}{ANSI.MODE_RESET_ALL}"
151 | f"{ANSI.MODE_BOLD_SET}{ANSI.FG_BRIGHT_BLACK}:{ANSI.MODE_RESET_ALL} {text}",
152 | end='\n',
153 | file=sys.stdout if level in [1,4] else sys.stderr,
154 | flush=True
155 | )
156 |
157 | def log_verbose(text: str, level: Literal[1,2,3,4] = 1) -> None:
158 | if get_verbose():
159 | log(text, level)
160 |
161 | def log_debug(text: str, level: Literal[1,2,3,4] = 1) -> None:
162 | if _DEBUG:
163 | log('[DEBUG] ' + text, level)
164 |
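# Usage sketch (illustrative only): `level` selects the prefix and output stream --
# 1=INFO (stdout), 2=WARNING (stderr), 3=ERROR (stderr), 4=STOPWATCH (stdout).
#
#   log('model loaded', level=1)            # always printed
#   log('context nearly full', level=2)     # warning, printed to stderr
#   log_verbose('prompt processed')         # only printed when get_verbose() is True
#   log_debug('n_batch=512')                # only printed when _DEBUG is True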
165 | def softmax(z, T: float = 1.0):
166 | """Numerically stable softmax over all dimensions of an arbitrarily shaped array in
167 | float32 precision."""
168 | z_arr = np.array(z, dtype=np.float32)
169 | if z_arr.size == 0:
170 | return z_arr
171 | if T == 0.0:
172 | result = np.zeros_like(z_arr)
173 | flat = z_arr.ravel()
174 | result.flat[np.argmax(flat)] = 1.0
175 | return result
176 | if T < 0:
177 | z_arr = -z_arr
178 | T = -T
179 | max_val = np.max(z_arr)
180 | scaled = (z_arr - max_val) / T
181 | exp_vals = np.exp(scaled)
182 | return exp_vals / np.sum(exp_vals)
183 |
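# Usage sketch (illustrative only): the temperature `T` reshapes the distribution --
# T=1.0 is the plain softmax, T=0.0 collapses to a one-hot argmax, larger T flattens.
#
#   softmax([1.0, 2.0, 3.0])          # -> approx. [0.090, 0.245, 0.665]
#   softmax([1.0, 2.0, 3.0], T=0.0)   # -> [0.0, 0.0, 1.0] (argmax one-hot)
#   softmax([1.0, 2.0, 3.0], T=2.0)   # -> flatter, approx. [0.186, 0.307, 0.506]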
184 | def cls() -> None:
185 | """Clear the terminal"""
186 | if os.name == 'nt':
187 | os.system('cls')
188 | else:
189 | #os.system('clear')
190 | print(ANSI.CLEAR, end='', flush=True)
191 |
192 | def truncate(text: str) -> str:
193 | return text if len(text) < 72 else f"{text[:69]}..."
194 |
195 | def ez_encode(txt: str) -> bytes:
196 | """Encode the given text `txt` from string to UTF-8. If strict encoding fails, an error
197 | will be shown and the offending character(s) will be replaced with `?`."""
198 | try:
199 | return txt.encode('utf-8', errors='strict')
200 | except UnicodeEncodeError:
201 | log('error encoding string to UTF-8. using ? replacement character.', level=3)
202 | return txt.encode('utf-8', errors='replace')
203 |
204 | def ez_decode(txt: bytes) -> str:
205 | """Decode the given text `txt` from UTF-8 to string. If strict decoding fails, an error
206 | will be shown and the offending character(s) will be replaced with `�` (U+FFFD)."""
207 | try:
208 | return txt.decode('utf-8', errors='strict')
209 | except UnicodeDecodeError:
210 | log('error decoding string from UTF-8. using � replacement character.', level=3)
211 | return txt.decode('utf-8', errors='replace')
212 |
213 | _open = open
214 | _sys = sys
215 | _os = os
216 |
217 | @contextlib.contextmanager
218 | def suppress_output(disable: bool = False):
219 | """Suppress stdout and stderr."""
220 |
221 | # NOTE: simply changing sys.stdout and sys.stderr does not affect output from llama.cpp.
222 | # this method (or similar) is required to suppress all undesired output, for example
223 | # when `verbose=False`.
224 |
225 | if disable:
226 | yield
227 | else:
228 | # save the original file descriptors
229 | original_stdout_fd = _sys.stdout.fileno()
230 | original_stderr_fd = _sys.stderr.fileno()
231 |
232 | saved_stdout_fd = _os.dup(original_stdout_fd)
233 | saved_stderr_fd = _os.dup(original_stderr_fd)
234 |
235 | with _open(_os.devnull, 'wb') as devnull:
236 | devnull_fd = devnull.fileno()
237 |
238 | _os.dup2(devnull_fd, original_stdout_fd)
239 | _os.dup2(devnull_fd, original_stderr_fd)
240 |
241 | try:
242 | yield
243 | finally:
244 | # restore the original file descriptors
245 | _os.dup2(saved_stdout_fd, original_stdout_fd)
246 | _os.dup2(saved_stderr_fd, original_stderr_fd)
247 |
248 | _os.close(saved_stdout_fd)
249 | _os.close(saved_stderr_fd)
250 |
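# Usage sketch (illustrative only): because the redirection happens at the file-descriptor
# level, it also silences output written directly by the native libllama library.
#
#   with suppress_output(disable=get_verbose()):
#       noisy_native_call()   # hypothetical call whose stdout/stderr is discarded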
251 | def assert_type(
252 | obj: object,
253 | expected_type: Union[type, tuple[type]],
254 | obj_name: str,
255 | code_location: str,
256 | hint: Optional[str] = None
257 | ):
258 | """Ensure that `obj` is an instance of `expected_type`.
259 |
260 | If `expected_type` is a tuple, ensure that `obj` is an instance of some type in the tuple.
261 |
262 | Raise `TypeError` otherwise, using `obj_name` and `code_location` to make an informative
263 | exception message.
264 |
265 | If specified, `hint` is added as a note to the exception."""
266 |
267 | if isinstance(obj, expected_type):
268 | return
269 |
270 | obj_type_repr = repr(type(obj).__name__)
271 |
272 | if not isinstance(expected_type, tuple):
273 | # represent `int` as 'int' instead of "<class 'int'>"
274 | expected_type_repr = repr(expected_type.__name__)
275 | exc = TypeError(
276 | f"{code_location}: {obj_name} should be an instance of {expected_type_repr}, "
277 | f"not {obj_type_repr}"
278 | )
279 | else:
280 | # represent `(int, list)` as "('int', 'list')" instead of
281 | # "(, )"
282 | expected_type_repr = repr(tuple(t.__name__ for t in expected_type))
283 | exc = TypeError(
284 | f"{code_location}: {obj_name} should be one of {expected_type_repr}, "
285 | f"not {obj_type_repr}"
286 | )
287 | if isinstance(hint, str):
288 | exc.add_note(hint)
289 | raise exc
290 |
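# Usage sketch (illustrative only): validate arguments up front and fail with a readable
# TypeError, e.g. "Llama.generate: n_predict should be an instance of 'int', not 'str'".
#
#   assert_type(n_predict, int, 'n_predict', 'Llama.generate')
#   assert_type(stop_tokens, (list, NoneType), 'stop_tokens', 'Llama.generate',
#               hint='pass an empty list to disable stop tokens')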
291 | def null_ptr_check(ptr: ptr, ptr_name: str, loc_hint: str) -> None | NoReturn:
292 | """Ensure that the given object `ptr` is not NULL / NULLPTR
293 |
294 | Raise LlamaNullException on failure
295 |
296 | - ptr:
297 | The object to check
298 | - ptr_name:
299 | The name of the object (for error messages)
300 | - loc_hint:
301 | Code location hint used in easy-llama"""
302 | if not bool(ptr):
303 | raise LlamaNullException(f"{loc_hint}: pointer `{ptr_name}` is null")
304 |
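# Usage sketch (illustrative only): guard the result of a native call before using it.
#
#   model = libllama.llama_model_load_from_file(path, params)   # hypothetical binding
#   null_ptr_check(model, 'model', 'Llama.__init__')  # raises LlamaNullException if NULL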
--------------------------------------------------------------------------------
/easy_llama/webui/index.html:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ddh0/easy-llama/04b75b10d8854025eb4855e616a682962e3eef82/easy_llama/webui/index.html
--------------------------------------------------------------------------------
/easy_llama/webui/script.js:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ddh0/easy-llama/04b75b10d8854025eb4855e616a682962e3eef82/easy_llama/webui/script.js
--------------------------------------------------------------------------------
/easy_llama/webui/style.css:
--------------------------------------------------------------------------------
1 | :root {
2 | --color-black: #141414;
3 | --color-dark-grey: #292929;
4 | --color-neutral-grey: #808080;
5 | --color-light-grey: #e1e1e1;
6 | --color-white: #f5f5f5;
7 | --color-green: #22c55e;
8 | --color-green-dim: #1cab51;
9 | --color-blue: #0ea5e9;
10 | --color-blue-dim: #0b8fcb;
11 | --color-red: #ef4444;
12 | --color-red-dim: #d03a3a;
13 | --color-purple: #a855f7;
14 | --color-purple-dim: #9249d7;
15 | --color-yellow: #eab308;
16 | --color-yellow-dim: #cb9b06;
17 | }
18 |
19 | body {
20 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif;
21 | background-color: var(--color-black);
22 | color: var(--color-white);
23 | line-height: 1.0;
24 | display: flex;
25 | justify-content: center;
26 | align-items: center;
27 | min-height: 100svh;
28 | padding: 10px;
29 | }
30 |
--------------------------------------------------------------------------------
/examples/chat_demo.py:
--------------------------------------------------------------------------------
1 | # chat_demo.py
2 | # Python 3.12.9
3 |
4 | import easy_llama as ez
5 |
6 | Llama = ez.Llama(
7 | path_model='/path/to/your-llama3-model.gguf',
8 | n_gpu_layers=-1,
9 | n_ctx=1024,
10 | offload_kqv=True,
11 | warmup=False,
12 | verbose=True
13 | )
14 |
15 | Thread = ez.Thread(
16 | llama=Llama,
17 | prompt_format=ez.PromptFormats.Llama3("You are a helpful assistant.")
18 | )
19 |
20 | ez.set_verbose(False)
21 | ez.utils.cls()
22 | Thread.interact()
23 | print('-' * 80)
24 | Thread.print_stats()
25 |
--------------------------------------------------------------------------------
/examples/pretrained_demo.py:
--------------------------------------------------------------------------------
1 | # pretrained_demo.py
2 | # Python 3.12.9
3 |
4 | import easy_llama as ez
5 |
6 | Llama = ez.Llama(
7 | path_model='/path/to/your-pretrained-model.gguf',
8 | n_gpu_layers=-1,
9 | n_ctx=1024,
10 | offload_kqv=True,
11 | warmup=False,
12 | verbose=True
13 | )
14 |
15 | ez.set_verbose(False)
16 | ez.utils.cls()
17 |
18 | prompt = input('Enter your prompt:\n > ')
19 |
20 | print(
21 | f'\n\n{ez.utils.ANSI.FG_BRIGHT_GREEN}{prompt}{ez.utils.ANSI.MODE_RESET_ALL}',
22 | end='',
23 | flush=True
24 | )
25 |
26 | prompt_bytes = prompt.encode('utf-8', errors='replace')
27 |
28 | prompt_tokens = Llama.tokenize(prompt_bytes, add_special=True, parse_special=False)
29 |
30 | token_generator = Llama.stream(
31 | input_tokens=prompt_tokens,
32 | n_predict=-1,
33 | stop_tokens=[]
34 | )
35 |
36 | for token in token_generator:
37 | tok_str = Llama.token_to_piece(token, special=True).decode('utf-8', errors='replace')
38 | print(
39 | f"{ez.utils.ANSI.FG_BRIGHT_CYAN}{tok_str}{ez.utils.ANSI.MODE_RESET_ALL}",
40 | sep='',
41 | end='',
42 | flush=True
43 | )
44 |
45 | print(f'\n{"-" * 80}')
46 |
--------------------------------------------------------------------------------
/examples/simple.py:
--------------------------------------------------------------------------------
1 | # import the package
2 | import easy_llama as ez
3 |
4 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail)
5 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf')
6 |
7 | # evaluate a single token and print the raw logits inferred for the next token
8 | logits = MyLlama.eval([0])
9 | print(logits)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # pyproject.toml
2 | # https://github.com/ddh0/easy-llama/
3 | # MIT License -- Copyright (c) 2024 Dylan Halladay
4 |
5 | [build-system]
6 | requires = ["setuptools>=61.0"]
7 | build-backend = "setuptools.build_meta"
8 |
9 | [project]
10 | name = "easy_llama"
11 | dynamic = ["version"]
12 | description = "This is easy-llama, a Python package that wraps the C/C++ API provided by llama.cpp. It's intended for developers and machine learning hobbyists who want to integrate on-device language models (LLMs) into their applications."
13 | readme = "README.md"
14 | authors = [{ name = "Dylan Halladay", email = "chemist-mulches-39@icloud.com" }]
15 | license = { text = "MIT" }
16 | requires-python = ">=3.9"
17 | classifiers = [
18 | "Development Status :: 4 - Beta",
19 | "Intended Audience :: Developers",
20 | "Intended Audience :: Science/Research",
21 | "License :: OSI Approved :: MIT License",
22 | "Natural Language :: English",
23 | "Programming Language :: Python :: 3 :: Only",
24 | "Programming Language :: Python :: 3.9",
25 | "Programming Language :: Python :: 3.10",
26 | "Programming Language :: Python :: 3.11",
27 | "Programming Language :: Python :: 3.12"
28 | ]
29 | dependencies = [
30 | "numpy",
31 | "fastapi",
32 | "uvicorn",
33 | "jinja2",
34 | "tqdm"
35 | ]
36 |
37 | [project.urls]
38 | Homepage = "https://github.com/ddh0/easy-llama"
39 |
40 | [tool.setuptools]
41 | packages = ["easy_llama", "easy_llama.webui"]
42 | include-package-data = true
43 |
44 | [tool.setuptools.dynamic]
45 | version = {attr = "easy_llama.__version__"}
46 |
47 | [tool.setuptools.package-data]
48 | "easy_llama.webui" = [
49 | "*.ico",
50 | "*.png",
51 | "*.html",
52 | "*.css",
53 | "*.js",
54 | "*.webmanifest",
55 | ]
56 |
--------------------------------------------------------------------------------