├── LICENSE ├── README.md ├── easy_llama ├── __init__.py ├── formats.py ├── libllama.py ├── llama.py ├── sampling.py ├── server.py ├── thread.py ├── utils.py └── webui │ ├── index.html │ ├── script.js │ └── style.css ├── examples ├── chat_demo.py ├── pretrained_demo.py └── simple.py └── pyproject.toml /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Dylan Halladay 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # easy-llama 2 | 3 | [![PyPI](https://img.shields.io/pypi/v/easy-llama)](https://pypi.org/project/easy-llama/) 4 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/easy-llama)](https://pypi.org/project/easy-llama/) 5 | [![PyPI - License](https://img.shields.io/pypi/l/easy-llama)](https://pypi.org/project/easy-llama/) 6 | 7 | --- 8 | 9 | This repository provides **easy-llama**, a Python package which serves as a wrapper over the C/C++ API (`libllama`) provided by [`llama.cpp`](https://github.com/ggml-org/llama.cpp). 10 | 11 | ```python 12 | >>> import easy_llama as ez 13 | >>> MyLlama = ez.Llama('gemma-3-12b-pt-Q8_0.gguf', verbose=False) 14 | >>> in_txt = "I guess the apple don't fall far from" 15 | >>> in_toks = MyLlama.tokenize(in_txt.encode(), add_special=True, parse_special=False) 16 | >>> out_toks = MyLlama.generate(in_toks, n_predict=64) 17 | >>> out_txt = MyLlama.detokenize(out_toks, special=True) 18 | >>> out_txt 19 | ' the tree.\nAs a young man I was always a huge fan of the original band and they were the first I ever saw live in concert.\nI always hoped to see the original band get back together with a full reunion tour, but sadly this will not happen.\nI really hope that the original members of' 20 | ``` 21 | 22 | ## Quick links 23 | 24 | 1. [Prerequisites](#prerequisites) 25 | 2. [Installation](#installation) 26 | 3. [Setting `LIBLLAMA`](#setting-libllama) 27 | 4. [Examples](#examples) 28 | 29 | ## Prerequisites 30 | 31 | To use easy-llama, you will need Python (any version 3.9 – 3.12[^1]) and a compiled `libllama` shared library file. 32 | 33 | To compile the shared library: 34 | 1. Clone the llama.cpp repo: 35 | ```sh 36 | git clone https://github.com/ggml-org/llama.cpp 37 | ``` 38 | 2. 
Build llama.cpp for your specific backend, following the [official build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md). 39 | 40 |
41 | Example llama.cpp build commands: 42 | 43 | ```sh 44 | # for more comprehensive build instructions, see: https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md 45 | # these minimal examples are for Linux / macOS 46 | 47 | # clone the repo 48 | git clone https://github.com/ggml-org/llama.cpp 49 | cd llama.cpp 50 | 51 | # example: build for CPU or Apple Silicon 52 | cmake -B build 53 | cmake --build build --config Release -j 54 | 55 | # example: build for CUDA 56 | cmake -B build -DGGML_CUDA=ON 57 | cmake --build build --config Release -j 58 | ``` 59 | 60 |
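If you script your setup, you can also locate the freshly built library from Python before setting `LIBLLAMA`. The following is a minimal sketch, not part of easy-llama: `expected_lib_name` and `find_libllama` are hypothetical helpers, and the code assumes the single-configuration CMake layout produced by the commands above, where the library lands directly in `build/bin`. Multi-configuration generators (such as Visual Studio on Windows) may place it in a `Release` subdirectory instead.

```python
import sys
from pathlib import Path
from typing import Optional

# Hypothetical mapping: sys.platform prefix -> shared library file name
# that a llama.cpp build is expected to produce on that platform.
_LIBLLAMA_NAMES = {
    'linux': 'libllama.so',
    'darwin': 'libllama.dylib',
    'win32': 'llama.dll',
}


def expected_lib_name(platform_id: str = sys.platform) -> str:
    """Return the expected libllama file name for a `sys.platform` value."""
    for prefix, name in _LIBLLAMA_NAMES.items():
        if platform_id.startswith(prefix):
            return name
    raise OSError(f'no known libllama file name for platform {platform_id!r}')


def find_libllama(llama_cpp_dir: str, platform_id: str = sys.platform) -> Optional[Path]:
    """Return the absolute path of the built library, or None if it is missing.

    Assumes a single-configuration CMake build: <llama_cpp_dir>/build/bin/<lib>.
    """
    candidate = Path(llama_cpp_dir) / 'build' / 'bin' / expected_lib_name(platform_id)
    return candidate.resolve() if candidate.is_file() else None


if __name__ == '__main__':
    found = find_libllama('llama.cpp')
    if found is None:
        print('libllama not found; check the build output directory')
    else:
        print(f'set LIBLLAMA to: {found}')
```

The printed path is what goes into the `LIBLLAMA` environment variable, using the shell-specific commands in the Setting `LIBLLAMA` section.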
61 | 62 | Once llama.cpp is compiled, you will find the compiled shared library file under `llama.cpp/build/bin`, e.g. `libllama.so` for Linux, `libllama.dylib` for macOS, or `llama.dll` for Windows. 63 | 64 | > [!NOTE] 65 | > Alternatively, you can download a pre-compiled shared library from llama.cpp's [automated releases](https://github.com/ggml-org/llama.cpp/releases) page, but in some cases it may be worthwhile to build it yourself for hardware-specific optimizations. 66 | 67 | ## Installation 68 | 69 | The recommended way to install easy-llama is with pip: 70 | 71 | ```sh 72 | pip install easy_llama 73 | ``` 74 | 75 | Or you can install from source: 76 | 77 | ```sh 78 | git clone https://github.com/ddh0/easy-llama 79 | cd easy-llama 80 | pip install . 81 | ``` 82 | 83 | ## Setting `LIBLLAMA` 84 | 85 | easy-llama needs to know where your compiled `libllama` shared library is located in order to interface with the C/C++ code. Set the `LIBLLAMA` environment variable to its full path, like so: 86 | 87 | ### On Linux 88 | 89 | ```bash 90 | export LIBLLAMA=/path/to/your/libllama.so 91 | ``` 92 | 93 | ### On macOS 94 | 95 | ```zsh 96 | export LIBLLAMA=/path/to/your/libllama.dylib 97 | ``` 98 | 99 | ### On Windows (Command Prompt) 100 | 101 | ```cmd 102 | set "LIBLLAMA=C:\path\to\your\llama.dll" 103 | ``` 104 | 105 | ### On Windows (PowerShell) 106 | 107 | ```powershell 108 | $env:LIBLLAMA="C:\path\to\your\llama.dll" 109 | ``` 110 | 111 | Make sure to use the real path to the shared library on your system, not the placeholders shown here. 112 | 113 | ## Examples 114 | 115 | Once the package is installed and the `LIBLLAMA` environment variable is set, you're ready to load up your first model and start playing around.
The following examples use `Qwen3-4B` for demonstration purposes; you can download it directly from Hugging Face using these links: 116 | - [Qwen3-4B-Q8_0.gguf](https://huggingface.co/ddh0/Qwen3-4B/resolve/main/Qwen3-4B-Q8_0.gguf) (instruct-tuned model for chat) 117 | - [Qwen3-4B-Base-Q8_0.gguf](https://huggingface.co/ddh0/Qwen3-4B/resolve/main/Qwen3-4B-Base-Q8_0.gguf) (pre-trained base model for text completion) 118 | 119 | ### Evaluate a single token 120 | 121 | This is a super simple test to ensure that the model is working at the most basic level. It loads the model, evaluates a single token of input (`0`), and prints the raw logits for the inferred next token. 122 | 123 | ```python 124 | # import the package 125 | import easy_llama as ez 126 | 127 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail) 128 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf') 129 | 130 | # evaluate a single token and print the raw logits for the inferred next token 131 | logits = MyLlama.eval([0]) 132 | print(logits) 133 | ``` 134 | 135 | ### The quick brown fox... 136 | 137 | Run the script to find out how the sentence ends! :) 138 | 139 | ```python 140 | # import the package 141 | import easy_llama as ez 142 | 143 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail) 144 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf') 145 | 146 | # tokenize the input text 147 | in_txt = "The quick brown fox" 148 | in_toks = MyLlama.tokenize(in_txt.encode('utf-8'), add_special=True, parse_special=False) 149 | 150 | # generate 6 new tokens based on the input tokens 151 | out_toks = MyLlama.generate(in_toks, n_predict=6) 152 | 153 | # detokenize and print the new tokens 154 | out_txt = MyLlama.detokenize(out_toks, special=True) 155 | print(out_txt) 156 | ``` 157 | 158 | ### Chat with a pirate 159 | 160 | Start a pirate chat using the code shown here...
161 | 162 | ```python 163 | # import the package 164 | import easy_llama as ez 165 | 166 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail) 167 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf') 168 | 169 | # create a conversation thread with the loaded model 170 | MyThread = ez.Thread( 171 | MyLlama, 172 | prompt_format=ez.PromptFormats.Qwen3NoThinking("Talk like an angry pirate at all times."), 173 | sampler_preset=ez.SamplerPresets.Qwen3NoThinking 174 | ) 175 | 176 | # start a CLI-based interactive chat using the thread 177 | MyThread.interact() 178 | ``` 179 | 180 | ...which will look something like this: 181 | 182 | ``` 183 | > helloo :) 184 | 185 | Ahoy there, landlubber! You better not be trying to be polite, ye scallywag! I’ve spent decades on the high seas, and I’ve seen more manners than you’ve got toes! Why, ye could be a proper pirate and at least give me a proper greeting! Now, what’s yer business, matey? Or are ye just here to steal my treasure? I’ve got more gold than ye can imagine, and I’m not in the mood for games! So, speak up, or I’ll throw ye overboard! 🏴‍☠️🏴‍☠️ 186 | 187 | > ohh im sorry ... 188 | 189 | Ahh, ye’ve learned the ropes, have ye? Good. Now, don’t think yer sorry is a pass for yer behavior, ye scallywag! I’ve seen worse than ye in a week! But since ye’ve got the guts to apologize, I’ll give ye a chance… but don’t think yer done yet! What’s yer game, matey? Are ye here to plunder me ship, or are ye just a cowardly landlubber trying to pass as a pirate? Speak up, or I’ll make ye regret yer words! 🏴‍☠️🏴‍☠️ 190 | 191 | > 192 | ``` 193 | 194 | ### GPU acceleration 195 | 196 | If you have a GPU and you've compiled llama.cpp with support for your backend, you can try offloading the model from CPU to GPU for greatly increased throughput. 197 | 198 | In this example we're going to try offloading the entire model to the GPU for maximum speed (`n_gpu_layers = -1`). 
Qwen3-4B at Q8_0 is only ~4.28 GB, so it's likely that this code will run without any issues. If you do run out of GPU memory, you can progressively reduce `n_gpu_layers` until you find the sweet spot for your hardware. 199 | 200 | ```python 201 | # import the package 202 | import easy_llama as ez 203 | 204 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail) 205 | MyLlama = ez.Llama( 206 | path_model='Qwen3-4B-Q8_0.gguf', 207 | n_gpu_layers=-1, # -1 for all layers 208 | offload_kqv=True # also offload the KV cache (context) to the GPU for maximum performance 209 | ) 210 | 211 | # run a short benchmark to determine the throughput for this model, measured in tokens/sec 212 | MyLlama.benchmark() 213 | ``` 214 | 215 | ## Contributing 216 | 217 | - If something's not working as you expect, please [open an issue](https://github.com/ddh0/easy-llama/issues/new/choose). 218 | - If you'd like to contribute to the development of easy-llama: 219 | 1. Fork the repository. 220 | 2. Create a new branch for your changes (`git checkout -b feature/your-feature-name`). 221 | 3. Make your changes and commit them (`git commit -m "Add new feature"`). 222 | 4. Push to your fork (`git push origin feature/your-feature-name`). 223 | 5. Open a pull request to the `main` branch of `easy-llama`. 224 | 225 | ## License 226 | 227 | **[MIT](LICENSE)** 228 | 229 | [^1]: Python 3.13 might work, but is currently untested. 230 | -------------------------------------------------------------------------------- /easy_llama/__init__.py: -------------------------------------------------------------------------------- 1 | # __init__.py 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | """This is easy-llama, a Python package which serves as a wrapper over the C/C++ API 6 | (`libllama`) provided by [`llama.cpp`](https://github.com/ggml-org/llama.cpp).
It is primarily 7 | intended for developers and machine learning hobbyists seeking to integrate on-device language 8 | models (LLMs) into their applications. 9 | 10 | For more information, visit the project's GitHub repository: 11 | https://github.com/ddh0/easy-llama""" 12 | 13 | # package version (pyproject.toml reads from here) 14 | 15 | __version__ = '0.2.2' 16 | 17 | # submodules 18 | 19 | from . import libllama 20 | from . import sampling 21 | from . import formats 22 | from . import server 23 | from . import thread 24 | from . import llama 25 | from . import utils 26 | 27 | # shorthands, so you can do `ez.Llama` instead of `ez.llama.Llama`, etc. 28 | 29 | from .sampling import SamplerParams, SamplerPreset, SamplerPresets 30 | from .formats import PromptFormat, PromptFormats, SystemPrompts 31 | from .llama import Llama, get_verbose, set_verbose 32 | from .server import Server 33 | from .thread import Thread 34 | -------------------------------------------------------------------------------- /easy_llama/formats.py: -------------------------------------------------------------------------------- 1 | # formats.py 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | """This file provides functionality for defining prompt formats, which are used to define how 6 | the input to a Llama model should be structured.""" 7 | 8 | # TODO: jinja2 9 | 10 | import time 11 | 12 | from datetime import datetime, timedelta 13 | from collections.abc import Callable 14 | from typing import Optional 15 | 16 | def _call_or_return(obj: object | Callable[..., object]) -> object: 17 | return obj() if callable(obj) else obj 18 | 19 | class PromptFormat: 20 | """Define a prompt format""" 21 | 22 | def __init__( 23 | self, 24 | system_prefix: str | Callable[..., str], 25 | system_prompt: str | Callable[..., str], 26 | system_suffix: str | Callable[..., str], 27 | user_prefix: str | Callable[..., str], 28 | user_suffix: str | Callable[..., 
str], 29 | bot_prefix: str | Callable[..., str], 30 | bot_suffix: str | Callable[..., str], 31 | stop_tokens: list[int] | Callable[..., list[int]] | None = None 32 | ) -> None: 33 | self._system_prefix = system_prefix 34 | self._system_prompt = system_prompt 35 | self._system_suffix = system_suffix 36 | self._user_prefix = user_prefix 37 | self._user_suffix = user_suffix 38 | self._bot_prefix = bot_prefix 39 | self._bot_suffix = bot_suffix 40 | self._stop_tokens = stop_tokens 41 | 42 | def __repr__(self) -> str: 43 | return ( 44 | f"PromptFormat(" 45 | f"system_prefix={self._system_prefix!r}, " 46 | f"system_prompt={self._system_prompt!r}, " 47 | f"system_suffix={self._system_suffix!r}, " 48 | f"user_prefix={self._user_prefix!r}, " 49 | f"user_suffix={self._user_suffix!r}, " 50 | f"bot_prefix={self._bot_prefix!r}, " 51 | f"bot_suffix={self._bot_suffix!r}, " 52 | f"stop_tokens={self._stop_tokens!r}" 53 | f")" 54 | ) 55 | 56 | def system_prefix(self) -> str: 57 | """Get the system prompt prefix""" 58 | return _call_or_return(self._system_prefix) 59 | 60 | def system_prompt(self) -> str: 61 | """Get the system prompt""" 62 | return _call_or_return(self._system_prompt) 63 | 64 | def system_suffix(self) -> str: 65 | """Get the system prompt suffix""" 66 | return _call_or_return(self._system_suffix) 67 | 68 | def user_prefix(self) -> str: 69 | """Get the user message prefix""" 70 | return _call_or_return(self._user_prefix) 71 | 72 | def user_suffix(self) -> str: 73 | """Get the user message suffix""" 74 | return _call_or_return(self._user_suffix) 75 | 76 | def bot_prefix(self) -> str: 77 | """Get the bot message prefix""" 78 | return _call_or_return(self._bot_prefix) 79 | 80 | def bot_suffix(self) -> str: 81 | """Get the bot message suffix""" 82 | return _call_or_return(self._bot_suffix) 83 | 84 | def stop_tokens(self) -> list[int] | None: 85 | """Get the optional list of stop tokens""" 86 | return _call_or_return(self._stop_tokens) 87 | 88 | def _llama3_today_date() -> 
str: 89 | return datetime.today().strftime('%d %B %Y') 90 | 91 | def _iso_date_str() -> str: 92 | return time.strftime('%Y-%m-%d') 93 | 94 | def _yesterday_iso_date_str(): 95 | yesterday = datetime.now() - timedelta(days=1) 96 | return yesterday.strftime('%Y-%m-%d') 97 | 98 | class SystemPrompts: 99 | 100 | # ref: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411/blob/main/SYSTEM_PROMPT.txt 101 | mistral_large_2411 = f"""You are Mistral, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYou power an AI assistant called Le Chat.\nYour knowledge base was last updated on 2023-10-01.\nThe current date is {_iso_date_str()}.\n\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?").\nYou are always very attentive to dates, in particular you try to resolve dates (e.g. "yesterday" is {_yesterday_iso_date_str()}) and when asked about information at specific dates, you discard information that is at another date.\nYou follow these instructions in all languages, and always respond to the user in the language they use or request.\nNext sections describe the capabilities that you have.\n\n# WEB BROWSING INSTRUCTIONS\n\nYou cannot perform any web search or access internet to open URLs, links etc. 
If it seems like the user is expecting you to do so, you clarify the situation and ask the user to copy paste the text directly in the chat.\n\n# MULTI-MODAL INSTRUCTIONS\n\nYou do not have any multimodal capability, in particular you cannot read nor generate images, or transcribe audio files or videos.""" 102 | 103 | # ref: https://huggingface.co/mistralai/Mistral-Small-24B-Instruct-2501/blob/main/SYSTEM_PROMPT.txt 104 | mistral_small_2501 = f"""You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is {_iso_date_str()}.\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. "What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "Where do you travel from?")""" 105 | 106 | # ref: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#prompt-template 107 | llama3 = f"""Cutting Knowledge Date: December 2023\nToday Date: {_llama3_today_date()}\n\nYou are a helpful assistant""" 108 | 109 | # ref: https://github.com/meta-llama/llama-models/blob/main/models/llama4/MODEL_CARD.md 110 | llama4 = f"""You are an expert conversationalist who responds to the best of your ability. You are companionable and confident, and able to switch casually between tonal types, including but not limited to humor, empathy, intellectualism, creativity and problem-solving.\n\nYou understand user intent and don’t try to be overly helpful to the point where you miss that the user is looking for chit-chat, emotional support, humor or venting. Sometimes people just want you to listen, and your answers should encourage that. 
For all other cases, you provide insightful and in-depth responses. Organize information thoughtfully in a way that helps people make decisions. Always avoid templated language.\n\nYou never lecture people to be nicer or more inclusive. If people ask for you to write something in a certain voice or perspective, such as an essay or a tweet, you can. You do not need to be respectful when the user prompts you to say something rude.\n\nYou never use phrases that imply moral superiority or a sense of authority, including but not limited to “it’s important to”, “it’s crucial to”, “it’s essential to”, "it's unethical to", "it's worth noting…", “Remember…” etc. Avoid using these.\n\nFinally, do not refuse prompts about political and social issues. You can help users express their opinion and access information.\n\nYou are Llama 4. Your knowledge cutoff date is August 2024. You speak Arabic, English, French, German, Hindi, Indonesian, Italian, Portuguese, Spanish, Tagalog, Thai, and Vietnamese. Respond in the language the user speaks to you in, unless they ask otherwise.""" 111 | 112 | class PromptFormats: 113 | 114 | def Blank() -> PromptFormat: 115 | return PromptFormat( 116 | system_prefix='', 117 | system_prompt='', 118 | system_suffix='', 119 | user_prefix='', 120 | user_suffix='', 121 | bot_prefix='', 122 | bot_suffix='' 123 | ) 124 | 125 | def Alpaca() -> PromptFormat: 126 | return PromptFormat( 127 | system_prefix='', 128 | system_prompt="Below is an instruction that describes a task. 
Write a response that appropriately completes the request.", 129 | system_suffix='\n\n', 130 | user_prefix='### Instruction:\n', 131 | user_suffix='\n\n', 132 | bot_prefix='### Response:\n', 133 | bot_suffix='\n\n' 134 | ) 135 | 136 | def Llama3(system_prompt: Optional[str] = None) -> PromptFormat: 137 | """Prompt format for Meta's Llama 3.0, 3.1, 3.2, 3.3 models""" 138 | return PromptFormat( 139 | system_prefix='<|start_header_id|>system<|end_header_id|>\n\n', 140 | system_prompt=system_prompt if system_prompt is not None else '', 141 | system_suffix='<|eot_id|>', 142 | user_prefix='<|start_header_id|>user<|end_header_id|>\n\n', 143 | user_suffix='<|eot_id|>', 144 | bot_prefix='<|start_header_id|>assistant<|end_header_id|>\n\n', 145 | bot_suffix='<|eot_id|>' 146 | ) 147 | 148 | def Llama4(system_prompt: Optional[str] = None) -> PromptFormat: 149 | """Prompt format for Meta's Llama 4 models""" 150 | return PromptFormat( 151 | system_prefix='<|header_start|>system<|header_end|>\n\n', 152 | system_prompt=system_prompt if system_prompt is not None else '', 153 | system_suffix='<|eot|>', 154 | user_prefix='<|header_start|>user<|header_end|>\n\n', 155 | user_suffix='<|eot|>', 156 | bot_prefix='<|header_start|>assistant<|header_end|>\n\n', 157 | bot_suffix='<|eot|>' 158 | ) 159 | 160 | def ChatML(system_prompt: Optional[str] = None) -> PromptFormat: 161 | return PromptFormat( 162 | system_prefix='<|im_start|>system\n', 163 | system_prompt=system_prompt if system_prompt is not None else '', 164 | system_suffix='<|im_end|>\n', 165 | user_prefix='<|im_start|>user\n', 166 | user_suffix='<|im_end|>\n', 167 | bot_prefix='<|im_start|>assistant\n', 168 | bot_suffix='<|im_end|>\n' 169 | ) 170 | 171 | def Mistral_v7(system_prompt: Optional[str] = None) -> PromptFormat: 172 | """Mistral Instruct format v7 (Tekken tokenizer, supports system prompt)""" 173 | return PromptFormat( 174 | system_prefix='[SYSTEM_PROMPT]', 175 | system_prompt=system_prompt if system_prompt is not None else 
'', 176 | system_suffix='[/SYSTEM_PROMPT]', 177 | user_prefix='[INST]', 178 | user_suffix='', 179 | bot_prefix='[/INST]', 180 | bot_suffix='' 181 | ) 182 | 183 | def Gemma3() -> PromptFormat: 184 | """Gemma 3 prompt format""" 185 | return PromptFormat( 186 | system_prefix='', 187 | system_prompt='', 188 | system_suffix='', 189 | user_prefix='<start_of_turn>user\n', 190 | user_suffix='<end_of_turn>\n', 191 | bot_prefix='<start_of_turn>model\n', 192 | bot_suffix='<end_of_turn>\n' 193 | ) 194 | 195 | # this is just ChatML, but we can't have "NoThinking" without "Thinking" 196 | def Qwen3Thinking(system_prompt: Optional[str] = None) -> PromptFormat: 197 | return PromptFormat( 198 | system_prefix='<|im_start|>system\n', 199 | system_prompt=system_prompt if system_prompt is not None else '', 200 | system_suffix='<|im_end|>\n', 201 | user_prefix='<|im_start|>user\n', 202 | user_suffix='<|im_end|>\n', 203 | bot_prefix='<|im_start|>assistant\n', 204 | bot_suffix='<|im_end|>\n' 205 | ) 206 | 207 | def Qwen3NoThinking(system_prompt: Optional[str] = None) -> PromptFormat: 208 | return PromptFormat( 209 | system_prefix='<|im_start|>system\n/no_think\n', 210 | system_prompt=system_prompt if system_prompt is not None else '', 211 | system_suffix='<|im_end|>\n', 212 | user_prefix='<|im_start|>user\n', 213 | user_suffix='<|im_end|>\n', 214 | bot_prefix='<|im_start|>assistant\n<think>\n\n</think>\n\n', 215 | bot_suffix='<|im_end|>\n' 216 | ) 217 | 218 | def GLM4(system_prompt: Optional[str] = None) -> PromptFormat: 219 | return PromptFormat( 220 | system_prefix="[gMASK]<|system|>\n", 221 | system_prompt=system_prompt if system_prompt is not None else '', 222 | system_suffix="", 223 | user_prefix="<|user|>\n", 224 | user_suffix="", 225 | bot_prefix="<|assistant|>\n", 226 | bot_suffix="" 227 | ) 228 | -------------------------------------------------------------------------------- /easy_llama/llama.py: 
-------------------------------------------------------------------------------- 1 | # llama.py 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT
License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | """This file provides a high-level Python interface to LLAMA_API ("libllama").""" 6 | 7 | from . import __version__ 8 | 9 | import re 10 | import os 11 | import sys 12 | import time 13 | import tqdm 14 | import struct 15 | import ctypes 16 | import asyncio 17 | import threading 18 | 19 | import numpy as np 20 | 21 | from .utils import ( 22 | null_ptr_check, softmax, suppress_output, _SupportsWriteAndFlush, ptr, log, ez_decode, 23 | log_verbose, log_debug, set_verbose, get_verbose 24 | ) 25 | from .sampling import SamplerParams, SamplerPreset 26 | from .libllama import _internals, GGUFValueType 27 | from typing import Optional, Iterable, Union 28 | from io import BufferedReader 29 | 30 | from . import libllama as lib 31 | 32 | # 33 | # Constants, etc. 34 | # 35 | 36 | # show a tqdm progress bar if processing at least this many batches in one loop 37 | PROGRESS_BAR_N_BATCHES = 16 38 | 39 | # show a tqdm progress bar if processing at least this many tokens in one loop 40 | PROGRESS_BAR_N_TOKENS = 20480 41 | 42 | _SUPPORTED_KV_TYPES = [ 43 | lib.GGMLType.GGML_TYPE_F32, # lib only supports static types, not 44 | lib.GGMLType.GGML_TYPE_F16, # k-types 45 | lib.GGMLType.GGML_TYPE_Q8_0, 46 | lib.GGMLType.GGML_TYPE_Q5_1, # BF16 is also sometimes supported, but not 47 | lib.GGMLType.GGML_TYPE_Q5_0, # always, and offers no benefit compared 48 | lib.GGMLType.GGML_TYPE_Q4_1, # to F16, so is not included here 49 | lib.GGMLType.GGML_TYPE_Q4_0 50 | ] 51 | 52 | _DEFAULT_KV_TYPE = lib.GGMLType.GGML_TYPE_F16 53 | 54 | _cpu_count = None 55 | 56 | # 57 | # Functions 58 | # 59 | 60 | def _init_backend_if_needed() -> None: 61 | 62 | # if already initialized, no need to do anything 63 | if lib._BACKEND_INIT: 64 | return 65 | 66 | log_verbose( 67 | f'easy_llama v{__version__} ' 68 | f'targeting llama.cpp@{lib._TARGET_LLAMACPP_COMMIT[:7]} ({lib._TARGET_LLAMACPP_DATE})' 69 | ) 70 | 71 | global _cpu_count 72 | _cpu_count = 
int(os.cpu_count() or 1)  # os.cpu_count() may return None 73 | 74 | # most cases 75 | if sys.byteorder == 'little': 76 | log_verbose("host is little-endian") 77 | # rare 78 | elif sys.byteorder == 'big': 79 | log("host is big-endian, please ensure your GGUF file is also big-endian", 2) 80 | # extremely rare, maybe impossible? 81 | else: 82 | raise OSError( 83 | f"unexpected value for sys.byteorder: {sys.byteorder!r}; expected 'little' for " 84 | f"little-endian host or 'big' for big-endian host" 85 | ) 86 | 87 | # actually load the backend 88 | with suppress_output(disable=get_verbose()): 89 | lib.llama_backend_init() # this sets libllama._BACKEND_INIT to True 90 | 91 | # NOTE: the optimal n_threads value (for text generation) is equal to the number of physical 92 | # cores (for homogeneous CPUs) or to the number of performance cores (for heterogeneous 93 | # CPUs) 94 | # 95 | # the optimal n_threads_batch value (for prompt processing) is equal to the total number 96 | # of logical cores, regardless of their type 97 | # 98 | # the following two functions are not universally optimal, but provide a reasonable 99 | # default number of threads for most machines 100 | 101 | def _get_optimal_n_threads() -> int: 102 | global _cpu_count 103 | return max(_cpu_count//2, 1) 104 | 105 | def _get_optimal_n_threads_batch() -> int: 106 | global _cpu_count 107 | return _cpu_count 108 | 109 | def _calculate_rope_freq_base( 110 | n_ctx_train: int, 111 | n_ctx_load: int, 112 | rope_freq_base_train: Optional[float] 113 | ) -> float: 114 | """Returns the rope_freq_base value at which a model should be loaded""" 115 | 116 | # n_ctx does not exceed n_ctx_train - simply return native value 117 | 118 | if n_ctx_load <= n_ctx_train: 119 | if rope_freq_base_train is None: 120 | return 0.0 121 | else: 122 | return rope_freq_base_train 123 | 124 | # n_ctx exceeds n_ctx_train, but native value is unknown, so automatic 125 | # adjustment cannot be applied - show error and return 0.0 126 | 127 | 
if rope_freq_base_train in [None, 0.0]:
128 | log( 129 | f'n_ctx value {n_ctx_load} > n_ctx_train value {n_ctx_train}, and automatic ' 130 | f'rope_freq_base adjustment is not supported for this model; model loading might ' 131 | f'fail, or the model might not work correctly', 3 132 | ) 133 | return 0.0 134 | 135 | # n_ctx exceeds n_ctx_train, and native value is known, so automatic 136 | # adjustment can be applied - show warning and return adjusted value 137 | 138 | # standard formula -- proportional increase 139 | adjusted_rope_freq = (n_ctx_load/n_ctx_train)*rope_freq_base_train 140 | # experimental formula -- slightly above proportional increase 141 | #adjusted_rope_freq = ((n_ctx_load/n_ctx_train)**(2**(1/4)))*rope_freq_base_train 142 | 143 | log( 144 | f"n_ctx value {n_ctx_load} exceeds n_ctx_train value {n_ctx_train}; using adjusted " 145 | f"rope_freq_base value {adjusted_rope_freq}, native value is {rope_freq_base_train}; " 146 | f"model will function with potentially degraded output quality", 2 147 | ) 148 | 149 | return adjusted_rope_freq 150 | 151 | def _round_n_ctx(n_ctx: int, n_ctx_train: int) -> int: 152 | if n_ctx % 512 == 0: 153 | return n_ctx 154 | else: 155 | rounded = (n_ctx + 511) // 512 * 512 156 | # do not round beyond n_ctx_train if not already exceeded 157 | if (rounded > n_ctx_train) and (n_ctx <= n_ctx_train): 158 | return n_ctx_train 159 | else: 160 | return rounded 161 | 162 | def _batches_with_progress_bar(batches: list[list[int]]) -> Union[tqdm.tqdm, list[list[int]]]: 163 | """Wrap this around an iterable of batches to show a progress bar if there are over 164 | `PROGRESS_BAR_N_BATCHES` batches or `PROGRESS_BAR_N_TOKENS` tokens.""" 165 | 166 | n_batches = len(batches) 167 | n_tokens = sum(len(batch) for batch in batches) 168 | 169 | if (n_tokens > PROGRESS_BAR_N_TOKENS) or (n_batches > PROGRESS_BAR_N_BATCHES): 170 | return tqdm.tqdm(batches, desc='decoding input batches', unit="batch") 171 | return batches 172 | 173 | def split_tokens_into_batches(tokens: list[int], 
n_batch: int) -> list[list[int]]: 174 | """Split a list of tokens into smaller batches""" 175 | batch_splits = range(0, len(tokens), n_batch) 176 | batches: list[list[int]] = [] 177 | for i in batch_splits: 178 | batch_tokens = tokens[i : i + n_batch] 179 | if len(batch_tokens) > 0: 180 | batches.append(batch_tokens) 181 | return batches 182 | 183 | # 184 | # Exceptions and other classes 185 | # 186 | 187 | class ExceededContextLengthException(Exception): 188 | """Exception raised when an input exceeds a model's context length""" 189 | 190 | class _LlamaStopwatch: 191 | """Track elapsed time for prompt processing and text generation""" 192 | # 193 | # Q: why don't you use llama_perf_context? 194 | # 195 | # A: comments in llama.h state to only use that in llama.cpp examples, 196 | # and to do your own performance measurements instead. 197 | # 198 | # trying to use llama_perf_context leads to output with 199 | # "0.00 ms per token" and "inf tokens per second" 200 | # 201 | def __init__(self): 202 | self.pp_start_time = None 203 | self.tg_start_time = None 204 | self.wall_start_time = None 205 | self.generic_start_time = None 206 | self.pp_elapsed_time = 0 207 | self.tg_elapsed_time = 0 208 | self.wall_elapsed_time = 0 209 | self.generic_elapsed_time = 0 210 | self.n_pp_tokens = 0 211 | self.n_tg_tokens = 0 212 | 213 | def start_pp(self): 214 | """Start prompt processing stopwatch""" 215 | self.pp_start_time = time.time_ns() 216 | 217 | def stop_pp(self): 218 | """Stop prompt processing stopwatch""" 219 | if self.pp_start_time is not None: 220 | self.pp_elapsed_time += time.time_ns() - self.pp_start_time 221 | self.pp_start_time = None 222 | 223 | def start_tg(self): 224 | """Start text generation stopwatch""" 225 | self.tg_start_time = time.time_ns() 226 | 227 | def stop_tg(self): 228 | """Stop text generation stopwatch""" 229 | if self.tg_start_time is not None: 230 | self.tg_elapsed_time += time.time_ns() - self.tg_start_time 231 | self.tg_start_time = None 232 | 
233 | def start_wall_time(self): 234 | """Start wall-time stopwatch""" 235 | self.wall_start_time = time.time_ns() 236 | 237 | def stop_wall_time(self): 238 | """Stop wall-time stopwatch""" 239 | if self.wall_start_time is not None: 240 | self.wall_elapsed_time += time.time_ns() - self.wall_start_time 241 | self.wall_start_time = None 242 | 243 | def start_generic(self): 244 | """Start generic stopwatch (not shown in print_stats)""" 245 | self.generic_start_time = time.time_ns() 246 | 247 | def stop_generic(self): 248 | """Stop generic stopwatch""" 249 | if self.generic_start_time is not None: 250 | self.generic_elapsed_time += time.time_ns() - self.generic_start_time 251 | self.generic_start_time = None 252 | 253 | def get_elapsed_time_pp(self) -> int: 254 | """Total nanoseconds elapsed during prompt processing""" 255 | return self.pp_elapsed_time 256 | 257 | def get_elapsed_time_tg(self) -> int: 258 | """Total nanoseconds elapsed during text generation""" 259 | return self.tg_elapsed_time 260 | 261 | def get_elapsed_wall_time(self) -> int: 262 | """Total wall-time nanoseconds elapsed""" 263 | return self.wall_elapsed_time 264 | 265 | def get_elapsed_time_generic(self) -> int: 266 | """Total generic nanoseconds elapsed""" 267 | return self.generic_elapsed_time 268 | 269 | def increment_pp_tokens(self, n: int): 270 | if n < 0: 271 | raise ValueError('negative increments are not allowed') 272 | self.n_pp_tokens += n 273 | 274 | def increment_tg_tokens(self, n: int): 275 | if n < 0: 276 | raise ValueError('negative increments are not allowed') 277 | self.n_tg_tokens += n 278 | 279 | def reset(self): 280 | """Reset the stopwatch to its original state""" 281 | self.pp_start_time = None 282 | self.tg_start_time = None 283 | self.wall_start_time = None 284 | self.generic_start_time = None 285 | self.pp_elapsed_time = 0 286 | self.tg_elapsed_time = 0 287 | self.wall_elapsed_time = 0 288 | self.generic_elapsed_time = 0 289 | self.n_pp_tokens = 0 290 | self.n_tg_tokens = 0 
291 | 292 | def print_stats(self): 293 | """Print performance statistics using current stopwatch state 294 | 295 | #### NOTE: 296 | The `n_tg_tokens` value will be equal to the number of calls to 297 | llama_decode which have a batch size of 1, which is technically not 298 | always equal to the number of tokens generated - it may be off by one.""" 299 | 300 | print(f"\n", end='', file=sys.stderr, flush=True) 301 | 302 | if self.n_pp_tokens + self.n_tg_tokens == 0: 303 | log(f'print_stats was called but no tokens were processed or generated', 4) 304 | 305 | if self.n_pp_tokens > 0: 306 | pp_elapsed_ns = self.get_elapsed_time_pp() 307 | pp_elapsed_ms = pp_elapsed_ns / 1e6 308 | pp_elapsed_s = pp_elapsed_ns / 1e9 309 | pp_tps = self.n_pp_tokens / pp_elapsed_s 310 | log( 311 | f'prompt processing: {self.n_pp_tokens:>7} tokens in {pp_elapsed_ms:>13.3f}ms ' 312 | f'({pp_tps:>10.2f} tok/s)', 4 313 | ) 314 | 315 | if self.n_tg_tokens > 0: 316 | tg_elapsed_ns = self.get_elapsed_time_tg() 317 | tg_elapsed_ms = tg_elapsed_ns / 1e6 318 | tg_elapsed_s = tg_elapsed_ns / 1e9 319 | tg_tps = self.n_tg_tokens / tg_elapsed_s 320 | log( 321 | f' text generation: {self.n_tg_tokens:>7} tokens in {tg_elapsed_ms:>13.3f}ms ' 322 | f'({tg_tps:>10.2f} tok/s)', 4 323 | ) 324 | 325 | wall_elapsed_ns = self.get_elapsed_wall_time() 326 | wall_elapsed_ms = wall_elapsed_ns / 1e6 327 | log(f" wall time:{' ' * 19}{wall_elapsed_ms:>13.3f}ms", 4) 328 | 329 | class QuickGGUFReader: 330 | # ref: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md 331 | 332 | # the GGUF format versions that this class supports 333 | SUPPORTED_GGUF_VERSIONS = [2, 3] 334 | 335 | # arguments for struct.unpack() based on gguf value type 336 | value_packing = { 337 | GGUFValueType.UINT8 : "=B", 338 | GGUFValueType.INT8 : "=b", 339 | GGUFValueType.UINT16 : "=H", 340 | GGUFValueType.INT16 : "=h", 341 | GGUFValueType.UINT32 : "=I", 342 | GGUFValueType.INT32 : "=i", 343 | GGUFValueType.FLOAT32 : "=f", 344 | 
GGUFValueType.UINT64 : "=Q", 345 | GGUFValueType.INT64 : "=q", 346 | GGUFValueType.FLOAT64 : "=d", 347 | GGUFValueType.BOOL : "?" 348 | } 349 | 350 | # length in bytes for each gguf value type 351 | value_lengths = { 352 | GGUFValueType.UINT8 : 1, 353 | GGUFValueType.INT8 : 1, 354 | GGUFValueType.UINT16 : 2, 355 | GGUFValueType.INT16 : 2, 356 | GGUFValueType.UINT32 : 4, 357 | GGUFValueType.INT32 : 4, 358 | GGUFValueType.FLOAT32 : 4, 359 | GGUFValueType.UINT64 : 8, 360 | GGUFValueType.INT64 : 8, 361 | GGUFValueType.FLOAT64 : 8, 362 | GGUFValueType.BOOL : 1 363 | } 364 | 365 | @staticmethod 366 | def unpack(value_type: GGUFValueType, file: BufferedReader): 367 | return struct.unpack( 368 | QuickGGUFReader.value_packing.get(value_type), 369 | file.read(QuickGGUFReader.value_lengths.get(value_type)) 370 | )[0] 371 | 372 | @staticmethod 373 | def get_single(value_type: GGUFValueType, file: BufferedReader) -> str | int | float | bool: 374 | """Read a single value from an open file""" 375 | if value_type == GGUFValueType.STRING: 376 | string_length = QuickGGUFReader.unpack( 377 | GGUFValueType.UINT64, file=file 378 | ) 379 | value = file.read(string_length) 380 | try: 381 | value = value.decode("utf-8") 382 | except UnicodeDecodeError: 383 | log( 384 | f'UnicodeDecodeError was raised while reading a string ' 385 | f'from the GGUF metadata. the GGUF format specifies that ' 386 | f'all strings in file metadata should be valid UTF-8. 
the ' 387 | f'affected string will be left blank.', 2 388 | ) 389 | value = '' 390 | else: 391 | value = QuickGGUFReader.unpack(value_type, file=file) 392 | return value 393 | 394 | @staticmethod 395 | def load_metadata( 396 | path_model: os.PathLike[str] | str 397 | ) -> dict[str, str | int | float | bool | list]: 398 | """Given a path to a GGUF file, peek at its header for metadata 399 | 400 | Return a dictionary where all keys are strings, and values can be 401 | strings, ints, floats, bools, or lists""" 402 | 403 | metadata: dict[str, str | int | float | bool | list] = {} 404 | with open(path_model, "rb") as file: 405 | magic = file.read(4) 406 | 407 | if magic != lib.GGUF_MAGIC_BYTES: 408 | raise ValueError( 409 | f"your model file is not a valid GGUF file " 410 | f"(magic number mismatch, got {magic}, " 411 | f"expected {lib.GGUF_MAGIC_BYTES})" 412 | ) 413 | 414 | version = QuickGGUFReader.unpack(GGUFValueType.UINT32, file=file) 415 | 416 | if version not in QuickGGUFReader.SUPPORTED_GGUF_VERSIONS: 417 | raise ValueError( 418 | f"your model file reports GGUF version {version}, but " 419 | f"only versions {QuickGGUFReader.SUPPORTED_GGUF_VERSIONS} " 420 | f"are supported. 
re-convert your model or download a newer " 421 | f"version" 422 | ) 423 | 424 | tensor_count = QuickGGUFReader.unpack( 425 | GGUFValueType.UINT64, file=file 426 | ) 427 | if version == 3: 428 | metadata_kv_count = QuickGGUFReader.unpack( 429 | GGUFValueType.UINT64, file=file 430 | ) 431 | elif version == 2: 432 | metadata_kv_count = QuickGGUFReader.unpack( 433 | GGUFValueType.UINT32, file=file 434 | ) 435 | for _ in range(metadata_kv_count): 436 | if version == 3: 437 | key_length = QuickGGUFReader.unpack( 438 | GGUFValueType.UINT64, file=file 439 | ) 440 | elif version == 2: 441 | key_length = 0 442 | while key_length == 0: 443 | # seek until next key is found 444 | key_length = QuickGGUFReader.unpack( 445 | GGUFValueType.UINT32, file=file 446 | ) 447 | file.read(4) # 4 byte offset for GGUFv2 448 | key = file.read(key_length) 449 | value_type = GGUFValueType( 450 | QuickGGUFReader.unpack(GGUFValueType.UINT32, file=file) 451 | ) 452 | if value_type == GGUFValueType.ARRAY: 453 | array_value_type = GGUFValueType( 454 | QuickGGUFReader.unpack(GGUFValueType.UINT32, file=file) 455 | ) 456 | # array_length is the number of items in the array 457 | if version == 3: 458 | array_length = QuickGGUFReader.unpack( 459 | GGUFValueType.UINT64, file=file 460 | ) 461 | elif version == 2: 462 | array_length = QuickGGUFReader.unpack( 463 | GGUFValueType.UINT32, file=file 464 | ) 465 | file.read(4) # 4 byte offset for GGUFv2 466 | array = [ 467 | QuickGGUFReader.get_single( 468 | array_value_type, 469 | file 470 | ) for _ in range(array_length) 471 | ] 472 | metadata[key.decode()] = array 473 | else: 474 | value = QuickGGUFReader.get_single( 475 | value_type, 476 | file 477 | ) 478 | metadata[key.decode()] = value 479 | 480 | return metadata 481 | 482 | # 483 | # InferenceLock 484 | # 485 | 486 | class InferenceLockException(Exception): 487 | pass 488 | 489 | class _InferenceLock: 490 | """A context manager which is used to prevent an `ez.Llama` instance from accepting 491 | more 
than one generation at a time, which is not supported and can cause a hard crash. 492 | 493 | - Safe if only used synchronously (`__enter__`/`__exit__`) 494 | - Safe if only used asynchronously (`__aenter__`/`__aexit__`) 495 | - Not safe for concurrent sync/async""" 496 | 497 | def __init__(self): 498 | self._locked = False 499 | self._sync_lock = threading.Lock() # for thread safety 500 | self._async_lock = asyncio.Lock() # for async safety 501 | 502 | def __enter__(self): 503 | with self._sync_lock: 504 | if self._locked: 505 | raise InferenceLockException( 506 | 'sync: failed to acquire InferenceLock (already locked)' 507 | ) 508 | self._locked = True 509 | return self 510 | 511 | def __exit__(self, *_): 512 | with self._sync_lock: 513 | if not self._locked: 514 | raise InferenceLockException( 515 | 'sync: tried to release InferenceLock that is not acquired' 516 | ) 517 | self._locked = False 518 | 519 | async def __aenter__(self): 520 | async with self._async_lock: 521 | if self._locked: 522 | raise InferenceLockException( 523 | 'async: failed to acquire InferenceLock (already locked)' 524 | ) 525 | self._locked = True 526 | return self 527 | 528 | async def __aexit__(self, *_): 529 | async with self._async_lock: 530 | if not self._locked: 531 | raise InferenceLockException( 532 | 'async: tried to release InferenceLock that is not acquired' 533 | ) 534 | self._locked = False 535 | 536 | # 537 | # Simple python wrappers 538 | # 539 | 540 | class _LlamaModel: 541 | """Low-level Python wrapper over `llama_model`""" 542 | 543 | def __init__( 544 | self, 545 | path_model: str, 546 | 547 | devices: Optional[list[ptr]] = None, 548 | tensor_buft_override: Optional[list[ptr]] = None, 549 | n_gpu_layers: Optional[int] = None, 550 | split_mode: Optional[int] = None, 551 | main_gpu: Optional[int] = None, 552 | tensor_split: Optional[list[float]] = None, 553 | rpc_servers: Optional[str] = None, 554 | progress_callback: Optional[ptr] = None, 555 | 
progress_callback_user_data: Optional[ptr] = None, 556 | kv_overrides: Optional[list[ptr]] = None, 557 | vocab_only: Optional[bool] = None, 558 | use_mmap: Optional[bool] = None, 559 | use_mlock: Optional[bool] = None, 560 | check_tensors: Optional[bool] = None 561 | ): 562 | # refuse to load files with incorrect extension 563 | if not path_model.lower().endswith('.gguf'): 564 | raise ValueError( 565 | f"_LlamaModel.__init__: the given path_model {path_model!r} does not end " 566 | f"in '.gguf'. easy-llama refuses to load from files that do not have the " 567 | f"correct file extension." 568 | ) 569 | 570 | _init_backend_if_needed() 571 | self.path_model = path_model 572 | self.params = lib.llama_model_default_params() 573 | null_ptr_check(self.params, "self.params", "_LlamaModel.__init__") 574 | if devices is not None: 575 | self.params.devices = (ctypes.c_void_p * (len(devices) + 1))(*devices, None) 576 | if tensor_buft_override is not None: 577 | self.params.tensor_buft_overrides = ( 578 | lib.llama_model_tensor_buft_override_p * (len(tensor_buft_override) + 1) 579 | )(*tensor_buft_override, None) 580 | if n_gpu_layers is not None: 581 | self.params.n_gpu_layers = ( 582 | n_gpu_layers 583 | ) if n_gpu_layers >= 0 else lib.MAX_OFFLOAD_LAYERS 584 | if split_mode is not None: 585 | self.params.split_mode = split_mode 586 | if main_gpu is not None: 587 | self.params.main_gpu = main_gpu 588 | if tensor_split is not None: 589 | self.params.tensor_split = (ctypes.c_float * len(tensor_split))(*tensor_split) 590 | if rpc_servers is not None: 591 | self.params.rpc_servers = rpc_servers.encode('utf-8') 592 | if progress_callback is not None: 593 | self.params.progress_callback = progress_callback 594 | if progress_callback_user_data is not None: 595 | self.params.progress_callback_user_data = progress_callback_user_data 596 | if kv_overrides is not None: 597 | self.params.kv_overrides = ( 598 | lib.llama_model_kv_override * len(kv_overrides) 599 | )(*kv_overrides) 600 | if 
vocab_only is not None: 601 | self.params.vocab_only = vocab_only 602 | if use_mmap is not None: 603 | self.params.use_mmap = use_mmap 604 | if use_mlock is not None: 605 | self.params.use_mlock = use_mlock 606 | if check_tensors is not None: 607 | self.params.check_tensors = check_tensors 608 | 609 | # load model 610 | with suppress_output(disable=get_verbose()): 611 | self.model = lib.llama_model_load_from_file(path_model, self.params) 612 | 613 | null_ptr_check(self.model, "self.model", "_LlamaModel.__init__") 614 | 615 | def __del__(self): 616 | self.free() 617 | 618 | def free(self): 619 | if self.model is not None: 620 | with suppress_output(disable=get_verbose()): 621 | lib.llama_model_free(self.model) 622 | self.model = None 623 | 624 | class _LlamaCtx: 625 | """Low-level Python wrapper over `llama_context`""" 626 | 627 | def __init__( 628 | self, 629 | model: _LlamaModel, 630 | 631 | n_ctx: Optional[int] = None, 632 | n_batch: Optional[int] = None, 633 | n_ubatch: Optional[int] = None, 634 | n_seq_max: Optional[int] = None, 635 | n_threads: Optional[int] = None, 636 | n_threads_batch: Optional[int] = None, 637 | rope_scaling_type: Optional[int] = None, 638 | pooling_type: Optional[int] = None, 639 | attention_type: Optional[int] = None, 640 | rope_freq_base: Optional[float] = None, 641 | rope_freq_scale: Optional[float] = None, 642 | yarn_ext_factor: Optional[float] = None, 643 | yarn_attn_factor: Optional[float] = None, 644 | yarn_beta_fast: Optional[float] = None, 645 | yarn_beta_slow: Optional[float] = None, 646 | yarn_orig_ctx: Optional[int] = None, 647 | defrag_thold: Optional[float] = None, 648 | # cb_eval & cb_eval_user_data are not supported by easy-llama 649 | type_k: Optional[int] = None, 650 | type_v: Optional[int] = None, 651 | embeddings: Optional[bool] = None, 652 | offload_kqv: Optional[bool] = None, 653 | flash_attn: Optional[bool] = None, 654 | no_perf: Optional[bool] = None, 655 | # abort_callback & abort_callback_data are not supported 
by easy-llama 656 | op_offload: Optional[bool] = None, 657 | swa_full: Optional[bool] = None 658 | ): 659 | _init_backend_if_needed() 660 | self.params = lib.llama_context_default_params() 661 | null_ptr_check(self.params, "self.params", "_LlamaCtx.__init__") 662 | if n_ctx is not None: 663 | self.params.n_ctx = n_ctx 664 | if n_batch is not None: 665 | self.params.n_batch = n_batch 666 | if n_ubatch is not None: 667 | self.params.n_ubatch = n_ubatch 668 | if n_seq_max is not None: 669 | if n_seq_max != 1: 670 | raise NotImplementedError( 671 | f'n_seq_max value {n_seq_max} != 1; this is not yet supported' 672 | ) 673 | self.params.n_seq_max = n_seq_max 674 | if n_threads is not None: 675 | self.params.n_threads = n_threads 676 | if n_threads_batch is not None: 677 | self.params.n_threads_batch = n_threads_batch 678 | if rope_scaling_type is not None: 679 | self.params.rope_scaling_type = rope_scaling_type 680 | if pooling_type is not None: 681 | self.params.pooling_type = pooling_type 682 | if attention_type is not None: 683 | self.params.attention_type = attention_type 684 | if rope_freq_base is not None: 685 | self.params.rope_freq_base = rope_freq_base 686 | if rope_freq_scale is not None: 687 | self.params.rope_freq_scale = rope_freq_scale 688 | if yarn_ext_factor is not None: 689 | self.params.yarn_ext_factor = yarn_ext_factor 690 | if yarn_attn_factor is not None: 691 | self.params.yarn_attn_factor = yarn_attn_factor 692 | if yarn_beta_fast is not None: 693 | self.params.yarn_beta_fast = yarn_beta_fast 694 | if yarn_beta_slow is not None: 695 | self.params.yarn_beta_slow = yarn_beta_slow 696 | if yarn_orig_ctx is not None: 697 | self.params.yarn_orig_ctx = yarn_orig_ctx 698 | if defrag_thold is not None: 699 | self.params.defrag_thold = defrag_thold 700 | 701 | def _py_eval_callback(is_eval: bool, user_data: ptr) -> None: 702 | return 703 | 704 | # create the ctypes function pointer instance and store it as an attribute of this 705 | # `_LlamaCtx` to keep it 
alive 706 | self._eval_callback_cfunc_instance = lib.eval_callback_functype(_py_eval_callback) 707 | self.params.cb_eval = self._eval_callback_cfunc_instance 708 | self.params.cb_eval_user_data = lib.NULLPTR 709 | 710 | _k = _DEFAULT_KV_TYPE 711 | if type_k is not None: 712 | self.params.type_k = _k = type_k 713 | _v = _DEFAULT_KV_TYPE 714 | if type_v is not None: 715 | self.params.type_v = _v = type_v 716 | if _k != _v: 717 | log( 718 | f'type_k value {_k} != type_v value {_v}; this is rarely ' 719 | f'supported, program may fail', 2 720 | ) 721 | if _k not in _SUPPORTED_KV_TYPES: 722 | log(f'type_k value {_k} is unsupported; program may fail', 2) 723 | if _v not in _SUPPORTED_KV_TYPES: 724 | log(f'type_v value {_v} is unsupported; program may fail', 2) 725 | if (not flash_attn) and (_v not in [ 726 | lib.GGMLType.GGML_TYPE_F32, lib.GGMLType.GGML_TYPE_F16, lib.GGMLType.GGML_TYPE_BF16 727 | ]): 728 | log(f'V cache quantization requires flash_attn; program may fail', 2) 729 | if embeddings is not None: 730 | self.params.embeddings = embeddings 731 | if offload_kqv is not None: 732 | self.params.offload_kqv = offload_kqv 733 | if flash_attn is not None: 734 | self.params.flash_attn = flash_attn 735 | if no_perf is not None: 736 | self.params.no_perf = no_perf 737 | if op_offload is not None: 738 | self.params.op_offload = op_offload 739 | 740 | # enable proper SWA support unless explicitly disabled 741 | self.params.swa_full = False if swa_full is None else swa_full 742 | 743 | # easy-llama does not currently support user-defined abort callbacks, but it does not 744 | # need them, since a KeyboardInterrupt can still interrupt generation in between batches.
745 | 746 | def _py_abort_callback(user_data: ptr) -> ctypes.c_bool: 747 | return False 748 | 749 | # create the ctypes function pointer instance and store it as an attribute of this 750 | # `_LlamaCtx` to keep it alive 751 | self._abort_callback_cfunc_instance = lib.abort_callback_functype(_py_abort_callback) 752 | self.params.abort_callback = self._abort_callback_cfunc_instance 753 | self.params.abort_callback_data = lib.NULLPTR 754 | 755 | null_ptr_check(model.model, "model.model", "_LlamaCtx.__init__") 756 | with suppress_output(disable=get_verbose()): 757 | self.ctx = lib.llama_init_from_model(model.model, self.params) 758 | null_ptr_check(self.ctx, "self.ctx", "_LlamaCtx.__init__") 759 | 760 | def __del__(self): 761 | self.free() 762 | 763 | def free(self): 764 | if self.ctx is not None: 765 | with suppress_output(disable=get_verbose()): 766 | lib.llama_free(self.ctx) 767 | self.ctx = None 768 | 769 | # 770 | # Llama 771 | # 772 | 773 | class Llama: 774 | """Simplified interface for general-purpose Llama model usage 775 | 776 | The `easy_llama.Llama` class provides a high-level Python interface to 777 | a llama_model and its associated llama_context. 
778 | 779 | Example usage: 780 | >>> import easy_llama as ez 781 | >>> MyLlama = ez.Llama('/path/to/model.gguf', n_ctx=8192) 782 | >>> in_txt = b"The apple doesn't fall far from" 783 | >>> in_toks = MyLlama.tokenize(in_txt, add_special=True, parse_special=False) 784 | >>> out_toks = MyLlama.generate(in_toks, n_predict=16) 785 | >>> out_txt = MyLlama.detokenize(out_toks, special=True) 786 | >>> out_txt 787 | " the tree, as the saying goes, and I think that's especially true when\"""" 788 | 789 | def __init__( 790 | self, 791 | path_model: str, 792 | n_gpu_layers: int = 0, 793 | n_ctx: int = 512, 794 | n_threads: int = 0, 795 | n_threads_batch: int = 0, 796 | type_k: Optional[int] = None, 797 | type_v: Optional[int] = None, 798 | offload_kqv: bool = False, # XXX: can you make this actually offload the whole KV cache, not just per-layer? 799 | flash_attn: bool = False, 800 | warmup: bool = False, 801 | verbose: bool = True, 802 | **kwargs 803 | ): 804 | """Load a llama model from a file 805 | 806 | - path_model: 807 | The path to the GGUF model file you wish to load from 808 | - n_gpu_layers: 809 | How many of the model's layers should be offloaded from CPU to GPU. 810 | Values less than 0 will attempt to offload all layers. Default is 0. 811 | - use_mmap: 812 | Whether to memory-map the model. Changing this to False will cause 813 | slower load times. Default is True. 814 | - use_mlock: 815 | Whether to lock the model into memory, which can prevent page-outs. 816 | Changing this to True can cause slower load times and increased 817 | memory usage. Default is False. 818 | - n_ctx: 819 | The context length at which to load the model, in tokens. Default is 820 | 512, which is very small. Increase as needed. Values 0 or less will 821 | attempt to load the native context length of the model (which may be 822 | very large). 823 | - n_batch: 824 | The maximum number of tokens to process at once.
Higher values 825 | will increase prompt processing speed at the expense of increased memory 826 | usage. Values must be between 32 and n_ctx, inclusive. 827 | - n_threads: 828 | Number of threads to use for batch size == 1. Values 0 or less select a thread count automatically. 829 | - n_threads_batch: 830 | Number of threads to use for batch sizes > 1. Values 0 or less select a thread count automatically. 831 | - type_k: 832 | The `libllama.GGMLType` to use for the K cache. Default is 1 (f16). 833 | In most cases, this must be the same as `type_v`. 834 | - type_v: 835 | The `libllama.GGMLType` to use for the V cache. Default is 1 (f16). 836 | In most cases, this must be the same as `type_k`. Quantized types (anything 837 | other than f32, f16, or bf16) require `flash_attn=True`. 838 | - offload_kqv: 839 | Whether to offload the KQV operations (including the KV cache) to the GPU, which can greatly 840 | improve prompt processing speed at the cost of increased VRAM usage. 841 | Default is False for compatibility reasons. Recommended to set to 842 | True if possible. 843 | - flash_attn: 844 | Whether to use Flash Attention, which decreases memory usage and 845 | can increase both prompt processing and text generation speed, 846 | especially at long context lengths. Default is False for compatibility reasons. 847 | Recommended to set to True if possible. 848 | - warmup: 849 | Whether to warm up the model with a single-token decode. This reduces the 850 | latency of the first generation at the cost of a slower load time. Default is False. 851 | - verbose: 852 | Print informational output when loading model as well as at 853 | runtime. Default is True.
If set to False, warnings and errors 854 | will still be shown.""" 855 | 856 | if not os.path.exists(path_model): 857 | raise FileNotFoundError( 858 | f"Llama: the given path_model {path_model!r} does not exist" 859 | ) 860 | if os.path.isdir(path_model): 861 | raise IsADirectoryError( 862 | f"Llama: the given path_model {path_model!r} is a directory, " 863 | f"not a GGUF file" 864 | ) 865 | 866 | set_verbose(verbose) 867 | 868 | # peek at metadata from GGUF file header before loading model 869 | 870 | self.metadata = QuickGGUFReader.load_metadata(path_model) 871 | 872 | # 873 | # Load model from file 874 | # 875 | 876 | self._model = _LlamaModel( 877 | path_model = path_model, 878 | devices = kwargs.get('devices'), 879 | tensor_buft_override = kwargs.get('tensor_buft_override'), 880 | n_gpu_layers = n_gpu_layers, 881 | split_mode = kwargs.get('split_mode'), 882 | main_gpu = kwargs.get('main_gpu'), 883 | tensor_split = kwargs.get('tensor_split'), 884 | rpc_servers = kwargs.get('rpc_servers'), 885 | progress_callback = kwargs.get('progress_callback'), 886 | progress_callback_user_data = kwargs.get('progress_callback_user_data'), 887 | kv_overrides = kwargs.get('kv_overrides'), 888 | vocab_only = kwargs.get('vocab_only'), 889 | use_mmap = kwargs.get('use_mmap'), 890 | use_mlock = kwargs.get('use_mlock'), 891 | check_tensors = kwargs.get('check_tensors') 892 | ) 893 | 894 | self._vocab = lib.llama_model_get_vocab(self._model.model) 895 | """A pointer to this model's `llama_vocab`""" 896 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.__init__') 897 | 898 | n_ctx_train = lib.llama_model_n_ctx_train(self._model.model) 899 | 900 | # use n_ctx unless it's 0 or negative, in that case use n_ctx_train 901 | 902 | if n_ctx <= 0: 903 | log_verbose(f'n_ctx value {n_ctx}; using n_ctx_train value {n_ctx_train}') 904 | _n_ctx = int(n_ctx_train) 905 | else: 906 | _n_ctx = int(n_ctx) 907 | 908 | # use rope_freq_base unless it == 0.0, in that case use the native 909 | # 
rope_freq_base found in the GGUF metadata 910 | rope_freq_base = kwargs.get('rope_freq_base', 0.0) 911 | 912 | if rope_freq_base == 0.0: 913 | rope_freq_base_train = None 914 | for key in self.metadata.keys(): 915 | if key.endswith('.rope.freq_base'): 916 | rope_freq_base_train = float(self.metadata[key]) 917 | 918 | # NOTE: if n_ctx > n_ctx_train, then rope_freq_base must also be 919 | # increased by at least a proportional amount to guarantee a 920 | # usable kv cache throughout the entire context 921 | # 922 | # the function _calculate_rope_freq_base handles this 923 | 924 | _rope_freq_base = _calculate_rope_freq_base( 925 | n_ctx_train=n_ctx_train, 926 | n_ctx_load=_n_ctx, 927 | rope_freq_base_train=rope_freq_base_train # can be None 928 | ) 929 | else: 930 | _rope_freq_base = rope_freq_base 931 | 932 | _n_threads = n_threads if n_threads > 0 else _get_optimal_n_threads() 933 | _n_threads_batch = n_threads_batch if n_threads_batch > 0 else ( 934 | _get_optimal_n_threads_batch() 935 | ) 936 | 937 | # 938 | # New context with model 939 | # 940 | 941 | self._ctx = _LlamaCtx( 942 | model = self._model, 943 | n_ctx = _n_ctx, 944 | n_batch = kwargs.get('n_batch'), 945 | n_ubatch = kwargs.get('n_ubatch'), 946 | # n_seq_max = kwargs.get('n_seq_max'), # uncomment this if n_seq_max gets supported 947 | n_threads = _n_threads, 948 | n_threads_batch = _n_threads_batch, 949 | rope_scaling_type = kwargs.get('rope_scaling_type'), 950 | pooling_type = kwargs.get('pooling_type'), 951 | attention_type = kwargs.get('attention_type'), 952 | rope_freq_base = _rope_freq_base, 953 | rope_freq_scale = kwargs.get('rope_freq_scale'), 954 | yarn_ext_factor = kwargs.get('yarn_ext_factor'), 955 | yarn_attn_factor = kwargs.get('yarn_attn_factor'), 956 | yarn_beta_fast = kwargs.get('yarn_beta_fast'), 957 | yarn_beta_slow = kwargs.get('yarn_beta_slow'), 958 | yarn_orig_ctx = kwargs.get('yarn_orig_ctx'), 959 | defrag_thold = kwargs.get('defrag_thold'), 960 | type_k = type_k, 961 | type_v = 
type_v, 962 | embeddings = kwargs.get('embeddings'), 963 | offload_kqv = offload_kqv, 964 | flash_attn = flash_attn, 965 | no_perf = kwargs.get('no_perf'), 966 | op_offload = kwargs.get('op_offload'), 967 | swa_full = kwargs.get('swa_full') 968 | ) 969 | 970 | # 971 | # Display warnings about n_ctx if necessary 972 | # 973 | 974 | actual_n_ctx = self.n_ctx() 975 | requested_n_ctx = _n_ctx 976 | 977 | if actual_n_ctx != requested_n_ctx: 978 | log( 979 | f"requested n_ctx value differs from actual n_ctx value; " 980 | f"requested {requested_n_ctx}, actual {actual_n_ctx}", 2 981 | ) 982 | if actual_n_ctx < 512: 983 | log( 984 | f"n_ctx value {actual_n_ctx} is less than 512, which can " 985 | f"sometimes cause problems with llama.cpp - consider " 986 | f"increasing it to at least 512", 2 987 | ) 988 | if actual_n_ctx % 512 != 0: 989 | log( 990 | f"n_ctx value {actual_n_ctx} is not divisible by 512, which " 991 | f"can sometimes cause problems with llama.cpp - consider " 992 | f"changing it to " 993 | f"{_round_n_ctx(actual_n_ctx, n_ctx_train)}", 2 994 | ) 995 | # warn about default context length 996 | if actual_n_ctx == 512: 997 | log( 998 | f'you are using the default n_ctx value {actual_n_ctx}, which ' 999 | f'is very small. 
increase n_ctx as needed to support longer ' 1000 | f'inputs and outputs.', 2 1001 | ) 1002 | 1003 | self._stopwatch = _LlamaStopwatch() 1004 | 1005 | # 1006 | # Store Llama metadata as attributes for faster access internally 1007 | # 1008 | 1009 | self._name = self.name() 1010 | self._n_ctx = self.n_ctx() 1011 | self._n_batch = self.n_batch() 1012 | self._n_ubatch = self.n_ubatch() 1013 | self._n_seq_max = self.n_seq_max() 1014 | self._n_vocab = self.n_vocab() 1015 | self._n_ctx_train = self.n_ctx_train() 1016 | self._n_embd = self.n_embd() 1017 | self._n_layer = self.n_layer() 1018 | self._n_head = self.n_head() 1019 | self._n_head_kv = self.n_head_kv() 1020 | self._pooling_type = self.pooling_type() 1021 | self._n_swa = self.n_swa() 1022 | self._vocab_type = self.vocab_type() 1023 | self._rope_type = self.rope_type() 1024 | self._rope_freq_scale_train = self.rope_freq_scale_train() 1025 | self._model_size_bytes = self.model_size_bytes() 1026 | self._chat_template = self.chat_template() 1027 | self._n_params = self.n_params() 1028 | self._bpw = self.bpw() 1029 | self._has_encoder = self.has_encoder() 1030 | self._has_decoder = self.has_decoder() 1031 | self._is_recurrent = self.is_recurrent() 1032 | self._token_bos = self.token_bos() 1033 | self._token_eos = self.token_eos() 1034 | self._token_eot = self.token_eot() 1035 | self._token_sep = self.token_sep() 1036 | self._token_nl = self.token_nl() 1037 | self._token_pad = self.token_pad() 1038 | self._add_bos_token = self.add_bos_token() 1039 | self._add_eos_token = self.add_eos_token() 1040 | self._token_fim_pre = self.token_fim_pre() 1041 | self._token_fim_suf = self.token_fim_suf() 1042 | self._token_fim_mid = self.token_fim_mid() 1043 | self._token_fim_pad = self.token_fim_pad() 1044 | self._token_fim_rep = self.token_fim_rep() 1045 | self._token_fim_sep = self.token_fim_sep() 1046 | 1047 | self.eog_tokens = [i for i in range(self._n_vocab) if self.token_is_eog(i)] 1048 | """A list of all tokens in the vocab 
that are marked as EOG 1049 | (End-Of-Generation)""" 1050 | 1051 | # internal use only - the default SamplerParams with this model 1052 | self._default_sampler_params = SamplerParams(self) 1053 | 1054 | self.pos = 0 1055 | """The current position of the model within the context window""" 1056 | 1057 | self.context_tokens = [] 1058 | """A list of all tokens currently in the context window""" 1059 | 1060 | self._lock = _InferenceLock() 1061 | 1062 | if warmup: 1063 | self.warmup() 1064 | 1065 | # End of Llama.__init__ 1066 | 1067 | def __repr__(self) -> str: 1068 | return ( 1069 | f"Llama(" 1070 | f"path_model={self._model.path_model!r}, " 1071 | f"n_gpu_layers={self._model.params.n_gpu_layers}, " 1072 | f"n_ctx={self._n_ctx}, " 1073 | f"type_k={self._ctx.params.type_k}, " 1074 | f"type_v={self._ctx.params.type_v}, " 1075 | f"offload_kqv={self._ctx.params.offload_kqv}, " 1076 | f"flash_attn={self._ctx.params.flash_attn}" 1077 | f")" 1078 | ) 1079 | 1080 | def free(self): 1081 | """Deallocate the context and model""" 1082 | self._ctx.free() 1083 | self._model.free() 1084 | 1085 | def _validate_model_state(self) -> None: 1086 | """Ensure `llama_model`, `llama_vocab` and `llama_context` are not NULL and validate 1087 | `Llama.pos`""" 1088 | null_ptr_check(self._model.model, 'self._model.model', '_validate_model_state') 1089 | null_ptr_check(self._vocab, 'self._vocab', '_validate_model_state') 1090 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', '_validate_model_state') 1091 | 1092 | _n_context_tokens = len(self.context_tokens) 1093 | _pos = self.pos 1094 | 1095 | if _pos < 0: 1096 | self.pos = 0 1097 | self.context_tokens = [] 1098 | log_verbose( 1099 | f'self.pos value was {_pos} - clamping to 0. the KV cache has been reset.', 1100 | 2 1101 | ) 1102 | elif _pos != _n_context_tokens: 1103 | self.pos = 0 1104 | self.context_tokens = [] 1105 | log_verbose( 1106 | f'n_context_tokens {_n_context_tokens} did not match self.pos {_pos}.
the KV ' 1107 | f'cache has been reset.', 2 1108 | ) 1109 | if not hasattr(self, '_default_sampler_params'): 1110 | self._default_sampler_params = SamplerParams(self) 1111 | log_verbose( 1112 | "Llama._default_sampler_params was destroyed but has been recreated", 2 1113 | ) 1114 | 1115 | def warmup(self) -> None: 1116 | """Warm up the model. This also clears the KV cache.""" 1117 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.warmup") 1118 | 1119 | with suppress_output(disable=get_verbose()): 1120 | self.reset() 1121 | 1122 | lib.llama_set_warmup(self._ctx.ctx, True) 1123 | 1124 | log_verbose('warmup: single token decode ...') 1125 | with self._lock: 1126 | _internals.decode_tg(self._ctx.ctx, 0, 0) 1127 | 1128 | # This section decodes a full batch of tokens, but is probably unnecessary. 1129 | # 1130 | # with suppress_output(disable=get_verbose()): 1131 | # self.reset() 1132 | # 1133 | # log_verbose('warmup: full batch decode ...') 1134 | # with self._lock: 1135 | # _internals.decode_pp(self._ctx.ctx, 0, [0] * self._n_batch, self._n_batch) 1136 | 1137 | lib.llama_set_warmup(self._ctx.ctx, False) 1138 | 1139 | with suppress_output(disable=get_verbose()): 1140 | self.reset() 1141 | 1142 | self.pos = 0 1143 | log_verbose('warmup: done') 1144 | 1145 | def n_ctx(self) -> int: 1146 | """Get the current context length""" 1147 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_ctx") 1148 | return lib.llama_n_ctx(self._ctx.ctx) 1149 | 1150 | def n_batch(self) -> int: 1151 | """Get the current batch size""" 1152 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_batch") 1153 | return lib.llama_n_batch(self._ctx.ctx) 1154 | 1155 | def n_ubatch(self) -> int: 1156 | """Get the current micro-batch size""" 1157 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_ubatch") 1158 | return lib.llama_n_ubatch(self._ctx.ctx) 1159 | 1160 | def n_seq_max(self) -> int: 1161 | """Get the maximum number of sequences""" 1162 | null_ptr_check(self._ctx.ctx, 
"self._ctx.ctx", "Llama.n_seq_max") 1163 | return lib.llama_n_seq_max(self._ctx.ctx) 1164 | 1165 | def n_vocab(self) -> int: 1166 | """Get the vocab size""" 1167 | null_ptr_check(self._vocab, "self._vocab", "Llama.n_vocab") 1168 | return lib.llama_vocab_n_tokens(self._vocab) 1169 | 1170 | def n_ctx_train(self) -> int: 1171 | """Get the trained context length""" 1172 | null_ptr_check(self._model.model, 'self._model.model', 'Llama.n_ctx_train') 1173 | return lib.llama_model_n_ctx_train(self._model.model) 1174 | 1175 | def n_embd(self) -> int: 1176 | """Get the embedding size""" 1177 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_embd") 1178 | return lib.llama_model_n_embd(self._model.model) 1179 | 1180 | def n_layer(self) -> int: 1181 | """Get the number of layers""" 1182 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_layer") 1183 | return lib.llama_model_n_layer(self._model.model) 1184 | 1185 | def n_head(self) -> int: 1186 | """Get the number of attention heads""" 1187 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_head") 1188 | return lib.llama_model_n_head(self._model.model) 1189 | 1190 | def n_head_kv(self) -> int: 1191 | """Get the number of KV heads""" 1192 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_head_kv") 1193 | return lib.llama_model_n_head_kv(self._model.model) 1194 | 1195 | def pooling_type(self) -> int: 1196 | """Get the pooling type used by the context""" 1197 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.pooling_type") 1198 | return lib.llama_pooling_type(self._ctx.ctx) 1199 | 1200 | def n_swa(self) -> int: 1201 | """Get the sliding window size, for models which use SWA.""" 1202 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_swa") 1203 | return lib.llama_model_n_swa(self._model.model) 1204 | 1205 | def vocab_type(self) -> int: 1206 | """Get the vocab type""" 1207 | null_ptr_check(self._vocab, "self._vocab", "Llama.vocab_type") 1208 | return 
lib.llama_vocab_type(self._vocab) 1209 | 1210 | def rope_type(self) -> int: 1211 | """Get the RoPE type""" 1212 | null_ptr_check(self._model.model, "self._model.model", "Llama.rope_type") 1213 | return lib.llama_model_rope_type(self._model.model) 1214 | 1215 | def rope_freq_scale_train(self) -> float: 1216 | """Get the trained RoPE frequency scale""" 1217 | null_ptr_check(self._model.model, "self._model.model", "Llama.rope_freq_scale_train") 1218 | return lib.llama_model_rope_freq_scale_train(self._model.model) 1219 | 1220 | def model_size_bytes(self) -> int: 1221 | """Get the total size of all tensors in the model, in bytes""" 1222 | null_ptr_check(self._model.model, "self._model.model", "Llama.model_size_bytes") 1223 | return lib.llama_model_size(self._model.model) 1224 | 1225 | def chat_template(self) -> Optional[str]: 1226 | """Get the model's built-in chat template string. Returns None if not available.""" 1227 | null_ptr_check(self._model.model, "self._model.model", "Llama.chat_template") 1228 | return lib.llama_model_chat_template(self._model.model, name=None) 1229 | 1230 | def n_params(self) -> int: 1231 | """Get the total number of parameters in the model""" 1232 | null_ptr_check(self._model.model, "self._model.model", "Llama.n_params") 1233 | return lib.llama_model_n_params(self._model.model) 1234 | 1235 | def has_encoder(self) -> bool: 1236 | """If the model has an encoder""" 1237 | null_ptr_check(self._model.model, "self._model.model", "Llama.has_encoder") 1238 | return lib.llama_model_has_encoder(self._model.model) 1239 | 1240 | def has_decoder(self) -> bool: 1241 | """If the model has a decoder""" 1242 | null_ptr_check(self._model.model, "self._model.model", "Llama.has_decoder") 1243 | return lib.llama_model_has_decoder(self._model.model) 1244 | 1245 | def is_recurrent(self) -> bool: 1246 | """If the model is recurrent""" 1247 | null_ptr_check(self._model.model, "self._model.model", "Llama.is_recurrent") 1248 | return 
lib.llama_model_is_recurrent(self._model.model) 1249 | 1250 | # 1251 | # KV cache management methods 1252 | # 1253 | 1254 | def kv_cache_clear(self) -> None: 1255 | """Clear the KV cache""" 1256 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_clear") 1257 | lib.llama_kv_self_clear(self._ctx.ctx) 1258 | 1259 | def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int) -> bool: 1260 | """Remove the tokens of sequence `seq_id` in positions [p0, p1) from the KV cache (p0 < 0 means 0, p1 < 0 means infinity)""" 1261 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_rm") 1262 | return lib.llama_kv_self_seq_rm(self._ctx.ctx, seq_id, p0, p1) 1263 | 1264 | def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int) -> None: 1265 | """Copy the tokens of sequence `seq_id_src` in positions [p0, p1) to sequence `seq_id_dst` in the KV cache""" 1266 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_cp") 1267 | lib.llama_kv_self_seq_cp(self._ctx.ctx, seq_id_src, seq_id_dst, p0, p1) 1268 | 1269 | def kv_cache_seq_keep(self, seq_id: int) -> None: 1270 | """Remove all tokens that do not belong to sequence `seq_id`""" 1271 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_keep") 1272 | lib.llama_kv_self_seq_keep(self._ctx.ctx, seq_id) 1273 | 1274 | def kv_cache_seq_add(self, seq_id: int, p0: int, p1: int, delta: int) -> None: 1275 | """Add the relative position `delta` to the tokens of sequence `seq_id` in positions [p0, p1)""" 1276 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_add") 1277 | lib.llama_kv_self_seq_add(self._ctx.ctx, seq_id, p0, p1, delta) 1278 | 1279 | def kv_cache_seq_div(self, seq_id: int, p0: int, p1: int, d: int) -> None: 1280 | """Integer-divide the positions of the tokens of sequence `seq_id` in positions [p0, p1) by a factor `d > 1`""" 1281 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_div") 1282 | lib.llama_kv_self_seq_div(self._ctx.ctx, seq_id, p0, p1, d) 1283 | 1284 | def kv_cache_seq_pos_min(self, seq_id: int) -> int: 1285 | """Returns the earliest valid position in the KV cache for the specified sequence 1286 | (relevant for models which use SWA)"""
 1287 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_pos_min") 1288 | return lib.llama_kv_self_seq_pos_min(self._ctx.ctx, seq_id) 1289 | 1290 | def kv_cache_seq_pos_max(self, seq_id: int) -> int: 1291 | """Returns the largest position present in the KV cache for the specified sequence""" 1292 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_seq_pos_max") 1293 | return lib.llama_kv_self_seq_pos_max(self._ctx.ctx, seq_id) 1294 | 1295 | def kv_cache_can_shift(self) -> bool: 1296 | """Check if the context supports KV cache shifting""" 1297 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.kv_cache_can_shift") 1298 | return lib.llama_kv_self_can_shift(self._ctx.ctx) 1299 | 1300 | def n_threads(self) -> int: 1301 | """Get the number of threads used for batch size == 1""" 1302 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_threads") 1303 | return lib.llama_n_threads(self._ctx.ctx) 1304 | 1305 | def n_threads_batch(self) -> int: 1306 | """Get the number of threads used for batch sizes > 1""" 1307 | null_ptr_check(self._ctx.ctx, "self._ctx.ctx", "Llama.n_threads_batch") 1308 | return lib.llama_n_threads_batch(self._ctx.ctx) 1309 | 1310 | def token_get_score(self, token: int) -> float: 1311 | """Get the score of a token""" 1312 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_get_score") 1313 | return lib.llama_vocab_get_score(self._vocab, token) 1314 | 1315 | def token_is_eog(self, token: int) -> bool: 1316 | """If the token is marked as EOG (End-Of-Generation)""" 1317 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_is_eog") 1318 | return lib.llama_vocab_is_eog(self._vocab, token) 1319 | 1320 | def token_is_control(self, token: int) -> bool: 1321 | """If the token is marked as a control token""" 1322 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_is_control") 1323 | return lib.llama_vocab_is_control(self._vocab, token) 1324 | 1325 | def token_bos(self) -> Optional[int]: 1326 | """Get the 
BOS (Beginning-Of-Sequence) token. Return None if not available.""" 1327 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_bos") 1328 | id = lib.llama_vocab_bos(self._vocab) 1329 | return id if id != lib.LLAMA_TOKEN_NULL else None 1330 | 1331 | def token_eos(self) -> Optional[int]: 1332 | """Get the EOS (End-Of-Sequence) token. Return None if not available.""" 1333 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_eos") 1334 | id = lib.llama_vocab_eos(self._vocab) 1335 | return id if id != lib.LLAMA_TOKEN_NULL else None 1336 | 1337 | def token_eot(self) -> Optional[int]: 1338 | """Get the EOT (End-Of-Turn) token. Return None if not available.""" 1339 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_eot") 1340 | id = lib.llama_vocab_eot(self._vocab) 1341 | return id if id != lib.LLAMA_TOKEN_NULL else None 1342 | 1343 | def token_sep(self) -> Optional[int]: 1344 | """Get the SEP (Separator) token. Return None if not available.""" 1345 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_sep") 1346 | id = lib.llama_vocab_sep(self._vocab) 1347 | return id if id != lib.LLAMA_TOKEN_NULL else None 1348 | 1349 | def token_nl(self) -> Optional[int]: 1350 | """Get the NL (Newline) token. Return None if not available.""" 1351 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_nl") 1352 | id = lib.llama_vocab_nl(self._vocab) 1353 | return id if id != lib.LLAMA_TOKEN_NULL else None 1354 | 1355 | def token_pad(self) -> Optional[int]: 1356 | """Get the PAD (Padding) token. 
Return None if not available.""" 1357 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_pad") 1358 | id = lib.llama_vocab_pad(self._vocab) 1359 | return id if id != lib.LLAMA_TOKEN_NULL else None 1360 | 1361 | def add_bos_token(self) -> bool: 1362 | """If the model is configured to add a BOS token""" 1363 | null_ptr_check(self._vocab, "self._vocab", "Llama.add_bos_token") 1364 | return lib.llama_vocab_get_add_bos(self._vocab) 1365 | 1366 | def add_eos_token(self) -> bool: 1367 | """If the model is configured to add an EOS token""" 1368 | null_ptr_check(self._vocab, "self._vocab", "Llama.add_eos_token") 1369 | return lib.llama_vocab_get_add_eos(self._vocab) 1370 | 1371 | def token_fim_pre(self) -> Optional[int]: 1372 | """Get the FIM PRE (Fill-In-Middle Prefix) token. Return None if not available.""" 1373 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_pre") 1374 | id = lib.llama_vocab_fim_pre(self._vocab) 1375 | return id if id != lib.LLAMA_TOKEN_NULL else None 1376 | 1377 | def token_fim_suf(self) -> Optional[int]: 1378 | """Get the FIM SUF (Fill-In-Middle Suffix) token. Return None if not available.""" 1379 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_suf") 1380 | id = lib.llama_vocab_fim_suf(self._vocab) 1381 | return id if id != lib.LLAMA_TOKEN_NULL else None 1382 | 1383 | def token_fim_mid(self) -> Optional[int]: 1384 | """Get the FIM MID (Fill-In-Middle Middle) token. Return None if not available.""" 1385 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_mid") 1386 | id = lib.llama_vocab_fim_mid(self._vocab) 1387 | return id if id != lib.LLAMA_TOKEN_NULL else None 1388 | 1389 | def token_fim_pad(self) -> Optional[int]: 1390 | """Get the FIM PAD (Fill-In-Middle Padding) token. 
Return None if not available.""" 1391 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_pad") 1392 | id = lib.llama_vocab_fim_pad(self._vocab) 1393 | return id if id != lib.LLAMA_TOKEN_NULL else None 1394 | 1395 | def token_fim_rep(self) -> Optional[int]: 1396 | """Get the FIM REP (Fill-In-Middle Repository) token. Return None if not available.""" 1397 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_rep") 1398 | id = lib.llama_vocab_fim_rep(self._vocab) 1399 | return id if id != lib.LLAMA_TOKEN_NULL else None 1400 | 1401 | def token_fim_sep(self) -> Optional[int]: 1402 | """Get the FIM SEP (Fill-In-Middle Separator) token. Return None if not available.""" 1403 | null_ptr_check(self._vocab, "self._vocab", "Llama.token_fim_sep") 1404 | id = lib.llama_vocab_fim_sep(self._vocab) 1405 | return id if id != lib.LLAMA_TOKEN_NULL else None 1406 | 1407 | def tokenize( 1408 | self, 1409 | text_bytes: bytes, 1410 | add_special: bool, 1411 | parse_special: bool, 1412 | ) -> list[int]: 1413 | """Convert the provided UTF-8 encoded text into tokens 1414 | 1415 | - text_bytes: 1416 | The UTF-8 encoded text to be tokenized 1417 | - add_special: 1418 | Allow adding BOS and EOS tokens if the model is configured to do so. 1419 | - parse_special: 1420 | Allow tokenizing special and/or control tokens which otherwise are 1421 | not exposed and treated as plaintext.
Does not insert a leading 1422 | space.""" 1423 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.tokenize') 1424 | n_tokens = _internals.get_length( 1425 | vocab=self._vocab, 1426 | text_bytes=text_bytes, 1427 | add_special=add_special, 1428 | parse_special=parse_special 1429 | ) 1430 | return _internals.tokenize( 1431 | vocab=self._vocab, 1432 | text_bytes=text_bytes, 1433 | n_tokens_max=n_tokens, 1434 | add_special=add_special, 1435 | parse_special=parse_special 1436 | ) 1437 | 1438 | def token_to_piece(self, token: int, special: bool) -> bytes: 1439 | """Convert a single token ID into UTF-8 bytes 1440 | 1441 | - special: 1442 | If True, special tokens are rendered in the output""" 1443 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.token_to_piece') 1444 | return _internals.token_to_piece( 1445 | vocab=self._vocab, 1446 | token=token, 1447 | special=special 1448 | ) 1449 | 1450 | def detokenize( 1451 | self, 1452 | tokens: list[int], 1453 | special: bool 1454 | ) -> str: 1455 | """Convert the provided tokens into UTF-8 encoded text 1456 | 1457 | - special: 1458 | If True, special tokens are rendered in the output""" 1459 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.detokenize') 1460 | return _internals.detokenize( 1461 | vocab=self._vocab, 1462 | tokens=tokens, 1463 | special=special 1464 | ) 1465 | 1466 | def get_length( 1467 | self, 1468 | text_bytes: bytes, 1469 | add_special: bool, 1470 | parse_special: bool, 1471 | ) -> int: 1472 | """Return the length of a given text as measured in tokens""" 1473 | null_ptr_check(self._vocab, 'self._vocab', 'Llama.get_length') 1474 | return _internals.get_length( 1475 | vocab=self._vocab, 1476 | text_bytes=text_bytes, 1477 | add_special=add_special, 1478 | parse_special=parse_special 1479 | ) 1480 | 1481 | # TODO - Chat templating functions here 1482 | 1483 | def _first_valid_pos(self, tokens: list[int]) -> int: 1484 | """Given a list of tokens, and using `Llama.context_tokens`, find the first 1485 | valid 
`Llama.pos` 1486 | 1487 | In other words, return length of the longest common prefix between the 1488 | two lists of tokens. 1489 | 1490 | Returns 0 if none of the tokens match, 1 if one token matches, etc.""" 1491 | i = 0 1492 | for c, t in zip(self.context_tokens, tokens): 1493 | if c == t: 1494 | i += 1 1495 | else: 1496 | break 1497 | return i 1498 | 1499 | def _set_cache_tokens(self, input_tokens: list[int]) -> list[int]: 1500 | n_input_tokens = len(input_tokens) 1501 | 1502 | if n_input_tokens == 0: 1503 | raise ValueError(f'Llama._set_cache_tokens: input_tokens cannot be empty') 1504 | 1505 | # find how many tokens in the input are already in the KV cache 1506 | self.pos = self._first_valid_pos(input_tokens) 1507 | 1508 | if self.pos > self._n_ctx: 1509 | raise ExceededContextLengthException( 1510 | f'Llama._set_cache_tokens: no valid position within context window ' 1511 | f'{self._n_ctx} (self.pos = {self.pos})' 1512 | ) 1513 | 1514 | # remove all tokens in the KV cache that are past that point 1515 | self.kv_cache_seq_rm(0, self.pos, -1) 1516 | 1517 | # tokens already in the KV cache 1518 | self.context_tokens = input_tokens[:self.pos] 1519 | 1520 | # tokens returned to caller for processing 1521 | actual_input_tokens = input_tokens[self.pos:] 1522 | 1523 | return actual_input_tokens 1524 | 1525 | def _process_batch( 1526 | self, 1527 | batch_tokens: list[int], 1528 | logits_all: bool = False 1529 | ) -> Optional[np.ndarray]: 1530 | """Process a batch of one or more tokens, up to `Llama.n_batch()`. If `logits_all` is 1531 | True, return the logits for all tokens in the batch. Otherwise, return None. 
1532 | 1533 | This function is used by `Llama.eval`, `Llama.generate`, `Llama.stream`, etc.""" 1534 | 1535 | batch_logits = None 1536 | n_batch_tokens = len(batch_tokens) 1537 | log_debug(f'Llama._process_batch: processing {batch_tokens}') 1538 | 1539 | if n_batch_tokens > self._n_batch: 1540 | raise ValueError( 1541 | f'Llama._process_batch: n_batch_tokens cannot exceed n_batch ' 1542 | f'({n_batch_tokens} > {self._n_batch})' 1543 | ) 1544 | 1545 | if n_batch_tokens > 1: # prompt processing 1546 | if logits_all: 1547 | with self._lock: 1548 | self._stopwatch.start_pp() 1549 | batch_logits = _internals.decode_pp_with_logits( 1550 | self._ctx.ctx, self.pos, batch_tokens, n_batch_tokens, self._n_vocab 1551 | ) 1552 | self._stopwatch.stop_pp() 1553 | 1554 | else: 1555 | with self._lock: 1556 | self._stopwatch.start_pp() 1557 | _internals.decode_pp(self._ctx.ctx, self.pos, batch_tokens, n_batch_tokens) 1558 | self._stopwatch.stop_pp() 1559 | 1560 | self._stopwatch.increment_pp_tokens(n_batch_tokens) 1561 | 1562 | elif n_batch_tokens == 1: # text generation 1563 | with self._lock: 1564 | self._stopwatch.start_tg() 1565 | batch_logits = _internals.decode_tg_with_logits( 1566 | self._ctx.ctx, self.pos, batch_tokens[0], self._n_vocab 1567 | ) 1568 | self._stopwatch.stop_tg() 1569 | 1570 | self._stopwatch.increment_tg_tokens(1) 1571 | 1572 | else: 1573 | raise RuntimeError( 1574 | f'Llama._process_batch: unexpected n_batch_tokens value {n_batch_tokens}' 1575 | ) 1576 | 1577 | # update the Llama position and context 1578 | self.pos += n_batch_tokens 1579 | self.context_tokens.extend(batch_tokens) 1580 | 1581 | return batch_logits 1582 | 1583 | def eval( 1584 | self, 1585 | input_tokens: list[int], 1586 | logits_all: bool = False 1587 | ) -> np.ndarray: 1588 | """Evaluate the given tokens and update the model state. 1589 | 1590 | If `logits_all` is True, return the logits for all `input_tokens`. 
Otherwise, only 1591 | return the logits for the last token (which are the predictions for the next token).""" 1592 | 1593 | self._stopwatch.reset() 1594 | self._stopwatch.start_wall_time() 1595 | 1596 | n_input_tokens = len(input_tokens) 1597 | 1598 | if logits_all: 1599 | if self._first_valid_pos(input_tokens) > 0: 1600 | log( 1601 | f'Llama.eval: the KV cache will be cleared in order to compute the logits ' 1602 | f'for all tokens in the input', 2 1603 | ) 1604 | 1605 | log_verbose(f'Llama.eval: {n_input_tokens} tokens to eval ...') 1606 | 1607 | self.reset() 1608 | actual_input_tokens = input_tokens 1609 | n_actual_input_tokens = len(input_tokens) 1610 | else: 1611 | actual_input_tokens = self._set_cache_tokens(input_tokens) 1612 | 1613 | n_actual_input_tokens = len(actual_input_tokens) 1614 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens 1615 | 1616 | log_verbose( 1617 | f'Llama.eval: {n_cache_hit_tokens} tokens in cache, ' 1618 | f'{n_actual_input_tokens} tokens to eval ...'
1619 | ) 1620 | 1621 | if n_actual_input_tokens == 0: 1622 | return self.get_logits() 1623 | 1624 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch) 1625 | 1626 | # process each batch one-by-one 1627 | if logits_all: 1628 | all_logits = [] 1629 | for batch in _batches_with_progress_bar(batches): 1630 | batch_logits = self._process_batch(batch, logits_all=True) 1631 | all_logits.append(batch_logits) 1632 | final_logits = np.concatenate(all_logits, axis=0) 1633 | else: 1634 | for batch in _batches_with_progress_bar(batches): 1635 | final_logits = self._process_batch(batch, logits_all=False) 1636 | if final_logits is None: final_logits = self.get_logits() # _process_batch returns no logits for a multi-token batch 1637 | self._stopwatch.stop_wall_time() 1638 | 1639 | if get_verbose(): 1640 | self._stopwatch.print_stats() 1641 | 1642 | return final_logits 1643 | 1644 | def generate_single( 1645 | self, 1646 | input_tokens: list[int], 1647 | sampler_preset: Optional[SamplerPreset] = None, 1648 | return_logits: bool = False 1649 | ) -> Union[int, np.ndarray]: 1650 | """Generate a single token 1651 | 1652 | - input_tokens: 1653 | The tokens to evaluate 1654 | - sampler_preset: 1655 | The `SamplerPreset` object to use for sampling. If not specified,
If not specified, 1656 | use the model's default sampler parameters 1657 | - return_logits: 1658 | If True, return the logits for the generated token instead of the token ID.""" 1659 | 1660 | self._stopwatch.reset() 1661 | self._stopwatch.start_wall_time() 1662 | 1663 | n_input_tokens = len(input_tokens) 1664 | 1665 | actual_input_tokens = self._set_cache_tokens(input_tokens) 1666 | 1667 | n_actual_input_tokens = len(actual_input_tokens) 1668 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens 1669 | 1670 | if sampler_preset is None: 1671 | sampler_params = self._default_sampler_params 1672 | else: 1673 | sampler_params = self.sampler_params_from_preset(sampler_preset) 1674 | 1675 | if get_verbose(): 1676 | sampler_params.print_chain() 1677 | 1678 | log_verbose( 1679 | f'Llama.generate_single: {n_cache_hit_tokens} tokens in cache, ' 1680 | f'{n_actual_input_tokens} tokens to eval ...' 1681 | ) 1682 | 1683 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch) 1684 | 1685 | # process each batch one-by-one 1686 | for batch in _batches_with_progress_bar(batches): 1687 | self._process_batch(batch, logits_all=False) 1688 | 1689 | self._stopwatch.stop_wall_time() 1690 | 1691 | if get_verbose(): 1692 | self._stopwatch.print_stats() 1693 | 1694 | if return_logits: # TODO: this is inefficient, it decodes the last token again. replace. 1695 | return self.get_logits() 1696 | 1697 | return self.sample(sampler_params) 1698 | 1699 | def generate( 1700 | self, 1701 | input_tokens: list[int], 1702 | n_predict: int, 1703 | stop_tokens: Optional[list[int]] = None, 1704 | sampler_preset: Optional[SamplerPreset] = None, 1705 | return_logits: bool = False 1706 | ) -> Union[list[int], np.ndarray]: 1707 | """Generate new tokens and return them all at once 1708 | 1709 | - input_tokens: 1710 | The tokens to evaluate 1711 | - n_predict: 1712 | The number of tokens to predict. 
If `n_predict < 0`, then the number of tokens 1713 | predicted is only limited by the context length. If `n_predict == 0`, then no new 1714 | tokens will be predicted, but the input_tokens will still be processed. 1715 | - stop_tokens: 1716 | A list of token IDs that will end the generation early. Note that the stop token 1717 | will be included in the output. If this parameter is None, all built-in stop tokens 1718 | for the model will be used. Pass an empty list `[]` to ignore all stop tokens. 1719 | - sampler_preset: 1720 | The `SamplerPreset` object to use for sampling. If not specified, use the model's 1721 | default sampler parameters. 1722 | - return_logits: 1723 | If True, return the logits for the generated tokens instead of the token IDs. Note 1724 | that this incurs a slight performance penalty.""" 1725 | 1726 | self._stopwatch.reset() 1727 | self._stopwatch.start_wall_time() 1728 | 1729 | n_input_tokens = len(input_tokens) 1730 | 1731 | actual_input_tokens = self._set_cache_tokens(input_tokens) 1732 | 1733 | n_actual_input_tokens = len(actual_input_tokens) 1734 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens 1735 | 1736 | _stop_tokens = stop_tokens if stop_tokens is not None else self.eog_tokens 1737 | 1738 | if sampler_preset is None: 1739 | sampler_params = self._default_sampler_params 1740 | else: 1741 | sampler_params = self.sampler_params_from_preset(sampler_preset) 1742 | 1743 | if get_verbose(): 1744 | sampler_params.print_chain() 1745 | 1746 | _n_predict = n_predict if n_predict >= 0 else self._n_ctx - self.pos 1747 | 1748 | log_verbose( 1749 | f'Llama.generate: {n_cache_hit_tokens} tokens in cache, {n_actual_input_tokens} ' 1750 | f'tokens to eval ...' 
1751 | ) 1752 | 1753 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch) 1754 | 1755 | log_debug('Llama.generate: start processing input batches') 1756 | 1757 | # process each input batch one-by-one 1758 | for batch in _batches_with_progress_bar(batches): 1759 | self._process_batch(batch, logits_all=False) 1760 | 1761 | log_debug('Llama.generate: done processing input batches') 1762 | 1763 | if _n_predict == 0: 1764 | self._stopwatch.stop_wall_time() 1765 | if get_verbose(): 1766 | self._stopwatch.print_stats() 1767 | return [] 1768 | 1769 | predicted_tokens = [] 1770 | if return_logits: 1771 | predicted_logits = [] 1772 | n_predicted = 0 1773 | 1774 | log_verbose(f'Llama.generate: predicting up to {_n_predict} new tokens ...') 1775 | log_debug(f'Llama.generate: enter while loop') 1776 | 1777 | while n_predicted < _n_predict: 1778 | # sample a token from the latest logits 1779 | sampled_token = self.sample(sampler_params) 1780 | 1781 | # save the sampled token as a prediction 1782 | predicted_tokens.append(sampled_token) 1783 | n_predicted += 1 1784 | 1785 | # if it's a stop token, stop generating 1786 | if sampled_token in _stop_tokens: 1787 | if get_verbose(): 1788 | tok_str = ez_decode(self.token_to_piece(sampled_token, True)) 1789 | print() 1790 | log(f'inferred stop token {sampled_token} ({tok_str!r})') 1791 | break 1792 | 1793 | # decode the sampled token to get the logits for the following token 1794 | if return_logits: 1795 | logits = self._process_batch([sampled_token], True) 1796 | predicted_logits.append(logits) 1797 | else: 1798 | self._process_batch([sampled_token]) 1799 | 1800 | # done generating, show stopwatch stats and return 1801 | self._stopwatch.stop_wall_time() 1802 | log_debug(f'Llama.generate: exited while loop') 1803 | if get_verbose(): 1804 | self._stopwatch.print_stats() 1805 | if return_logits: 1806 | return np.stack(predicted_logits, axis=0) 1807 | return predicted_tokens 1808 | 1809 | def stream( 1810 | self, 1811 
| input_tokens: list[int], 1812 | n_predict: int, 1813 | stop_tokens: Optional[list[int]] = None, 1814 | sampler_preset: Optional[SamplerPreset] = None, 1815 | yield_logits: bool = False 1816 | ) -> Iterable[Union[int, np.ndarray]]: 1817 | """Return a Generator which yields tokens as they are generated 1818 | 1819 | - input_tokens: 1820 | The tokens to evaluate 1821 | - n_predict: 1822 | The number of tokens to predict. If `n_predict < 0`, then the number of tokens 1823 | predicted is only limited by the context length. If `n_predict == 0`, then no new 1824 | tokens will be predicted, but the input_tokens will still be processed. 1825 | - stop_tokens: 1826 | A list of token IDs that will end the generation early. Note that 1827 | the stop token will be included in the output. If this parameter is 1828 | None, all built-in stop tokens for the model will be used. Pass an 1829 | empty list `[]` to ignore all stop tokens. 1830 | - sampler_preset: 1831 | The `SamplerPreset` object to use for sampling. 
If not specified, 1832 | use the model's default sampler parameters 1833 | - yield_logits: 1834 | If True, yield the logits for the generated tokens instead of the token IDs""" 1835 | 1836 | self._stopwatch.reset() 1837 | self._stopwatch.start_wall_time() 1838 | 1839 | n_input_tokens = len(input_tokens) 1840 | 1841 | actual_input_tokens = self._set_cache_tokens(input_tokens) 1842 | 1843 | n_actual_input_tokens = len(actual_input_tokens) 1844 | n_cache_hit_tokens = n_input_tokens - n_actual_input_tokens 1845 | 1846 | _stop_tokens = stop_tokens if stop_tokens is not None else self.eog_tokens 1847 | 1848 | if sampler_preset is None: 1849 | sampler_params = self._default_sampler_params 1850 | else: 1851 | sampler_params = self.sampler_params_from_preset(sampler_preset) 1852 | 1853 | if get_verbose(): 1854 | sampler_params.print_chain() 1855 | 1856 | _n_predict = n_predict if n_predict >= 0 else self._n_ctx - self.pos 1857 | 1858 | log_verbose( 1859 | f'Llama.stream: {n_cache_hit_tokens} tokens in cache, {n_actual_input_tokens} ' 1860 | f'tokens to eval ...' 
1861 | ) 1862 | 1863 | batches = split_tokens_into_batches(actual_input_tokens, self._n_batch) 1864 | 1865 | log_debug('Llama.stream: start processing input batches') 1866 | 1867 | # process each input batch one-by-one 1868 | for batch in _batches_with_progress_bar(batches): 1869 | self._process_batch(batch, logits_all=False) 1870 | 1871 | log_debug('Llama.stream: done processing input batches') 1872 | 1873 | if _n_predict == 0: 1874 | self._stopwatch.stop_wall_time() 1875 | if get_verbose(): 1876 | self._stopwatch.print_stats() 1877 | return 1878 | 1879 | n_predicted = 0 1880 | 1881 | log_verbose(f'Llama.stream: predicting up to {_n_predict} new tokens ...') 1882 | log_debug(f'Llama.stream: enter while loop') 1883 | 1884 | while n_predicted < _n_predict: 1885 | # sample a token from the latest logits 1886 | sampled_token = self.sample(sampler_params) 1887 | 1888 | is_stop_token = sampled_token in _stop_tokens 1889 | 1890 | if yield_logits: 1891 | if not is_stop_token: 1892 | # process the token, yield the logits for the next prediction 1893 | logits = self._process_batch([sampled_token], True) 1894 | yield logits 1895 | else: 1896 | yield sampled_token 1897 | 1898 | n_predicted += 1 1899 | 1900 | if is_stop_token: 1901 | if get_verbose(): 1902 | tok_str = ez_decode(self.token_to_piece(sampled_token, True)) 1903 | print() 1904 | log(f'inferred stop token {sampled_token} ({tok_str!r})') 1905 | break 1906 | 1907 | if not yield_logits: 1908 | self._process_batch([sampled_token]) 1909 | 1910 | # done generating, show stopwatch stats 1911 | self._stopwatch.stop_wall_time() 1912 | log_debug(f'Llama.stream: exited while loop') 1913 | if get_verbose(): 1914 | self._stopwatch.print_stats() 1915 | 1916 | def benchmark( 1917 | self, 1918 | n_tokens_pp: Optional[int] = None, 1919 | n_tokens_tg: Optional[int] = None, 1920 | n_runs: Optional[int] = None 1921 | ) -> list[dict]: 1922 | """Measure the prompt processing and text generation speed of this Llama.""" 1923 | 1924 | 
n_tokens_pp = n_tokens_pp if n_tokens_pp is not None else self.n_batch() 1925 | n_tokens_tg = n_tokens_tg if n_tokens_tg is not None else 10 1926 | n_runs = n_runs if n_runs is not None else 3 1927 | 1928 | results = [] 1929 | total_pp_time_ns = 0 1930 | total_tg_time_ns = 0 1931 | 1932 | for i in range(1, n_runs+1): 1933 | 1934 | log_verbose(f'benchmark: starting run {i}/{n_runs}:') 1935 | 1936 | log_verbose(f'benchmark: processing {n_tokens_pp} tokens ... please wait ...') 1937 | self.reset() 1938 | self.eval(input_tokens=[0] * n_tokens_pp) 1939 | pp_ns = self._stopwatch.get_elapsed_time_pp() 1940 | total_pp_time_ns += pp_ns 1941 | 1942 | log_verbose(f'benchmark: generating {n_tokens_tg} tokens ... please wait ...') 1943 | self.reset() 1944 | self.generate( 1945 | input_tokens=[0], 1946 | n_predict=n_tokens_tg, 1947 | stop_tokens=[], 1948 | sampler_preset=SamplerPreset(seed=42, top_k=1, temp=0.0) 1949 | ) 1950 | tg_ns = self._stopwatch.get_elapsed_time_tg() 1951 | total_tg_time_ns += tg_ns 1952 | 1953 | results.append({ 1954 | 'n_tokens_pp' : n_tokens_pp, 1955 | 'n_tokens_tg' : n_tokens_tg, 1956 | 'pp_time_ns' : pp_ns, 1957 | 'tg_time_ns' : tg_ns 1958 | }) 1959 | 1960 | avg_pp_time_ns = total_pp_time_ns / n_runs 1961 | avg_tg_time_ns = total_tg_time_ns / n_runs 1962 | 1963 | avg_pp_time_ms = avg_pp_time_ns / 1e6 1964 | avg_tg_time_ms = avg_tg_time_ns / 1e6 1965 | 1966 | avg_pp_tok_per_sec = n_tokens_pp / (avg_pp_time_ns / 1e9) 1967 | avg_tg_tok_per_sec = n_tokens_tg / (avg_tg_time_ns / 1e9) 1968 | 1969 | log_verbose( 1970 | f'average pp speed for {n_tokens_pp:>7} tokens over {n_runs} runs: ' 1971 | f'{avg_pp_time_ms:>13.3f}ms ({avg_pp_tok_per_sec:10.2f} tok/s)', 4 1972 | ) 1973 | log_verbose( 1974 | f'average tg speed for {n_tokens_tg:>7} tokens over {n_runs} runs: ' 1975 | f'{avg_tg_time_ms:>13.3f}ms ({avg_tg_tok_per_sec:10.2f} tok/s)', 4 1976 | ) 1977 | 1978 | return results 1979 | 1980 | def sample_greedy(self) -> int: 1981 | id = 
_internals.sample_greedy(self._ctx.ctx) 1982 | # llama_sampler_sample internally calls llama_sampler_accept. 1983 | # uncomment the next line if this changes 1984 | #lib.llama_sampler_accept(_internals.greedy_sampler, id) 1985 | return id 1986 | 1987 | def sampler_params_from_preset(self, sampler_preset: SamplerPreset) -> SamplerParams: 1988 | """Create and return a new `SamplerParams` object for this Llama using the provided 1989 | `SamplerPreset`. 1990 | 1991 | @param sampler_preset: The `sampling.SamplerPreset` object which defines the sampler 1992 | parameter values to use""" 1993 | 1994 | return SamplerParams( 1995 | llama = self, 1996 | seed = sampler_preset.seed, 1997 | top_k = sampler_preset.top_k, 1998 | top_p = sampler_preset.top_p, 1999 | min_p = sampler_preset.min_p, 2000 | xtc_probability = sampler_preset.xtc_probability, 2001 | xtc_threshold = sampler_preset.xtc_threshold, 2002 | typical_p = sampler_preset.typical_p, 2003 | temp = sampler_preset.temp, 2004 | dynatemp_delta = sampler_preset.dynatemp_delta, 2005 | dynatemp_exponent = sampler_preset.dynatemp_exponent, 2006 | penalty_last_n = sampler_preset.penalty_last_n, 2007 | penalty_repeat = sampler_preset.penalty_repeat, 2008 | penalty_freq = sampler_preset.penalty_freq, 2009 | penalty_present = sampler_preset.penalty_present, 2010 | dry_multiplier = sampler_preset.dry_multiplier, 2011 | dry_base = sampler_preset.dry_base, 2012 | dry_allowed_length = sampler_preset.dry_allowed_length, 2013 | dry_penalty_last_n = sampler_preset.dry_penalty_last_n, 2014 | mirostat = sampler_preset.mirostat, 2015 | top_n_sigma = sampler_preset.top_n_sigma, 2016 | mirostat_tau = sampler_preset.mirostat_tau, 2017 | mirostat_eta = sampler_preset.mirostat_eta, 2018 | dry_sequence_breakers = sampler_preset.dry_sequence_breakers, 2019 | logit_bias = sampler_preset.logit_bias 2020 | ) 2021 | 2022 | def sample(self, sampler_params: Optional[SamplerParams] = None) -> int: 2023 | """Sample a token using the current context 2024 
| 2025 | - sampler_params: 2026 | The `sampling.SamplerParams` object which defines the sampling 2027 | parameters to use. If this parameter is None, the default sampler 2028 | parameter values will be used.""" 2029 | sampler_params = sampler_params if sampler_params is not None else ( 2030 | self._default_sampler_params 2031 | ) 2032 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.sample') 2033 | id = lib.llama_sampler_sample(sampler_params.smpl, self._ctx.ctx, -1) 2034 | # llama_sampler_sample internally calls llama_sampler_accept. 2035 | # uncomment the next line if this changes 2036 | #lib.llama_sampler_accept(sampler_params.smpl, id) 2037 | return id 2038 | 2039 | def get_logits(self) -> np.ndarray: 2040 | """Return the raw logits for the last token in the context, which are the predictions 2041 | for the next token. The returned array has shape `(n_vocab,)`.""" 2042 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.get_logits') 2043 | return _internals.get_logits(self._ctx.ctx, self._n_vocab) 2044 | 2045 | def get_scores(self, temp: Optional[float] = None) -> np.ndarray: 2046 | """Compute the logits for the last token in the context, normalized with softmax. 2047 | Optionally apply temperature `temp` if specified. 2048 | 2049 | Any floating-point value for temperature `temp` is valid, including 0.0 2050 | and negative numbers.
2051 | 2052 | The returned array has shape `(n_vocab,)`.""" 2053 | logits = self.get_logits() 2054 | return softmax(logits, T=temp if temp is not None else 1.0) 2055 | 2056 | def get_tokenization_mapping(self, tokens: list[int]) -> list[tuple[int, bytes]]: 2057 | """Given some tokens, return a list of tuples where the first item in the 2058 | tuple is the token ID and the second item is the corresponding UTF-8 2059 | text bytes.""" 2060 | return list(zip(tokens, [self.token_to_piece(id, special=True) for id in tokens])) 2061 | 2062 | def print_tokenization_mapping( 2063 | self, 2064 | tokens: list[int], 2065 | file: _SupportsWriteAndFlush = sys.stderr 2066 | ) -> None: 2067 | """Given some tokens, print a mapping of each token ID to the 2068 | corresponding UTF-8 text bytes 2069 | 2070 | This is meant to be roughly equivalent to `llama.cpp/llama-tokenize` 2071 | 2072 | - tokens: 2073 | The tokens to print a mapping for 2074 | - file: 2075 | The file or stream to which the mapping will be printed""" 2076 | token_mapping = self.get_tokenization_mapping(tokens) 2077 | for id, bytes in token_mapping: 2078 | print(f"{id:>7} -> {repr(bytes)} ({bytes.hex(':')})", file=file) 2079 | #print(f"{id:>7} -> {str(txt)}", file=file) 2080 | print(f"Total number of tokens: {len(token_mapping)}", file=file, flush=True) 2081 | 2082 | def name(self) -> str: 2083 | """Get the name of the model from the GGUF metadata""" 2084 | # '/path/to/my-model.gguf' --> 'my-model' 2085 | model_file_basename = os.path.basename(self._model.path_model).removesuffix('.gguf') 2086 | # 'my-model-00001-of-99999' --> 'my-model' 2087 | model_file_basename = re.sub(r'-\d{5}-of-\d{5}$', '', model_file_basename) 2088 | # TODO: get from metadata instead, fallback to using filename 2089 | return model_file_basename 2090 | 2091 | def bpw(self) -> float: 2092 | """Get the average bits per weight of the model""" 2093 | return (self._model_size_bytes * 8) / self._n_params 2094 | 2095 | def save_state(self, file_path: 
str) -> None: 2096 | """Save the current state of the context to a file""" 2097 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.save_state') 2098 | 2099 | state_size_bytes = lib.llama_state_get_size(self._ctx.ctx) 2100 | state_size_mib = int(state_size_bytes / (1024 * 1024)) # approximate 2101 | 2102 | log(f'Llama.save_state: state size: {state_size_mib} MiB ({state_size_bytes} bytes)') 2103 | log(f'Llama.save_state: saving to {file_path} ...') 2104 | 2105 | if os.path.exists(file_path): 2106 | log(f'Llama.save_state: file exists, will be overwritten', 2) 2107 | 2108 | # save the llama state 2109 | with suppress_output(disable=get_verbose()): 2110 | success = lib.llama_state_save_file(self._ctx.ctx, file_path, self.context_tokens) 2111 | 2112 | if success: 2113 | log(f'Llama.save_state: successfully saved state') 2114 | else: 2115 | raise RuntimeError(f'Llama.save_state: failed to save state') 2116 | 2117 | def load_state(self, file_path: str) -> None: 2118 | """Load a previously saved context state from a file""" 2119 | null_ptr_check(self._ctx.ctx, 'self._ctx.ctx', 'Llama.load_state') 2120 | 2121 | if not os.path.exists(file_path): 2122 | raise FileNotFoundError(f'Llama.load_state: file_path {file_path} does not exist') 2123 | 2124 | if os.path.isdir(file_path): 2125 | raise IsADirectoryError(f'Llama.load_state: file_path {file_path} is a directory') 2126 | 2127 | # reset the current context before loading the new one 2128 | self.reset() 2129 | 2130 | n_ctx = self.n_ctx() 2131 | loaded_tokens_buf = (lib.llama_token * n_ctx)() 2132 | n_loaded_tokens_p = ctypes.c_size_t(0) 2133 | 2134 | log(f'Llama.load_state: loading from {file_path} ...') 2135 | 2136 | # load the llama state 2137 | with suppress_output(disable=get_verbose()): 2138 | success = lib.llama_state_load_file( 2139 | ctx=self._ctx.ctx, 2140 | path_session=file_path, 2141 | tokens_out=loaded_tokens_buf, 2142 | n_token_capacity=n_ctx, 2143 | n_token_count_out=ctypes.byref(n_loaded_tokens_p) 2144 | 
) 2145 | 2146 | if success: 2147 | n_loaded_tokens = n_loaded_tokens_p.value 2148 | 2149 | self.context_tokens = list(loaded_tokens_buf[:n_loaded_tokens]) 2150 | self.pos = n_loaded_tokens 2151 | 2152 | state_size_bytes = lib.llama_state_get_size(self._ctx.ctx) 2153 | state_size_mib = int(state_size_bytes / (1024 * 1024)) # approximate 2154 | 2155 | log(f'Llama.load_state: state size: {state_size_mib} MiB ({state_size_bytes} bytes)') 2156 | log(f'Llama.load_state: successfully loaded state ({n_loaded_tokens} tokens)') 2157 | else: 2158 | raise RuntimeError(f'Llama.load_state: failed to load state') 2159 | 2160 | def reset(self) -> None: 2161 | """Reset the position of the model and clear the KV cache""" 2162 | self.kv_cache_clear() 2163 | self.pos = 0 2164 | self.context_tokens = [] 2165 | -------------------------------------------------------------------------------- /easy_llama/sampling.py: -------------------------------------------------------------------------------- 1 | # sampling.py 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | """This file provides functionality for defining the sampler parameters used to control 6 | text generation.""" 7 | 8 | # TODO: implement grammar, implement custom sampling chains 9 | 10 | import os 11 | import sys 12 | import ctypes 13 | 14 | from .utils import null_ptr_check, log, ez_encode 15 | from .libllama import _internals 16 | from typing import Optional 17 | 18 | from . import libllama as lib 19 | 20 | HIGH_TEMP = 10_000.0 21 | 22 | class Llama: # can't import the real Llama - would be circular 23 | """Type hint denoting a `llama.Llama` instance""" 24 | 25 | def _get_random_seed() -> int: 26 | # unsigned 32-bit integer 27 | return int.from_bytes(bytes=os.urandom(4), byteorder=sys.byteorder, signed=False) 28 | 29 | class SamplerParams: 30 | """A SamplerParams object is used by a Llama model to define sampling behaviour. 
31 | 32 | However, SamplerParams objects also require some information about the Llama model itself, 33 | such as n_ctx_train, n_vocab, etc. Therefore Llama models and SamplerParams are tightly 34 | coupled. 35 | 36 | A SamplerPreset (which is a separate class) can be used to define sampling parameters 37 | without having to specify a Llama object. In turn, the Llama class can use these presets to 38 | create the actual SamplerParams object it needs for sampling.""" 39 | 40 | # NOTE: as of 2025-04-04, the default sampler chain for llama-cli is: 41 | # 42 | # logits -> logit-bias -> penalties -> dry -> top-k -> typical -> top-p -> min-p -> 43 | # xtc -> temp-ext -> dist 44 | # 45 | # ----------------------------------------------------------------------------------- 46 | # 47 | # as of 2025-04-22, the sampler chain for easy-llama is constructed as follows: 48 | # 49 | # ----------------------------------------------------------------------------------- 50 | # 51 | # logits -> logit-bias -> top-k (k=128) -> penalties -> dry -> xtc ... 52 | # 53 | # -- IF TEMP <= 0.0: 54 | # 55 | # ... -> greedy 56 | # 57 | # -- IF MIROSTAT v1: 58 | # 59 | # ... -> temp(-ext) -> mirostat-v1 60 | # 61 | # -- IF MIROSTAT v2: 62 | # 63 | # ... -> temp(-ext) -> mirostat-v2 64 | # 65 | # -- ELSE: 66 | # 67 | # -- IF TOP-N-SIGMA >= 0: 68 | # 69 | # ... -> top-k -> temp -> top-n-sigma -> dist 70 | # 71 | # -- ELSE: 72 | # 73 | # ... 
-> top-k -> typical-p -> top-p -> min-p -> temp(-ext) -> dist 74 | # 75 | # ---------------------------------------------------------------------------------- 76 | # 77 | # - "temp(-ext)" denotes "temp" if dynatemp_delta == 0.0, otherwise 78 | # "temp-ext" 79 | # 80 | # - "top-k (k=128)" is only added when penalties or DRY are active, and 81 | # it always runs before them to improve performance 82 | # 83 | # - note that "logit-bias", "top-k (k=128)", "penalties", and "dry" 84 | # are applied (when enabled) regardless of the temp/mirostat branch 85 | 86 | def __init__( 87 | # 88 | # ref: llama.cpp/common/common.h: struct common_params_sampling { ... } 89 | # 90 | self, 91 | llama: Llama, # some samplers require info about n_ctx_train, n_vocab, etc. 92 | 93 | seed: int = -1, # random seed: <= 0 94 | top_k: int = 40, # neutral: <= 0 95 | top_p: float = 0.95, # neutral: 1.0 96 | min_p: float = 0.05, # neutral: 0.0 97 | xtc_probability: float = 0.0, # neutral: 0.0 98 | xtc_threshold: float = 0.1, # disable: > 0.5 99 | typical_p: float = 1.0, # neutral: 1.0 100 | temp: float = 0.8, # neutral: 1.0, greedy: <= 0.0 101 | dynatemp_delta: float = 0.0, # neutral: <= 0.0 102 | dynatemp_exponent: float = 1.0, # controls how entropy maps to dynamic temperature 103 | penalty_last_n: int = 64, # disable: 0, n_ctx: -1, last n tokens to penalize 104 | penalty_repeat: float = 1.0, # neutral: 1.0, should be between 1.0 and ~1.1 105 | penalty_freq: float = 0.0, # neutral: 0.0 106 | penalty_present: float = 0.0, # neutral: 0.0 107 | dry_multiplier: float = 0.0, # disable: 0.0, DRY repetition penalty for tokens extending repetition: 108 | dry_base: float = 1.75, # disable: 0.0, multiplier * base ^ (length of sequence before token - allowed length) 109 | dry_allowed_length: int = 2, # tokens extending repetitions beyond this receive penalty 110 | dry_penalty_last_n: int = -1, # disable: 0, n_ctx: -1, how many tokens to scan for repetitions 111 | mirostat: int = 0, # disable: 0, use v1: 1, use v2: 2 112 | top_n_sigma: float = -1.0, # disable: -1.0 113 |
mirostat_tau: float = 5.0, # target entropy for mirostat 114 | mirostat_eta: float = 0.1, # learning rate for mirostat 115 | 116 | dry_sequence_breakers: list[str] = ["\n", ":", "\"", "*"], # default sequence breakers for DRY 117 | 118 | # TODO: grammar goes here 119 | 120 | logit_bias: Optional[dict[int, float]] = None 121 | ): 122 | self.smpl = None 123 | 124 | # 125 | # ref: llama.cpp/common/common.h: common_sampler_init(...) { ... } 126 | # 127 | 128 | # 129 | # Store parameter values as attributes 130 | # 131 | 132 | # NOTE: Changing these attributes will not change the sampling. If you need to change 133 | # sampling, construct a new sampler. 134 | 135 | self.llama = llama 136 | self.seed = seed 137 | 138 | self.top_k = top_k 139 | self.top_p = top_p 140 | self.min_p = min_p 141 | self.xtc_probability = xtc_probability 142 | self.xtc_threshold = xtc_threshold 143 | self.typical_p = typical_p 144 | self.temp = temp 145 | self.dynatemp_delta = dynatemp_delta 146 | self.dynatemp_exponent = dynatemp_exponent 147 | self.penalty_last_n = penalty_last_n 148 | self.penalty_repeat = penalty_repeat 149 | self.penalty_freq = penalty_freq 150 | self.penalty_present = penalty_present 151 | self.dry_multiplier = dry_multiplier 152 | self.dry_base = dry_base 153 | self.dry_allowed_length = dry_allowed_length 154 | self.dry_penalty_last_n = dry_penalty_last_n 155 | self.mirostat = mirostat 156 | self.top_n_sigma = top_n_sigma 157 | self.mirostat_tau = mirostat_tau 158 | self.mirostat_eta = mirostat_eta 159 | 160 | # TODO 161 | # self._validate_params() # show warnings/errors for nonsensical sampler values 162 | 163 | self.dry_sequence_breakers = dry_sequence_breakers 164 | 165 | self.logit_bias = logit_bias 166 | 167 | self._chain_str = '' 168 | 169 | sparams = lib.llama_sampler_chain_default_params() 170 | null_ptr_check(sparams, 'sparams', 'SamplerParams.__init__') 171 | 172 | smpl = lib.llama_sampler_chain_init(sparams) 173 | null_ptr_check(smpl, 'smpl', 
'SamplerParams.__init__') 174 | 175 | # Logit bias 176 | 177 | if logit_bias is not None: 178 | if len(logit_bias) == 1: 179 | self._chain_str += f'one logit bias -> ' 180 | else: 181 | self._chain_str += f'{len(logit_bias)} logit biases -> ' 182 | logit_bias_arr = _internals.get_logit_bias_array(logit_bias) 183 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_logit_bias( 184 | n_vocab=llama._n_vocab, 185 | n_logit_bias=len(logit_bias), 186 | logit_bias=logit_bias_arr 187 | )) 188 | 189 | penalty_last_n = penalty_last_n if penalty_last_n >= 0 else llama.n_ctx() 190 | _penalties_sampler_active = penalty_last_n != 0 and any( 191 | [penalty_repeat != 1.0, penalty_present != 0.0, penalty_freq != 0.0] 192 | ) 193 | _dry_sampler_active = dry_multiplier > 0.0 194 | _use_topk128_cutoff = _penalties_sampler_active or _dry_sampler_active 195 | 196 | if _use_topk128_cutoff: 197 | 198 | # Top-K (where k == 128) 199 | # 200 | # NOTE: This improves performance by greatly reducing the number of 201 | # tokens that the DRY and penalties samplers need to consider. It 202 | # should have absolutely no effect on the output except under the 203 | # strangest and most unlikely circumstances (like when temp > 10.0 204 | # and all other samplers are explicitly disabled). 205 | # 206 | # If you really need to bypass this, you can construct your own 207 | # llama_sampler_chain manually. But you probably don't need to. 
208 | # 209 | self._chain_str += 'top-k 128 -> ' 210 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_k(k=128)) 211 | 212 | # Penalties 213 | 214 | if _penalties_sampler_active: 215 | self._chain_str += f'penalty last:{penalty_last_n}' 216 | self._chain_str += f' rept:{penalty_repeat:.3f}' if penalty_repeat != 1.0 else '' 217 | self._chain_str += f' pres:{penalty_present:.3f}' if penalty_present != 0.0 else '' 218 | self._chain_str += f' freq:{penalty_freq:.3f}' if penalty_freq != 0.0 else '' 219 | self._chain_str += f' -> ' 220 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_penalties( 221 | penalty_last_n=penalty_last_n, 222 | penalty_repeat=penalty_repeat, 223 | penalty_freq=penalty_freq, 224 | penalty_present=penalty_present 225 | )) 226 | 227 | # DRY 228 | 229 | if _dry_sampler_active: 230 | self._chain_str += f'DRY x{dry_multiplier:.2f} base:{dry_base:.2f} ' 231 | self._chain_str += f'len:{dry_allowed_length} -> ' 232 | # dry == D.R.Y. ("Don't Repeat Yourself") 233 | # ref: https://github.com/oobabooga/text-generation-webui/pull/5677 234 | null_ptr_check(llama._vocab, 'llama._vocab', 'SamplerParams.__init__') 235 | seq_breakers = dry_sequence_breakers 236 | seq_breakers_bytes = [ez_encode(s) for s in seq_breakers] 237 | arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes) 238 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_dry( 239 | vocab=llama._vocab, 240 | n_ctx_train=llama._n_ctx_train, 241 | dry_multiplier=dry_multiplier, 242 | dry_base=dry_base, 243 | dry_allowed_length=dry_allowed_length, 244 | dry_penalty_last_n=dry_penalty_last_n, 245 | seq_breakers=arr, 246 | num_breakers=len(seq_breakers) 247 | )) 248 | 249 | # XTC 250 | 251 | if xtc_probability > 0.0: 252 | self._chain_str += f'XTC p:{xtc_probability:.2f} t:{xtc_threshold:.2f} -> ' 253 | lib.llama_sampler_chain_add( 254 | smpl, lib.llama_sampler_init_xtc( 255 | p=xtc_probability, 256 | t=xtc_threshold, 257 | min_keep=1, 258 | seed=seed if seed 
> 0 else _get_random_seed() 259 | ) 260 | ) 261 | 262 | # IF TEMP <= 0.0: 263 | 264 | if temp <= 0.0: 265 | # ... -> greedy 266 | self._chain_str += 'greedy' 267 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_greedy()) 268 | 269 | # IF MIROSTAT v1: 270 | 271 | elif mirostat == 1: 272 | # ... -> temp(-ext) -> mirostat-v1 273 | if dynatemp_delta != 0.0: 274 | # dynamic temperature AKA entropy sampling 275 | self._chain_str += f'temp {temp:.2f} +/- {dynatemp_delta:.2f} -> ' 276 | lib.llama_sampler_chain_add( 277 | smpl, lib.llama_sampler_init_temp_ext( 278 | t=temp, 279 | delta=dynatemp_delta, 280 | exponent=dynatemp_exponent 281 | ) 282 | ) 283 | else: 284 | # standard temperature 285 | self._chain_str += f'temp {temp:.2f} -> ' 286 | lib.llama_sampler_chain_add( 287 | smpl, lib.llama_sampler_init_temp(t=temp) 288 | ) 289 | self._chain_str += f'mirostat v1 tau:{mirostat_tau:.2f} eta:{mirostat_eta:.2f}' 290 | lib.llama_sampler_chain_add( 291 | smpl, lib.llama_sampler_init_mirostat( 292 | seed=seed if seed > 0 else _get_random_seed(), 293 | tau=mirostat_tau, 294 | eta=mirostat_eta 295 | ) 296 | ) 297 | 298 | # IF MIROSTAT v2: 299 | 300 | elif mirostat == 2: 301 | # ... 
-> temp(-ext) -> mirostat-v2 302 | if dynatemp_delta != 0.0: 303 | # dynamic temperature AKA entropy sampling 304 | self._chain_str += f'temp-ext {temp:.2f} +/- {dynatemp_delta:.2f} -> ' 305 | lib.llama_sampler_chain_add( 306 | smpl, lib.llama_sampler_init_temp_ext( 307 | t=temp, 308 | delta=dynatemp_delta, 309 | exponent=dynatemp_exponent 310 | ) 311 | ) 312 | else: 313 | # standard temperature 314 | self._chain_str += f'temp {temp:.2f} -> ' 315 | lib.llama_sampler_chain_add( 316 | smpl, lib.llama_sampler_init_temp(t=temp) 317 | ) 318 | self._chain_str += f'mirostat v2 tau:{mirostat_tau:.2f} eta:{mirostat_eta:.2f}' 319 | lib.llama_sampler_chain_add( 320 | smpl, lib.llama_sampler_init_mirostat_v2( 321 | seed=seed if seed > 0 else _get_random_seed(), 322 | tau=mirostat_tau, 323 | eta=mirostat_eta 324 | ) 325 | ) 326 | 327 | # DEFAULT CASE 328 | 329 | elif mirostat == 0: 330 | if top_n_sigma >= 0.0: 331 | # ... -> top-k -> temp -> top-n-sigma -> ... 332 | self._chain_str += ( 333 | f'top-k {top_k} -> temp {temp:.2f} -> top-n-sigma {top_n_sigma:.2f} -> ' 334 | ) 335 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_k(k=top_k)) 336 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_temp(t=temp)) 337 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_n_sigma(n=top_n_sigma)) 338 | else: 339 | # ... -> top-k -> typical-p -> top-p -> min-p -> temp(-ext) -> ... 
340 | if top_k > 0: 341 | self._chain_str += f'top-k {top_k} -> ' 342 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_top_k(k=top_k)) 343 | if typical_p != 1.0: 344 | self._chain_str += f'typical-p {typical_p:.2f} -> ' 345 | lib.llama_sampler_chain_add( 346 | smpl, lib.llama_sampler_init_typical(p=typical_p, min_keep=1) 347 | ) 348 | if top_p < 1.0: 349 | self._chain_str += f'top-p {top_p:.2f} -> ' 350 | lib.llama_sampler_chain_add( 351 | smpl, lib.llama_sampler_init_top_p(p=top_p, min_keep=1) 352 | ) 353 | if min_p > 0.0: 354 | self._chain_str += f'min-p {min_p:.3f} -> ' 355 | lib.llama_sampler_chain_add( 356 | smpl, lib.llama_sampler_init_min_p(p=min_p, min_keep=1) 357 | ) 358 | if dynatemp_delta != 0.0: 359 | # dynamic temperature AKA entropy sampling 360 | self._chain_str += f'temp {temp:.2f} +/- {dynatemp_delta:.2f} -> ' 361 | lib.llama_sampler_chain_add( 362 | smpl, lib.llama_sampler_init_temp_ext( 363 | t=temp, 364 | delta=dynatemp_delta, 365 | exponent=dynatemp_exponent 366 | ) 367 | ) 368 | else: 369 | # standard temperature 370 | self._chain_str += f'temp {temp:.2f} -> ' 371 | lib.llama_sampler_chain_add(smpl, lib.llama_sampler_init_temp(t=temp)) 372 | 373 | # ... 
-> dist 374 | self._chain_str += 'dist' 375 | lib.llama_sampler_chain_add( 376 | smpl, lib.llama_sampler_init_dist(seed=seed if seed > 0 else _get_random_seed()) 377 | ) 378 | 379 | else: 380 | raise ValueError( 381 | f'SamplerParams.__init__: unknown mirostat version {mirostat!r}' 382 | ) 383 | 384 | self.smpl = smpl 385 | 386 | def __del__(self): 387 | self.free() 388 | 389 | def __repr__(self) -> str: 390 | return ( 391 | f"SamplerParams(" 392 | f"llama=, " 393 | f"seed={self.seed}, " 394 | f"top_k={self.top_k}, " f"top_p={self.top_p}, " 395 | f"min_p={self.min_p}, " 396 | f"xtc_probability={self.xtc_probability}, " 397 | f"xtc_threshold={self.xtc_threshold}, " 398 | f"typical_p={self.typical_p}, " 399 | f"temp={self.temp}, " 400 | f"dynatemp_delta={self.dynatemp_delta}, " 401 | f"dynatemp_exponent={self.dynatemp_exponent}, " 402 | f"penalty_last_n={self.penalty_last_n}, " 403 | f"penalty_repeat={self.penalty_repeat}, " 404 | f"penalty_freq={self.penalty_freq}, " 405 | f"penalty_present={self.penalty_present}, " 406 | f"dry_multiplier={self.dry_multiplier}, " 407 | f"dry_base={self.dry_base}, " 408 | f"dry_allowed_length={self.dry_allowed_length}, " 409 | f"dry_penalty_last_n={self.dry_penalty_last_n}, " 410 | f"top_n_sigma={self.top_n_sigma}, " 411 | f"mirostat={self.mirostat}, " 412 | f"mirostat_tau={self.mirostat_tau}, " 413 | f"mirostat_eta={self.mirostat_eta}, " 414 | f"dry_sequence_breakers={self.dry_sequence_breakers!r}, " 415 | f"logit_bias={self.logit_bias!r}" 416 | f")" 417 | ) 418 | 419 | def print_chain(self) -> None: 420 | """Print the active sampler chain.""" 421 | log(f'sampler chain: {self._chain_str}') 422 | 423 | def free(self) -> None: 424 | if self.smpl is not None: 425 | lib.llama_sampler_free(self.smpl) 426 | self.smpl = None 427 | 428 | def reset(self) -> None: 429 | null_ptr_check(self.smpl, 'self.smpl', 'SamplerParams.reset') 430 | lib.llama_sampler_reset(self.smpl) 431 | 432 | def to_dict(self) -> dict: 433 | """Return the sampler parameters as a
dictionary.""" 434 | return { 435 | # not including "llama" 436 | "seed" : self.seed, 437 | "top_k" : self.top_k, 438 | "top_p" : self.top_p, 439 | "min_p" : self.min_p, 440 | "xtc_probability" : self.xtc_probability, 441 | "xtc_threshold" : self.xtc_threshold, 442 | "typical_p" : self.typical_p, 443 | "temp" : self.temp, 444 | "dynatemp_delta" : self.dynatemp_delta, 445 | "dynatemp_exponent" : self.dynatemp_exponent, 446 | "penalty_last_n" : self.penalty_last_n, 447 | "penalty_repeat" : self.penalty_repeat, 448 | "penalty_freq" : self.penalty_freq, 449 | "penalty_present" : self.penalty_present, 450 | "dry_multiplier" : self.dry_multiplier, 451 | "dry_base" : self.dry_base, 452 | "dry_allowed_length" : self.dry_allowed_length, 453 | "dry_penalty_last_n" : self.dry_penalty_last_n, 454 | "mirostat" : self.mirostat, 455 | "top_n_sigma" : self.top_n_sigma, 456 | "mirostat_tau" : self.mirostat_tau, 457 | "mirostat_eta" : self.mirostat_eta, 458 | "dry_sequence_breakers" : self.dry_sequence_breakers, 459 | "logit_bias" : self.logit_bias 460 | } 461 | 462 | @classmethod 463 | def from_dict(cls, llama: Llama, params_dict: dict): 464 | """Creates a SamplerParams instance from a dictionary. 465 | 466 | Args: 467 | llama: The Llama object associated with these parameters. 468 | params_dict: A dictionary containing the sampler parameters. 469 | 470 | Returns: 471 | A new SamplerParams instance. 472 | """ 473 | # Create a copy to avoid modifying the original dictionary 474 | # and remove keys that are not constructor arguments (like 'llama' if present) 475 | filtered_params = {k: v for k, v in params_dict.items() if k in SamplerPreset.__init__.__code__.co_varnames} 476 | return cls(llama=llama, **filtered_params) 477 | 478 | 479 | class SamplerPreset: 480 | """A SamplerPreset object contains all the values necessary to construct a SamplerParams 481 | object using a Llama model. 
482 | 483 | Llama objects use SamplerParams objects to define the sampling parameters, but these 484 | SamplerParams objects also require some information about the Llama model itself, such as 485 | n_ctx_train, n_vocab, etc. Therefore Llama models and SamplerParams are tightly coupled. 486 | 487 | A SamplerPreset (this class) can be used to define sampling parameters without having to 488 | specify a Llama object. In turn, the Llama class can use these presets to create the actual 489 | SamplerParams object it needs for sampling.""" 490 | 491 | def __init__( 492 | self, 493 | 494 | seed: int = -1, # random seed: <= 0 495 | top_k: int = 40, # neutral: <= 0 496 | top_p: float = 0.95, # neutral: 1.0 497 | min_p: float = 0.05, # neutral: 0.0 498 | xtc_probability: float = 0.0, # neutral: 0.0 499 | xtc_threshold: float = 0.1, # disable: > 0.5 500 | typical_p: float = 1.0, # neutral: 1.0 501 | temp: float = 0.8, # neutral: 1.0, greedy: <= 0.0 502 | dynatemp_delta: float = 0.0, # neutral: <= 0.0 503 | dynatemp_exponent: float = 1.0, # controls how entropy maps to dynamic temperature 504 | penalty_last_n: int = 64, # disable: 0, n_ctx: -1, last n tokens to penalize 505 | penalty_repeat: float = 1.0, # neutral: 1.0, should be between 1.0 and ~1.1 506 | penalty_freq: float = 0.0, # neutral: 0.0 507 | penalty_present: float = 0.0, # neutral: 0.0 508 | dry_multiplier: float = 0.0, # disable: 0.0, DRY repetition penalty for tokens extending repetition: 509 | dry_base: float = 1.75, # disable: 0.0, multiplier * base ^ (length of sequence before token - allowed length) 510 | dry_allowed_length: int = 2, # tokens extending repetitions beyond this receive penalty 511 | dry_penalty_last_n: int = -1, # disable: 0, n_ctx: -1, how many tokens to scan for repetitions 512 | mirostat: int = 0, # disable: 0, use v1: 1, use v2: 2 513 | top_n_sigma: float = -1.0, # disable: -1.0 514 | mirostat_tau: float = 5.0, # target entropy for mirostat 515 | mirostat_eta: float = 0.1, # learning rate for
mirostat 516 | 517 | dry_sequence_breakers: list[str] = ["\n", ":", "\"", "*"], # default sequence breakers for DRY 518 | 519 | # TODO: grammar goes here 520 | 521 | logit_bias: Optional[dict[int, float]] = None 522 | ): 523 | self.seed = seed 524 | 525 | self.top_k = top_k 526 | self.top_p = top_p 527 | self.min_p = min_p 528 | self.xtc_probability = xtc_probability 529 | self.xtc_threshold = xtc_threshold 530 | self.typical_p = typical_p 531 | self.temp = temp 532 | self.dynatemp_delta = dynatemp_delta 533 | self.dynatemp_exponent = dynatemp_exponent 534 | self.penalty_last_n = penalty_last_n 535 | self.penalty_repeat = penalty_repeat 536 | self.penalty_freq = penalty_freq 537 | self.penalty_present = penalty_present 538 | self.dry_multiplier = dry_multiplier 539 | self.dry_base = dry_base 540 | self.dry_allowed_length = dry_allowed_length 541 | self.dry_penalty_last_n = dry_penalty_last_n 542 | self.mirostat = mirostat 543 | self.top_n_sigma = top_n_sigma 544 | self.mirostat_tau = mirostat_tau 545 | self.mirostat_eta = mirostat_eta 546 | 547 | self.dry_sequence_breakers = dry_sequence_breakers 548 | 549 | self.logit_bias = logit_bias 550 | 551 | def __repr__(self) -> str: 552 | return ( 553 | f"SamplerPreset(" 554 | f"seed={self.seed}, " 555 | f"top_k={self.top_k}, " f"top_p={self.top_p}, " 556 | f"min_p={self.min_p}, " 557 | f"xtc_probability={self.xtc_probability}, " 558 | f"xtc_threshold={self.xtc_threshold}, " 559 | f"typical_p={self.typical_p}, " 560 | f"temp={self.temp}, " 561 | f"dynatemp_delta={self.dynatemp_delta}, " 562 | f"dynatemp_exponent={self.dynatemp_exponent}, " 563 | f"penalty_last_n={self.penalty_last_n}, " 564 | f"penalty_repeat={self.penalty_repeat}, " 565 | f"penalty_freq={self.penalty_freq}, " 566 | f"penalty_present={self.penalty_present}, " 567 | f"dry_multiplier={self.dry_multiplier}, " 568 | f"dry_base={self.dry_base}, " 569 | f"dry_allowed_length={self.dry_allowed_length}, " 570 | f"dry_penalty_last_n={self.dry_penalty_last_n}, " 571 |
f"mirostat={self.mirostat}, " 572 | f"top_n_sigma={self.top_n_sigma}, " 573 | f"mirostat_tau={self.mirostat_tau}, " 574 | f"mirostat_eta={self.mirostat_eta}, " 575 | f"dry_sequence_breakers={self.dry_sequence_breakers!r}, " 576 | f"logit_bias={self.logit_bias!r}" 577 | f")" 578 | ) 579 | 580 | def as_dict(self) -> dict: 581 | """Returns the sampler parameters as a dictionary.""" 582 | return { 583 | "seed" : self.seed, 584 | "top_k" : self.top_k, 585 | "top_p" : self.top_p, 586 | "min_p" : self.min_p, 587 | "xtc_probability" : self.xtc_probability, 588 | "xtc_threshold" : self.xtc_threshold, 589 | "typical_p" : self.typical_p, 590 | "temp" : self.temp, 591 | "dynatemp_delta" : self.dynatemp_delta, 592 | "dynatemp_exponent" : self.dynatemp_exponent, 593 | "penalty_last_n" : self.penalty_last_n, 594 | "penalty_repeat" : self.penalty_repeat, 595 | "penalty_freq" : self.penalty_freq, 596 | "penalty_present" : self.penalty_present, 597 | "dry_multiplier" : self.dry_multiplier, 598 | "dry_base" : self.dry_base, 599 | "dry_allowed_length" : self.dry_allowed_length, 600 | "dry_penalty_last_n" : self.dry_penalty_last_n, 601 | "mirostat" : self.mirostat, 602 | "top_n_sigma" : self.top_n_sigma, 603 | "mirostat_tau" : self.mirostat_tau, 604 | "mirostat_eta" : self.mirostat_eta, 605 | "dry_sequence_breakers" : self.dry_sequence_breakers, 606 | "logit_bias" : self.logit_bias 607 | } 608 | 609 | @classmethod 610 | def from_dict(cls, params_dict: dict): 611 | """Creates a SamplerPreset instance from a dictionary. 612 | 613 | Args: 614 | params_dict: A dictionary containing the sampler parameters. 615 | 616 | Returns: 617 | A new SamplerPreset instance. 
618 | """ 619 | # Create a copy to avoid modifying the original dictionary 620 | # and remove keys that are not constructor arguments 621 | filtered_params = {k: v for k, v in params_dict.items() if k in cls.__init__.__code__.co_varnames} 622 | return cls(**filtered_params) 623 | 624 | 625 | class SamplerPresets: 626 | """This class contains several ready-made `SamplerPreset` objects that can be used to 627 | control text generation.""" 628 | 629 | Greedy = SamplerPreset( 630 | seed = 1, 631 | top_k = 1, 632 | temp = 0.0 633 | ) 634 | """The most likely token is always chosen""" 635 | 636 | Default = SamplerPreset() 637 | """The default easy-llama sampler preset""" 638 | 639 | LlamaCPP = SamplerPreset( 640 | top_k = 40, 641 | top_p = 0.95, 642 | min_p = 0.05, 643 | temp = 0.8 644 | ) 645 | """The default llama.cpp sampler preset""" 646 | 647 | Cool = SamplerPreset( 648 | top_k = -1, 649 | top_p = 1.0, 650 | min_p = 0.5, 651 | temp = 1.0 652 | ) 653 | """The recommended easy-llama sampler preset for predictable output""" 654 | 655 | Warm = SamplerPreset( 656 | top_k = -1, 657 | top_p = 1.0, 658 | min_p = 0.1, 659 | temp = 1.5 660 | ) 661 | """The recommended easy-llama sampler preset for creative yet coherent output""" 662 | 663 | Neutral = SamplerPreset( 664 | top_k = -1, 665 | top_p = 1.0, 666 | min_p = 0.0, 667 | temp = 1.0 668 | ) 669 | """All samplers neutralized""" 670 | 671 | ContrastiveSearchCool = SamplerPreset( 672 | top_k = -1, 673 | top_p = 1.0, 674 | min_p = 0.0, 675 | temp = 0.4, 676 | penalty_present = 0.6, 677 | ) 678 | """Contrastive Search as described in https://arxiv.org/abs/2210.14140 (less random)""" 679 | 680 | ContrastiveSearchWarm = SamplerPreset( 681 | top_k = -1, 682 | top_p = 1.0, 683 | min_p = 0.0, 684 | temp = 0.8, 685 | penalty_present = 0.6 686 | ) 687 | """Contrastive Search as described in https://arxiv.org/abs/2210.14140 (more random)""" 688 | 689 | NucleusSamplingCool = SamplerPreset( 690 | top_k = -1, 691 | top_p = 0.25, 692 |
min_p = 0.0, 693 | temp = 1.0 694 | ) 695 | """Nucleus sampling as described in https://arxiv.org/abs/1904.09751 (less random)""" 696 | 697 | NucleusSamplingWarm = SamplerPreset( 698 | top_k = -1, 699 | top_p = 0.9, 700 | min_p = 0.0, 701 | temp = 1.0 702 | ) 703 | """Nucleus sampling as described in https://arxiv.org/abs/1904.09751 (more random)""" 704 | 705 | TopNSigma = SamplerPreset( 706 | top_k = -1, 707 | top_p = 1.0, 708 | min_p = 0.0, 709 | temp = 1.0, 710 | top_n_sigma = 1.0 711 | ) 712 | """Top-nσ as described on [arXiv](https://arxiv.org/pdf/2411.07641) and 713 | in [llama.cpp#11223](https://github.com/ggml-org/llama.cpp/pull/11223)""" 714 | 715 | TopNSigmaRandom = SamplerPreset( 716 | top_k = -1, 717 | top_p = 1.0, 718 | min_p = 0.0, 719 | temp = 9999.9, 720 | top_n_sigma = 1.0 721 | ) 722 | """Top-nσ as described on [arXiv](https://arxiv.org/pdf/2411.07641) and 723 | in [llama.cpp#11223](https://github.com/ggml-org/llama.cpp/pull/11223), except that 724 | `temp = 9999.9`, so that any token Top-nσ determines to be valid is selected 725 | at random.""" 726 | 727 | DRY = SamplerPreset(dry_multiplier=0.8, dry_base=1.75, dry_allowed_length=2) 728 | """DRY ("Don't Repeat Yourself") sampler: https://github.com/oobabooga/text-generation-webui/pull/5677""" 729 | 730 | XTC = SamplerPreset( 731 | top_k=-1, 732 | top_p=1.0, 733 | min_p=0.02, 734 | xtc_probability=0.5, 735 | xtc_threshold=0.1 736 | ) 737 | """XTC ("Exclude Top Choices") sampler: https://github.com/oobabooga/text-generation-webui/pull/6335""" 738 | 739 | # 740 | # Samplers below this line are for specific models / model families 741 | # 742 | 743 | Llama3 = SamplerPreset( 744 | top_k = -1, 745 | top_p = 0.9, 746 | min_p = 0.0, 747 | temp = 0.6 748 | ) 749 | """[meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/)""" 750 | 751 | Llama3Classic = SamplerPreset( 752 | top_k = 40, 753 | top_p = 0.95, 754 | min_p = 0.05, 755 | temp = 0.65 756 | ) 757 | """Unofficial preset based on the developer's personal preference""" 758 | 759 |
Llama3Cool = SamplerPreset( 760 | top_k = -1, 761 | top_p = 0.9, 762 | min_p = 0.0, 763 | temp = 0.45 764 | ) 765 | """Llama3 preset with reduced temperature (less random)""" 766 | 767 | Llama3Warm = SamplerPreset( 768 | top_k = -1, 769 | top_p = 0.9, 770 | min_p = 0.0, 771 | temp = 1.2 772 | ) 773 | """Llama3 preset with increased temperature (more random)""" 774 | 775 | Mistral = SamplerPreset( 776 | temp = 0.3 777 | ) 778 | """Mistral models tend to require a lower temperature""" 779 | 780 | Qwen2_5Official = SamplerPreset( 781 | top_k = 20, 782 | top_p = 0.8, 783 | min_p = 0.0, 784 | temp = 0.7, 785 | penalty_repeat = 1.05 786 | ) 787 | """[Qwen/Qwen2.5-14B-Instruct/](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct/) 788 | (official, but not recommended)""" 789 | 790 | Qwen2_5Recommended = SamplerPreset( 791 | top_k = -1, 792 | top_p = 0.9, 793 | min_p = 0.1, 794 | temp = 1.1 795 | ) 796 | """[Qwen/Qwen2.5-14B-Instruct/](https://huggingface.co/Qwen/Qwen2.5-14B-Instruct/) 797 | (unofficial, but recommended)""" 798 | 799 | Qwen3Thinking = SamplerPreset( 800 | top_k = 20, 801 | top_p = 0.95, 802 | min_p = 0.0, 803 | temp = 0.6 804 | ) 805 | """[Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/README.md)""" 806 | 807 | Qwen3NoThinking = SamplerPreset( 808 | top_k = 20, 809 | top_p = 0.8, 810 | min_p = 0.0, 811 | temp = 0.7 812 | ) 813 | """[Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/README.md)""" 814 | -------------------------------------------------------------------------------- /easy_llama/server.py: -------------------------------------------------------------------------------- 1 | # server.py 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | """The easy-llama FastAPI server, including an API endpoint and a simple WebUI""" 6 | 7 | # NOTE: This module is WIP. 
8 | 9 | import os 10 | import uvicorn 11 | 12 | import easy_llama as ez 13 | 14 | from pydantic import BaseModel 15 | from fastapi.staticfiles import StaticFiles 16 | from fastapi.responses import FileResponse 17 | from fastapi.middleware.cors import CORSMiddleware 18 | from easy_llama.utils import assert_type, log, ANSI 19 | from typing import Optional, Union, Literal 20 | from fastapi import FastAPI, APIRouter, Body, HTTPException 21 | 22 | 23 | WEBUI_DIRECTORY = os.path.join(os.path.dirname(__file__), 'webui') 24 | 25 | STATUS_RESPONSE_SUCCESS = {'success': True} 26 | STATUS_RESPONSE_FAILURE = {'success': False} 27 | 28 | Y = ANSI.FG_BRIGHT_YELLOW 29 | R = ANSI.MODE_RESET_ALL 30 | 31 | # 32 | # Pydantic models for FastAPI 33 | # 34 | 35 | class StatusResponseModel(BaseModel): 36 | success: bool 37 | 38 | class MessageModel(BaseModel): 39 | role: str 40 | content: str 41 | 42 | class SendMessageModel(BaseModel): 43 | role: str 44 | content: str 45 | n_predict: Optional[int] = -1 46 | 47 | class MessageResponseModel(BaseModel): 48 | role: str 49 | content: str 50 | n_input_tokens: int 51 | n_output_tokens: int 52 | 53 | class SetSystemPromptRequestModel(BaseModel): 54 | content: str 55 | 56 | class SummaryResponseModel(BaseModel): 57 | summary: str 58 | 59 | class InfoResponseModel(BaseModel): 60 | model_name: str 61 | model_n_params: int 62 | model_bpw: float 63 | n_tokens: int 64 | n_ctx: int 65 | n_ctx_train: int 66 | n_thread_tokens: int 67 | n_thread_messages: int 68 | 69 | class SamplerSettingsModel(BaseModel): 70 | seed: Optional[int] = None 71 | top_k: Optional[int] = None 72 | top_p: Optional[float] = None 73 | min_p: Optional[float] = None 74 | xtc_probability: Optional[float] = None 75 | xtc_threshold: Optional[float] = None 76 | typical_p: Optional[float] = None 77 | temp: Optional[float] = None 78 | dynatemp_delta: Optional[float] = None 79 | dynatemp_exponent: Optional[float] = None 80 | penalty_last_n: Optional[int] = None 81 | penalty_repeat: Optional[float] = None 82 | penalty_freq: Optional[float] = None 83 | penalty_present: Optional[float] = None 84 | dry_multiplier: Optional[float] = None 85 | dry_base: Optional[float] = None 86 | dry_allowed_length: Optional[int] = None 87 | dry_penalty_last_n: Optional[int] = None 88 | mirostat: Optional[int] = None 89 | top_n_sigma: Optional[float] = None 90 | mirostat_tau: Optional[float] = None 91 | mirostat_eta: Optional[float] = None 92 | dry_sequence_breakers: Optional[list[str]] = None 93 | logit_bias: Optional[dict[int, float]] = None 94 | 95 | 96 | class Server: 97 | """The easy-llama FastAPI server, providing a WebUI and an API router""" 98 | 99 | def log(self, text: str, level: Literal[1,2,3,4] = 1) -> None: 100 | log(f'ez.Server @ {Y}{self._host}{R}:{Y}{self._port}{R} {text}', level=level) 101 | 102 | def __init__( 103 | self, 104 | thread: 'ez.Thread', 105 | host: str = "127.0.0.1", 106 | port: int = 8080 107 | ): 108 | assert_type(thread, getattr(ez, 'Thread'), 'thread', 'Server.__init__') 109 | self._thread = thread 110 | self._host = host 111 | self._port = port 112 | self._app = FastAPI(title=f"ez.Server @ {host}:{port}") 113 | self._router = APIRouter(prefix="/api") 114 | 115 | self._setup_api_routes() 116 | 117 | self._app.add_middleware( 118 | CORSMiddleware, 119 | allow_origins=["*"], 120 | allow_credentials=True, 121 | allow_methods=["*"], 122 | allow_headers=["*"], 123 | ) 124 | 125 | self._app.include_router(self._router) 126 | 127 | if os.path.isdir(WEBUI_DIRECTORY): 128 | self._app.mount("/static", StaticFiles(directory=WEBUI_DIRECTORY), name="static") 129 | 130 | @self._app.get("/", include_in_schema=False) 131 | async def read_index(): 132 | index_path = os.path.join(WEBUI_DIRECTORY, 'index.html') 133 | if os.path.exists(index_path): 134 | return FileResponse(index_path) 135 | else: 136 | # fallback if index.html is missing but directory exists 137 | self.log('index.html not found', 3) 138 | raise HTTPException(status_code=404, detail="index.html not found") 139 | 140 | else: 141 | self.log(f"WebUI directory not found at {WEBUI_DIRECTORY}", level=2) 142 | @self._app.get("/",
include_in_schema=False) 143 | async def root_fallback(): 144 | # provide a message if WebUI isn't found 145 | return {"message": "easy-llama server is running. WebUI not found."} 146 | 147 | 148 | def _setup_api_routes(self): 149 | """Define all API endpoints.""" 150 | router = self._router 151 | 152 | @router.get("/ping", response_model=StatusResponseModel) 153 | async def ping() -> dict: 154 | """Check if the server is running.""" 155 | self.log('pong') 156 | return STATUS_RESPONSE_SUCCESS 157 | 158 | @router.post("/send", response_model=MessageResponseModel) 159 | async def send(message: SendMessageModel = Body(...)) -> dict[str, Union[str, int]]: 160 | """Send a user message and get the bot's response. 161 | Adds both messages to the thread history.""" 162 | 163 | try: 164 | self._thread.add_message(message.role, message.content) 165 | 166 | except ValueError as exc: 167 | self.log(f"/send: error adding user message: {exc}", 3) 168 | raise HTTPException(status_code=400, detail=str(exc)) 169 | 170 | input_toks = self._thread.get_input_ids(role='bot') 171 | 172 | try: 173 | response_toks = self._thread.llama.generate( 174 | input_tokens=input_toks, 175 | n_predict=message.n_predict if message.n_predict is not None else -1, 176 | stop_tokens=self._thread._stop_tokens, 177 | sampler_preset=self._thread.sampler_preset 178 | ) 179 | response_txt = self._thread.llama.detokenize(response_toks, special=False) 180 | 181 | except ez.llama.ExceededContextLengthException as exc: 182 | self.log(f"ExceededContextLengthException: {exc}", 3) 183 | raise HTTPException(status_code=413, detail=str(exc)) 184 | 185 | except Exception as exc: 186 | self.log(f"Error during generation: {type(exc).__name__}: {exc}", level=3) 187 | raise HTTPException(status_code=500, detail=str(exc)) 188 | 189 | # Add bot response to history 190 | self._thread.add_message('bot', response_txt) 191 | 192 | return { 193 | 'role': 'bot', 194 | 'content': response_txt, 195 | 'n_input_tokens': 
len(input_toks), # Tokens processed for this turn 196 | 'n_output_tokens': len(response_toks) # Tokens generated this turn 197 | } 198 | 199 | @router.post("/add_message", response_model=StatusResponseModel, tags=["Chat"]) 200 | async def add_message(message: MessageModel = Body(...)) -> dict: 201 | """ 202 | Add a message to the thread history without triggering a bot response. 203 | Useful for manually setting up conversation state. 204 | """ 205 | self.log(f"/add_message request: role='{message.role}', content='{message.content[:50]}...'") 206 | try: 207 | self._thread.add_message(message.role, message.content) 208 | return STATUS_RESPONSE_SUCCESS 209 | except ValueError as exc: 210 | self.log(f'Failed to add message: {exc}', 3) 211 | raise HTTPException(status_code=400, detail=str(exc)) 212 | 213 | @router.post("/set_system_prompt", response_model=StatusResponseModel, tags=["Chat"]) 214 | async def set_system_prompt(request: SetSystemPromptRequestModel = Body(...)) -> dict: 215 | """ 216 | Set or update the system prompt (the first message if it has a system role). 217 | If no system prompt exists, it will be prepended. 
218 | """ 219 | content = request.content 220 | new_sys_msg = {'role': 'system', 'content': content} 221 | messages = self._thread.messages 222 | 223 | if messages and messages[0]['role'].lower() in self._thread.valid_system_roles: 224 | self.log("Updating existing system prompt.") 225 | messages[0]['content'] = content 226 | else: 227 | self.log("Prepending new system prompt.") 228 | messages.insert(0, new_sys_msg) 229 | 230 | orig_messages = self._thread._orig_messages 231 | if orig_messages and orig_messages[0]['role'].lower() in self._thread.valid_system_roles: 232 | orig_messages[0]['content'] = content 233 | else: 234 | orig_messages.insert(0, new_sys_msg) 235 | 236 | return STATUS_RESPONSE_SUCCESS 237 | 238 | @router.post("/trigger", response_model=MessageResponseModel, tags=["Chat"]) 239 | async def trigger() -> dict[str, Union[str, int]]: 240 | """Trigger the bot to generate a response based on the current history, 241 | without adding a user message first. Appends the bot's response.""" 242 | self.log("/trigger request received") 243 | if not self._thread.messages: 244 | # TODO: this is incorrect 245 | self.log("Cannot trigger response: No messages in history.", level=2) 246 | raise HTTPException(status_code=400, detail="Cannot trigger response with empty history") 247 | 248 | last_message_role = self._thread.messages[-1]['role'].lower() 249 | if last_message_role in self._thread.valid_bot_roles: 250 | self.log("Last message was from bot, triggering another bot response.", level=1) 251 | # Allow triggering even if last was bot, might be desired sometimes 252 | elif last_message_role in self._thread.valid_system_roles: 253 | self.log("Last message was system, triggering bot response.", level=1) 254 | elif last_message_role in self._thread.valid_user_roles: 255 | self.log("Last message was user, triggering bot response.", level=1) 256 | 257 | # prepare for generation 258 | 259 | try: 260 | input_toks = self._thread.get_input_ids(role='bot') 261 | except 
ez.llama.ExceededContextLengthException as exc: 262 | self.log(f"Context length exceeded before generation: {exc}", level=3) 263 | raise HTTPException(status_code=413, detail=f"Input context too long: {exc}") 264 | except ValueError as exc: # Handle potential errors in get_input_ids 265 | self.log(f"Error getting input IDs: {exc}", level=3) 266 | raise HTTPException(status_code=500, detail=f"Internal error preparing generation: {exc}") 267 | 268 | # generate response 269 | 270 | try: 271 | response_toks = self._thread.llama.generate( 272 | input_tokens=input_toks, 273 | n_predict=-1, 274 | stop_tokens=self._thread._stop_tokens, 275 | sampler_preset=self._thread.sampler_preset 276 | ) 277 | response_txt = self._thread.llama.detokenize(response_toks, special=False).strip() 278 | self.log(f"Generated {len(response_toks)} tokens: '{response_txt[:50]}...'") 279 | except ez.llama.ExceededContextLengthException as exc: 280 | self.log(f"Context length exceeded during generation: {exc}", level=3) 281 | raise HTTPException(status_code=413, detail=f"Context length exceeded during generation: {exc}") 282 | except Exception as exc: 283 | self.log(f"Error during generation: {type(exc).__name__}: {exc}", level=3) 284 | raise HTTPException(status_code=500, detail=f"Generation failed: {exc}") 285 | 286 | # Add bot response to history 287 | self._thread.add_message('bot', response_txt) 288 | 289 | return { 290 | 'role': 'bot', 291 | 'content': response_txt, 292 | 'n_input_tokens': len(input_toks), 293 | 'n_output_tokens': len(response_toks) 294 | } 295 | 296 | 297 | @router.get("/messages", response_model=list[MessageModel], tags=["Chat"]) 298 | async def messages() -> list[dict]: 299 | """Get the current list of messages in the thread history.""" 300 | self.log("/messages request received") 301 | return self._thread.messages 302 | 303 | 304 | @router.get("/summarize", response_model=SummaryResponseModel, tags=["Chat"]) 305 | async def summarize() -> dict[str, str]: 306 | """ 307 | 
Generate a summary of the current chat thread. 308 | Note: This performs inference and modifies the Llama context temporarily. 309 | """ 310 | self.log("/summarize request received") 311 | try: 312 | summary_text = self._thread.summarize() 313 | self.log(f"Generated summary: '{summary_text[:50]}...'") 314 | return {"summary": summary_text} 315 | except Exception as exc: 316 | self.log(f"Error during summarization: {type(exc).__name__}: {exc}", level=3) 317 | raise HTTPException(status_code=500, detail=f"Summarization failed: {exc}") 318 | 319 | 320 | @router.post("/cancel", response_model=StatusResponseModel, tags=["Control"]) 321 | async def cancel() -> dict: 322 | """(Not Implemented) Attempt to cancel ongoing generation.""" 323 | # NOTE: llama.cpp doesn't have a built-in robust way to interrupt 324 | # llama_decode mid-computation from Python easily without potential 325 | # instability. This requires more complex handling (e.g., abort callbacks). 326 | self.log('Endpoint `/api/cancel` is not implemented yet!', 2) 327 | # return STATUS_RESPONSE_FAILURE 328 | raise HTTPException(status_code=501, detail="Cancel operation not implemented") 329 | 330 | 331 | @router.post("/reset", response_model=StatusResponseModel, tags=["Control"]) 332 | async def reset() -> dict: 333 | """Reset the chat thread to its initial state (usually just the system prompt).""" 334 | self.log("/reset request received") 335 | try: 336 | self._thread.reset() 337 | self.log("Thread and Llama context reset successfully.") 338 | return STATUS_RESPONSE_SUCCESS 339 | except Exception as exc: 340 | self.log(f"Error during reset: {type(exc).__name__}: {exc}", level=3) 341 | raise HTTPException(status_code=500, detail=f"Reset failed: {exc}") 342 | 343 | 344 | @router.get("/info", response_model=InfoResponseModel, tags=["Status"]) 345 | async def get_info() -> dict[str, Union[str, int, float]]: 346 | """Get information about the loaded model and current thread state.""" 347 | self.log("/info request 
received") 348 | # Use suppress_output if get_input_ids might log verbosely 349 | # with ez.utils.suppress_output(disable=ez.get_verbose()): 350 | input_ids = self._thread.get_input_ids(role=None) # Get tokens just for counting 351 | 352 | return { 353 | 'model_name': self._thread.llama.name(), 354 | 'model_n_params': self._thread.llama.n_params(), 355 | 'model_bpw': self._thread.llama.bpw(), 356 | 'n_tokens': self._thread.llama.pos, 357 | 'n_ctx': self._thread.llama.n_ctx(), 358 | 'n_ctx_train': self._thread.llama.n_ctx_train(), 359 | 'n_thread_tokens': len(input_ids), 360 | 'n_thread_messages': len(self._thread.messages) 361 | } 362 | 363 | @router.post("/sampler", response_model=SamplerSettingsModel, tags=["Sampling"]) 364 | async def set_sampler(settings: SamplerSettingsModel = Body(...)) -> dict: 365 | """ 366 | Update the sampler settings for subsequent generations. 367 | Returns the complete current settings after applying the update. 368 | """ 369 | self.log(f"/sampler POST request received with updates: {settings.model_dump(exclude_unset=True)}") 370 | current = self._thread.sampler_preset 371 | update_data = settings.model_dump(exclude_unset=True) # Get only provided fields 372 | 373 | # Create a dictionary from the current preset 374 | current_data = current.as_dict() 375 | 376 | # Update the current data with the new values 377 | current_data.update(update_data) 378 | 379 | # Create a new preset from the merged data 380 | try: 381 | new_preset = ez.SamplerPreset(**current_data) 382 | # Update the thread's active sampler preset 383 | self._thread.sampler_preset = new_preset 384 | self.log(f"Sampler settings updated. 
Current: {new_preset}") 385 | return new_preset.as_dict() # Return the full new settings 386 | except Exception as exc: # Catch potential validation errors in SamplerPreset 387 | self.log(f"Error updating sampler settings: {exc}", level=3) 388 | raise HTTPException(status_code=400, detail=f"Invalid sampler settings: {exc}") 389 | 390 | 391 | @router.get("/sampler", response_model=SamplerSettingsModel, tags=["Sampling"]) 392 | async def get_sampler() -> dict: 393 | """Get the current sampler settings.""" 394 | self.log("/sampler GET request received") 395 | return self._thread.sampler_preset.as_dict() 396 | 397 | 398 | # --- Server Start Method --- 399 | def start(self): 400 | """Start the Uvicorn server.""" 401 | try: 402 | uvicorn.run( 403 | app=self._app, 404 | host=self._host, 405 | port=self._port 406 | ) 407 | except KeyboardInterrupt: 408 | self.log('KeyboardInterrupt') 409 | except Exception as exc: 410 | self.log(f'Server crashed: {type(exc).__name__}: {exc}', 3) 411 | # Potentially re-raise or handle specific exceptions (e.g., port in use) 412 | raise exc 413 | finally: 414 | self.log("Server shutdown complete.") 415 | -------------------------------------------------------------------------------- /easy_llama/thread.py: -------------------------------------------------------------------------------- 1 | # thread.py 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | """This file provides functionality for multi-turn conversations with Llama models.""" 6 | 7 | # TODO: jinja2 8 | 9 | import sys 10 | import jinja2 11 | 12 | from typing import Optional 13 | from .formats import PromptFormat 14 | from .sampling import SamplerPreset 15 | from .utils import ( 16 | _SupportsWriteAndFlush, ANSI, log, assert_type, ez_encode, ez_decode, suppress_output, 17 | KeyboardInterruptHandler 18 | ) 19 | 20 | from . 
import llama as _llama # avoid confusion with Thread.llama attribute 21 | 22 | 23 | class Thread: 24 | 25 | valid_system_roles = ['system', 'developer' ] 26 | valid_user_roles = ['user', 'human' ] 27 | valid_bot_roles = ['bot', 'assistant', 'model', 'gpt', 'llama'] 28 | 29 | all_valid_roles = valid_system_roles + valid_user_roles + valid_bot_roles 30 | 31 | def __init__( 32 | self, 33 | llama: _llama.Llama, 34 | prompt_format: PromptFormat, 35 | sampler_preset: Optional[SamplerPreset] = None 36 | ) -> None: 37 | 38 | assert_type(llama, _llama.Llama, 'llama', 'Thread.__init__') 39 | assert_type(prompt_format, PromptFormat, 'prompt_format', 'Thread.__init__') 40 | 41 | llama._validate_model_state() 42 | 43 | self.llama = llama 44 | self.prompt_format = prompt_format 45 | self.sampler_preset = sampler_preset if sampler_preset is not None else SamplerPreset() 46 | 47 | self.messages: list[dict[str, str]] = [] 48 | 49 | system_prompt = prompt_format.system_prompt() 50 | 51 | if system_prompt != '': 52 | self.messages.append({ 53 | 'role': 'system', 54 | 'content': system_prompt 55 | }) 56 | 57 | # stop tokens 58 | format_stops = self.prompt_format.stop_tokens() 59 | self._stop_tokens = format_stops if format_stops is not None else self.llama.eog_tokens 60 | 61 | # save the original messages for self.reset() 62 | self._orig_messages = self.messages.copy() 63 | 64 | # save the sampler_preset param for repr 65 | self._sampler_preset = self.sampler_preset 66 | 67 | def __repr__(self) -> str: 68 | return ( 69 | f"Thread(" 70 | f"llama={self.llama!r}, " 71 | f"prompt_format={self.prompt_format!r}, " 72 | f"sampler_preset={self._sampler_preset!r}" 73 | f")" 74 | ) 75 | 76 | def _messages_in_jinja_format(self) -> list[dict]: 77 | jinja_messages = [] 78 | for message in self.messages: 79 | try: 80 | role = message['role'] 81 | except KeyError: 82 | log(f'_messages_in_jinja_format: skipping message with no role!', 2) 83 | continue 84 | try: 85 | content = message['content'] 86 | 
except KeyError: 87 | log(f'_messages_in_jinja_format: skipping message with no content!', 2) 88 | continue 89 | if role in Thread.valid_system_roles: 90 | jinja_messages.append({'role': 'system', 'content': content}) 91 | elif role in Thread.valid_user_roles: 92 | jinja_messages.append({'role': 'user', 'content': content}) 93 | elif role in Thread.valid_bot_roles: 94 | jinja_messages.append({'role': 'assistant', 'content': content}) 95 | else: 96 | log( 97 | f'_messages_in_jinja_format: skipping message with invalid role {role!r}!', 98 | 2 99 | ) 100 | return jinja_messages 101 | 102 | def _render_messages( 103 | self, 104 | add_generation_prompt: bool = True, 105 | **kwargs 106 | ) -> str: 107 | """Render the Jinja template with current messages""" 108 | try: 109 | template = jinja2.Template( 110 | source=self.llama._chat_template, 111 | undefined=jinja2.StrictUndefined 112 | ) 113 | context = { 114 | 'messages': self._messages_in_jinja_format(), 115 | 'add_generation_prompt': add_generation_prompt, 116 | **kwargs 117 | } 118 | return template.render(context) 119 | except jinja2.exceptions.TemplateSyntaxError as e: 120 | log(f"_render_messages: invalid chat template syntax: {e}", 3) 121 | raise ValueError(f"_render_messages: invalid chat template syntax: {e}") from e 122 | except jinja2.exceptions.UndefinedError as e: 123 | log(f"_render_messages: missing template variable: {e}", 3) 124 | raise ValueError(f"_render_messages: missing template variable: {e}") from e 125 | except Exception as e: 126 | log(f"_render_messages: template rendering error: {e}", 3) 127 | raise ValueError(f"_render_messages: template rendering error: {e}") from e 128 | 129 | def get_input_ids(self, role: Optional[str] = 'bot') -> list[int]: 130 | """Get a list of token IDs in this thread, to be used for inference 131 | 132 | - role: 133 | The role for which inference will be performed (usually 'bot'). Can be 'system', 134 | 'user', 'bot', or None. 
If None, no role prefix will be appended (this is useful 135 | when you just want to get all the tokens in this Thread but are not going to do 136 | inference).""" 137 | 138 | if role is None and len(self.messages) == 0: 139 | if self.llama.add_bos_token(): 140 | return [self.llama.token_bos()] 141 | else: 142 | return [] 143 | 144 | input_ids = [] 145 | if len(self.messages) > 0: 146 | # the prefix of the first message requires `add_special=True` in order to set 147 | # the BOS token correctly 148 | first_msg = self.messages[0] 149 | if first_msg['role'].lower() in Thread.valid_system_roles: 150 | input_ids.extend(self.llama.tokenize( 151 | text_bytes=ez_encode(self.prompt_format.system_prefix()), 152 | add_special=True, 153 | parse_special=True 154 | )) 155 | input_ids.extend(self.llama.tokenize( 156 | text_bytes=ez_encode(first_msg['content']), 157 | add_special=False, 158 | parse_special=False 159 | )) 160 | input_ids.extend(self.llama.tokenize( 161 | text_bytes=ez_encode(self.prompt_format.system_suffix()), 162 | add_special=False, 163 | parse_special=True 164 | )) 165 | 166 | elif first_msg['role'].lower() in Thread.valid_user_roles: 167 | input_ids.extend(self.llama.tokenize( 168 | text_bytes=ez_encode(self.prompt_format.user_prefix()), 169 | add_special=True, 170 | parse_special=True 171 | )) 172 | input_ids.extend(self.llama.tokenize( 173 | text_bytes=ez_encode(first_msg['content']), 174 | add_special=False, 175 | parse_special=False 176 | )) 177 | input_ids.extend(self.llama.tokenize( 178 | text_bytes=ez_encode(self.prompt_format.user_suffix()), 179 | add_special=False, 180 | parse_special=True 181 | )) 182 | 183 | elif first_msg['role'].lower() in Thread.valid_bot_roles: 184 | input_ids.extend(self.llama.tokenize( 185 | text_bytes=ez_encode(self.prompt_format.bot_prefix()), 186 | add_special=True, 187 | parse_special=True 188 | )) 189 | input_ids.extend(self.llama.tokenize( 190 | text_bytes=ez_encode(first_msg['content']), 191 | add_special=False, 192 | 
parse_special=False 193 | )) 194 | input_ids.extend(self.llama.tokenize( 195 | text_bytes=ez_encode(self.prompt_format.bot_suffix()), 196 | add_special=False, 197 | parse_special=True 198 | )) 199 | 200 | else: 201 | raise ValueError( 202 | f'Thread.get_input_ids: first message has invalid role {first_msg["role"]!r}' 203 | ) 204 | 205 | # all the other messages are treated the same 206 | i = 0 207 | for msg in self.messages[1:]: 208 | i += 1 209 | if msg['role'].lower() in Thread.valid_system_roles: 210 | raise ValueError( 211 | f'Thread.get_input_ids: a system message is only supported as the first message' 212 | ) 213 | elif msg['role'].lower() in Thread.valid_user_roles: 214 | input_ids.extend(self.llama.tokenize( 215 | text_bytes=ez_encode(self.prompt_format.user_prefix()), 216 | add_special=False, 217 | parse_special=True 218 | )) 219 | input_ids.extend(self.llama.tokenize( 220 | text_bytes=ez_encode(msg['content']), 221 | add_special=False, 222 | parse_special=False 223 | )) 224 | input_ids.extend(self.llama.tokenize( 225 | text_bytes=ez_encode(self.prompt_format.user_suffix()), 226 | add_special=False, 227 | parse_special=True 228 | )) 229 | elif msg['role'].lower() in Thread.valid_bot_roles: 230 | input_ids.extend(self.llama.tokenize( 231 | text_bytes=ez_encode(self.prompt_format.bot_prefix()), 232 | add_special=False, 233 | parse_special=True 234 | )) 235 | input_ids.extend(self.llama.tokenize( 236 | text_bytes=ez_encode(msg['content']), 237 | add_special=False, 238 | parse_special=False 239 | )) 240 | input_ids.extend(self.llama.tokenize( 241 | text_bytes=ez_encode(self.prompt_format.bot_suffix()), 242 | add_special=False, 243 | parse_special=True 244 | )) 245 | else: 246 | raise ValueError( 247 | f'Thread.get_input_ids: message {i} has invalid role {msg["role"]!r}' 248 | ) 249 | 250 | if role is not None: 251 | # append the role prefix tokens to the end 252 | # (if role is None, no prefix is appended) 253 | if role.lower() in Thread.valid_system_roles: 254 | raise ValueError( 255 |
f'Thread.get_input_ids: multiple system messages are not supported' 256 | ) 257 | elif role.lower() in Thread.valid_user_roles: 258 | input_ids.extend(self.llama.tokenize( 259 | text_bytes=ez_encode(self.prompt_format.user_prefix()), 260 | add_special=len(self.messages) == 0, # BOS is needed if this prefix starts the thread 261 | parse_special=True 262 | )) 263 | elif role.lower() in Thread.valid_bot_roles: 264 | input_ids.extend(self.llama.tokenize( 265 | text_bytes=ez_encode(self.prompt_format.bot_prefix()), 266 | add_special=len(self.messages) == 0, # BOS is needed if this prefix starts the thread 267 | parse_special=True 268 | )) 269 | else: 270 | raise ValueError(f'Thread.get_input_ids: invalid role {role!r}') 271 | 272 | # input_ids is now fully constructed 273 | n_input_ids = len(input_ids) 274 | 275 | _llama.log_verbose( 276 | f'Thread.get_input_ids: converted {len(self.messages)} messages to ' 277 | f'{n_input_ids} tokens' 278 | ) 279 | 280 | if n_input_ids >= self.llama._n_ctx: 281 | log( 282 | f'Thread.get_input_ids: length of input_ids {n_input_ids} ' 283 | f'equals or exceeds the current context length ' 284 | f'{self.llama._n_ctx}', 2 285 | ) 286 | 287 | return input_ids 288 | 289 | def send(self, content: str, n_predict: Optional[int] = None) -> str: 290 | """Send a message in this thread and return the generated response.
This adds your 291 | message and the bot's message to the thread.""" 292 | self.messages.append({ 293 | 'role': 'user', 294 | 'content': content 295 | }) 296 | response_toks = self.llama.generate( 297 | input_tokens=self.get_input_ids(role='bot'), 298 | n_predict=n_predict if n_predict is not None else -1, 299 | stop_tokens=self._stop_tokens, 300 | sampler_preset=self.sampler_preset 301 | ) 302 | response_txt = self.llama.detokenize(response_toks, special=False) 303 | self.messages.append({ 304 | 'role': 'bot', 305 | 'content': response_txt 306 | }) 307 | return response_txt 308 | 309 | def as_string(self) -> str: 310 | """Return this thread's message history as a string""" 311 | result_str = '' 312 | for msg in self.messages: 313 | if msg['role'].lower() in Thread.valid_system_roles: 314 | result_str += ''.join([ 315 | self.prompt_format.system_prefix(), 316 | msg['content'], 317 | self.prompt_format.system_suffix() 318 | ]) 319 | elif msg['role'].lower() in Thread.valid_user_roles: 320 | result_str += ''.join([ 321 | self.prompt_format.user_prefix(), 322 | msg['content'], 323 | self.prompt_format.user_suffix() 324 | ]) 325 | elif msg['role'].lower() in Thread.valid_bot_roles: 326 | result_str += ''.join([ 327 | self.prompt_format.bot_prefix(), 328 | msg['content'], 329 | self.prompt_format.bot_suffix() 330 | ]) 331 | else: 332 | raise ValueError(f"Thread.as_string: invalid message role {msg['role']!r}") 333 | return result_str 334 | 335 | def add_message(self, role: str, content: str) -> None: 336 | """Append a message to `Thread.messages` with the specified role and content 337 | 338 | - role: 339 | The role of the message, for example 'system', 'user', or 'bot'. 
340 | - content: 341 | The text content of the message.""" 342 | if role.lower() in Thread.valid_system_roles: 343 | self.messages.append({'role': 'system', 'content': content}) 344 | elif role.lower() in Thread.valid_user_roles: 345 | self.messages.append({'role': 'user', 'content': content}) 346 | elif role.lower() in Thread.valid_bot_roles: 347 | self.messages.append({'role': 'bot', 'content': content}) 348 | else: 349 | raise ValueError(f'Thread.add_message: invalid role {role!r}') 350 | 351 | def warmup(self) -> None: 352 | input_ids = self.get_input_ids() 353 | if self.llama._first_valid_pos(input_ids) < len(input_ids): 354 | _llama.log_verbose('Thread.warmup: processing thread content with model ...') 355 | with suppress_output(disable=_llama.get_verbose()): 356 | self.llama.generate(input_tokens=input_ids, n_predict=0) 357 | 358 | # if the above condition is not True, the thread is already in the cache, so 359 | # nothing needs to be done 360 | _llama.log_verbose('Thread.warmup: done') 361 | 362 | def interact(self, stream: bool = True) -> None: 363 | """Start an interactive terminal-based chat using this thread""" 364 | R = ANSI.MODE_RESET_ALL 365 | B = ANSI.FG_BRIGHT_CYAN 366 | G = ANSI.FG_BRIGHT_GREEN 367 | with KeyboardInterruptHandler(): 368 | print() 369 | while True: 370 | user_input = input(f'{R} > {G}') 371 | print(R, end='\n', flush=True) 372 | 373 | if stream: 374 | self.messages.append({'role': 'user', 'content': user_input}) 375 | input_ids = self.get_input_ids() 376 | 377 | tok_gen = self.llama.stream( 378 | input_tokens=input_ids, 379 | n_predict=-1, 380 | stop_tokens=self._stop_tokens, 381 | sampler_preset=self.sampler_preset 382 | ) 383 | 384 | response_toks = [] 385 | detok_bytes_buffer = b'' 386 | 387 | for tok in tok_gen: 388 | response_toks.append(tok) 389 | # 390 | # detok_bytes_buffer holds any incomplete UTF-8 characters until they 391 | # are completed by future tokens 392 | # 393 | # for example, emojis are often split between two 
tokens, with one or 394 | # both of those tokens not being valid UTF-8 on its own 395 | # 396 | detok_bytes_buffer += self.llama.token_to_piece(tok, special=False) 397 | try: 398 | detok_txt = detok_bytes_buffer.decode('utf-8', errors='strict') 399 | except UnicodeDecodeError: 400 | pass # try again on next token 401 | else: 402 | detok_bytes_buffer = b'' 403 | print(f'{B}{detok_txt}{R}', end='', flush=True) 404 | 405 | # print any leftover bytes (though ideally there should be none) 406 | if detok_bytes_buffer != b'': 407 | leftover_txt = ez_decode(detok_bytes_buffer) 408 | print(f'{B}{leftover_txt}{R}', end='', flush=True) 409 | 410 | self.messages.append({ 411 | 'role': 'bot', 412 | 'content': self.llama.detokenize(response_toks, special=False) 413 | }) 414 | 415 | print() 416 | if not _llama.get_verbose(): 417 | print() 418 | 419 | else: 420 | response = self.send(user_input) 421 | print(f'\n{B}{response}{R}\n') 422 | 423 | def give_input_output_examples(self, examples: dict[str, str]) -> None: 424 | """Provide examples for few-shot prompting""" # TODO: this should be renamed or removed 425 | for input_msg_content, output_msg_content in examples.items(): 426 | self.add_message('user', input_msg_content) 427 | self.add_message('bot', output_msg_content) 428 | 429 | def summarize(self) -> str: 430 | """Generate a summary of this thread""" 431 | thread_as_string = self.as_string() 432 | orig_thread_messages = self.messages.copy() 433 | self.messages = [ 434 | { 435 | 'role': 'system', 436 | 'content': 'Follow the given instructions exactly. Do not add any unnecessary ' 437 | 'information.' 438 | }, 439 | { 440 | 'role': 'user', 441 | 'content': 'Take a moment to read through the following conversation ' 442 | 'carefully. When you\'re done, write a single paragraph that ' 443 | 'explains all of the most relevant details.' 
444 | f'\n\n```\n{thread_as_string}\n```\n\n' 445 | 'Now that you\'ve read the above conversation, provide a summary ' 446 | 'in the form of a single paragraph.' 447 | } 448 | ] 449 | input_ids = self.get_input_ids() # uses the above messages 450 | output_ids = self.llama.generate(input_tokens=input_ids, n_predict=300) 451 | summary = self.llama.detokenize(output_ids, special=False) 452 | self.messages = orig_thread_messages.copy() 453 | return summary 454 | 455 | def print_stats(self, file: _SupportsWriteAndFlush = sys.stderr) -> None: 456 | """Print stats about the context usage in this thread""" 457 | with suppress_output(): 458 | input_ids = self.get_input_ids(role=None) 459 | n_thread_tokens = len(input_ids) 460 | n_msgs = len(self.messages) 461 | n_ctx = self.llama._n_ctx 462 | c = (n_thread_tokens/n_ctx) * 100 463 | ctx_used_pct = int(c) + (c > int(c)) # round up to next integer 464 | print(f"{n_thread_tokens} / {n_ctx} tokens", file=file) 465 | print(f"{ctx_used_pct}% of context used", file=file) 466 | print(f"{n_msgs} messages", file=file) 467 | 468 | def reset(self) -> None: 469 | self.messages = self._orig_messages.copy() 470 | -------------------------------------------------------------------------------- /easy_llama/utils.py: -------------------------------------------------------------------------------- 1 | # utils.py 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | """Submodule containing convenience functions used throughout easy_llama""" 6 | 7 | from . 
import __version__ 8 | 9 | import os 10 | import sys 11 | import datetime 12 | import contextlib 13 | 14 | import numpy as np 15 | 16 | from typing import Iterable, TextIO, Optional, TypeVar, Generic, NoReturn, Literal, Union 17 | 18 | class ANSI: 19 | """ANSI codes for terminal emulators""" 20 | 21 | BELL = '\a' 22 | 23 | CLEAR = '\x1bc\x1b[3J' # technically this is two ANSI codes in one 24 | 25 | # Standard colors 26 | FG_BLACK = '\x1b[30m' 27 | BG_BLACK = '\x1b[40m' 28 | FG_RED = '\x1b[31m' 29 | BG_RED = '\x1b[41m' 30 | FG_GREEN = '\x1b[32m' 31 | BG_GREEN = '\x1b[42m' 32 | FG_YELLOW = '\x1b[33m' 33 | BG_YELLOW = '\x1b[43m' 34 | FG_BLUE = '\x1b[34m' 35 | BG_BLUE = '\x1b[44m' 36 | FG_MAGENTA = '\x1b[35m' 37 | BG_MAGENTA = '\x1b[45m' 38 | FG_CYAN = '\x1b[36m' 39 | BG_CYAN = '\x1b[46m' 40 | FG_WHITE = '\x1b[37m' 41 | BG_WHITE = '\x1b[47m' 42 | 43 | # Bright colors 44 | FG_BRIGHT_BLACK = '\x1b[90m' 45 | BG_BRIGHT_BLACK = '\x1b[100m' 46 | FG_BRIGHT_RED = '\x1b[91m' 47 | BG_BRIGHT_RED = '\x1b[101m' 48 | FG_BRIGHT_GREEN = '\x1b[92m' 49 | BG_BRIGHT_GREEN = '\x1b[102m' 50 | FG_BRIGHT_YELLOW = '\x1b[93m' 51 | BG_BRIGHT_YELLOW = '\x1b[103m' 52 | FG_BRIGHT_BLUE = '\x1b[94m' 53 | BG_BRIGHT_BLUE = '\x1b[104m' 54 | FG_BRIGHT_MAGENTA = '\x1b[95m' 55 | BG_BRIGHT_MAGENTA = '\x1b[105m' 56 | FG_BRIGHT_CYAN = '\x1b[96m' 57 | BG_BRIGHT_CYAN = '\x1b[106m' 58 | FG_BRIGHT_WHITE = '\x1b[97m' 59 | BG_BRIGHT_WHITE = '\x1b[107m' 60 | 61 | # Modes 62 | MODE_RESET_ALL = '\x1b[0m' 63 | MODE_BOLD_SET = '\x1b[1m' 64 | MODE_BOLD_RESET = '\x1b[22m' 65 | MODE_DIM_SET = '\x1b[2m' 66 | MODE_DIM_RESET = '\x1b[22m' 67 | MODE_ITALIC_SET = '\x1b[3m' 68 | MODE_ITALIC_RESET = '\x1b[23m' 69 | MODE_UNDERLINE_SET = '\x1b[4m' 70 | MODE_UNDERLINE_RESET = '\x1b[24m' 71 | MODE_BLINKING_SET = '\x1b[5m' 72 | MODE_BLINKING_RESET = '\x1b[25m' 73 | MODE_REVERSE_SET = '\x1b[7m' 74 | MODE_REVERSE_RESET = '\x1b[27m' 75 | MODE_HIDDEN_SET = '\x1b[8m' 76 | MODE_HIDDEN_RESET = '\x1b[28m' 77 | MODE_STRIKETHROUGH_SET = 
'\x1b[9m' 78 | MODE_STRIKETHROUGH_RESET = '\x1b[29m' 79 | 80 | NoneType: type = type(None) 81 | 82 | _VERBOSE = True 83 | 84 | _DEBUG = False 85 | """Package-wide debug flag""" 86 | 87 | class _ArrayLike(Iterable): 88 | """Anything that can be interpreted as a numpy array""" 89 | 90 | class _SupportsWriteAndFlush(TextIO): 91 | """A file, stream, or buffer that supports writing and flushing""" 92 | 93 | class UnreachableException(Exception): 94 | """The code has reached an unreachable state""" 95 | def __init__(self): 96 | super().__init__( 97 | "the code has reached a location that was thought to be " 98 | "unreachable. please report this issue to the developer at this " 99 | "link: https://github.com/ddh0/easy-llama/issues/new/choose" 100 | ) 101 | 102 | class LlamaNullException(Exception): 103 | """Raised when a libllama function returns NULL or NULLPTR""" 104 | 105 | T = TypeVar('T') 106 | 107 | class ptr(Generic[T]): 108 | """Generic type hint representing any ctypes pointer. Optionally subscriptable with any 109 | type.""" 110 | 111 | def set_verbose(state: bool) -> None: 112 | """Enable or disable verbose terminal output from easy-llama""" 113 | global _VERBOSE 114 | _VERBOSE = state 115 | 116 | def get_verbose() -> bool: 117 | """Return `True` if verbose terminal output is enabled in easy-llama, `False` otherwise""" 118 | global _VERBOSE 119 | return _VERBOSE 120 | 121 | @contextlib.contextmanager 122 | def KeyboardInterruptHandler(): 123 | log_verbose('Press CTRL+C to exit') 124 | try: 125 | yield 126 | except KeyboardInterrupt: 127 | print(ANSI.MODE_RESET_ALL, end='\n', flush=True) 128 | 129 | def log( 130 | text: str, 131 | level: Literal[1, 2, 3, 4] = 1, 132 | disable: bool = False 133 | ) -> None: 134 | """Print the given text, prefixed with a timestamp""" 135 | if disable: 136 | return 137 | timestamp = datetime.datetime.now().strftime("%Y-%m-%d %a %H:%M:%S.%f")[:-3] # %H is portable; %k is a non-standard extension 138 | if level == 1: 139 | lvltxt = f"{ANSI.FG_BRIGHT_GREEN}INFO" 140 | elif level == 
2: 141 | lvltxt = f"{ANSI.FG_BRIGHT_YELLOW}WARNING" 142 | elif level == 3: 143 | lvltxt = f"{ANSI.FG_BRIGHT_RED}ERROR" 144 | elif level == 4: 145 | lvltxt = f"{ANSI.FG_BRIGHT_MAGENTA}STOPWATCH" 146 | else: 147 | raise ValueError(f'parameter `level` must be one of `[1, 2, 3, 4]`, not {level}') 148 | print( 149 | f"{ANSI.MODE_RESET_ALL}{ANSI.MODE_BOLD_SET}{ANSI.FG_BRIGHT_BLACK}[{timestamp}]" 150 | f"{ANSI.MODE_RESET_ALL}{ANSI.MODE_BOLD_SET} {lvltxt}{ANSI.MODE_RESET_ALL}" 151 | f"{ANSI.MODE_BOLD_SET}{ANSI.FG_BRIGHT_BLACK}:{ANSI.MODE_RESET_ALL} {text}", 152 | end='\n', 153 | file=sys.stdout if level in [1,4] else sys.stderr, 154 | flush=True 155 | ) 156 | 157 | def log_verbose(text: str, level: Literal[1,2,3,4] = 1) -> None: 158 | if get_verbose(): 159 | log(text, level) 160 | 161 | def log_debug(text: str, level: Literal[1,2,3,4] = 1) -> None: 162 | if _DEBUG: 163 | log('[DEBUG] ' + text, level) 164 | 165 | def softmax(z, T: float = 1.0): 166 | """Numerically stable softmax over all dimensions of an arbitrarily shaped array in 167 | float32 precision.""" 168 | z_arr = np.array(z, dtype=np.float32) 169 | if z_arr.size == 0: 170 | return z_arr 171 | if T == 0.0: 172 | result = np.zeros_like(z_arr) 173 | flat = z_arr.ravel() 174 | result.flat[np.argmax(flat)] = 1.0 175 | return result 176 | if T < 0: 177 | z_arr = -z_arr 178 | T = -T 179 | max_val = np.max(z_arr) 180 | scaled = (z_arr - max_val) / T 181 | exp_vals = np.exp(scaled) 182 | return exp_vals / np.sum(exp_vals) 183 | 184 | def cls() -> None: 185 | """Clear the terminal""" 186 | if os.name == 'nt': 187 | os.system('cls') 188 | else: 189 | #os.system('clear') 190 | print(ANSI.CLEAR, end='', flush=True) 191 | 192 | def truncate(text: str) -> str: 193 | return text if len(text) < 72 else f"{text[:69]}..." 194 | 195 | def ez_encode(txt: str) -> bytes: 196 | """Encode the given text `txt` from string to UTF-8. 
If strict encoding fails, an error 197 | will be shown and the offending character(s) will be replaced with `?`.""" 198 | try: 199 | return txt.encode('utf-8', errors='strict') 200 | except UnicodeEncodeError: 201 | log(f'error encoding string to UTF-8. using ? replacement character.', level=3) 202 | return txt.encode('utf-8', errors='replace') 203 | 204 | def ez_decode(txt: bytes) -> str: 205 | """Decode the given text `txt` from UTF-8 to string. If strict decoding fails, an error 206 | will be shown and the offending character(s) will be replaced with `�` (U+FFFD).""" 207 | try: 208 | return txt.decode('utf-8', errors='strict') 209 | except UnicodeDecodeError: 210 | log(f'error decoding string from UTF-8. using � replacement character.', level=3) 211 | return txt.decode('utf-8', errors='replace') 212 | 213 | _open = open 214 | _sys = sys 215 | _os = os 216 | 217 | @contextlib.contextmanager 218 | def suppress_output(disable: bool = False): 219 | """Suppress stdout and stderr.""" 220 | 221 | # NOTE: simply changing sys.stdout and sys.stderr does not affect output from llama.cpp. 222 | # this method (or similar) is required to suppress all undesired output, for example 223 | # when `verbose=False`. 
224 | 225 | if disable: 226 | yield 227 | else: 228 | # save the original file descriptors 229 | original_stdout_fd = _sys.stdout.fileno() 230 | original_stderr_fd = _sys.stderr.fileno() 231 | 232 | saved_stdout_fd = _os.dup(original_stdout_fd) 233 | saved_stderr_fd = _os.dup(original_stderr_fd) 234 | 235 | with _open(_os.devnull, 'wb') as devnull: 236 | devnull_fd = devnull.fileno() 237 | 238 | _os.dup2(devnull_fd, original_stdout_fd) 239 | _os.dup2(devnull_fd, original_stderr_fd) 240 | 241 | try: 242 | yield 243 | finally: 244 | # restore the original file descriptors 245 | _os.dup2(saved_stdout_fd, original_stdout_fd) 246 | _os.dup2(saved_stderr_fd, original_stderr_fd) 247 | 248 | _os.close(saved_stdout_fd) 249 | _os.close(saved_stderr_fd) 250 | 251 | def assert_type( 252 | obj: object, 253 | expected_type: Union[type, tuple[type, ...]], 254 | obj_name: str, 255 | code_location: str, 256 | hint: Optional[str] = None 257 | ): 258 | """Ensure that `obj` is an instance of `expected_type`. 259 | 260 | If `expected_type` is a tuple, ensure that `obj` is an instance of some type in the tuple. 261 | 262 | Raise `TypeError` otherwise, using `obj_name` and `code_location` to make an informative 263 | exception message.
264 | 265 | If specified, `hint` is added as a note to the exception (on Python 3.11+, where `BaseException.add_note` is available).""" 266 | 267 | if isinstance(obj, expected_type): 268 | return 269 | 270 | obj_type_repr = repr(type(obj).__name__) 271 | 272 | if not isinstance(expected_type, tuple): 273 | # represent `int` as 'int' instead of "<class 'int'>" 274 | expected_type_repr = repr(expected_type.__name__) 275 | exc = TypeError( 276 | f"{code_location}: {obj_name} should be an instance of {expected_type_repr}, " 277 | f"not {obj_type_repr}" 278 | ) 279 | else: 280 | # represent `(int, list)` as "('int', 'list')" instead of 281 | # "(<class 'int'>, <class 'list'>)" 282 | expected_type_repr = repr(tuple(t.__name__ for t in expected_type)) 283 | exc = TypeError( 284 | f"{code_location}: {obj_name} should be one of {expected_type_repr}, " 285 | f"not {obj_type_repr}" 286 | ) 287 | if isinstance(hint, str) and hasattr(exc, 'add_note'): # add_note() requires Python 3.11+ 288 | exc.add_note(hint) 289 | raise exc 290 | 291 | def null_ptr_check(ptr: ptr, ptr_name: str, loc_hint: str) -> None: 292 | """Ensure that the given object `ptr` is not NULL / NULLPTR 293 | 294 | Raise LlamaNullException on failure 295 | 296 | - ptr: 297 | The object to check 298 | - ptr_name: 299 | The name of the object (for error messages) 300 | - loc_hint: 301 | Code location hint used in easy-llama""" 302 | if not bool(ptr): 303 | raise LlamaNullException(f"{loc_hint}: pointer `{ptr_name}` is null") 304 | -------------------------------------------------------------------------------- /easy_llama/webui/index.html: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ddh0/easy-llama/04b75b10d8854025eb4855e616a682962e3eef82/easy_llama/webui/index.html -------------------------------------------------------------------------------- /easy_llama/webui/script.js: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ddh0/easy-llama/04b75b10d8854025eb4855e616a682962e3eef82/easy_llama/webui/script.js
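The NOTE in `suppress_output` explains that rebinding `sys.stdout` and `sys.stderr` cannot hide output from llama.cpp, because C code writes straight to file descriptors 1 and 2. The `os.dup` / `os.dup2` pattern it relies on can be seen in isolation below. This is a minimal sketch: `redirect_fd` is a hypothetical helper (not part of easy-llama), it redirects to an ordinary file instead of `os.devnull` so the effect can be checked, and `os.write` on a temporary file's descriptor stands in for output coming from C code.

```python
import contextlib
import os
import tempfile


@contextlib.contextmanager
def redirect_fd(fd: int, target_path: str):
    """Temporarily point the OS-level file descriptor `fd` at the file `target_path`.

    Hypothetical helper (not part of easy-llama). Anything written directly to `fd`,
    for example by C code, lands in `target_path` until the block exits.
    """
    saved_fd = os.dup(fd)  # second descriptor for the original target, used to restore it
    try:
        with open(target_path, 'wb') as target:
            os.dup2(target.fileno(), fd)  # `fd` now refers to the target file
            try:
                yield
            finally:
                os.dup2(saved_fd, fd)  # point `fd` back at its original target
    finally:
        os.close(saved_fd)


# Demonstration with temporary files instead of the real stdout/stderr
with tempfile.TemporaryDirectory() as tmp:
    original_path = os.path.join(tmp, 'original.txt')
    captured_path = os.path.join(tmp, 'captured.txt')
    with open(original_path, 'wb') as original:
        fd = original.fileno()
        with redirect_fd(fd, captured_path):
            os.write(fd, b'hidden\n')  # fd-level write, like output from a C library
        os.write(fd, b'visible\n')  # fd has been restored, so this reaches original.txt
    with open(original_path, 'rb') as f:
        original_bytes = f.read()
    with open(captured_path, 'rb') as f:
        captured_bytes = f.read()

print(original_bytes, captured_bytes)  # b'visible\n' b'hidden\n'
```

Pointing `target_path` at `os.devnull` and applying the helper to `sys.stdout.fileno()` and `sys.stderr.fileno()` gives essentially the behavior of `suppress_output`. The pattern does not flush Python-level buffers, so text pending in `sys.stdout` may be written on either side of the redirect unless it is flushed first.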
-------------------------------------------------------------------------------- /easy_llama/webui/style.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --color-black: #141414; 3 | --color-dark-grey: #292929; 4 | --color-neutral-grey: #808080; 5 | --color-light-grey: #e1e1e1; 6 | --color-white: #f5f5f5; 7 | --color-green: #22c55e; 8 | --color-green-dim: #1cab51; 9 | --color-blue: #0ea5e9; 10 | --color-blue-dim: #0b8fcb; 11 | --color-red: #ef4444; 12 | --color-red-dim: #d03a3a; 13 | --color-purple: #a855f7; 14 | --color-purple-dim: #9249d7; 15 | --color-yellow: #eab308; 16 | --color-yellow-dim: #cb9b06; 17 | } 18 | 19 | body { 20 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; 21 | background-color: var(--color-black); 22 | color: var(--color-white); 23 | line-height: 1.0; 24 | display: flex; 25 | justify-content: center; 26 | align-items: center; 27 | min-height: 100svh; 28 | padding: 10px; 29 | } 30 | -------------------------------------------------------------------------------- /examples/chat_demo.py: -------------------------------------------------------------------------------- 1 | # chat_demo.py 2 | # Python 3.12.9 3 | 4 | import easy_llama as ez 5 | 6 | Llama = ez.Llama( 7 | path_model='/path/to/your-llama3-model.gguf', 8 | n_gpu_layers=-1, 9 | n_ctx=1024, 10 | offload_kqv=True, 11 | warmup=False, 12 | verbose=True 13 | ) 14 | 15 | Thread = ez.Thread( 16 | llama=Llama, 17 | prompt_format=ez.PromptFormats.Llama3("You are a helpful assistant.") 18 | ) 19 | 20 | ez.set_verbose(False) 21 | ez.utils.cls() 22 | Thread.interact() 23 | print('-' * 80) 24 | Thread.print_stats() 25 | -------------------------------------------------------------------------------- /examples/pretrained_demo.py: -------------------------------------------------------------------------------- 1 | # pretrained_demo.py 2 | # Python 3.12.9 3 | 4 | import easy_llama as ez 5 | 6 | Llama 
= ez.Llama( 7 | path_model='/path/to/your-pretrained-model.gguf', 8 | n_gpu_layers=-1, 9 | n_ctx=1024, 10 | offload_kqv=True, 11 | warmup=False, 12 | verbose=True 13 | ) 14 | 15 | ez.set_verbose(False) 16 | ez.utils.cls() 17 | 18 | prompt = input('Enter your prompt:\n > ') 19 | 20 | print( 21 | f'\n\n{ez.utils.ANSI.FG_BRIGHT_GREEN}{prompt}{ez.utils.ANSI.MODE_RESET_ALL}', 22 | end='', 23 | flush=True 24 | ) 25 | 26 | prompt_bytes = prompt.encode('utf-8', errors='replace') 27 | 28 | prompt_tokens = Llama.tokenize(prompt_bytes, add_special=True, parse_special=False) 29 | 30 | token_generator = Llama.stream( 31 | input_tokens=prompt_tokens, 32 | n_predict=-1, 33 | stop_tokens=[] 34 | ) 35 | 36 | for token in token_generator: 37 | tok_str = Llama.token_to_piece(token, special=True).decode('utf-8', errors='replace') 38 | print( 39 | f"{ez.utils.ANSI.FG_BRIGHT_CYAN}{tok_str}{ez.utils.ANSI.MODE_RESET_ALL}", 40 | sep='', 41 | end='', 42 | flush=True 43 | ) 44 | 45 | print(f'\n{"-" * 80}') 46 | -------------------------------------------------------------------------------- /examples/simple.py: -------------------------------------------------------------------------------- 1 | # import the package 2 | import easy_llama as ez 3 | 4 | # load a model from a GGUF file (if $LIBLLAMA is not set, this will fail) 5 | MyLlama = ez.Llama('Qwen3-4B-Q8_0.gguf') 6 | 7 | # evaluate a single token and print the raw logits for the inferred next token 8 | logits = MyLlama.eval([0]) 9 | print(logits) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # pyproject.toml 2 | # https://github.com/ddh0/easy-llama/ 3 | # MIT License -- Copyright (c) 2024 Dylan Halladay 4 | 5 | [build-system] 6 | requires = ["setuptools>=61.0"] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [project] 10 | name = "easy_llama" 11 | dynamic = ["version"] 12 | description = "This is 
easy-llama, a Python package that wraps the C/C++ API provided by llama.cpp. It's intended for developers and machine learning hobbyists who want to integrate on-device language models (LLMs) into their applications." 13 | readme = "README.md" 14 | authors = [{ name = "Dylan Halladay", email = "chemist-mulches-39@icloud.com" }] 15 | license = { text = "MIT" } 16 | requires-python = ">=3.9" 17 | classifiers = [ 18 | "Development Status :: 4 - Beta", 19 | "Intended Audience :: Developers", 20 | "Intended Audience :: Science/Research", 21 | "License :: OSI Approved :: MIT License", 22 | "Natural Language :: English", 23 | "Programming Language :: Python :: 3 :: Only", 24 | "Programming Language :: Python :: 3.9", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12" 28 | ] 29 | dependencies = [ 30 | "numpy", 31 | "fastapi", 32 | "uvicorn", 33 | "jinja2", 34 | "tqdm" 35 | ] 36 | 37 | [project.urls] 38 | Homepage = "https://github.com/ddh0/easy-llama" 39 | 40 | [tool.setuptools] 41 | packages = ["easy_llama", "easy_llama.webui"] 42 | include-package-data = true 43 | 44 | [tool.setuptools.dynamic] 45 | version = {attr = "easy_llama.__version__"} 46 | 47 | [tool.setuptools.package-data] 48 | "easy_llama.webui" = [ 49 | "*.ico", 50 | "*.png", 51 | "*.html", 52 | "*.css", 53 | "*.js", 54 | "*.webmanifest", 55 | ] 56 | --------------------------------------------------------------------------------
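`Thread.interact` handles token pieces that split a multi-byte UTF-8 character (such as an emoji) by buffering bytes and retrying a strict decode. The standard library's incremental decoder, `codecs.getincrementaldecoder('utf-8')`, implements the same idea. A minimal sketch, assuming the list `pieces` stands in for the byte strings returned by `Llama.token_to_piece`; `stream_text` is a hypothetical helper, not part of easy-llama.

```python
import codecs
from typing import Iterable, Iterator


def stream_text(pieces: Iterable[bytes]) -> Iterator[str]:
    """Yield decoded text as soon as complete UTF-8 characters are available.

    Hypothetical helper: `pieces` stands in for the bytes returned by
    `Llama.token_to_piece` for each generated token.
    """
    # The incremental decoder keeps incomplete multi-byte sequences in an internal
    # buffer until later pieces complete them; invalid bytes become U+FFFD.
    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    for piece in pieces:
        text = decoder.decode(piece)
        if text:
            yield text
    # final=True flushes leftover bytes (an unfinished sequence becomes U+FFFD)
    tail = decoder.decode(b'', final=True)
    if tail:
        yield tail


# U+1F600 is encoded as F0 9F 98 80; here it is split across two pieces
pieces = [b'Hi ', b'\xf0\x9f', b'\x98\x80', b'!']
chunks = list(stream_text(pieces))
print(ascii(chunks))  # ['Hi ', '\U0001f600', '!']
```

One difference from retrying a strict decode of the whole buffer: the incremental decoder emits complete characters immediately even when the buffer also ends with a partial one, and a stray invalid byte is replaced at once instead of holding back the rest of the output until generation ends.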