├── .gitignore ├── LICENSE ├── README.md ├── pyproject.toml └── src ├── examples ├── attention.py ├── json_schema_cli.py ├── llm_schema.py ├── reluctance.py ├── requirements.txt ├── reusable_kv_cache.py ├── server.py └── static │ ├── attention.html │ └── ui.html ├── llm_structured_output ├── __init__.py ├── acceptor.py ├── json_acceptor.py ├── json_schema_acceptor.py └── util │ ├── __init__.py │ ├── bitmap.py │ ├── output.py │ ├── tokenization.py │ └── tokentrie.py └── tests ├── __init__.py ├── data └── fireworks-ai_function-calling-eval-dataset-v0 │ ├── completions-multi_turn-Meta-Llama-3-8B-Instruct-4bit.jsonl │ ├── completions-multi_turn-OpenAI-gpt-4o-2024-05-13.jsonl │ ├── completions-single_turn-Meta-Llama-3-8B-Instruct-4bit.jsonl │ ├── completions-single_turn-OpenAI-gpt-4o-2024-05-13.jsonl │ ├── multi_turn-00000-of-00001.jsonl │ ├── multi_turn-00000-of-00001.parquet │ ├── parquet_to_jsonl.py │ ├── report-multi_turn.md │ ├── report-single_turn.md │ ├── requirements.txt │ ├── single_turn-00000-of-00001.jsonl │ └── single_turn-00000-of-00001.parquet ├── eval_api.py ├── eval_local.py ├── eval_report.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .cache/ 2 | .venv/ 3 | **/Library/ 4 | **/__pycache__/ 5 | dist -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Oscar D.P. Triscon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM Structured Output: JSON Schema, Function Calling, Tools 2 | 3 | This repository contains a library to constrain LLM generation to structured 4 | output, such as function calling a.k.a. tool use. 5 | 6 | We include examples of application implementations using the MLX library. 7 | 8 | Differences with other approaches: 9 | 10 | - "JSON mode": this library constrains output to be valid JSON, but goes 11 | beyond JSON mode in also enforcing a JSON schema. This enables much tighter 12 | steeing: specifying data types, property names, etc. 
13 | 14 | - GBNF translation: rather than converting the JSON schema to a formal grammar, 15 | we steer the output directly using the schema, which enables more flexible 16 | and deeper control with lower overhead. For example, expressing minimum and 17 | maximum array or string lengths in GBNF can lead to a very large set of 18 | production rules, and certain JSON schema features simply cannot be expressed. 19 | 20 | - Fine-tuning: our approach is complementary to fine-tuning an LLM to produce 21 | structured output. While fine-tuning currently can enhance but not guarantee 22 | adherence to a schema, our system introduces strong guarantees on the output. 23 | 24 | ## Demo 25 | 26 | https://github.com/otriscon/llm-structured-output/assets/165947759/f38704da-34b0-4601-be8b-48b92199445d 27 | 28 | Without a schema, Mistral 7B Instruct 0.2 solves the data extraction task but, 29 | despite our instructions to the contrary, it adds a lot of additional output that's 30 | not necessary, is hard to parse, and wastes time. 31 | 32 | https://github.com/otriscon/llm-structured-output/assets/165947759/f79a78ca-8244-4ec6-9e90-b6cdedfbb8b0 33 | 34 | With the schema, the generation is precisely the output we require. 35 | 36 | ## What's in the box 37 | 38 | You'll find: 39 | 40 | - A framework and set of acceptors for constraining LLM output, which are 41 | application-independent. 42 | 43 | - Reference implementations and examples using Apple's MLX library. 44 | 45 | ### Framework and JSON acceptors 46 | 47 | - An acceptor/state machine framework which progresses all valid states of a 48 | given graph simultaneously. This minimizes the need for backtracking, which 49 | is expensive for LLMs as it would require re-computing past tokens. In this 50 | sense, the concept is similar to a chart parser or Earley-style recognizer 51 | and shares a similar motivation. In practice, it's quite different because 52 | we're dealing with token-level input. We implemented several optimizations 53 | to minimize combinatorial explosion: we use a trie to traverse the token 54 | vocabulary in logarithmic time, and collapse the trie branches when multiple 55 | options are equivalent. We also prune the chart by removing equivalent 56 | states arrived at by different paths. See [acceptor.py](src/llm_structured_output/acceptor.py). 57 | 58 | - A JSON acceptor based on the framework above that accepts valid JSON. See 59 | [json_acceptor.py](src/llm_structured_output/json_acceptor.py). 60 | 61 | - A JSON schema acceptor based on both items above that accepts valid JSON that 62 | conforms to a JSON schema. See [json_schema_acceptor.py](src/llm_structured_output/json_schema_acceptor.py). 63 | Please note that most, but not all, JSON schema directives are implemented. 64 | Please open an issue if one that you need is not supported. 65 | 66 | ### Reference implementation / examples 67 | 68 | - An example of using the acceptors above to guide decoding in an LLM using 69 | Apple's MLX framework. See [llm_schema.py](src/examples/llm_schema.py). 70 | This example includes several decoding techniques, including pre-emptive evaluation, 71 | which uses the acceptor to anticipate the tokens that can be generated 72 | according to the schema and evaluate two tokens at a time instead of 73 | one, sometimes leading to noticeable performance improvements. 74 | 75 | - A server example that implements an OpenAI-compatible API including tools / function 76 | calling.
Unlike [OpenAI's](https://platform.openai.com/docs/api-reference/chat/object), 77 | this implementation always generates valid JSON, and does not return hallucinated 78 | parameters not defined in your function schema (but it may still hallucinate their 79 | values). See [server.py](src/examples/server.py). 80 | 81 | ## Usage 82 | 83 | ### Run the examples on Apple hardware with MLX 84 | 85 | Clone this repo: 86 | 87 | ```sh 88 | git clone https://github.com/otriscon/llm-structured-output.git 89 | cd llm-structured-output 90 | ``` 91 | 92 | Optional, but recommended: create and activate a virtual environment with your favorite tool of choice, e.g. 93 | 94 | ```sh 95 | python -m venv .venv 96 | source .venv/bin/activate 97 | ``` 98 | 99 | Move into the examples folder and install the requirements, then move back: 100 | 101 | ```sh 102 | cd src/examples 103 | pip install -r requirements.txt 104 | cd .. 105 | ``` 106 | 107 | Choose a model from the [HuggingFace MLX community](https://huggingface.co/mlx-community), e.g. `mlx-community/Meta-Llama-3.1-8B-Instruct-4bit`. Models are downloaded automatically on first use and cached locally. 108 | 109 | Run the llm_schema example: 110 | 111 | ```sh 112 | MODEL=mlx-community/Meta-Llama-3.1-8B-Instruct-4bit 113 | 114 | LLM_PROMPT='[INST] Parse the following address into a JSON object: "27 Barrow St, New York, NY 10014". Your answer should be only a JSON object according to this schema: {"type": "object", "properties": {"streetNumber": {"type": "number"}, "streetName": {"type": "string"}, "city": {"type": {"string"}}, "state": {"type": "string"}, "zipCode": {"type": "number"}}}. Do not explain the result, just output it. Do not add any additional information. [/INST]' 115 | 116 | LLM_SCHEMA='{"type": "object", "properties": {"streetNumber": {"type": "number"}, "streetName": {"type": "string"}, "city": {"type": "string"}, "state": {"type": "string"}, "zipCode": {"type": "number"}}}' 117 | 118 | python3 -m examples.llm_schema --model-path $MODEL --prompt "$LLM_PROMPT" --schema "$LLM_SCHEMA" --max-tokens 1000 --repeat-prompt 119 | ``` 120 | 121 | Run the server example: 122 | 123 | ```sh 124 | MODEL_PATH=mlx-community/Meta-Llama-3.1-8B-Instruct-4bit uvicorn examples.server:app --port 8080 --reload 125 | ``` 126 | 127 | Try calling the server with this example adapted from [the OpenAI documentation (click on the example request titled _Functions_)](https://platform.openai.com/docs/api-reference/chat/create): 128 | ```sh 129 | curl http://localhost:8080/v1/chat/completions \ 130 | -H "Content-Type: application/json" \ 131 | -d '{ 132 | "model": "ignored", 133 | "messages": [ 134 | { 135 | "role": "user", 136 | "content": "What'\''s the weather like in Boston today?" 137 | } 138 | ], 139 | "tools": [ 140 | { 141 | "type": "function", 142 | "function": { 143 | "name": "get_current_weather", 144 | "description": "Get the current weather in a given location", 145 | "parameters": { 146 | "type": "object", 147 | "properties": { 148 | "location": { 149 | "type": "string", 150 | "description": "The city and state, e.g. 
San Francisco, CA" 151 | }, 152 | "unit": { 153 | "type": "string", 154 | "enum": ["celsius", "fahrenheit"] 155 | } 156 | }, 157 | "required": ["location"] 158 | } 159 | } 160 | } 161 | ], 162 | "tool_choice": "auto" 163 | }' 164 | ``` 165 | 166 | ### Using the JSON schema acceptor in your project 167 | 168 | Install in your project with `pip install llm-structured-output` and 169 | use a `JsonSchemaAcceptorDriver` within your normal generation loop: 170 | 171 | ```python 172 | import json 173 | import mlx.core as mx 174 | from mlx_lm.utils import load # Needs pip import mlx_lm 175 | from llm_structured_output import JsonSchemaAcceptorDriver, HuggingfaceTokenizerHelper, bias_logits 176 | 177 | 178 | MODEL_PATH = "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit" 179 | SCHEMA = { 180 | "type": "object", 181 | "properties": { 182 | "streetNumber": {"type": "number"}, 183 | "streetName": {"type": "string"}, 184 | "city": {"type": "string"}, 185 | "state": {"type": "string"}, 186 | "zipCode": {"type": "number"}, 187 | }, 188 | } 189 | PROMPT = f''' 190 | [INST] Parse the following address into a JSON object: "27 Barrow St, New York, NY 10014". 191 | Your answer should be only a JSON object according to this schema: {json.dumps(SCHEMA)} 192 | Do not explain the result, just output it. Do not add any additional information. [/INST] 193 | ''' 194 | 195 | 196 | # Load the model as usual. 197 | model, tokenizer = load(MODEL_PATH) 198 | 199 | # Instantiate a token acceptor 200 | tokenizer_helper = HuggingfaceTokenizerHelper(tokenizer) 201 | vocabulary, eos_id = tokenizer_helper.extract_vocabulary() 202 | token_acceptor_factory = JsonSchemaAcceptorDriver.driver_factory_for_model(vocabulary, eos_id) 203 | token_acceptor = token_acceptor_factory(SCHEMA) 204 | 205 | cache = None 206 | tokens = tokenizer_helper.encode_prompt(PROMPT) 207 | 208 | while tokens[-1] != eos_id: 209 | # Evaluate the model as usual. 210 | logits, cache = model(mx.array(tokens)[None], cache) 211 | 212 | # Set probability to -inf for invalid tokens. 213 | accepted_token_bitmap = token_acceptor.select_valid_tokens() 214 | logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap) 215 | 216 | # Sample as usual, e.g.: 217 | token = mx.argmax(logits, axis=-1).item() 218 | 219 | if token == eos_id: 220 | break 221 | 222 | # Store or use the generated token. 223 | tokens = [token] 224 | text = tokenizer_helper.no_strip_decode(tokens) 225 | print(text, end="") 226 | 227 | # Advance the acceptor to the next state. 228 | token_acceptor.advance_token(token) 229 | ``` 230 | 231 | ## A note about guarantees on the output 232 | 233 | Constraining the output of an LLM to follow a schema doesn't magically make the 234 | LLM great at producing output that solves a particular task. 235 | 236 | If an LLM that is not prompted or fine-tuned correctly to solve the task, it 237 | will produce syntactically valid output but the values inside won't necessarily 238 | constitute a good solution. As with any other technique, proper LLM prompting 239 | and/or n-shot examples are crucial to avoid getting nice-looking, 240 | well-formatted, schema-compliant nonsense. 241 | 242 | In particular, it's crucial to instruct the LLM regarding the desired output 243 | format, including making the desired schema part of the prompt. Here's an 244 | example of a prompt that includes the schema: 245 | 246 | ``` 247 | Parse the following address into a JSON object: "27 Barrow St, New York, NY 10014". 
248 | Your answer should be only a JSON object according to this schema: {"type": "object", "properties": {"streetNumber": {"type": "number"}, "streetName": {"type": "string"}, "city": {"type": "string"}, "state": {"type": "string"}, "zipCode": {"type": "number"}}}. 249 | Do not explain the result, just output it. Do not add any additional information. 250 | ``` 251 | 252 | To give the LLM a scratch-pad prior to JSON generation, e.g. for 253 | chain-of-thought reasoning, we include an option for the acceptor to kick in 254 | only on output within a section delimited by the lines `` ```json `` and `` ``` ``, 255 | with the prior output treated as free text. This is enabled with the `is_encapsulated_json` 256 | option of the `JsonSchemaAcceptorDriver` constructor. Here's an example of a 257 | prompt that produces encapsulated JSON: 258 | ``` 259 | Your mission is to parse the following address into a JSON object: "27 Barrow St, New York, NY 10014". 260 | Your answer should be a JSON object according to this schema: {"type": "object", "properties": {"streetNumber": {"type": "number"}, "streetName": {"type": "string"}, "city": {"type": "string"}, "state": {"type": "string"}, "zipCode": {"type": "number"}}}. 261 | First, think through the task step by step, and then output a JSON object wrapped between the lines ```json and ```. 262 | ``` 263 | 264 | In our OpenAI-compatible server example, when the request specifies `tool_calls` or a 265 | legacy `function_call`, we automatically prepend a system message to the prompt with 266 | the schema and instructions for the LLM to use the tools provided. If your prompt already 267 | includes these instructions (because e.g. you want to customize them), this can be disabled 268 | with a non-standard option in the request payload: `"tool_options": { "no_prompt_steering": true }` 269 | 270 | 271 | ## Testing 272 | 273 | The library has been tested with the following datasets: 274 | 275 | - [Fireworks.ai's function calling eval dataset](https://huggingface.co/datasets/fireworks-ai/function-calling-eval-dataset-v0/) 276 | 277 | - [ALU.AI's table extraction](https://blog.alu.ai/tables-and-structured-data/) evaluation dataset (not yet open-source) 278 | 279 | ## Evaluations 280 | 281 | We're starting to perform evaluations to understand how well different LLMs perform 282 | in function calling tasks. The tools and data can be found in the [src/tests](src/tests/) folder. 283 | 284 | ### Fireworks.ai function calling eval dataset 285 | 286 | Environment: 287 | 288 | - llm_structured_output v0.0.15 289 | - mlx 0.14.1 290 | - 2023 Mac Studio M2 Ultra 24 cores (16 performance and 8 efficiency) 192 GB RAM running macOS Sonoma 14.5 291 | - LLM: mlx-community/Meta-Llama-3-8B-Instruct-4bit 292 | - Benchmarking LLM: gpt-4o-2024-05-13 293 | 294 | Results: 295 | 296 | - [multi-turn dataset report](src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/report-multi_turn.md) 297 | 298 | - [single-turn dataset report](src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/report-single_turn.md) 299 | 300 | 301 | ## Performance 302 | 303 | Since we need to select the acceptable tokens prior to sampling, constraining 304 | the output according to a schema introduces a delay for every token, which 305 | depends on the complexity of the schema.
On the other hand, since the output is 306 | guaranteed to be valid JSON and to conform to the schema, it can reduce the 307 | number of tokens generated and reduce or eliminate the number of retries 308 | required to solve the task. 309 | 310 | ### Pre-emptive decoding experiment 311 | As an experiment to improve performance, we implement the option to use 312 | pre-emptive decoding: when the range of tokens that can be accepted after the 313 | current one is small, as often happens with structured output, we submit to the 314 | LLM a batch of two-token continuations where the first token is the one that 315 | was to be evaluated anyway, and the second token in each item in the batch is 316 | one of the possible continuations predicted according to the schema. We can 317 | then sample two tokens instead of one. We find that this approach can 318 | occasionally produce considerable increases in token generation speed, but in 319 | general it can also considerably slow it down, depending on model and 320 | quantization. We found that it works better with fp16 models (i.e. no quantization), 321 | but batching performance degrades vastly with quantized models, making pre-emptive 322 | decoding not worth it for those models. 323 | 324 | ### Benchmarks 325 | 326 | - The following tests were performed on a Mac Studio with an M2 Ultra (24 core) 327 | with 192GB of RAM using MLX version 0.9.0, with models converted to MLX format. 328 | 329 | - The results are the average of 5 runs on a simple data extraction task with a 330 | 127-token prompt. 331 | 332 | - Pre-emptive decoding was tested in two different forms: with a constant batch 333 | size, where we always sent the same-size matrices for evaluation, and variable-size 334 | batching, where we made the batch longer or shorter depending on the number 335 | of possible follow-up tokens. A minimal sketch of this batching logic is shown below, before the results tables. 336 | 337 |
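To make the two batching modes concrete, here is a minimal, self-contained sketch of how a pre-emptive two-token batch can be assembled from the set of tokens the schema allows next. The token ids, bitmap value, and `MAX_BATCH_SIZE` are made-up example values, and the small bit-enumeration helper stands in for `enumerate_set_bits` from `llm_structured_output.util.bitmap`; the full logic lives in `generate_with_preemptive_decoding` in [llm_schema.py](src/examples/llm_schema.py).

```python
# Illustrative sketch with example values, not the benchmark harness itself.
MAX_BATCH_SIZE = 5


def enumerate_set_bits(bitmap: int):
    """Yield the indices of the bits set in an integer bitmap."""
    index = 0
    while bitmap:
        if bitmap & 1:
            yield index
        bitmap >>= 1
        index += 1


last_token = 123                     # the token we have to evaluate anyway
accepted_token_bitmap = 0b101001000  # pretend the schema allows tokens 3, 6 and 8 next

followups = list(enumerate_set_bits(accepted_token_bitmap))
if 1 <= len(followups) <= MAX_BATCH_SIZE:
    # Variable-size batch: one two-token row per schema-allowed follow-up token.
    batch = [[last_token, t] for t in followups]
else:
    # Too many (or no) candidates: fall back to a normal one-token continuation.
    batch = [[last_token]]

print(batch)  # [[123, 3], [123, 6], [123, 8]]
```

In the constant-size variant, matrices of the same size are sent for evaluation at every step; the tables below compare both variants against plain schema-constrained decoding.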
338 | 339 | | Mistral-7B-v0.2-Instruct (fp16) | Prompt tps | Generation tps | Generation tokens | 340 | | --- | :-: | :-: | :-: | 341 | | No schema | 305.82 | 34.76 | 321 | 342 | | Schema | 307.00 | 31.70 | 42 | 343 | | Pre-emptive constant batch =5 | 211.72 | 33.16 | 42 | 344 | | Pre-emptive variable batch <=5 | 321.85 | 36.53 | 42 | 345 | 346 | 347 | **Notes:** 348 | 349 | - Pre-emptive decoding accelerates generation even over schemaless generation. 350 | 351 |
352 |
353 | 354 | | Mistral-7B-v0.2-Instruct (q4) | Prompt tps | Generation tps | Generation tokens | 355 | | --- | :-: | :-: | :-: | 356 | | No schema | 487.19 | 86.36 | 137 | 357 | | Schema | 487.83 | 67.60 | 42 | 358 | | Pre-emptive constant batch =5 | 139.61 | 27.16 | 42 | 359 | | Pre-emptive variable batch <=5 | 488.88 | 36.25 | 42 | 360 | 361 | **Notes:** 362 | 363 | - Pre-emptive decoding is vastly slower, with the only change being quantization. 364 | 365 |
366 |
367 | 368 | | Mixtral-8x7B-Instruct-v0.1 (fp16) | Prompt tps | Generation tps | Generation tokens | 369 | | --- | :-: | :-: | :-: | 370 | | No schema | 3.48 | 2.23 | 50 | 371 | | Schema | 3.49 | 2.21 | 50 | 372 | | Pre-emptive constant batch =5 | 2.36 | 1.16 | 50 | 373 | | Pre-emptive variable batch <=5 | 3.18 | 1.68 | 50 | 374 | 375 | **Notes:** 376 | 377 | - This is the only tested model that produces schema-conforming output even without a schema. 378 | 379 | - Pre-emptive decoding is again much slower. 380 | 381 |
382 |
383 | 384 | | Mixtral-8x7B-Instruct-v0.1 (q4) | Prompt tps | Generation tps | Generation tokens | 385 | | --- | :-: | :-: | :-: | 386 | | No schema | 15.02 | 32.21 | 165 | 387 | | Schema | 14.94 | 23.75 | 50 | 388 | | Pre-emptive constant batch =5 | 9.29 | 11.28 | 50 | 389 | | Pre-emptive variable batch <=5 | 15.02 | 17.94 | 50 | 390 | 391 | ## Roadmap 392 | 393 | - Extend JSON schema support as needed (see TODOs in code). Please, feel free to 394 | open an issue if you need a feature that not supported at the moment. Also open to 395 | implement additional schemas such as YAML and reference implementations for other LLMs. 396 | 397 | - Add formal test cases. 398 | 399 | - Reference implementation for the Transformers library. 400 | 401 | - Port to C++ and reference implementation for llama.cpp 402 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "llm_structured_output" 7 | version = "0.0.20" 8 | authors = [ 9 | { name="Oscar D.P. Triscon", email="github@triscon.com" }, 10 | ] 11 | description = "Constrain LLM generation to structured output, such as function calling and tool use" 12 | readme = "README.md" 13 | requires-python = ">=3.8" 14 | classifiers = [ 15 | "Programming Language :: Python :: 3", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | ] 19 | 20 | [project.urls] 21 | Homepage = "https://github.com/otriscon/llm-structured-output" 22 | Issues = "https://github.com/otriscon/llm-structured-output/issues" 23 | 24 | [tool.hatch.build.targets.sdist] 25 | include = [ 26 | "LICENSE", 27 | "README.md", 28 | "pyproject.toml", 29 | "llm_structured_output", 30 | ] 31 | -------------------------------------------------------------------------------- /src/examples/attention.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring,missing-class-docstring 2 | """ 3 | Example server to visualize the generation mechanism. 4 | """ 5 | import json 6 | from operator import itemgetter 7 | import os 8 | from typing import Optional 9 | 10 | from fastapi import FastAPI, Request, status 11 | from fastapi.responses import FileResponse, JSONResponse 12 | from fastapi.exceptions import RequestValidationError 13 | from pydantic import BaseModel 14 | 15 | import mlx.core as mx 16 | import mlx.nn as nn 17 | from mlx_lm.utils import load 18 | from llm_structured_output import ( 19 | JsonSchemaAcceptorDriver, 20 | HuggingfaceTokenizerHelper, 21 | ) 22 | from llm_structured_output.util.output import info, warning 23 | 24 | from .reusable_kv_cache import ReusableKVCache 25 | 26 | 27 | def calc_prompt_perplexity(logits, prompt: list[int]): 28 | """ 29 | Try to get a measure for how much a prompt is baked into the training of the LLM. 30 | When evaluating a prompt, we pass several (perhaps many) input tokens. The LLM 31 | returns a matrix with the logits for each next token in the sequence, applying a 32 | mask to avoid "seeing" tokens that appear after the current one. We can compare 33 | the probability distrubtion formed by these logits with the actual token that 34 | follows, giving us an idea of how "surprised" the model is by the next token in 35 | the prompt. 
Note that this is related but not quite the same as the perplexity 36 | metric used to measure the training quality of a language model. 37 | The output is a list with a value for each token in the prompt, with the value 38 | being zero for no suprise (the model assigns probability 1 to that token appearing 39 | at that position given the prior tokens), and tending to infinity for the model 40 | assigning zero probability for that token in that position. By convention, we 41 | assign a value of zero to the first token in the prompt. 42 | """ 43 | # Input: 44 | # batch_size, ntokens, voc_size = logits.shape 45 | # len(prompt) == ntokens 46 | # Note that row i of the output logits vector corresponds to the evaluation after 47 | # input token i, i.e. it's the probability distribution for token i+1. The last 48 | # row of logits corresponds to the first token after the prompt. 49 | target = mx.array([prompt[1:]]) 50 | loss = nn.losses.cross_entropy(logits[:, :-1, :], target)[0] 51 | # Add a zero for the first token in the prompt. 52 | return [0] + loss.tolist() 53 | 54 | 55 | class ObservedLLM: 56 | def __init__(self, model_path: str): 57 | self.model, self.tokenizer = load(model_path) 58 | self.tokenizer_helper = HuggingfaceTokenizerHelper(self.tokenizer) 59 | self.vocabulary, self.eos_id = self.tokenizer_helper.extract_vocabulary() 60 | self.token_acceptor_factory = JsonSchemaAcceptorDriver.driver_factory_for_model( 61 | self.vocabulary, self.eos_id 62 | ) 63 | self.cache = ReusableKVCache.for_model(self.model) 64 | self.tokens = [] 65 | self.fragments = [] 66 | self.layer_attention_scores = [] 67 | self.token_acceptor = None 68 | self.layer_attention_scores = [] 69 | 70 | # Replace the attention dot product function to be able to look into it. 71 | def mock_fast_scaled_dot_product_attention( 72 | queries, keys, values, *, scale, mask=None, stream=None 73 | ): 74 | """ 75 | O = softmax(Q @ K.T, dim=-1) @ V 76 | """ 77 | B, n_kv_heads, L, _ = keys.shape 78 | _, n_heads, _, _ = queries.shape 79 | repeats = n_heads // n_kv_heads 80 | 81 | def repeat(a): 82 | a = mx.concatenate([mx.expand_dims(a, 2)] * repeats, axis=2) 83 | return a.reshape([B, n_heads, L, -1]) 84 | 85 | keys, values = map(repeat, (keys, values)) 86 | 87 | scores = (queries * scale) @ keys.transpose(0, 1, 3, 2) 88 | if mask is not None: 89 | scores += mask 90 | scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype) 91 | result = scores @ values 92 | self.layer_attention_scores.append(scores[0, :, -1, :].tolist()) 93 | return result 94 | 95 | mx.fast.scaled_dot_product_attention = mock_fast_scaled_dot_product_attention 96 | 97 | def start(self, prompt: str, schema: dict): 98 | if schema is None: 99 | self.token_acceptor = None 100 | else: 101 | self.token_acceptor = self.token_acceptor_factory(schema) 102 | 103 | prior_tokens = self.tokens 104 | self.tokens = self.tokenizer_helper.encode_prompt(prompt) 105 | self.fragments = [ 106 | self.tokenizer_helper.no_strip_decode([token]) for token in self.tokens 107 | ] 108 | 109 | # If we had started a generation before, try to reuse as much of the cache as possible. 
110 | i = 0 111 | for i, t in enumerate(prior_tokens): 112 | if i >= len(self.tokens) - 1 or self.tokens[i] != t: 113 | break 114 | for layer_cache in self.cache: 115 | layer_cache.reuse(len(self.tokens), i) 116 | new_tokens = self.tokens[i:] 117 | 118 | print(f"{new_tokens}") 119 | return self._generate(new_tokens) 120 | 121 | def add_token(self, token): 122 | self.tokens.append(token) 123 | self.fragments.append(self.tokenizer_helper.no_strip_decode([token])) 124 | if self.token_acceptor: 125 | self.token_acceptor.advance_token(token) 126 | return self._generate([token]) 127 | 128 | def _generate(self, new_input_tokens: list[int]): 129 | self.layer_attention_scores = [] 130 | logits = self.model(mx.array(new_input_tokens)[None], self.cache) 131 | 132 | TOP_TOKEN_COUNT = 1000 133 | probs = mx.softmax(logits[0, -1, :]) 134 | top_token_partition = mx.argpartition(probs, -TOP_TOKEN_COUNT)[ 135 | -TOP_TOKEN_COUNT: 136 | ] 137 | top_token_probs = sorted( 138 | [*zip(top_token_partition.tolist(), probs[top_token_partition].tolist())], 139 | key=itemgetter(1), 140 | reverse=True, 141 | ) 142 | 143 | if len(new_input_tokens) > 1: 144 | prompt_perplexity = calc_prompt_perplexity(logits, new_input_tokens) 145 | else: 146 | prompt_perplexity = None 147 | 148 | if self.token_acceptor: 149 | accepted_token_bitmap = self.token_acceptor.select_valid_tokens() 150 | rejected_top_tokens = set( 151 | token 152 | for token, _ in top_token_probs 153 | if not (accepted_token_bitmap & (1 << token)) 154 | ) 155 | else: 156 | rejected_top_tokens = set() 157 | 158 | top_tokens = [ 159 | ( 160 | token, 161 | self.tokenizer_helper.no_strip_decode([token]), 162 | p, 163 | token in rejected_top_tokens, 164 | ) 165 | for token, p in top_token_probs 166 | ] 167 | 168 | return { 169 | "attention_scores": self.layer_attention_scores, 170 | "fragments": self.fragments, 171 | "top_tokens": top_tokens, 172 | "prompt_perplexity": prompt_perplexity, 173 | } 174 | 175 | 176 | try: 177 | MODEL_PATH = os.environ["MODEL_PATH"] 178 | except KeyError: 179 | MODEL_PATH = "mlx-community/Meta-Llama-3-8B-Instruct-4bit" 180 | info("No MODEL_PATH environment variable, using default model.") 181 | 182 | info(f"Loading model {MODEL_PATH}...") 183 | llm = ObservedLLM(MODEL_PATH) 184 | 185 | 186 | app = FastAPI() 187 | 188 | 189 | @app.exception_handler(RequestValidationError) 190 | async def validation_exception_handler(_request: Request, exc: RequestValidationError): 191 | exc_str = f"{exc}" 192 | warning(f"RequestValidationError: {exc_str}") 193 | content = {"status_code": 10422, "message": exc_str, "data": None} 194 | return JSONResponse( 195 | content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY 196 | ) 197 | 198 | 199 | @app.get("/") 200 | def get_root(): 201 | return FileResponse( 202 | f"{os.path.dirname(os.path.realpath(__file__))}/static/attention.html" 203 | ) 204 | 205 | 206 | @app.get("/status") 207 | def get_status(): 208 | return {"status": "OK"} 209 | 210 | 211 | class GenerationStartRequest(BaseModel): 212 | prompt: str 213 | schema: Optional[str] = None 214 | 215 | 216 | @app.post("/generation/start") 217 | async def post_generation_start(request: GenerationStartRequest): 218 | if request.schema: 219 | schema = json.loads(request.schema) 220 | else: 221 | schema = None 222 | response = llm.start(request.prompt, schema) 223 | return response 224 | 225 | 226 | class GenerationTokenRequest(BaseModel): 227 | token: int 228 | 229 | 230 | @app.post("/generation/token") 231 | async def post_generation_token(request: 
GenerationTokenRequest): 232 | response = llm.add_token(request.token) 233 | return response 234 | -------------------------------------------------------------------------------- /src/examples/json_schema_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Command-line tool to validate a JSON input against a JSON schema. 4 | """ 5 | 6 | import argparse 7 | import json 8 | import sys 9 | 10 | from llm_structured_output.json_schema_acceptor import JsonSchemaAcceptorDriver 11 | from llm_structured_output.util.output import debug 12 | 13 | 14 | def main(): # pylint: disable=missing-function-docstring 15 | arg_parser = argparse.ArgumentParser( 16 | description=""" 17 | Incrementally validate a JSON input against a JSON schema. 18 | """, 19 | ) 20 | arg_parser.add_argument( 21 | "schema", 22 | help='JSON schema string, or "@" file containing JSON schema, or "-" for stdin)', 23 | ) 24 | arg_parser.add_argument( 25 | "json", 26 | help='JSON input string, or "@" file containing JSON input, or "-" for stdin)', 27 | ) 28 | arg_parser.add_argument( 29 | "--debug", 30 | action="store_true", 31 | help="Output more debug information", 32 | ) 33 | arg_parser.add_argument( 34 | "--paths", 35 | action="store_true", 36 | help="Extract value paths", 37 | ) 38 | args = arg_parser.parse_args() 39 | 40 | if args.schema == "-": 41 | schema = sys.stdin 42 | elif args.schema[0] == "@": 43 | with open(args.schema[1:], encoding="utf-8") as f: 44 | schema = f.read() 45 | else: 46 | schema = args.schema 47 | schema = json.loads(schema) 48 | 49 | if args.json == "-": 50 | input_json = sys.stdin 51 | elif args.json[0] == "@": 52 | with open(args.json[1:], encoding="utf-8") as f: 53 | input_json = f.read() 54 | else: 55 | input_json = args.json 56 | 57 | if args.paths: 58 | token_len = 1 59 | else: 60 | # For test purposes, just split the input into groups of 3 letters. 
61 | token_len = 3 62 | fragments = [ 63 | input_json[i : i + token_len] for i in range(0, len(input_json), token_len) 64 | ] 65 | eos_fragment = chr(3) 66 | eos_token = 0 67 | vocabulary = list(enumerate([eos_fragment] + [*set(fragments)])) 68 | reverse_vocabulary = dict((f, i) for i, f in vocabulary) 69 | tokens = [reverse_vocabulary[f] for f in fragments] 70 | 71 | acceptor_factory = JsonSchemaAcceptorDriver.driver_factory_for_model( 72 | vocabulary, eos_id=eos_token 73 | ) 74 | acceptor = acceptor_factory(schema) 75 | fail = False 76 | values_by_path = {} 77 | for fragment, token in zip(fragments, tokens): 78 | if args.debug: 79 | debug(f"FRAGMENT={repr(fragment)} TOKEN={token}") 80 | try: 81 | if args.debug: 82 | acceptor.debug_advance_token(token) 83 | else: 84 | acceptor.advance_token(token) 85 | except acceptor.TokenRejected: 86 | fail = True 87 | break 88 | if args.paths: 89 | for path in acceptor.get_current_value_paths(): 90 | values_by_path[path] = values_by_path.get(path, "") + fragment 91 | print("\n".join(repr(c) for c in acceptor.cursors)) 92 | if not fail: 93 | if args.debug: 94 | debug(f"FRAGMENT= TOKEN={eos_token}") 95 | try: 96 | fail = acceptor.advance_token(eos_token) 97 | except acceptor.TokenRejected: 98 | fail = True 99 | if fail: 100 | print("[FAIL]") 101 | result = 1 102 | else: 103 | print("[SUCCESS]") 104 | result = 0 105 | if debug: 106 | debug("\n".join(repr(c) for c in acceptor.cursors)) 107 | if args.paths: 108 | print(json.dumps(values_by_path, indent=2)) 109 | return result 110 | 111 | 112 | if __name__ == "__main__": 113 | sys.exit(main()) 114 | -------------------------------------------------------------------------------- /src/examples/llm_schema.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-class-docstring,missing-function-docstring 2 | """ 3 | Example of JSON schema decoding with MLX. 4 | """ 5 | import argparse 6 | import json 7 | import time 8 | from math import inf 9 | from operator import itemgetter 10 | from typing import Iterable, Optional, Union 11 | 12 | import mlx.core as mx 13 | import mlx.nn as nn 14 | 15 | from mlx_lm.utils import load 16 | 17 | from llm_structured_output import JsonSchemaAcceptorDriver 18 | from llm_structured_output.util.bitmap import ( 19 | bias_logits, 20 | count_set_bits, 21 | enumerate_set_bits, 22 | ) 23 | from llm_structured_output.util.output import info, bold, bolddim, debug 24 | from llm_structured_output.util.tokenization import HuggingfaceTokenizerHelper 25 | 26 | from .reusable_kv_cache import ReusableKVCache 27 | 28 | 29 | class RejectedCompletion(Exception): 30 | """ 31 | It's rare, but sometimes we reach a state where it's not possible to 32 | advance the acceptor. For example, when closing a JSON string we get 33 | a higher probability for slanted quotes than straight ones and select 34 | the wrong token. At that point, the LLM will continue generating with 35 | the prior that the string is closed, but our acceptor will remain in 36 | the string-accepting state. This can indicate an issue with the 37 | tokenizer vocabulary passed to the acceptor, or a bug in the code 38 | used to decode tokens from the LLM. If none of these apply, check that 39 | the LLM is actually able to generate JSON, although most are. 
40 | """ 41 | 42 | 43 | class Model: 44 | def __init__(self): 45 | mx.random.seed(0) 46 | self.model = None 47 | self.tokenizer = None 48 | self.vocabulary = None 49 | self.eos_id = None 50 | self.json_schema_acceptor_driver_factory = None 51 | self._cached_prompt = None 52 | self._cached_cache = None 53 | 54 | def load(self, model_path: str): 55 | """ 56 | Load locally or download from Huggingface hub. 57 | """ 58 | self.model, tokenizer = load(model_path) 59 | self.tokenizer = HuggingfaceTokenizerHelper(tokenizer) 60 | self.vocabulary, self.eos_id = self.tokenizer.extract_vocabulary() 61 | self.json_schema_acceptor_driver_factory = ( 62 | JsonSchemaAcceptorDriver.driver_factory_for_model( 63 | self.vocabulary, self.eos_id 64 | ) 65 | ) 66 | 67 | def get_driver_for_json_schema(self, schema, encapsulated: bool = False): 68 | return self.json_schema_acceptor_driver_factory( 69 | schema, is_encapsulated_json=encapsulated 70 | ) 71 | 72 | def _evaluate_prompt( 73 | self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None 74 | ): 75 | if prior_prompt: 76 | i = 0 77 | for i, t in enumerate(prior_prompt): 78 | # We need to leave at least one token to evaluate because we don't 79 | # save the past logits. 80 | if i >= len(prompt) - 1 or prompt[i] != t: 81 | break 82 | cache = prior_cache 83 | for layer_cache in cache: 84 | layer_cache.reuse(len(prompt), i) 85 | tokens = prompt[i:] 86 | else: 87 | cache = ReusableKVCache.for_model(self.model) 88 | tokens = prompt 89 | 90 | logits = self.model(mx.array(tokens)[None], cache=cache) 91 | return logits, cache 92 | 93 | def _decode(self, tokens): 94 | return self.tokenizer.no_strip_decode(tokens) 95 | 96 | def _debug_top_tokens(self, logits, count=10): 97 | token_logits = sorted( 98 | enumerate(logits.tolist()), key=itemgetter(1), reverse=True 99 | ) 100 | top_tokens = [ 101 | (self._decode([t]), p) for t, p in token_logits[:count] if p != -inf 102 | ] 103 | debug("TOP TOKENS:", top_tokens) 104 | 105 | def _sample(self, logits, temp: float = 0): 106 | if temp == 0: 107 | result = mx.argmax(logits, axis=-1) 108 | else: 109 | result = mx.random.categorical(logits * (1 / temp)) 110 | return result.item() 111 | 112 | def _sample_with_bias( 113 | self, logits, temp: float = 0, token_acceptor=None, lazy_bias: bool = True 114 | ): 115 | if token_acceptor is None: 116 | return self._sample(logits, temp) 117 | 118 | if lazy_bias: 119 | token = self._sample(logits, temp) 120 | try: 121 | token_acceptor.advance_token(token) 122 | return token 123 | except JsonSchemaAcceptorDriver.TokenRejected: 124 | pass 125 | 126 | accepted_token_bitmap = token_acceptor.select_valid_tokens() 127 | if not accepted_token_bitmap: 128 | debug(token_acceptor.cursors) 129 | self._debug_top_tokens(logits) 130 | raise RejectedCompletion() 131 | token = self._sample(bias_logits(mx, logits, accepted_token_bitmap), temp) 132 | token_acceptor.advance_token(token) 133 | return token 134 | 135 | def generate_without_schema(self, logits, cache, temp: Optional[float] = 0.0): 136 | """ 137 | For testing / comparison purposes. 
138 | """ 139 | while True: 140 | tokens = [self._sample(logits[0, -1, :], temp)] 141 | yield tokens 142 | if tokens[-1] == self.eos_id: 143 | break 144 | logits = self.model(mx.array(tokens)[None], cache=cache) 145 | 146 | def generate_with_schema( 147 | self, logits, cache, token_acceptor, temp: Optional[float] = 0.0 148 | ): 149 | while True: 150 | tokens = [self._sample_with_bias(logits[0, -1, :], temp, token_acceptor)] 151 | yield tokens 152 | if tokens[-1] == self.eos_id: 153 | break 154 | logits = self.model(mx.array(tokens)[None], cache=cache) 155 | 156 | def generate_with_preemptive_decoding( 157 | self, 158 | logits, 159 | cache, 160 | token_acceptor, 161 | temp: Optional[float] = 0.0, 162 | max_batch_size=5, 163 | ): 164 | """ 165 | Try to generate faster by precomputing two tokens at a time when possible. 166 | If we know that the acceptor will only accept a small set of tokens after 167 | the current one, we can evaluate a batch with one entry per possible 168 | future token. Each entry in the batch contains the current token sampled, 169 | which we have to evaluate anyway, and a second token corresponding to one 170 | of the possible tokens that could be sampled from the output to the first 171 | token. We get back logits for both tokens for each item in the batch: the 172 | logits for the first token will be the same (as long as the model applies 173 | a causal mask), and we can sample those logits to select from which of the 174 | items in the batch we can select the second token. 175 | In practice, this only seems to accelerate things for unquantized models. 176 | """ 177 | # Sample token from prompt evaluation 178 | accepted_token_bitmap = token_acceptor.select_valid_tokens() 179 | first_token_logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap) 180 | first_token = self._sample(first_token_logits, temp) 181 | tokens = [first_token] 182 | yield tokens 183 | token_acceptor.advance_token(first_token) 184 | accepted_token_bitmap = token_acceptor.select_valid_tokens() 185 | 186 | while True: 187 | last_token = tokens[-1] 188 | if count_set_bits(accepted_token_bitmap) in range(1, max_batch_size + 1): 189 | # If the number of possible follow-up tokens is small, submit for 190 | # evaluation a batch of 2-token continuations. 191 | batch = [] 192 | for followup_token in enumerate_set_bits(accepted_token_bitmap): 193 | batch.append([last_token, followup_token]) 194 | # Re-shape the cache to match the input. 195 | for layer_cache in cache: 196 | layer_cache.keys = mx.concatenate([layer_cache.keys] * len(batch)) 197 | layer_cache.values = mx.concatenate( 198 | [layer_cache.values] * len(batch) 199 | ) 200 | else: # Otherwise, submit the normal one-token continuation. 
201 | batch = [[last_token]] 202 | 203 | logits = self.model(mx.array(batch), cache=cache) 204 | mx.eval(logits) 205 | 206 | first_token_logits = bias_logits(mx, logits[0, 0, :], accepted_token_bitmap) 207 | first_token = self._sample(first_token_logits, temp) 208 | tokens = [first_token] 209 | 210 | if first_token == self.eos_id: 211 | yield tokens 212 | break 213 | 214 | token_acceptor.advance_token(first_token) 215 | accepted_token_bitmap = token_acceptor.select_valid_tokens() 216 | if not accepted_token_bitmap: 217 | raise RejectedCompletion() 218 | 219 | # If we had submitted 2-token continuations, we can decode a second token 220 | if len(batch[0]) > 1: 221 | index = next( # Find which of the second tokens was selected 222 | i 223 | for i, batch_item in enumerate(batch) 224 | if batch_item[1] == first_token 225 | ) 226 | second_token_logits = bias_logits( 227 | mx, logits[index, 1, :], accepted_token_bitmap 228 | ) 229 | second_token = self._sample(second_token_logits, temp) 230 | tokens.append(second_token) 231 | 232 | token_acceptor.advance_token(second_token) 233 | accepted_token_bitmap = token_acceptor.select_valid_tokens() 234 | 235 | # Select the accepted generation in the cache, restoring it to batch dimension 1. 236 | for layer_cache in cache: 237 | layer_cache.keys = layer_cache.keys.split([index, index + 1])[1] 238 | layer_cache.values = layer_cache.values.split([index, index + 1])[1] 239 | 240 | yield tokens 241 | 242 | def _generate_tokens( 243 | self, 244 | generator: Iterable, 245 | max_tokens: int = 1000, 246 | ) -> Iterable: 247 | start_time = time.time_ns() 248 | token_count = 0 249 | 250 | for tokens in generator: 251 | token_count += len(tokens) 252 | 253 | try: 254 | eos_index = tokens.index(self.eos_id) 255 | tokens = tokens[0:eos_index] 256 | except ValueError: 257 | eos_index = -1 258 | 259 | if tokens: 260 | text = self._decode(tokens) 261 | yield { 262 | "op": "generatedTokens", 263 | "text": text, 264 | "token_count": len(tokens), 265 | "time_ms": (time.time_ns() - start_time) / 1e6, 266 | } 267 | 268 | if eos_index >= 0: 269 | yield {"op": "stop", "reason": "end"} 270 | return 271 | 272 | if token_count >= max_tokens: 273 | yield {"op": "stop", "reason": "max_tokens"} 274 | return 275 | 276 | start_time = time.time_ns() 277 | 278 | assert False 279 | 280 | def completion( 281 | self, 282 | prompt: Union[str, Iterable[dict[str, str]]], 283 | schema: dict, 284 | encapsulated: bool = False, 285 | max_tokens: int = 1000, 286 | temp: float = 0.0, 287 | seed: int = None, 288 | preemptive_batch_size: int = 0, 289 | cache_prompt: bool = False, 290 | ): 291 | if seed is not None: 292 | mx.random.seed(seed) 293 | 294 | start_time = time.time_ns() 295 | prompt_tokens = self.tokenizer.encode_prompt(prompt) 296 | logits, cache = self._evaluate_prompt( 297 | prompt_tokens, self._cached_prompt, self._cached_cache 298 | ) 299 | if cache_prompt: 300 | self._cached_prompt = prompt_tokens 301 | self._cached_cache = cache 302 | # Eager eval to more accurately reflect the prompt evaluation time. 
303 | mx.eval(logits) 304 | prompt_time = time.time_ns() - start_time 305 | yield { 306 | "op": "evaluatedPrompt", 307 | "prompt": prompt, 308 | "token_count": len(prompt_tokens), 309 | "time_ms": prompt_time / 1e6, 310 | "prompt_tps": len(prompt_tokens) / (prompt_time / 1e9), 311 | } 312 | 313 | if schema: 314 | token_acceptor = self.get_driver_for_json_schema(schema, encapsulated) 315 | if preemptive_batch_size > 0: 316 | generator = self.generate_with_preemptive_decoding( 317 | logits, 318 | cache, 319 | token_acceptor, 320 | temp, 321 | max_batch_size=preemptive_batch_size, 322 | ) 323 | else: 324 | generator = self.generate_with_schema( 325 | logits, cache, token_acceptor, temp 326 | ) 327 | else: 328 | generator = self.generate_without_schema(logits, cache, temp) 329 | 330 | token_count = 0 331 | generation_time = 0 332 | for generation_result in self._generate_tokens(generator, max_tokens): 333 | if generation_result["op"] == "generatedTokens": 334 | token_count += generation_result["token_count"] 335 | generation_time += generation_result["time_ms"] 336 | elif generation_result["op"] == "stop": 337 | generation_result["token_count"] = token_count 338 | generation_result["time_ms"] = generation_time 339 | # This is slightly incorrect, because the first token is generated 340 | # from the prompt evaluation. 341 | generation_result["generation_tps"] = token_count / ( 342 | generation_time / 1e3 343 | ) 344 | yield generation_result 345 | 346 | 347 | def main(): 348 | parser = argparse.ArgumentParser( 349 | description="LLM inference script with schema-constrained sampling" 350 | ) 351 | parser.add_argument( 352 | "--model-path", 353 | type=str, 354 | default="mlx_model", 355 | help="The path to the model weights and tokenizer", 356 | ) 357 | parser.add_argument( 358 | "--prompt", 359 | default="Once upon a midnight dreary", 360 | help="The message to be processed by the model", 361 | ) 362 | parser.add_argument( 363 | "--max-tokens", 364 | "-m", 365 | type=int, 366 | default=100, 367 | help="Maximum number of tokens to generate", 368 | ) 369 | parser.add_argument( 370 | "--temp", 371 | help="The sampling temperature.", 372 | type=float, 373 | default=0.0, 374 | ) 375 | parser.add_argument("--seed", type=int, default=None, help="The PRNG seed") 376 | parser.add_argument( 377 | "--repeat-prompt", 378 | action=argparse.BooleanOptionalAction, 379 | help="Print prompt before start of generation", 380 | ) 381 | parser.add_argument( 382 | "--schema", 383 | help="A JSON schema to constrain the output.", 384 | type=str, 385 | default=None, 386 | ) 387 | parser.add_argument( 388 | "--encapsulated", 389 | action=argparse.BooleanOptionalAction, 390 | help="Whether the LLM is expected to encapsulate the JSON within ```json and ```.", 391 | ) 392 | parser.add_argument( 393 | "--preemptive", 394 | type=int, 395 | default=0, 396 | help="If greater than zero, the maximum size of the batch for pre-emptive decoding", 397 | ) 398 | 399 | args = parser.parse_args() 400 | 401 | info("Loading model from disk.") 402 | model = Model() 403 | model.load(args.model_path) 404 | 405 | if args.schema is not None: 406 | schema = json.loads(args.schema) 407 | info("Using schema") 408 | else: 409 | schema = None 410 | info("Starting generation...") 411 | 412 | for result in model.completion( 413 | prompt=args.prompt, 414 | schema=schema, 415 | encapsulated=args.encapsulated, 416 | max_tokens=args.max_tokens, 417 | temp=args.temp, 418 | seed=args.seed, 419 | preemptive_batch_size=args.preemptive, 420 | ): 421 | if 
result["op"] == "evaluatedPrompt": 422 | prompt_token_count = result["token_count"] 423 | prompt_time = result["time_ms"] 424 | prompt_tps = result["prompt_tps"] 425 | if args.repeat_prompt: 426 | bolddim(result["prompt"], flush=True) 427 | elif result["op"] == "generatedTokens": 428 | bold(result["text"], end="", flush=True) 429 | elif result["op"] == "stop": 430 | end_reason = result["reason"] 431 | generated_token_count = result["token_count"] 432 | generation_time = result["time_ms"] 433 | generation_tps = result["generation_tps"] 434 | else: 435 | assert False 436 | 437 | print() 438 | info(f"End reason: {end_reason}") 439 | info(f"Tokens: prompt {prompt_token_count}, generation {generated_token_count}") 440 | info(f"Tokens per second: prompt {prompt_tps:.2f}, generation {generation_tps:.2f}") 441 | info(f"Total time: prompt {prompt_time:.2f}ms, generation {generation_time:.2f}ms") 442 | 443 | 444 | if __name__ == "__main__": 445 | main() 446 | -------------------------------------------------------------------------------- /src/examples/reluctance.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-class-docstring,missing-function-docstring 2 | """ 3 | Example of JSON schema decoding with MLX. 4 | """ 5 | import argparse 6 | import json 7 | 8 | import mlx.core as mx 9 | import mlx.nn as nn 10 | from mlx_lm.utils import load 11 | 12 | from llm_structured_output import ( 13 | JsonSchemaAcceptorDriver, 14 | HuggingfaceTokenizerHelper, 15 | bias_logits, 16 | ) 17 | from llm_structured_output.util.output import info, setbg, setfg, clear 18 | 19 | from .reusable_kv_cache import ReusableKVCache 20 | 21 | 22 | def compute_reluctance(logits, accepted_token_bitmap) -> float: 23 | """ 24 | Sum the probabilities of each token that has higher probability than 25 | the highest-probability token selected by the schema. This gives an 26 | idea of the model's preference for tokens that don't follow the schema. 27 | """ 28 | p = nn.softmax(logits) 29 | indices = mx.argsort(p)[::-1] 30 | r = 0 31 | for i in indices.tolist(): 32 | if (1 << i) & accepted_token_bitmap: 33 | break 34 | r += p[i].item() 35 | return r 36 | 37 | 38 | def main(): 39 | parser = argparse.ArgumentParser( 40 | description="Visualize LLM reluctance to generate according to the schema." 
41 | ) 42 | parser.add_argument( 43 | "--model-path", 44 | type=str, 45 | default="mlx_model", 46 | help="The path to the model weights and tokenizer", 47 | ) 48 | parser.add_argument( 49 | "--schema", 50 | help="A JSON schema to constrain the output.", 51 | type=str, 52 | ) 53 | parser.add_argument( 54 | "--prompt", 55 | help="The message to be processed by the model", 56 | ) 57 | parser.add_argument( 58 | "--max-tokens", 59 | "-m", 60 | type=int, 61 | default=1000, 62 | help="Maximum number of tokens to generate", 63 | ) 64 | 65 | args = parser.parse_args() 66 | 67 | info("Loading model from disk...") 68 | model, tokenizer = load(args.model_path) 69 | schema = json.loads(args.schema) 70 | 71 | tokenizer_helper = HuggingfaceTokenizerHelper(tokenizer) 72 | vocabulary, eos_id = tokenizer_helper.extract_vocabulary() 73 | token_acceptor_factory = JsonSchemaAcceptorDriver.driver_factory_for_model(vocabulary, eos_id) 74 | token_acceptor = token_acceptor_factory(schema) 75 | 76 | 77 | info("Starting generation...") 78 | tokens = tokenizer_helper.encode_prompt(args.prompt) 79 | cache = ReusableKVCache.for_model(model) 80 | while tokens[-1] != eos_id: 81 | logits = model(mx.array(tokens)[None], cache) 82 | accepted_token_bitmap = token_acceptor.select_valid_tokens() 83 | reluctance = compute_reluctance(logits[0, -1, :], accepted_token_bitmap) 84 | biased_logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap) 85 | token = mx.argmax(biased_logits, axis=-1).item() 86 | if token == eos_id: 87 | break 88 | tokens = [token] 89 | text = tokenizer_helper.no_strip_decode(tokens) 90 | setbg(reluctance, 0.8 * (1 - reluctance), 0) 91 | setfg(1, 1, 1) 92 | print(text, end="") 93 | token_acceptor.advance_token(token) 94 | clear() 95 | print() 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /src/examples/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx >= 0.19.1 2 | mlx-lm >= 0.19.2 3 | tokenizers >= 0.20.1 4 | sentencepiece 5 | fastapi 6 | pydantic 7 | uvicorn 8 | -------------------------------------------------------------------------------- /src/examples/reusable_kv_cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper with improvements over mlx-lm's KVCache. 3 | """ 4 | 5 | import mlx.core as mx 6 | from mlx_lm.models.cache import KVCache 7 | 8 | 9 | class ReusableKVCache(KVCache): 10 | """ 11 | Usability improvements over KVCache. 12 | """ 13 | 14 | @classmethod 15 | def for_model(cls, model): 16 | return [cls() for _ in model.layers] 17 | 18 | def reuse(self, new_prompt_length, common_prefix_length): 19 | """ 20 | Reuse (part of) this cache for a new prompt that shares a prefix with it. 21 | """ 22 | if self.keys is None: 23 | return 24 | # Clip the cache to the common length. 25 | self.offset = common_prefix_length 26 | # Make sure the cache can fit the whole prompt. Because the offset is 27 | # (very likely) not a multiple of the step size, update_and_fetch() 28 | # won't resize the cache when evaluating the rest of the prompt as it 29 | # would if it were an empty cache. 
30 | current_size = self.keys.shape[2] 31 | if current_size < new_prompt_length: 32 | _, n_kv_heads, _, k_head_dim = self.keys.shape 33 | v_head_dim = self.values.shape[3] 34 | n_steps = (self.step + new_prompt_length - 1) // self.step 35 | k_add_shape = (1, n_kv_heads, n_steps * self.step - current_size, k_head_dim) 36 | v_add_shape = (1, n_kv_heads, n_steps * self.step - current_size, v_head_dim) 37 | k_zeros = mx.zeros(k_add_shape, self.keys.dtype) 38 | v_zeros = mx.zeros(v_add_shape, self.values.dtype) 39 | self.keys = mx.concatenate([self.keys, k_zeros], axis=2) 40 | self.values = mx.concatenate([self.values, v_zeros], axis=2) 41 | 42 | def update_and_fetch(self, keys, values): 43 | """ 44 | Override the base class method to allow the cache to be used with batches of 45 | size greater than 1. 46 | This is just a tiny change in the line that determines the shape. 47 | """ 48 | prev = self.offset 49 | if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: 50 | B, n_kv_heads, _, k_head_dim = keys.shape 51 | v_head_dim = values.shape[3] 52 | n_steps = (self.step + keys.shape[2] - 1) // self.step 53 | k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim) 54 | v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim) 55 | new_k = mx.zeros(k_shape, keys.dtype) 56 | new_v = mx.zeros(v_shape, values.dtype) 57 | if self.keys is not None: 58 | if prev % self.step != 0: 59 | self.keys = self.keys[..., :prev, :] 60 | self.values = self.values[..., :prev, :] 61 | self.keys = mx.concatenate([self.keys, new_k], axis=2) 62 | self.values = mx.concatenate([self.values, new_v], axis=2) 63 | else: 64 | self.keys, self.values = new_k, new_v 65 | 66 | self.offset += keys.shape[2] 67 | self.keys[..., prev : self.offset, :] = keys 68 | self.values[..., prev : self.offset, :] = values 69 | return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] 70 | 71 | -------------------------------------------------------------------------------- /src/examples/server.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring,missing-class-docstring 2 | """ 3 | Example model server with OpenAI-like API, including function calls / tools. 
4 | """ 5 | import json 6 | import time 7 | import os 8 | from enum import Enum 9 | from traceback import format_exc 10 | from typing import Literal, List, Optional, Union 11 | 12 | from fastapi import FastAPI, Request, status 13 | from fastapi.responses import FileResponse, JSONResponse, StreamingResponse 14 | from fastapi.exceptions import RequestValidationError 15 | from pydantic import BaseModel 16 | 17 | from examples.llm_schema import Model 18 | from llm_structured_output.util.output import info, warning, debug 19 | 20 | 21 | app = FastAPI() 22 | 23 | model = Model() 24 | info("Loading model...") 25 | try: 26 | model_path = os.environ["MODEL_PATH"] 27 | model.load(model_path) 28 | except KeyError: 29 | warning("Need to specify MODEL_PATH environment variable") 30 | 31 | 32 | @app.exception_handler(RequestValidationError) 33 | # pylint: disable-next=unused-argument 34 | async def validation_exception_handler(request: Request, exc: RequestValidationError): 35 | exc_str = f"{exc}" 36 | warning(f"RequestValidationError: {exc_str}") 37 | content = {"error": exc_str} 38 | return JSONResponse( 39 | content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY 40 | ) 41 | 42 | 43 | @app.get("/status") 44 | def get_status(): 45 | return {"status": "OK"} 46 | 47 | 48 | @app.get("/") 49 | def get_root(): 50 | return FileResponse(f"{os.path.dirname(os.path.realpath(__file__))}/static/ui.html") 51 | 52 | 53 | class V1ChatMessageRole(str, Enum): 54 | SYSTEM = "system" 55 | USER = "user" 56 | ASSISTANT = "assistant" 57 | 58 | 59 | class V1ChatMessage(BaseModel): 60 | role: V1ChatMessageRole 61 | content: str 62 | 63 | 64 | class V1Function(BaseModel): 65 | name: str 66 | description: str = "" 67 | parameters: dict = {} 68 | 69 | 70 | class V1ToolFunction(BaseModel): 71 | type: Literal["function"] 72 | function: V1Function 73 | 74 | 75 | class V1ToolChoiceKeyword(str, Enum): 76 | AUTO = "auto" 77 | NONE = "none" 78 | 79 | 80 | class V1ToolChoiceFunction(BaseModel): 81 | type: Optional[Literal["function"]] = None 82 | name: str 83 | 84 | 85 | class V1ToolOptions(BaseModel): # Non-standard, our addition. 86 | # We automatically add instructions with the JSON schema 87 | # for the tool calls to the prompt. This option disables 88 | # it and is useful when the user prompt already includes 89 | # the schema and relevant instructions. 90 | no_prompt_steering: bool = False 91 | 92 | 93 | class V1ResponseFormatType(str, Enum): 94 | JSON_OBJECT = "json_object" 95 | 96 | 97 | class V1ResponseFormat(BaseModel): 98 | type: V1ResponseFormatType 99 | # schema is our addition, not an OpenAI API parameter 100 | schema: str = None 101 | 102 | 103 | class V1StreamOptions(BaseModel): 104 | include_usage: bool = False 105 | 106 | 107 | class V1ChatCompletionsRequest( 108 | BaseModel 109 | ): # pylint: disable=too-many-instance-attributes 110 | model: str = "default" 111 | max_tokens: int = 1000 112 | temperature: float = 0.0 113 | messages: List[V1ChatMessage] 114 | # The 'functions' and 'function_call' fields have been dreprecated and 115 | # replaced with 'tools' and 'tool_choice', that work similarly but allow 116 | # for multiple functions to be invoked. 
117 | functions: List[V1Function] = None 118 | function_call: Union[V1ToolChoiceKeyword, V1ToolChoiceFunction] = None 119 | tools: List[V1ToolFunction] = None 120 | tool_choice: Union[V1ToolChoiceKeyword, V1ToolChoiceFunction] = None 121 | tool_options: V1ToolOptions = None 122 | response_format: V1ResponseFormat = None 123 | stream: bool = False 124 | stream_options: V1StreamOptions = None 125 | 126 | 127 | @app.post("/v1/chat/completions") 128 | async def post_v1_chat_completions(request: V1ChatCompletionsRequest): 129 | debug("REQUEST", request) 130 | if request.stream: 131 | async def get_content(): 132 | try: 133 | async for message in post_v1_chat_completions_impl(request): 134 | yield message 135 | # pylint: disable-next=broad-exception-caught 136 | except Exception as e: 137 | warning(format_exc()) 138 | yield 'data: {"choices": [{"index": 0, "finish_reason": "error: ' + str(e) + '"}]}' 139 | return StreamingResponse( 140 | content=get_content(), 141 | media_type="text/event-stream", 142 | ) 143 | else: 144 | # FUTURE: Python 3.10 can use `await anext(x))` instead of `await x.__anext__()`. 145 | try: 146 | response = await post_v1_chat_completions_impl(request).__anext__() 147 | # pylint: disable-next=broad-exception-caught 148 | except Exception as e: 149 | warning(format_exc()) 150 | content = {"error": str(e)} 151 | response = JSONResponse( 152 | content=content, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR 153 | ) 154 | debug("RESPONSE", response) 155 | return response 156 | 157 | 158 | async def post_v1_chat_completions_impl(request: V1ChatCompletionsRequest): 159 | messages = request.messages[:] 160 | 161 | # Extract valid functions from the request. 162 | functions = [] 163 | is_legacy_function_call = False 164 | if request.tool_choice == "none": 165 | pass 166 | elif request.tool_choice == "auto": 167 | functions = [tool.function for tool in request.tools if tool.type == "function"] 168 | elif request.tool_choice is not None: 169 | functions = [ 170 | next( 171 | tool.function 172 | for tool in request.tools 173 | if tool.type == "function" 174 | and tool.function.name == request.function_call.name 175 | ) 176 | ] 177 | elif request.function_call == "none": 178 | pass 179 | elif request.function_call == "auto": 180 | functions = request.functions 181 | is_legacy_function_call = True 182 | elif request.function_call is not None: 183 | functions = [ 184 | next( 185 | fn for fn in request.functions if fn.name == request.function_call.name 186 | ) 187 | ] 188 | is_legacy_function_call = True 189 | 190 | model_name = model_path 191 | schema = None 192 | if functions: 193 | # If the request includes functions, create a system prompt to instruct the LLM 194 | # to use tools, and assemble a JSON schema to steer the LLM output. 
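        # For a single hypothetical get_weather(city) function, the schema assembled
        # by ToolCallResponder below would look roughly like this (illustrative sketch):
        #
        #   {"type": "object",
        #    "properties": {
        #        "name": {"type": "const", "const": "get_weather"},
        #        "arguments": {"type": "object",
        #                      "properties": {"city": {"type": "string"}}}},
        #    "required": ["name", "arguments"]}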
195 | if request.stream: 196 | responder = ToolCallStreamingResponder( 197 | model_name, 198 | functions, 199 | is_legacy_function_call, 200 | model, 201 | ) 202 | else: 203 | responder = ToolCallResponder( 204 | model_name, functions, is_legacy_function_call 205 | ) 206 | if not (request.tool_options and request.tool_options.no_prompt_steering): 207 | messages.insert( 208 | 0, 209 | V1ChatMessage( 210 | role="system", 211 | content=responder.tool_prompt, 212 | ), 213 | ) 214 | schema = responder.schema 215 | else: 216 | if request.response_format: 217 | assert request.response_format.type == V1ResponseFormatType.JSON_OBJECT 218 | # The request may specify a JSON schema (this option is not in the OpenAI API) 219 | if request.response_format.schema: 220 | schema = json.loads(request.response_format.schema) 221 | else: 222 | schema = {"type": "object"} 223 | if request.stream: 224 | responder = ChatCompletionStreamingResponder(model_name, schema, model) 225 | else: 226 | responder = ChatCompletionResponder(model_name) 227 | 228 | if schema is not None: 229 | debug("Using schema:", schema) 230 | 231 | info("Starting generation...") 232 | 233 | prompt_tokens = None 234 | 235 | for result in model.completion( 236 | messages, 237 | schema=schema, 238 | max_tokens=request.max_tokens, 239 | temp=request.temperature, 240 | cache_prompt=True, 241 | ): 242 | if result["op"] == "evaluatedPrompt": 243 | prompt_tokens = result["token_count"] 244 | elif result["op"] == "generatedTokens": 245 | message = responder.generated_tokens(result["text"]) 246 | if message: 247 | yield message 248 | elif result["op"] == "stop": 249 | completion_tokens = result["token_count"] 250 | yield responder.generation_stopped( 251 | result["reason"], prompt_tokens, completion_tokens 252 | ) 253 | else: 254 | assert False 255 | 256 | 257 | class ChatCompletionResponder: 258 | def __init__(self, model_name: str): 259 | self.object_type = "chat.completion" 260 | self.model_name = model_name 261 | self.created = int(time.time()) 262 | self.id = f"{id(self)}_{self.created}" 263 | self.content = "" 264 | 265 | def message_properties(self): 266 | return { 267 | "object": self.object_type, 268 | "id": f"chatcmpl-{self.id}", 269 | "created": self.created, 270 | "model": self.model_name, 271 | } 272 | 273 | def translate_reason(self, reason): 274 | """ 275 | Translate our reason codes to OpenAI ones. 
276 | """ 277 | if reason == "end": 278 | return "stop" 279 | if reason == "max_tokens": 280 | return "length" 281 | return f"error: {reason}" # Not a standard OpenAI API reason 282 | 283 | def format_usage(self, prompt_tokens: int, completion_tokens: int): 284 | return { 285 | "usage": { 286 | "completion_tokens": completion_tokens, 287 | "prompt_tokens": prompt_tokens, 288 | "total_tokens": completion_tokens + prompt_tokens, 289 | }, 290 | } 291 | 292 | def generated_tokens( 293 | self, 294 | text: str, 295 | ): 296 | self.content += text 297 | return None 298 | 299 | def generation_stopped( 300 | self, 301 | stop_reason: str, 302 | prompt_tokens: int, 303 | completion_tokens: int, 304 | ): 305 | finish_reason = self.translate_reason(stop_reason) 306 | message = {"role": "assistant", "content": self.content} 307 | return { 308 | "choices": [ 309 | {"index": 0, "message": message, "finish_reason": finish_reason} 310 | ], 311 | **self.format_usage(prompt_tokens, completion_tokens), 312 | **self.message_properties(), 313 | } 314 | 315 | 316 | class ChatCompletionStreamingResponder(ChatCompletionResponder): 317 | def __init__(self, model_name: str, schema: dict = None, _model = None): 318 | super().__init__(model_name) 319 | self.object_type = "chat.completion.chunk" 320 | if schema: 321 | assert _model 322 | self.schema_parser = _model.get_driver_for_json_schema(schema) 323 | else: 324 | self.schema_parser = None 325 | 326 | def generated_tokens( 327 | self, 328 | text: str, 329 | ): 330 | delta = {"role": "assistant", "content": text} 331 | if self.schema_parser: 332 | values = {} 333 | for char in text: 334 | self.schema_parser.advance_char(char) 335 | for path in self.schema_parser.get_current_value_paths(): 336 | values[path] = values.get(path, "") + char 337 | delta["values"] = values 338 | message = { 339 | "choices": [{"index": 0, "delta": delta, "finish_reason": None}], 340 | **self.message_properties(), 341 | } 342 | return f"data: {json.dumps(message)}\n" 343 | 344 | def generation_stopped( 345 | self, 346 | stop_reason: str, 347 | prompt_tokens: int, 348 | completion_tokens: int, 349 | ): 350 | finish_reason = self.translate_reason(stop_reason) 351 | delta = {"role": "assistant", "content": ""} 352 | message = { 353 | "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}], 354 | # Usage field notes: 355 | # - OpenAI only sends usage in streaming if the option 356 | # stream_options.include_usage is true, but we send it always. 357 | **self.format_usage(prompt_tokens, completion_tokens), 358 | **self.message_properties(), 359 | } 360 | return f"data: {json.dumps(message)}\ndata: [DONE]\n" 361 | 362 | 363 | class ToolCallResponder(ChatCompletionResponder): 364 | def __init__( 365 | self, model_name: str, functions: list[dict], is_legacy_function_call: bool 366 | ): 367 | super().__init__(model_name) 368 | 369 | self.is_legacy_function_call = is_legacy_function_call 370 | 371 | function_schemas = [ 372 | { 373 | "type": "object", 374 | "properties": { 375 | "name": {"type": "const", "const": fn.name}, 376 | "arguments": fn.parameters, 377 | }, 378 | "required": ["name", "arguments"], 379 | } 380 | for fn in functions 381 | ] 382 | if len(function_schemas) == 1: 383 | self.schema = function_schemas[0] 384 | self.tool_prompt = self._one_tool_prompt(functions[0], function_schemas[0]) 385 | elif is_legacy_function_call: # Only allows one function to be called. 
386 | self.schema = {"oneOf": function_schemas} 387 | self.tool_prompt = self._select_tool_prompt(functions, function_schemas) 388 | else: 389 | self.schema = {"type": "array", "items": {"anyOf": function_schemas}} 390 | self.tool_prompt = self._multiple_tool_prompt(functions, function_schemas) 391 | 392 | def translate_reason(self, reason): 393 | if reason == "end": 394 | if self.is_legacy_function_call: 395 | return "function_call" 396 | return "tool_calls" 397 | return super().translate_reason(reason) 398 | 399 | def generation_stopped( 400 | self, 401 | stop_reason: str, 402 | prompt_tokens: int, 403 | completion_tokens: int, 404 | ): 405 | finish_reason = self.translate_reason(stop_reason) 406 | if finish_reason == "tool_calls": 407 | tool_calls = json.loads(self.content) 408 | if not isinstance(tool_calls, list): 409 | # len(functions) == 1 was special cased 410 | tool_calls = [tool_calls] 411 | message = { 412 | "role": "assistant", 413 | "tool_calls": [ 414 | { 415 | "id": f"call_{self.id}_{i}", 416 | "type": "function", 417 | "function": { 418 | "name": function_call["name"], 419 | "arguments": json.dumps(function_call["arguments"]), 420 | }, 421 | } 422 | for i, function_call in enumerate(tool_calls) 423 | ], 424 | } 425 | elif finish_reason == "function_call": 426 | function_call = json.loads(self.content) 427 | message = { 428 | "role": "assistant", 429 | "function_call": { 430 | "name": function_call["name"], 431 | "arguments": json.dumps(function_call["arguments"]), 432 | }, 433 | } 434 | else: 435 | message = None 436 | return { 437 | "choices": [ 438 | {"index": 0, "message": message, "finish_reason": finish_reason} 439 | ], 440 | **self.format_usage(prompt_tokens, completion_tokens), 441 | **self.message_properties(), 442 | } 443 | 444 | def _one_tool_prompt(self, tool, tool_schema): 445 | return f""" 446 | You are a helpful assistant with access to a tool that you must invoke to answer the user's request. 447 | The tool is: 448 | Tool {tool.name}: {tool.description} 449 | Invocation schema: {json.dumps(tool_schema)} 450 | Your answer is a JSON object according to the invocation schema in order to answer the user request below. 451 | """ 452 | 453 | def _multiple_tool_prompt(self, tools, tool_schemas, separator="\n"): 454 | return f""" 455 | You are a helpful assistant with access to tools that you must invoke to answer the user's request. 456 | The following tools are available: 457 | {separator.join([ f''' 458 | Tool {tool.name}: {tool.description} 459 | Invocation schema: {json.dumps(tool_schema)} 460 | ''' for tool, tool_schema in zip(tools, tool_schemas) ])} 461 | Your answer is a JSON array with one or more tool invocations according to the appropriate schema(s) 462 | in order to answer the user request below. 463 | """ 464 | 465 | def _select_tool_prompt(self, tools, tool_schemas, separator="\n"): 466 | return f""" 467 | You are a helpful assistant with access to tools that you must invoke to answer the user's request. 468 | The following tools are available: 469 | {separator.join([ f''' 470 | Function {tool.name}: {tool.description} 471 | Tool schema: {json.dumps(tool_schema)} 472 | ''' for tool, tool_schema in zip(tools, tool_schemas) ])} 473 | Your answer is a JSON object according to the invocation schema of the most appropriate tool to use 474 | to answer the user request below. 
475 | """ 476 | 477 | 478 | class ToolCallStreamingResponder(ToolCallResponder): 479 | def __init__( 480 | self, 481 | model_name: str, 482 | functions: list[dict], 483 | is_legacy_function_call: bool, 484 | _model, 485 | ): 486 | super().__init__(model_name, functions, is_legacy_function_call) 487 | self.object_type = "chat.completion.chunk" 488 | 489 | # We need to parse the output as it's being generated in order to send 490 | # streaming messages that contain the name and arguments of the function 491 | # being called. 492 | 493 | self.current_function_index = -1 494 | self.current_function_name = None 495 | self.in_function_arguments = False 496 | 497 | def set_function_name(_prop_name: str, prop_value): 498 | self.current_function_index += 1 499 | self.current_function_name = prop_value 500 | 501 | def start_function_arguments(_prop_name: str): 502 | self.in_function_arguments = True 503 | 504 | def end_function_arguments(_prop_name: str, _prop_value: str): 505 | self.in_function_arguments = False 506 | 507 | hooked_function_schemas = [ 508 | { 509 | "type": "object", 510 | "properties": { 511 | "name": { 512 | "type": "const", 513 | "const": fn.name, 514 | "__hooks": { 515 | "value_end": set_function_name, 516 | }, 517 | }, 518 | "arguments": { 519 | **fn.parameters, 520 | "__hooks": { 521 | "value_start": start_function_arguments, 522 | "value_end": end_function_arguments, 523 | }, 524 | }, 525 | }, 526 | "required": ["name", "arguments"], 527 | } 528 | for fn in functions 529 | ] 530 | if len(hooked_function_schemas) == 1: 531 | hooked_schema = hooked_function_schemas[0] 532 | elif is_legacy_function_call: 533 | hooked_schema = {"oneOf": hooked_function_schemas} 534 | else: 535 | hooked_schema = { 536 | "type": "array", 537 | "items": {"anyOf": hooked_function_schemas}, 538 | } 539 | self.tool_call_parser = _model.get_driver_for_json_schema(hooked_schema) 540 | 541 | def generated_tokens( 542 | self, 543 | text: str, 544 | ): 545 | argument_text = "" 546 | for char in text: 547 | if self.in_function_arguments: 548 | argument_text += char 549 | # Update state. This is certain to parse, no need to check for rejections. 550 | self.tool_call_parser.advance_char(char) 551 | if not argument_text: 552 | return None 553 | assert self.current_function_name 554 | if self.is_legacy_function_call: 555 | delta = { 556 | "function_call": { 557 | "name": self.current_function_name, 558 | "arguments": argument_text, 559 | } 560 | } 561 | else: 562 | delta = { 563 | "tool_calls": [ 564 | { 565 | "index": self.current_function_index, 566 | "id": f"call_{self.id}_{self.current_function_index}", 567 | "type": "function", 568 | "function": { 569 | # We send the name on every update, but OpenAI only sends it on 570 | # the first one for each call, with empty arguments (""). Further 571 | # updates only have the arguments field. This is something we may 572 | # want to emulate if client code depends on this behavior. 
573 | "name": self.current_function_name, 574 | "arguments": argument_text, 575 | }, 576 | } 577 | ] 578 | } 579 | message = { 580 | "choices": [{"index": 0, "delta": delta, "finish_reason": None}], 581 | **self.message_properties(), 582 | } 583 | return f"data: {json.dumps(message)}\n" 584 | 585 | def generation_stopped( 586 | self, 587 | stop_reason: str, 588 | prompt_tokens: int, 589 | completion_tokens: int, 590 | ): 591 | finish_reason = self.translate_reason(stop_reason) 592 | message = { 593 | "choices": [{"index": 0, "delta": {}, "finish_reason": finish_reason}], 594 | # Usage field notes: 595 | # - OpenAI only sends usage in streaming if the option 596 | # stream_options.include_usage is true, but we send it always. 597 | # - OpenAI sends two separate messages: one with the finish_reason and no 598 | # usage field, and one with an empty choices array and the usage field. 599 | **self.format_usage(prompt_tokens, completion_tokens), 600 | **self.message_properties(), 601 | } 602 | return f"data: {json.dumps(message)}\ndata: [DONE]\n" 603 | -------------------------------------------------------------------------------- /src/examples/static/attention.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Attention 6 | 176 | 386 | 387 | 388 | 389 |
390 |

Prompt

391 | 392 | 393 | 394 | 395 | 396 | 401 |
402 | 403 |
404 |
405 | 409 | 437 | 453 | 454 | 455 | -------------------------------------------------------------------------------- /src/examples/static/ui.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | LLM 8 | 163 | 295 | 296 | 297 | 298 |
299 |

Prompt

300 | 301 | 302 | 303 | 304 | 305 | 310 |
311 | Generation options 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 |
321 |
322 | 323 | 324 |
325 |
326 | 335 | 336 | 337 | 338 | -------------------------------------------------------------------------------- /src/llm_structured_output/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | LLM structured output: constrain generation to a JSON schema. 3 | """ 4 | from .json_schema_acceptor import JsonSchemaAcceptor, JsonSchemaAcceptorDriver 5 | from .json_acceptor import JsonAcceptor 6 | from .util.bitmap import bias_logits 7 | from .util.tokenization import HuggingfaceTokenizerHelper 8 | -------------------------------------------------------------------------------- /src/llm_structured_output/acceptor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base token acceptors. 3 | 4 | A token acceptor constrains the tokens that are acceptable at this point in 5 | the parsing or generation of a text. 6 | 7 | Since multiple parses of a given input may be possible (or multiple generations 8 | valid according to e.g. a schema), the acceptor creates multiple "cursors", one 9 | for each valid current state of the acceptor. This is akin to a chart parser, 10 | where all possible parses of the input are carried in parallel, which minimizes 11 | backtracking that is expensive on an LLM. 12 | 13 | The basic flow is: 14 | - First, the vocabulary (list of possible tokens for the LLM) is prepared into 15 | a trie for logarithmic traversal. Subclasses may also perform their own 16 | vocabulary preparation. 17 | - The acceptor's get_cursors() method is called, and the acceptor issues one or 18 | more cursors with initial state(s). 19 | - The trie is traversed to find which tokens are a valid match in the current 20 | state of the active cursor(s). For each cursor: 21 | - The select() method is called to narrow down the next character(s) that the 22 | cursor can accept in its current state. 23 | - For each selected character, we advance() the cursor, obtaining one or more 24 | follow-up cursors that represent the next state(s) of the cursor. 25 | - We descend down the trie branch corresponding to the selected character, and 26 | perform the same select(), advance() operation on the new cursor(s). 27 | - We traverse until the cursor(s) have reached an accepted state or we reach a 28 | leaf node. 29 | - As we traverse the trie recursively, we collect the token ids for each node. 30 | This creates a set of valid tokens that will be accepted. 31 | 32 | For example: if we have a TextAcceptor that will accept the word "true", the 33 | initial cursor's select() method will return "t" as the set of acceptable 34 | characters. We will then advance the cursor and obtain a cursor that accepts the 35 | word "rue", and our current trie node will become the "t" child branch of the 36 | prior trie node. We will then match the new trie node with the new acceptor, etc. 37 | 38 | Acceptors can be chained with e.g. a StateMachineAcceptor. In this case, when a 39 | cursor reaches a final state, the parent acceptor moves its own cursor forward, 40 | potentially issuing more cursors that can be matched with the remainder of the 41 | trie. 42 | 43 | Some methods have been added to help prevent combinatorial explosions while 44 | searching that can have a big effect in performance. For example, an acceptor 45 | for a quoted string can select() a very large amount of characters after the 46 | first quote. Descending upon every branch of the trie is not necessary in as 47 | much as every character is essentially equivalently valid. 
To avoid this, we 48 | allow the acceptor to prune the trie so that all equivalent characters are 49 | collapsed into one branch. In such a collapsed trie, each node keeps a set with 50 | all the ids for valid tokens of the same length, which are equivalent from the 51 | point of the view of the acceptor. 52 | """ 53 | 54 | from __future__ import annotations 55 | from copy import copy as shallowcopy 56 | from time import time_ns 57 | from typing import Iterable, Tuple 58 | 59 | from .util.tokentrie import TokenTrie 60 | 61 | 62 | class TokenAcceptor: 63 | """ 64 | Base class for token acceptors. 65 | """ 66 | 67 | @classmethod 68 | def prepare_vocabulary(cls, vocabulary: Iterable[Tuple[int, str]]) -> TokenTrie: 69 | """ 70 | Given a list of tokens (typically the vocabulary of an LLM), create 71 | a trie that will be used to select the tokens accepted by the current 72 | set of cursors. 73 | """ 74 | vocabulary_trie = TokenTrie() 75 | vocabulary_trie.insert_all(vocabulary) 76 | return vocabulary_trie 77 | 78 | @classmethod 79 | def match_all(cls, cursors: Iterable[TokenAcceptor.Cursor], trie: TokenTrie) -> int: 80 | """ 81 | Find which tokens in the vocabulary move any of the cursors towards an 82 | acceptance state from their current state. 83 | """ 84 | if any(cursor.matches_all() for cursor in cursors): 85 | return trie.collect_ids() 86 | bitmap = 0 87 | for cursor in cursors: 88 | bitmap |= cursor.match(trie) 89 | return bitmap 90 | 91 | @classmethod 92 | def debug_match_all( 93 | cls, 94 | cursors: Iterable[TokenAcceptor.Cursor], 95 | trie: TokenTrie, 96 | debug_output_fn=print, 97 | ) -> int: 98 | """ 99 | Same as match_all() but outputs debug information. 100 | """ 101 | if any(cursor.matches_all() for cursor in cursors): 102 | return trie.collect_ids() 103 | debug_output_fn("MATCH ALL") 104 | bitmap = 0 105 | for cursor in cursors: 106 | start = time_ns() 107 | cursor_matches = cursor.debug_match(trie, debug_output_fn) 108 | dt_ns = time_ns() - start 109 | match_count = bin(cursor_matches).count("1") 110 | debug_output_fn(f"t={dt_ns/1e6:.02f}ms {match_count=} {repr(cursor)}") 111 | bitmap |= cursor_matches 112 | return bitmap 113 | 114 | @classmethod 115 | def advance_all( 116 | cls, cursors: Iterable[TokenAcceptor.Cursor], char: str 117 | ) -> list[TokenAcceptor.Cursor]: 118 | """ 119 | Advance multiple cursors in parallel. 120 | """ 121 | return [ 122 | new_cursor 123 | for cursor in cursors 124 | if char in cursor.select(set(char)) 125 | for new_cursor in cursor.advance(char) 126 | ] 127 | 128 | def get_cursors(self) -> Iterable[TokenAcceptor.Cursor]: 129 | """ 130 | Get one or more cursors to traverse the acceptor. 131 | Override. 132 | """ 133 | return [self.__class__.Cursor(self)] 134 | 135 | class Cursor: 136 | """ 137 | A cursor encapsulates a valid current state of a token acceptor. 138 | """ 139 | 140 | def __init__(self, acceptor: TokenAcceptor): 141 | pass 142 | 143 | def clone(self): 144 | """ 145 | Cursors are never mutated, they are cloned as they advance. 146 | They should also be lightweight: think twice before overriding this 147 | to e.g. a deepcopy. 148 | """ 149 | return shallowcopy(self) 150 | 151 | def matches_all(self) -> bool: 152 | """ 153 | The acceptor accepts all the tokens (i.e. free text). 154 | This is an optimization and only useful for acceptors that don't constrain 155 | the input, such as WaitForAcceptor. 
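            When this returns True, match_all() above short-circuits and returns
            every token id in the trie without traversing it.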
156 | """ 157 | return False 158 | 159 | def select(self, candidate_chars: set[str]) -> Iterable[str]: 160 | """ 161 | Narrow down the characters that are offered to the cursor for advancement. 162 | This is a crucial performance improvement for cursors in a state where they'll 163 | accept only a small set of characters, since they will be tested against that 164 | set instead of the whole range of characters available. 165 | Override. 166 | """ 167 | return candidate_chars 168 | 169 | # pylint: disable-next=unused-argument 170 | def advance(self, char: str) -> Iterable[TokenAcceptor.Cursor]: 171 | """ 172 | If the character can be consumed, return new cursor(s) for the possible 173 | continuation(s). IMPORTANT: Cursors should not mutate their state, only 174 | return mutated copies of the object, as the advance method is called 175 | multiple times with different inputs. See clone() method above. 176 | Override. 177 | """ 178 | return [] 179 | 180 | def in_accepted_state(self) -> bool: 181 | """ 182 | Returns True if the cursor has reached a final state. 183 | Typically, rather than override you should return an AcceptedState object 184 | in the advance() method when the state is reached after consuming input. 185 | """ 186 | return False 187 | 188 | def get_value(self): 189 | """ 190 | Returns the current value of the cursor as defined by itself. This can be 191 | either the ongoing representation of its temporary state, or its final value 192 | usable for the application once it reaches accepted state. At that point, 193 | cursors that return the same value are considered identical and duplicates 194 | may be discarded for performance. 195 | Override. 196 | """ 197 | return None 198 | 199 | def get_value_path(self): 200 | """ 201 | Returns the path of the value being pointed at by the cursor as defined by the 202 | application. This can be for example a JSON path in the case of a JSON acceptor. 203 | For higher-level application purposes only, not required for accepting. 204 | Override. 205 | """ 206 | return "" 207 | 208 | def is_in_value(self): 209 | """ 210 | Returns true if the cursor is accepting a value as opposed to syntactic elements. 211 | Used in conjunction with get_value_path(). 212 | Override. 213 | """ 214 | return False 215 | 216 | def prune(self, trie: TokenTrie) -> Iterable[(str, TokenTrie)]: 217 | """ 218 | Select the children of the trie to search for matches. See match() below. 219 | This can be overriden in order to e.g. use a collapsed trie. 220 | """ 221 | if trie.children: 222 | chars = set(trie.children.keys()) 223 | selected_chars = chars & set(self.select(chars)) 224 | for char in selected_chars: 225 | yield (char, trie.children[char]) 226 | 227 | def match(self, trie: TokenTrie) -> int: 228 | """ 229 | Find which tokens in the vocabulary move the acceptor towards an acceptance 230 | state from the current state held by this cursor. 231 | Returns a bit map with the bits corresponding to the index if the matched 232 | tokens set to 1. 
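            Illustrative sketch (not from the original tests):

                vocabulary = [(0, "true"), (1, "tr"), (2, "false"), (3, "ue")]
                trie = TokenAcceptor.prepare_vocabulary(vocabulary)
                cursor = TextAcceptor("true").get_cursors()[0]
                cursor.match(trie)  # -> 0b11, i.e. tokens "true" (0) and "tr" (1)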
233 | """ 234 | if self.matches_all(): 235 | return trie.collect_ids() 236 | bitmap = 0 237 | for char, child in self.prune(trie): 238 | followup_cursors = self.advance(char) 239 | if followup_cursors: 240 | bitmap |= child.ids 241 | for followup_cursor in followup_cursors: 242 | bitmap |= followup_cursor.match(child) 243 | return bitmap 244 | 245 | def debug_match( 246 | self, trie: TokenTrie, debug_output_fn=print, debug_indent=1 247 | ) -> int: 248 | """ 249 | Same as match() but outputs debug information 250 | """ 251 | debug_start = time_ns() 252 | if self.matches_all(): 253 | return trie.collect_ids() 254 | bitmap = 0 255 | debug_label = type(self).__qualname__ 256 | if isinstance(self, StateMachineAcceptor.Cursor): 257 | debug_label += f"({type(self.transition_cursor).__qualname__})" 258 | debug_prefix = " " * debug_indent + debug_label 259 | debug_prune_start = time_ns() 260 | for char, child in self.prune(trie): 261 | debug_advance_start = time_ns() 262 | followup_cursors = self.advance(char) 263 | debug_advance_end = time_ns() 264 | prune_time = (debug_advance_start - debug_prune_start) / 1e6 265 | advance_time = (debug_advance_end - debug_advance_start) / 1e6 266 | debug_output_fn( 267 | f"{debug_prefix} >>> " 268 | f"{prune_time=:.02f}ms {advance_time=:.02f}ms char={repr(char)}" 269 | ) 270 | debug_followup_start = time_ns() 271 | if followup_cursors: 272 | bitmap |= child.ids 273 | for followup_cursor in followup_cursors: 274 | bitmap |= followup_cursor.debug_match( 275 | child, debug_output_fn, debug_indent + 1 276 | ) 277 | debug_followup_end = time_ns() 278 | followup_time = (debug_followup_end - debug_followup_start) / 1e6 279 | followup_count = len(followup_cursors) 280 | match_count = bin(bitmap).count("1") 281 | debug_output_fn( 282 | f"{debug_prefix} <<< {followup_count=} {followup_time=:.02f}ms {match_count=}" 283 | ) 284 | debug_prune_start = time_ns() 285 | total_time = (time_ns() - debug_start) / 1e6 286 | debug_output_fn(f"{debug_prefix} {total_time=:.02f}ms") 287 | return bitmap 288 | 289 | def __repr__(self): 290 | return f"{type(self).__qualname__}(value={repr(self.get_value())})" 291 | 292 | 293 | class AcceptedState(TokenAcceptor.Cursor): 294 | """ 295 | Holds a cursor that has reached the accepted state. 296 | """ 297 | 298 | def __init__(self, cursor: TokenAcceptor.Cursor): 299 | self.cursor = cursor 300 | 301 | def in_accepted_state(self): 302 | return True 303 | 304 | def get_value(self): 305 | return self.cursor.get_value() 306 | 307 | def __repr__(self): 308 | return f"✅{repr(self.cursor)}" 309 | 310 | 311 | class CharAcceptor(TokenAcceptor): 312 | """ 313 | Accept one character iff is in the set of expected characters. 314 | """ 315 | 316 | def __init__(self, charset: Iterable[str]): 317 | self.charset = charset 318 | 319 | class Cursor(TokenAcceptor.Cursor): 320 | """ 321 | Cursor for CharAcceptor 322 | """ 323 | 324 | def __init__(self, acceptor, value=None): 325 | self.acceptor = acceptor 326 | self.value = value 327 | 328 | def select(self, candidate_chars): 329 | return self.acceptor.charset 330 | 331 | def advance(self, char): 332 | # Because we implemented the select method, we are guaranteed that the 333 | # char is in our accepted set. 
334 | return [AcceptedState(self.__class__(self.acceptor, char))] 335 | 336 | def get_value(self): 337 | return self.value 338 | 339 | def __repr__(self): 340 | return f"charset={repr(self.acceptor.charset)} value={repr(self.value)}" 341 | 342 | 343 | class TextAcceptor(TokenAcceptor): 344 | """ 345 | Accept a pre-determined string of characters. 346 | """ 347 | 348 | def __init__(self, text: str): 349 | assert len(text) > 0 350 | self.text = text 351 | 352 | class Cursor(TokenAcceptor.Cursor): 353 | """ 354 | Cursor for TextAcceptor 355 | """ 356 | 357 | def __init__(self, acceptor, pos=0): 358 | self.acceptor = acceptor 359 | self.pos = pos 360 | 361 | def select(self, candidate_chars): 362 | return self.acceptor.text[self.pos] 363 | 364 | def advance(self, char): 365 | next_cursor = self.__class__(self.acceptor, self.pos + 1) 366 | if next_cursor.pos == len(self.acceptor.text): 367 | return [AcceptedState(next_cursor)] 368 | return [next_cursor] 369 | 370 | def get_value(self) -> str: 371 | head = self.acceptor.text[0 : self.pos] 372 | tail = self.acceptor.text[self.pos :] 373 | if len(tail): 374 | return f"{head}👉{tail}" 375 | else: 376 | return f"{head}" 377 | 378 | 379 | class StateMachineAcceptor(TokenAcceptor): 380 | """ 381 | Token acceptor that follows a state graph that defines edges to transition 382 | from state to state. Each state can have multiple edges, defined by the 383 | target state and a TokenAcceptor that, when reaching accepted state, causes 384 | the state machine acceptor to move to the target state. This is repeated 385 | until the state machine reaches a final state. Multiple transition paths 386 | are explored in parallel. 387 | """ 388 | 389 | def __init__(self, graph=None, initial_state=None, end_states=None): 390 | self.graph = graph or [] 391 | self.initial_state = initial_state or 0 392 | self.end_states = set(end_states or ["$"]) 393 | 394 | def get_edges(self, state): 395 | """ 396 | Retrieve the graph edges for transitions out of this state. 397 | Can be overriden for dynamic graphs. 398 | """ 399 | return self.graph[state] 400 | 401 | def get_cursors(self): 402 | initial_cursor = self.Cursor(self) 403 | initial_cursor.current_state = self.initial_state 404 | return self._find_transitions(initial_cursor, [], set()) 405 | 406 | def _find_transitions(self, cursor, visited_states, traversed_edges): 407 | try: 408 | edges = self.get_edges(cursor.current_state) 409 | except (KeyError, IndexError, TypeError): 410 | assert cursor.current_state in self.end_states 411 | return [] 412 | cursors = [] 413 | for transition_acceptor, target_state in edges: 414 | if cursor.start_transition(transition_acceptor, target_state): 415 | for transition_cursor in transition_acceptor.get_cursors(): 416 | copy = cursor.clone() 417 | copy.transition_cursor = transition_cursor 418 | copy.target_state = target_state 419 | # Handle cursors that start in an accepted state, 420 | # e.g. EmptyTransition, WhitespaceAcceptor 421 | if transition_cursor.in_accepted_state(): 422 | new_visited_states = visited_states + [cursor.current_state] 423 | assert target_state not in new_visited_states # Infinite loop 424 | cursors += self._cascade_transition( 425 | copy, new_visited_states, traversed_edges 426 | ) 427 | else: 428 | cursors.append(copy) 429 | return cursors 430 | 431 | def _cascade_transition(self, cursor, visited_states, traversed_edges): 432 | assert cursor.transition_cursor.in_accepted_state() 433 | # Copy before validation to allow for cursor mutation, e.g. 
storing the transition_value 434 | cursors = [] 435 | copy: StateMachineAcceptor.Cursor = cursor.clone() 436 | if copy.complete_transition( 437 | copy.transition_cursor.get_value(), 438 | copy.target_state, 439 | copy.target_state in copy.acceptor.end_states, 440 | ): 441 | copy.current_state = copy.target_state 442 | copy.target_state = None 443 | copy.accept_history = copy.accept_history + [copy.transition_cursor.cursor] 444 | copy.transition_cursor = None 445 | copy.consumed_character_count = 0 446 | # De-duplicate cursors that have reached the same state with the same value. 447 | # This prevents combinatorial explosion because of e.g. empty transitions. 448 | state_value = (copy.current_state, repr(copy.get_value())) 449 | if state_value not in traversed_edges: 450 | traversed_edges.add(state_value) 451 | if copy.current_state in self.end_states: 452 | cursors.append(AcceptedState(copy)) 453 | cursors += self._find_transitions(copy, visited_states, traversed_edges) 454 | return cursors 455 | 456 | def advance_cursor(self, cursor, char): 457 | """ 458 | Advance a cursor, and if it reaches accepted state, cause the state machine to transition. 459 | """ 460 | next_cursors = [] 461 | traversed_edges = set() 462 | for followup_cursor in cursor.transition_cursor.advance(char): 463 | copy = cursor.clone() 464 | copy.transition_cursor = followup_cursor 465 | copy.consumed_character_count += 1 466 | if followup_cursor.in_accepted_state(): 467 | next_cursors += self._cascade_transition( 468 | copy, [], traversed_edges 469 | ) 470 | else: 471 | next_cursors.append(copy) 472 | return next_cursors 473 | 474 | class Cursor(TokenAcceptor.Cursor): 475 | """ 476 | Cursor for StateMachineAcceptor 477 | """ 478 | 479 | def __init__(self, acceptor): 480 | self.acceptor = acceptor 481 | self.accept_history = [] 482 | self.current_state = None 483 | self.transition_cursor = None 484 | self.target_state = None 485 | self.consumed_character_count = 0 486 | 487 | def matches_all(self): 488 | if self.transition_cursor is None: 489 | return False 490 | return self.transition_cursor.matches_all() 491 | 492 | def select(self, candidate_chars): 493 | if self.transition_cursor is None: 494 | return set() 495 | return self.transition_cursor.select(candidate_chars) 496 | 497 | def prune(self, trie): 498 | if self.transition_cursor is None: 499 | return [] 500 | return self.transition_cursor.prune(trie) 501 | 502 | def advance(self, char): 503 | return self.acceptor.advance_cursor(self, char) 504 | 505 | # pylint: disable-next=unused-argument 506 | def start_transition(self, transition_acceptor, target_state) -> bool: 507 | """ 508 | Override to prevent an edge to be traversed. 509 | """ 510 | return True 511 | 512 | def complete_transition( # pylint: disable-next=unused-argument 513 | self, transition_value, target_state, is_end_state 514 | ) -> bool: 515 | """ 516 | Override to perform additional checks on the acceptee and mutate the cursor 517 | with the transition_value as appropriate. 
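            For example, BooleanAcceptor.Cursor in json_acceptor.py overrides this
            to store True or False from the transition value when an end state is
            reached.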
518 | """ 519 | return True 520 | 521 | def get_value(self): 522 | value = [ 523 | accepted_transition_cursor.get_value() 524 | for accepted_transition_cursor in self.accept_history 525 | ] 526 | if self.transition_cursor is not None: 527 | value.append(self.transition_cursor.get_value()) 528 | return value 529 | 530 | def is_in_value(self): 531 | if self.consumed_character_count > 0: 532 | return self.transition_cursor.is_in_value() 533 | return self.accept_history[-1].is_in_value() if self.accept_history else None 534 | 535 | def get_value_path(self): 536 | if self.consumed_character_count > 0: 537 | return self.transition_cursor.get_value_path() 538 | return self.accept_history[-1].get_value_path() if self.accept_history else "" 539 | 540 | def __repr__(self) -> str: 541 | if self.transition_cursor is not None: 542 | transition_cursor = repr(self.transition_cursor) 543 | target_state = self.target_state 544 | else: 545 | transition_cursor = "None" 546 | target_state = "None" 547 | if self.accept_history: 548 | accept_history = [] 549 | for accepted_transition_cursor in self.accept_history: 550 | if isinstance( 551 | accepted_transition_cursor, StateMachineAcceptor.Cursor 552 | ): 553 | accept_history += accepted_transition_cursor.accept_history 554 | else: 555 | accept_history.append(accepted_transition_cursor) 556 | history = repr( 557 | "".join( 558 | [ 559 | str(accepted_transition_cursor.get_value()) 560 | for accepted_transition_cursor in accept_history 561 | ] 562 | ) 563 | ) 564 | else: 565 | history = "" 566 | state = ( 567 | f"{history} {self.current_state}⇒{target_state} {transition_cursor}" 568 | ) 569 | return f"{type(self).__qualname__}({state})" 570 | 571 | class EmptyTransitionAcceptor(TokenAcceptor): 572 | """ 573 | Faux acceptor that allows to create empty transition edges in a state 574 | machine graph for convenience in expressing complex graphs. 575 | An empty edge skips the current state altogether, without the need to 576 | consume input. 577 | """ 578 | 579 | def get_cursors(self): 580 | return [AcceptedState(self.Cursor(self))] 581 | 582 | class Cursor(TokenAcceptor.Cursor): 583 | """ 584 | Cursor for EmptyTransitionAcceptor 585 | """ 586 | 587 | def get_value(self): 588 | return "" 589 | 590 | # Singleton EmptyTransitionAcceptor 591 | EmptyTransition = EmptyTransitionAcceptor() 592 | 593 | 594 | class SequenceAcceptor(StateMachineAcceptor): 595 | """ 596 | Chain acceptors in sequence 597 | """ 598 | 599 | def __init__(self, acceptors): 600 | graph = [[(acceptor, i + 1)] for i, acceptor in enumerate(acceptors)] 601 | super().__init__(graph, initial_state=0, end_states=[len(acceptors)]) 602 | 603 | class Cursor(StateMachineAcceptor.Cursor): 604 | """ 605 | Cursor for SequenceAcceptor. Defined for inspectability. 606 | """ 607 | 608 | 609 | class WaitForAcceptor(TokenAcceptor): 610 | """ 611 | Accept all text until finding a segment that triggers another acceptor. 612 | This is useful to allow for free text until a delimiter is found, e.g. 613 | when the output of an LLM includes JSON that is encapsulated within a 614 | ```json ... ``` block. 
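    For example (illustrative), WaitForAcceptor(TextAcceptor("```json")) accepts
    arbitrary text until the literal delimiter "```json" has been generated.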
615 | """ 616 | 617 | def __init__(self, wait_for_acceptor: TokenAcceptor): 618 | self.wait_for_acceptor = wait_for_acceptor 619 | 620 | class Cursor(TokenAcceptor.Cursor): 621 | """ 622 | Cursor for WaitForAcceptor 623 | """ 624 | 625 | def __init__(self, acceptor, cursors=None): 626 | self.acceptor = acceptor 627 | if cursors: 628 | self.cursors = cursors 629 | else: 630 | self.cursors = acceptor.wait_for_acceptor.get_cursors() 631 | 632 | def matches_all(self): 633 | return True 634 | 635 | def advance(self, char): 636 | cursors = TokenAcceptor.advance_all(self.cursors, char) 637 | accepted_cursors = [ 638 | cursor for cursor in cursors if cursor.in_accepted_state() 639 | ] 640 | if accepted_cursors: 641 | return accepted_cursors 642 | return [self.__class__(self.acceptor, cursors)] 643 | 644 | def get_value(self): 645 | return f"Waiting for {repr(self.cursors)}" 646 | -------------------------------------------------------------------------------- /src/llm_structured_output/json_acceptor.py: -------------------------------------------------------------------------------- 1 | """ 2 | Acceptors for JSON parsing or constraning LLM generation to JSON outputs. 3 | """ 4 | 5 | import json 6 | 7 | from .acceptor import ( 8 | TokenAcceptor, 9 | AcceptedState, 10 | CharAcceptor, 11 | StateMachineAcceptor, 12 | SequenceAcceptor, 13 | TextAcceptor, 14 | ) 15 | from .util.tokentrie import TokenTrie 16 | 17 | 18 | class WhitespaceTokenTrie(TokenTrie): 19 | """ 20 | Create a smaller trie by collapsing all whitespace to a single symbol. 21 | Since all whitespace is equivalent in JSON, tokens that only differ in 22 | the type of whitespace are equivalent from a semantic point of view. 23 | 24 | For example, the tokens "\n\n\n", "\t\t\t" and " " are all mapped to the same 25 | node root -> " " -> " " -> " ", which now contains the token ids of all three 26 | tokens in its set of ids. 27 | 28 | This allows us to reduce the number of equivalent branches we explore when 29 | finding valid tokens. Note that this doesn't limit the possible output of 30 | an LLM, since the token ids are kept in the trie and thus matched as valid, 31 | and are accepted by the acceptor. 32 | """ 33 | 34 | @classmethod 35 | def from_trie(cls, trie, whitespace_charset): 36 | """ 37 | Create a WhitespaceTokenTrie given a full vocabulary trie. 38 | """ 39 | if isinstance(trie, WhitespaceTokenTrie): 40 | return trie 41 | 42 | def _whitespace_collapse_fn(char, level): 43 | if char in whitespace_charset: 44 | return " " 45 | if level == 0: 46 | # The trie doesn't need to contain tokens that don't start with whitespace, 47 | # since they won't be selected by the WhitespaceAcceptor. 48 | return None 49 | return True 50 | 51 | # pylint: disable-next=protected-access 52 | return trie._map(_whitespace_collapse_fn, WhitespaceTokenTrie()) 53 | 54 | 55 | class WhitespaceAcceptor(TokenAcceptor): 56 | """ 57 | Optional whitespace 58 | """ 59 | 60 | WHITESPACE = " \n\r\t" 61 | 62 | _cached_tries = {} 63 | 64 | @classmethod 65 | def prepare_trie(cls, trie: TokenTrie): 66 | """ 67 | Build a collapsed trie that reduces the search space for valid tokens. 
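        The collapsed trie is cached per vocabulary trie (keyed by id()), so the
        collapse cost is only paid once per vocabulary.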
68 | """ 69 | trie_id = id(trie) 70 | if trie_id in cls._cached_tries: 71 | return cls._cached_tries[trie_id] 72 | collapsed_trie = WhitespaceTokenTrie.from_trie( 73 | trie, WhitespaceAcceptor.WHITESPACE 74 | ) 75 | cls._cached_tries[trie_id] = collapsed_trie 76 | return collapsed_trie 77 | 78 | def __init__(self, max_whitespace: int = 40): 79 | self.max_whitespace = max_whitespace 80 | 81 | def get_cursors(self): 82 | # Whitespace is optional 83 | cursor = WhitespaceAcceptor.Cursor(self) 84 | return [cursor, AcceptedState(cursor)] 85 | 86 | class Cursor(TokenAcceptor.Cursor): 87 | """ 88 | Cursor for WhitespaceAcceptor 89 | """ 90 | 91 | def __init__(self, acceptor, text=""): 92 | self.acceptor = acceptor 93 | self.text = text 94 | self.length_exceeded = len(text) > self.acceptor.max_whitespace 95 | 96 | def select(self, candidate_chars): 97 | if self.length_exceeded: 98 | return set() 99 | return WhitespaceAcceptor.WHITESPACE 100 | 101 | def prune(self, trie): 102 | """ 103 | Use a custom matching trie to collapse all equivalent whitespace 104 | into one, saving time when selecting valid tokens. 105 | """ 106 | collapsed_trie = WhitespaceAcceptor.prepare_trie(trie) 107 | return super().prune(collapsed_trie) 108 | 109 | def advance(self, char): 110 | # Sometimes, LLMs try to run away with spaces when they don't know how to continue. 111 | # If the LLM triggers this often, consider whether the LLM is suitable for emitting 112 | # JSON and/or whether the task is achievable and makes sense with the information 113 | # provided in the prompt. 114 | if self.length_exceeded: 115 | return [] 116 | next_cursor = WhitespaceAcceptor.Cursor(self.acceptor, self.text + char) 117 | # More whitespace is optional 118 | return [next_cursor, AcceptedState(next_cursor)] 119 | 120 | def get_value(self): 121 | return self.text 122 | 123 | 124 | class BooleanAcceptor(StateMachineAcceptor): 125 | """ 126 | Accepts a JSON boolean value: true, false 127 | """ 128 | 129 | def __init__(self): 130 | super().__init__([[(TextAcceptor("true"), "$"), (TextAcceptor("false"), "$")]]) 131 | 132 | class Cursor(StateMachineAcceptor.Cursor): 133 | """ 134 | Cursor for BooleanAcceptor 135 | """ 136 | 137 | def __init__(self, acceptor): 138 | super().__init__(acceptor) 139 | self.value = None 140 | 141 | def complete_transition(self, transition_value, target_state, is_end_state): 142 | if is_end_state: 143 | if transition_value == "true": 144 | self.value = True 145 | else: 146 | assert transition_value == "false" 147 | self.value = False 148 | return True 149 | 150 | def get_value(self): 151 | return self.value 152 | 153 | def is_in_value(self): 154 | return True 155 | 156 | 157 | class NullAcceptor(TextAcceptor): 158 | """ 159 | Accepts the JSON null value 160 | """ 161 | 162 | def __init__(self): 163 | super().__init__("null") 164 | 165 | class Cursor(TextAcceptor.Cursor): 166 | """ 167 | Cursor for NullAcceptor 168 | """ 169 | 170 | def is_in_value(self): 171 | return True 172 | 173 | 174 | DigitAcceptor = CharAcceptor("0123456789") 175 | HexDigitAcceptor = CharAcceptor("0123456789ABCDEFabcdef") 176 | 177 | 178 | class StringCharTokenTrie(TokenTrie): 179 | """ 180 | Create a smaller trie by collapsing all unescaped valid string characters 181 | to a single one while keeping the token ids. This is useful to reduce 182 | combinatorial explosion in string acceptance when all strings of equal 183 | length are equally acceptable. 
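    For example, the tokens "abc", "foo" and "123" are all mapped to the same
    path root -> "." -> "." -> ".", whose node then holds the ids of all three
    tokens (same idea as WhitespaceTokenTrie above).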
184 | """ 185 | 186 | @classmethod 187 | def from_trie(cls, trie): 188 | """ 189 | Create a StringCharTokenTrie given a full trie. 190 | """ 191 | if isinstance(trie, StringCharTokenTrie): 192 | return trie 193 | 194 | def _string_char_acceptor_collapse_fn(char, _level): 195 | if char in ['"', "\\"]: 196 | return True 197 | if char in StringCharAcceptor.INVALID_CHARS: 198 | return None 199 | return "." 200 | 201 | # pylint: disable-next=protected-access 202 | return trie._map(_string_char_acceptor_collapse_fn, StringCharTokenTrie()) 203 | 204 | 205 | class StringCharAcceptor(TokenAcceptor): 206 | """ 207 | Accepts a valid JSON unescaped string character 208 | """ 209 | 210 | INVALID_CHARS = set(chr(c) for c in range(0, 0x20)) | set(['"', "\\"]) 211 | _cached_tries = {} 212 | 213 | @classmethod 214 | def prepare_trie(cls, trie: TokenTrie): 215 | """ 216 | Build a collapsed trie that reduces the search space for valid tokens. 217 | Note that while there is only one main vocabulary trie, we may need to 218 | several collapsed tries because sometimes string matching will start 219 | in the middle of the main trie. I.e. we ara half way through the main 220 | trie with another acceptor; that acceptor reaches an end state and then 221 | we transition to the string acceptor; thus we start string matching in 222 | the middle of the main trie instead of the root. This can happen e.g. 223 | if there's tokens in the vocabulary that contain a quote and then 224 | additional characters afterwards. 225 | """ 226 | trie_id = id(trie) 227 | if trie_id in cls._cached_tries: 228 | return cls._cached_tries[trie_id] 229 | collapsed_trie = StringCharTokenTrie().from_trie(trie) 230 | cls._cached_tries[trie_id] = collapsed_trie 231 | return collapsed_trie 232 | 233 | class Cursor(TokenAcceptor.Cursor): 234 | """ 235 | Cursor for StringCharAcceptor 236 | """ 237 | 238 | def __init__(self, acceptor, value=None): 239 | self.acceptor = acceptor 240 | self.value = value 241 | 242 | def select(self, candidate_chars): 243 | return candidate_chars - StringCharAcceptor.INVALID_CHARS 244 | 245 | def prune(self, trie): 246 | """ 247 | Use a custom matching trie to avoid an explosion of valid options that 248 | are equivalent from the point of view of token matching. 
249 | """ 250 | return super().prune(StringCharAcceptor.prepare_trie(trie)) 251 | 252 | def advance(self, char): 253 | return [AcceptedState(StringCharAcceptor.Cursor(self.acceptor, char))] 254 | 255 | def get_value(self): 256 | return self.value 257 | 258 | 259 | class StringAcceptor(StateMachineAcceptor): 260 | """ 261 | Accepts a well-formed JSON string 262 | """ 263 | 264 | STATES = [ 265 | [(CharAcceptor('"'), 1)], 266 | [(CharAcceptor('"'), "$"), (CharAcceptor("\\"), 2), (StringCharAcceptor(), 1)], 267 | [ 268 | (CharAcceptor('"\\/bfnrt'), 1), 269 | (CharAcceptor("u"), 3), 270 | ], 271 | [(HexDigitAcceptor, 4)], 272 | [(HexDigitAcceptor, 5)], 273 | [(HexDigitAcceptor, 6)], 274 | [(HexDigitAcceptor, 1)], 275 | ] 276 | 277 | def __init__(self): 278 | super().__init__(StringAcceptor.STATES) 279 | 280 | class Cursor(StateMachineAcceptor.Cursor): 281 | """ 282 | Cursor for StringAcceptor 283 | """ 284 | 285 | def __init__(self, acceptor): 286 | super().__init__(acceptor) 287 | self.text = "" 288 | self.length = 0 289 | self.value = None 290 | 291 | def complete_transition(self, transition_value, target_state, is_end_state): 292 | self.text += transition_value 293 | if target_state == 1 and self.current_state != 0: 294 | self.length += 1 295 | if is_end_state: 296 | self.value = json.loads(self.text) 297 | return True 298 | 299 | def get_value(self): 300 | if self.value is not None: 301 | return self.value 302 | else: 303 | return f"{self.text}👉" 304 | 305 | def is_in_value(self): 306 | return True 307 | 308 | 309 | class StringConstantAcceptor(TextAcceptor): 310 | """ 311 | Accept a constant string, quoted and escaped. 312 | """ 313 | 314 | def __init__(self, string: str): 315 | self.string = string 316 | super().__init__(json.dumps(string)) 317 | 318 | class Cursor(TextAcceptor.Cursor): 319 | """ 320 | Cursor for StringConstantAcceptor 321 | """ 322 | 323 | def get_value(self) -> str: 324 | if self.pos == len(self.acceptor.text): 325 | return self.acceptor.string 326 | return super().get_value() 327 | 328 | def is_in_value(self): 329 | return True 330 | 331 | 332 | class NumberTokenTrie(TokenTrie): 333 | """ 334 | Create a smaller trie by collapsing digit sequences. 335 | """ 336 | 337 | @classmethod 338 | def from_trie(cls, trie): 339 | """ 340 | Create a NumberTokenTrie given a full trie. 341 | """ 342 | if isinstance(trie, NumberTokenTrie): 343 | return trie 344 | 345 | def _number_acceptor_collapse_fn(char, level): 346 | if char in "0123456789": 347 | return "9" 348 | # Only store branches that start with a digit. 
349 | return level > 0 350 | 351 | # pylint: disable-next=protected-access 352 | return trie._map(_number_acceptor_collapse_fn, StringCharTokenTrie()) 353 | 354 | 355 | class NumberAcceptor(StateMachineAcceptor): 356 | """ 357 | Accepts a well-formed JSON number 358 | """ 359 | 360 | STATES = { 361 | 0: [(CharAcceptor("-"), 1), (StateMachineAcceptor.EmptyTransition, 1)], # Sign 362 | 1: [(CharAcceptor("123456789"), 2), (CharAcceptor("0"), 3)], # First digit 363 | 2: [ 364 | (DigitAcceptor, 2), 365 | (StateMachineAcceptor.EmptyTransition, 3), 366 | ], # More digits 367 | 3: [(CharAcceptor("."), 4), (StateMachineAcceptor.EmptyTransition, 6)], 368 | 4: [(DigitAcceptor, 5)], # First decimal 369 | 5: [ 370 | (DigitAcceptor, 5), 371 | (StateMachineAcceptor.EmptyTransition, 6), 372 | ], # More decimals 373 | 6: [(CharAcceptor("eE"), 7)], 374 | 7: [(CharAcceptor("+-"), 8), (StateMachineAcceptor.EmptyTransition, 8)], 375 | 8: [(DigitAcceptor, 9)], # Exponential, first digit 376 | 9: [(DigitAcceptor, 9)], # Exponential, more digits 377 | "$": [2, 3, 5, 9], 378 | } 379 | _cached_tries = {} 380 | 381 | @classmethod 382 | def prepare_trie(cls, trie: TokenTrie): 383 | """ 384 | Build a collapsed trie that reduces the search space for valid tokens. 385 | """ 386 | trie_id = id(trie) 387 | if trie_id in cls._cached_tries: 388 | return cls._cached_tries[trie_id] 389 | collapsed_trie = NumberTokenTrie().from_trie(trie) 390 | cls._cached_tries[trie_id] = collapsed_trie 391 | return collapsed_trie 392 | 393 | def __init__(self): 394 | super().__init__(self.STATES, 0, self.STATES["$"]) 395 | 396 | class Cursor(StateMachineAcceptor.Cursor): 397 | """ 398 | Cursor for NumberAcceptor 399 | """ 400 | 401 | def __init__(self, acceptor): 402 | super().__init__(acceptor) 403 | self.text = "" 404 | self.value = None 405 | 406 | def prune(self, trie): 407 | """ 408 | Use a custom matching trie to avoid an explosion of valid options that 409 | are equivalent from the point of view of token matching. 
410 | """ 411 | return super().prune(NumberAcceptor.prepare_trie(trie)) 412 | 413 | def complete_transition(self, transition_value, target_state, is_end_state): 414 | self.text += transition_value 415 | if is_end_state: 416 | self.value = json.loads(self.text) 417 | return True 418 | 419 | def get_value(self): 420 | if self.value is None: 421 | return f"{self.text}👉" 422 | return self.value 423 | 424 | def is_in_value(self): 425 | return True 426 | 427 | 428 | class ArrayAcceptor(StateMachineAcceptor): 429 | """ 430 | Accepts a well-formed JSON array 431 | """ 432 | 433 | def __init__(self): 434 | super().__init__() 435 | 436 | def get_edges(self, state): 437 | return { 438 | 0: [(TextAcceptor("["), 1)], 439 | 1: [(WhitespaceAcceptor(), 2), (TextAcceptor("]"), "$")], 440 | 2: [(JsonAcceptor(), 3)], 441 | 3: [(WhitespaceAcceptor(), 4)], 442 | 4: [ 443 | (SequenceAcceptor([TextAcceptor(","), WhitespaceAcceptor()]), 2), 444 | (TextAcceptor("]"), "$"), 445 | ], 446 | }[state] 447 | 448 | class Cursor(StateMachineAcceptor.Cursor): 449 | """ 450 | Cursor for ArrayAcceptor 451 | """ 452 | 453 | def __init__(self, acceptor): 454 | super().__init__(acceptor) 455 | self.value = [] 456 | 457 | def clone(self): 458 | c = super().clone() 459 | c.value = self.value[:] 460 | return c 461 | 462 | def complete_transition( 463 | self, transition_value, target_state, is_end_state 464 | ) -> bool: 465 | if self.current_state == 2: 466 | self.value.append(transition_value) 467 | return True 468 | 469 | def get_value_path(self): 470 | index = len(self.value) 471 | if self.current_state > 2: 472 | index -= 1 473 | return f"[{index}]{super().get_value_path()}" 474 | 475 | 476 | class ObjectAcceptor(StateMachineAcceptor): 477 | """ 478 | Accepts a well-formed JSON object 479 | """ 480 | 481 | def __init__(self): 482 | super().__init__() 483 | 484 | def get_edges(self, state): 485 | return { 486 | 0: [(TextAcceptor("{"), 1)], 487 | 1: [(self.EmptyTransition, 2), (self.EmptyTransition, 6)], 488 | 2: [(WhitespaceAcceptor(), 3)], 489 | 3: [(ObjectAcceptor.PropertyAcceptor(), 4)], 490 | 4: [(WhitespaceAcceptor(), 5)], 491 | 5: [(TextAcceptor(","), 2), (self.EmptyTransition, 7)], 492 | 6: [(WhitespaceAcceptor(), 7)], 493 | 7: [(TextAcceptor("}"), "$")], 494 | }[state] 495 | 496 | class Cursor(StateMachineAcceptor.Cursor): 497 | """ 498 | Cursor for ObjectAcceptor 499 | """ 500 | 501 | def __init__(self, acceptor): 502 | super().__init__(acceptor) 503 | self.value = {} 504 | 505 | def complete_transition( 506 | self, transition_value, target_state, is_end_state 507 | ) -> bool: 508 | if self.current_state == 3: 509 | prop_name, prop_value = transition_value 510 | self.value[prop_name] = prop_value 511 | return True 512 | 513 | def get_value(self): 514 | return self.value 515 | 516 | class PropertyAcceptor(SequenceAcceptor): 517 | """ 518 | JSON object property acceptor 519 | """ 520 | 521 | def __init__(self, graph=None): 522 | if graph is None: 523 | graph = [ 524 | StringAcceptor(), 525 | WhitespaceAcceptor(), 526 | TextAcceptor(":"), 527 | WhitespaceAcceptor(), 528 | JsonAcceptor(), 529 | ] 530 | super().__init__(graph) 531 | 532 | class Cursor(SequenceAcceptor.Cursor): 533 | """ 534 | Cursor for ObjectAcceptor.PropertyAcceptor 535 | """ 536 | 537 | def __init__(self, acceptor): 538 | super().__init__(acceptor) 539 | self.prop_name = None 540 | self.prop_value = None 541 | 542 | def complete_transition( 543 | self, transition_value, target_state, is_end_state 544 | ) -> bool: 545 | if target_state == 1: 546 | 
self.prop_name = transition_value 547 | elif is_end_state: 548 | self.prop_value = transition_value 549 | return True 550 | 551 | def get_value(self): 552 | return (self.prop_name, self.prop_value) 553 | 554 | def is_in_value(self): 555 | return self.current_state >= 4 and super().is_in_value() 556 | 557 | def get_value_path(self): 558 | return f".{self.prop_name}{super().get_value_path()}" 559 | 560 | 561 | class JsonAcceptor(StateMachineAcceptor): 562 | """ 563 | Acceptor for a JSON value 564 | """ 565 | 566 | def get_edges(self, state): 567 | if state == 0: 568 | return [ 569 | (BooleanAcceptor(), "$"), 570 | (NumberAcceptor(), "$"), 571 | (StringAcceptor(), "$"), 572 | (NullAcceptor(), "$"), 573 | (ObjectAcceptor(), "$"), 574 | (ArrayAcceptor(), "$"), 575 | ] 576 | return [] 577 | 578 | 579 | def prepare_json_acceptor_tries(trie: TokenTrie): 580 | """ 581 | Pre-cache custom acceptor tries. 582 | """ 583 | WhitespaceAcceptor.prepare_trie(trie) 584 | NumberAcceptor.prepare_trie(trie) 585 | StringCharAcceptor.prepare_trie(trie) 586 | if '"' in trie.children: 587 | StringCharAcceptor.prepare_trie(trie.children['"']) 588 | -------------------------------------------------------------------------------- /src/llm_structured_output/util/__init__.py: -------------------------------------------------------------------------------- 1 | from . import bitmap 2 | from . import output 3 | from . import tokentrie 4 | from . import tokenization 5 | -------------------------------------------------------------------------------- /src/llm_structured_output/util/bitmap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities to use the bitmap of accepted token ids returned by TokenAcceptor. 3 | """ 4 | 5 | from math import inf 6 | from typing import Iterable 7 | 8 | 9 | def count_set_bits(bitmap: int) -> int: 10 | """ 11 | Count the number of bits set to one. 12 | """ 13 | # FUTURE: self.ids.bit_count() available from Python 3.10 is said to be 6x faster 14 | return bin(bitmap).count("1") 15 | 16 | 17 | def highest_bit_set(bitmap: int) -> int: 18 | """ 19 | Return the index of the highest bit set in the bitmap. 20 | """ 21 | return bitmap.bit_length() - 1 22 | 23 | 24 | def bitmap_complement(bitmap: int, set_size: int = None) -> int: 25 | """ 26 | Negate the bits in the bitmap. 27 | Since the bitmap is encoded as a Python int, it can be of arbitrary length. 28 | I.e. we don't know how many zeros are above the top set bit. The set_size 29 | parameter can be passed to indicate the number of bits in the bitmap (which 30 | is akin to the number of members in the set it represents). If unspecified, 31 | the top set bit in the bitmap is used as its set size. 32 | """ 33 | if not set_size: 34 | set_size = bitmap.bit_length() 35 | return (1 << set_size) - 1 - bitmap 36 | 37 | 38 | def enumerate_set_bits(bitmap: int) -> Iterable[int]: 39 | """ 40 | Generator that yields the indices of the set bits in the bitmap. 41 | Note that it does so from highest to lowest. 42 | """ 43 | while bitmap: 44 | highest_bit = highest_bit_set(bitmap) 45 | yield highest_bit 46 | bitmap -= 1 << highest_bit 47 | 48 | 49 | def bias_logits(np, logits, accepted_token_bitmap): 50 | """ 51 | Apply a -inf bias to tokens that will not be accepted. 52 | Rather than import here, the np parameters is numpy or a compatible library 53 | import, such as mlx.core. 
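    For example (illustrative), with a 5-token vocabulary and
    accepted_token_bitmap == 0b01010, logits 1 and 3 are returned unchanged and
    logits 0, 2 and 4 are set to -inf.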
54 | """ 55 | vocab_size = logits.shape[0] 56 | highest_token_accepted = highest_bit_set(accepted_token_bitmap) 57 | accepted_token_count = count_set_bits(accepted_token_bitmap) 58 | # Check whether there's more tokens to be rejected or to be allowed, then do what's less work. 59 | if accepted_token_count <= highest_token_accepted / 2: 60 | bias = np.full(vocab_size, -inf) 61 | indices = np.array([*enumerate_set_bits(accepted_token_bitmap)]) 62 | bias[indices] = 0 63 | else: 64 | bias = np.concatenate( 65 | [ 66 | np.full(highest_token_accepted + 1, 0), 67 | # All tokens above the highest accepted token are rejected. 68 | np.full(vocab_size - highest_token_accepted - 1, -inf), 69 | ] 70 | ) 71 | rejected_token_bitmap = bitmap_complement(accepted_token_bitmap) 72 | indices = np.array([*enumerate_set_bits(rejected_token_bitmap)]) 73 | bias[indices] = -inf 74 | return np.add(logits, bias) 75 | -------------------------------------------------------------------------------- /src/llm_structured_output/util/output.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | """ 3 | Terminal colored output 4 | """ 5 | 6 | 7 | def info(*args, **kwargs): 8 | print("\033[34mℹ ", end="") 9 | print(*args, **kwargs) 10 | print("\033[0m", end="") 11 | 12 | 13 | def warning(*args, **kwargs): 14 | print("\033[43;37m", end="") 15 | print(*args, **kwargs) 16 | print("\033[0m", end="") 17 | 18 | 19 | def debug(*args, **kwargs): 20 | print("\033[33m", end="") 21 | print(*args, **kwargs) 22 | print("\033[0m", end="") 23 | 24 | 25 | def debugbold(*args, **kwargs): 26 | print("\033[1;33m", end="") 27 | print(*args, **kwargs) 28 | print("\033[0m", end="") 29 | 30 | 31 | def bold(*args, **kwargs): 32 | print("\033[1;30m", end="") 33 | print(*args, **kwargs) 34 | print("\033[0m", end="") 35 | 36 | 37 | def bolddim(*args, **kwargs): 38 | print("\033[1;2;30m", end="") 39 | print(*args, **kwargs) 40 | print("\033[0m", end="") 41 | 42 | 43 | def boldalt(*args, **kwargs): 44 | print("\033[1;36m", end="") 45 | print(*args, **kwargs) 46 | print("\033[0m", end="") 47 | 48 | 49 | def underline(*args, **kwargs): 50 | print("\033[4m", end="") 51 | print(*args, **kwargs) 52 | print("\033[0m", end="") 53 | 54 | 55 | def inverse(*args, **kwargs): 56 | print("\033[7m", end="") 57 | print(*args, **kwargs) 58 | print("\033[0m", end="") 59 | 60 | 61 | def setfg(r: float, g: float, b: float): 62 | """Each of r,g,b must be between 0 and 1""" 63 | color = 16 + 36 * round(5 * r) + 6 * round(5 * g) + round(5 * b) 64 | print(f"\033[38;5;{color}m", end="") 65 | 66 | 67 | def setbg(r: float, g: float, b: float): 68 | """Each of r,g,b must be between 0 and 1""" 69 | color = 16 + 36 * round(5 * r) + 6 * round(5 * g) + round(5 * b) 70 | print(f"\033[48;5;{color}m", end="") 71 | 72 | 73 | def clear(): 74 | print("\033[0m", end="") 75 | -------------------------------------------------------------------------------- /src/llm_structured_output/util/tokenization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tokenizer utils. 3 | """ 4 | 5 | from typing import Union 6 | 7 | SPIECE_UNDERLINE = "▁" 8 | 9 | 10 | class HuggingfaceTokenizerHelper: 11 | """ 12 | Helper to use Huggingface tokenizers effectively. 
13 | """ 14 | 15 | def __init__(self, tokenizer): 16 | """ 17 | tokenizer is expected to be a Huggingface PreTrainedTokenizer[Fast] 18 | """ 19 | self.tokenizer = tokenizer 20 | self.token_has_space_prefix = dict( 21 | [ 22 | (i, fragment[0] == SPIECE_UNDERLINE) 23 | for fragment, i in tokenizer.vocab.items() 24 | ] 25 | ) 26 | 27 | def encode_prompt(self, prompt: Union[str, list[dict[str, str]]]) -> list[int]: 28 | """ 29 | Encode the prompt, applying the tokenizer template first if the prompt 30 | is a series of messages instead of a straight string. 31 | """ 32 | if isinstance(prompt, str): 33 | return self.tokenizer.encode(prompt) 34 | if not self.tokenizer.chat_template: 35 | return self.tokenizer.encode("\n\n".join( 36 | f"{message['role']}: {message['content']}" 37 | for message in prompt 38 | )) 39 | return self.tokenizer.apply_chat_template(prompt) 40 | 41 | def no_strip_decode(self, tokens): 42 | """ 43 | Allows to decode single tokens without removing the initial space. 44 | The Huggingface tokenizer doesn't seem to have an easy way to do this. 45 | """ 46 | fragment = self.tokenizer.decode(tokens) 47 | if self.token_has_space_prefix[tokens[0]]: 48 | return f" {fragment}" 49 | else: 50 | return fragment 51 | 52 | def extract_vocabulary(self) -> tuple[list[tuple[int, str]], int]: 53 | """ 54 | Extract the vocabulary and eos_token_id from a Huggingface PreTrainedTokenizer. 55 | """ 56 | return ( 57 | [(i, self.no_strip_decode([i])) for _, i in self.tokenizer.vocab.items()], 58 | self.tokenizer.eos_token_id, 59 | ) 60 | -------------------------------------------------------------------------------- /src/llm_structured_output/util/tokentrie.py: -------------------------------------------------------------------------------- 1 | """ 2 | TokenTrie: hold the LLM token vocabulary in a prefix tree in otder to perform 3 | operations over the whole vocabulary or parts of it in logarithmic time instead 4 | of linear. 5 | """ 6 | 7 | from __future__ import annotations 8 | from collections import namedtuple 9 | from typing import Callable, Iterable, Tuple 10 | 11 | 12 | TokenTrieStats = namedtuple( 13 | "TokenTrieStats", ["tokenids", "trienodes", "trieleaves", "triedepth"] 14 | ) 15 | 16 | 17 | class TokenTrie: 18 | """ 19 | Access the tokens in a vocabulary hierarchically by prefix. 20 | Ids are stored as a bitmap with bits set to one meaning id is present. 21 | """ 22 | 23 | def __init__(self): 24 | self.children: dict[str, TokenTrie] = {} 25 | self.ids: int = 0 26 | 27 | def insert_all(self, vocabulary: Iterable[Tuple[int, str]]): 28 | """ 29 | Insert all the tokens in the vocabulary in the trie, with the id of 30 | each token being its index in the vocabulary. 31 | """ 32 | for _id, token in vocabulary: 33 | if len(token) > 0: 34 | self.insert(token, _id) 35 | 36 | def insert(self, token, _id): 37 | """ 38 | Insert one token in the trie, with the given id. 39 | """ 40 | if len(token) == 0: 41 | self.ids |= 1 << _id 42 | else: 43 | head, tail = token[0], token[1:] 44 | child = self.children.get(head, self.__class__()) 45 | child.insert(tail, _id) 46 | self.children[head] = child 47 | 48 | def insert_ids(self, token, ids): 49 | """ 50 | Insert a token in the trie, with the given id set. 51 | This is useful e.g. when collapsing multiple branches into one. 
52 | """ 53 | if len(token) == 0: 54 | self.ids |= ids 55 | else: 56 | head, tail = token[0], token[1:] 57 | child = self.children.get(head, self.__class__()) 58 | child.insert_ids(tail, ids) 59 | self.children[head] = child 60 | 61 | def collect_ids(self) -> set[int]: 62 | """ 63 | Returns a set with the ids of the token(s) in this node and all the 64 | nodes below it. 65 | """ 66 | ids = self.ids 67 | for child in self.children.values(): 68 | ids |= child.collect_ids() 69 | return ids 70 | 71 | def dfs(self, prefix="") -> Iterable[tuple[str, int]]: 72 | """ 73 | Traverse the trie depth-first, yielding (token, ids) tuples. 74 | """ 75 | if self.ids: 76 | yield (prefix, self.ids) 77 | for char, child in self.children.items(): 78 | yield from child.dfs(prefix + char) 79 | 80 | def map(self, map_fn: Callable[[str, int], str]) -> TokenTrie: 81 | """ 82 | Return a trie where the characters are mapped to other characters using a 83 | function. This is useful for example to collapse a tree into a smaller one 84 | by pruning or merging branches where the characters are equivalent for a 85 | particular use case. The mapping function is passed a character to map, and 86 | the recursion level in the tree, and it can return True to preserve the 87 | branch of the tree as is, None to prune it, or a replacement character. 88 | If the latter, the branch will be recursed upon and stored under the 89 | replacement branch. 90 | """ 91 | return self._map(map_fn, self.__class__()) 92 | 93 | def _map( 94 | self, map_fn: Callable[[str, int], str], mapped_trie: TokenTrie, level: int = 0 95 | ) -> TokenTrie: 96 | """ 97 | Internal implementation of map() 98 | """ 99 | mapped_trie.ids |= self.ids 100 | for char, child in self.children.items(): 101 | mapped_char = map_fn(char, level) 102 | if mapped_char is True: 103 | # If the mapping function returns True, preserve the original branch 104 | mapped_trie.children[char] = child 105 | elif mapped_char is None: 106 | # If the mapping function returns None, prune the original branch 107 | pass 108 | else: 109 | # Map the branch to a new character, e.g. merge several chars into one 110 | mapped_child = mapped_trie.children.get( 111 | mapped_char, mapped_trie.__class__() 112 | ) 113 | # pylint: disable-next=protected-access 114 | mapped_trie.children[mapped_char] = child._map( 115 | map_fn, mapped_child, level + 1 116 | ) 117 | return mapped_trie 118 | 119 | def _id_count(self) -> int: 120 | """ 121 | Returns the number of ids in this node 122 | """ 123 | # FUTURE: self.ids.bit_count() available from Python 3.10 is said to be 6x faster 124 | return bin(self.ids).count("1") 125 | 126 | def max_depth(self) -> int: 127 | """ 128 | Return the max depth of any branch on the trie, i.e. the length of the longest token. 129 | """ 130 | return max((child.max_depth() for child in self.children.values()), default=0) + 1 131 | 132 | def stats(self) -> TokenTrieStats: 133 | """ 134 | Compute and return statistics on the trie, for debugging purposes. 
135 | """ 136 | ids = self._id_count() 137 | nodes = 1 138 | leaves = 0 139 | depth = 0 140 | if len(self.children) == 0: 141 | leaves = 1 142 | else: 143 | for branch in self.children.values(): 144 | branch_ids, branch_nodes, branch_leaves, branch_depth = branch.stats() 145 | ids += branch_ids 146 | nodes += branch_nodes 147 | leaves += branch_leaves 148 | depth = max(depth, branch_depth) 149 | return TokenTrieStats( 150 | tokenids=ids, trienodes=nodes, trieleaves=leaves, triedepth=depth + 1 151 | ) 152 | 153 | def __repr__(self): 154 | id_count = self._id_count() 155 | child_count = len(self.children) 156 | return f"{super().__repr__()}({id_count=}, {child_count=})" 157 | -------------------------------------------------------------------------------- /src/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otriscon/llm-structured-output/037e8eb7447005fda06e7d811b041efcb94b0cef/src/tests/__init__.py -------------------------------------------------------------------------------- /src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/multi_turn-00000-of-00001.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otriscon/llm-structured-output/037e8eb7447005fda06e7d811b041efcb94b0cef/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/multi_turn-00000-of-00001.parquet -------------------------------------------------------------------------------- /src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/parquet_to_jsonl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert a fireworks function calling dataset from parquet to jsonl that can be 3 | used by the evaluation scripts. 4 | 5 | https://huggingface.co/datasets/fireworks-ai/function-calling-eval-dataset-v0 6 | """ 7 | 8 | import sys 9 | import json 10 | import pyarrow.parquet as pq 11 | 12 | if len(sys.argv) < 2: 13 | print("Need path to parquet file.") 14 | sys.exit(1) 15 | input_file = sys.argv[1] 16 | data = pq.read_table(input_file).to_pydict() 17 | prompts = data["prompt"] 18 | completions = data["completion"] 19 | tools = data["tools"] 20 | 21 | output_file = input_file.replace(".parquet", ".jsonl") 22 | if output_file == input_file: 23 | output_file += ".jsonl" 24 | 25 | with open(output_file, mode="w", encoding="utf-8") as f: 26 | for i, prompt in enumerate(prompts): 27 | json.dump( 28 | { 29 | "prompt": prompt, 30 | "tools": json.loads(tools[i]), 31 | # The source dataset contains one gold completion per case, but we output an array 32 | # to support multiple gold answers down the line. 33 | "gold": [ 34 | { 35 | "type": "function", 36 | "function": json.loads( 37 | completions[i].partition("")[2] 38 | ), 39 | } 40 | ], 41 | "options": { 42 | "prompt_includes_schema": True, 43 | # This dataset has only cases where one tool is invoked, and the prompt includes 44 | # an example in which the output is not an array but a single tool call. 
45 | "single_tool": True, 46 | }, 47 | }, 48 | f, 49 | ) 50 | f.write("\n") 51 | -------------------------------------------------------------------------------- /src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/report-multi_turn.md: -------------------------------------------------------------------------------- 1 | case | mlx-community/Meta-Llama-3-8B-Instruct-4bit | gpt-4o-2024-05-13 2 | --- | --- | --- 3 | 0 | ✅ | ✅ 4 | 1 | ✅ | ✅ 5 | 2 | ✅ | ✅ 6 | 3 | ✅ | ✅ 7 | 4 | ✅ | ✅ 8 | 5 | ✅ | ✅ 9 | 6 | ✅ | _function_call[0].url_ ~~'www.mywebsite.com'~~ 'http://www.mywebsite.com' 10 | 7 | ✅ | ✅ 11 | 8 | ✅ | ✅ 12 | 9 | ✅ | ✅ 13 | 10 | ✅ | ✅ 14 | 11 | ✅ | ✅ 15 | 12 | ✅ | ✅ 16 | 13 | ✅ | ✅ 17 | 14 | ✅ | ✅ 18 | 15 | ✅ | ✅ 19 | 16 | ✅ | ✅ 20 | 17 | ✅ | ✅ 21 | 18 | ➕ _function_call[0].genre_ 'action' | ✅ 22 | 19 | ✅ | ✅ 23 | 20 | ✅ | ✅ 24 | 21 | ✅ | ✅ 25 | 22 | ✅ | ✅ 26 | 23 | ✅ | ✅ 27 | 24 | ✅ | ✅ 28 | 25 | ➕ _function_call[0].genre_ 'pop' ⸱ _function_call[0].keyword_ ~~'pop'~~ 'Taylor Swift' | ✅ 29 | 26 | ✅ | _function_call[0].country_ ~~'US'~~ 'us' 30 | 27 | ✅ | ✅ 31 | 28 | ✅ | ✅ 32 | 29 | _function_call[0].amount_ ~~1500 [int]~~ 1500.0 [float] | ✅ 33 | 30 | ✅ | ✅ 34 | 31 | ✅ | ✅ 35 | 32 | ✅ | ✅ 36 | 33 | ✅ | ✅ 37 | 34 | ✅ | ✅ 38 | 35 | ✅ | ✅ 39 | 36 | _function_call[0].rating_ ~~7 [int]~~ 7.0 [float] | ✅ 40 | 37 | _function_call[0].date_range['start_date']_ ~~'2022-02-01'~~ '2023-02-20' ⸱ _function_call[0].date_range['end_date']_ ~~'2022-02-08'~~ '2023-02-27' | _function_call[0].date_range['start_date']_ ~~'2022-02-01'~~ '2023-09-30' ⸱ _function_call[0].date_range['end_date']_ ~~'2022-02-08'~~ '2023-10-07' 41 | 38 | ✅ | ✅ 42 | 39 | ✅ | ✅ 43 | 40 | _function_call[0].event_date_ ~~'2022-04-15'~~ 'today' | _function_call[0].event_date_ ~~'2022-04-15'~~ '2023-10-06' 44 | 41 | ✅ | ✅ 45 | 42 | ✅ | ✅ 46 | 43 | ✅ | ✅ 47 | 44 | ✅ | ✅ 48 | 45 | ✅ | ✅ 49 | 46 | ✅ | ✅ 50 | 47 | ✅ | ✅ 51 | 48 | ✅ | ✅ 52 | 49 | _function_call[0].query_ ~~''~~ 'comedy' | ➖ ~~_function_call[0].query_ ''~~ 53 | 50 | ✅ | _function_call[0].url_ ~~'www.example.com'~~ 'http://www.example.com' 54 | 51 | ✅ | _function_call[0].username_ ~~'@JohnDoe'~~ 'JohnDoe' 55 | 52 | ✅ | ✅ 56 | 53 | _function_call[0].source_language_ ~~'fr'~~ 'French' ⸱ _function_call[0].target_language_ ~~'en'~~ 'English' | ✅ 57 | 54 | ✅ | ✅ 58 | 55 | ✅ | ✅ 59 | 56 | ✅ | ✅ 60 | 57 | ✅ | _function_call[0].country_ ~~'United States'~~ 'us' 61 | 58 | _function_call[0].event_location_ ~~'conference room in our office'~~ 'Conference room in our office' | _function_call[0].event_date_ ~~'15th of next month'~~ '2023-11-15' ⸱ _function_call[0].event_time_ ~~'10 AM'~~ '10:00 AM' ⸱ _function_call[0].event_location_ ~~'conference room in our office'~~ 'Conference Room, Office' 62 | 59 | ✅ | ✅ 63 | 60 | ✅ | ✅ 64 | 61 | ✅ | _function_call[0].locations[0]_ ~~'Brooklyn'~~ 'Brooklyn, NY' ⸱ _function_call[0].locations[1]_ ~~'Manhattan'~~ 'Manhattan, NY' ⸱ _function_call[0].locations[2]_ ~~'Queens'~~ 'Queens, NY' ⸱ _function_call[0].locations[3]_ ~~'Brooklyn'~~ 'Brooklyn, NY' 65 | 62 | _function_call[0].image_ ~~'user_image'~~ 'The image you sent' | ➕ _tool_call[0]['error']_ {'error': "Parsing tool_calls: KeyError('tool_calls')", 'completion_message': {'role': 'assistant', 'content': 'Please provide the image of the barcode so I can proceed with scanning it.'}} ⸱ ➖ ~~_function_call[0]._ {'name': 'scan_barcode', 'arguments': {'image': 'user_image'}}~~ ⸱ _tool_call[0]['type']_ ~~'function'~~ 'error' 66 | 63 | ✅ | ✅ 67 | 64 | ✅ | ✅ 68 | 65 | ✅ | ✅ 69 | 66 | 
_function_call[0].meal_ ~~'pizza'~~ 'lunch' ⸱ _function_call[0].date_ ~~'2022-03-01'~~ 'today' | _function_call[0].date_ ~~'2022-03-01'~~ '2023-10-10' 70 | 67 | ✅ | ✅ 71 | 68 | ✅ | ✅ 72 | 69 | _function_call[0].language_ ~~'English'~~ 'en' | _function_call[0].language_ ~~'English'~~ 'en' 73 | 70 | ✅ | ✅ 74 | 71 | _function_call[0].order_items[0]['product_name']_ ~~'laptop'~~ 'Laptop' | ✅ 75 | 72 | ✅ | ✅ 76 | 73 | ✅ | _function_call[0].background_color_ ~~'white'~~ '#FFFFFF' ⸱ _function_call[0].foreground_color_ ~~'black'~~ '#000000' 77 | 74 | _function_call[0].items[0]['name']_ ~~'apple'~~ 'apples' ⸱ _function_call[0].items[0]['price']_ ~~0.5~~ 1.0 ⸱ _function_call[0].items[1]['name']_ ~~'orange'~~ 'oranges' ⸱ _function_call[0].items[1]['price']_ ~~0.75~~ 0.5 | _function_call[0].items[0]['price']_ ~~0.5~~ 1.0 ⸱ _function_call[0].items[1]['price']_ ~~0.75~~ 0.5 78 | 75 | ✅ | ✅ 79 | 76 | ✅ | ✅ 80 | 77 | ✅ | ✅ 81 | 78 | ✅ | ✅ 82 | 79 | ✅ | _function_call[0].start_date_ ~~'1st June'~~ '2023-06-01' ⸱ _function_call[0].end_date_ ~~'10th June'~~ '2023-06-10' 83 | 80 | ✅ | ✅ 84 | 81 | ✅ | ✅ 85 | 82 | ✅ | ✅ 86 | 83 | ➕ _function_call[0].source_currency_ 'USD' ⸱ ➖ ~~_function_call[0].base_currency_ 'USD'~~ ⸱ _function_call[0].['name']_ ~~'get_currency_conversion_rate'~~ 'convert_currency' | ➕ _function_call[0].source_currency_ 'USD' ⸱ ➖ ~~_function_call[0].base_currency_ 'USD'~~ ⸱ _function_call[0].['name']_ ~~'get_currency_conversion_rate'~~ 'convert_currency' 87 | 84 | ✅ | ✅ 88 | 85 | _function_call[0].location_ ~~'main office'~~ 'our main office' | ✅ 89 | 86 | ✅ | ✅ 90 | 87 | ✅ | ✅ 91 | 88 | ✅ | ✅ 92 | 89 | ✅ | _tool_call[1]_ ➕ {'type': 'function', 'function': {'name': 'get_news', 'arguments': {'interests': ['sports'], 'location': 'New York'}}} ⸱ _function_call[0].interests[1]_ ➖ ~~'sports'~~ 93 | 90 | ✅ | ✅ 94 | 91 | ✅ | ✅ 95 | 92 | ✅ | ✅ 96 | 93 | ✅ | ✅ 97 | 94 | ✅ | ✅ 98 | 95 | ✅ | ✅ 99 | 96 | ✅ | ✅ 100 | 97 | ✅ | ✅ 101 | 98 | ✅ | ✅ 102 | 99 | ✅ | ✅ 103 | pass | 84 (84.0%) | 82 (82.0%) 104 | -------------------------------------------------------------------------------- /src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/report-single_turn.md: -------------------------------------------------------------------------------- 1 | case | mlx-community/Meta-Llama-3-8B-Instruct-4bit | gpt-4o-2024-05-13 2 | --- | --- | --- 3 | 0 | ✅ | ✅ 4 | 1 | ✅ | ✅ 5 | 2 | ✅ | ✅ 6 | 3 | ✅ | ✅ 7 | 4 | ➕ _function_call[0].limit_ 100 ⸱ ➖ ~~_function_call[0].relationship_ 'siblings'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | _function_call[0].relationship_ ~~'siblings'~~ 'sibling_domains' 8 | 5 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'referrer_files'~~ 'includes' | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'referrer_files'~~ 'files' 9 | 6 | _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅ 10 | 7 | ✅ | ✅ 11 | 8 | ✅ | ➕ _function_call[0].ip_ '192.0.2.1' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '192.0.2.1'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address' 12 | 9 | _function_call[0].relationship_ ~~'referrer_files'~~ 'has_file' | _function_call[0].relationship_ ~~'referrer_files'~~ 'files' 
13 | 10 | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' 14 | 11 | ✅ | ➕ _function_call[0].ip_ '203.0.113.0' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '203.0.113.0'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address' 15 | 12 | ✅ | ✅ 16 | 13 | ✅ | ✅ 17 | 14 | ✅ | _function_call[0].ip_ ~~'http://www.example.org'~~ '93.184.216.34' 18 | 15 | ✅ | ✅ 19 | 16 | _function_call[0].ip_ ~~'12.234.56.126'~~ '22.242.75.136' | _function_call[0].ip_ ~~'12.234.56.126'~~ '22.242.75.136' 20 | 17 | _function_call[0].relationship_ ~~'urls'~~ 'related_to' | ✅ 21 | 18 | ✅ | ✅ 22 | 19 | _function_call[0].ip_ ~~'explorerweb.org'~~ 'http://explorerweb.org' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'explorerweb.org'~~ 'http://explorerweb.org' 23 | 20 | ✅ | ✅ 24 | 21 | ✅ | ✅ 25 | 22 | _function_call[0].relationship_ ~~'referrer_files'~~ 'contains' | _function_call[0].relationship_ ~~'referrer_files'~~ 'communicating_files' 26 | 23 | ➖ ~~_function_call[0].relationship_ 'downloaded_files'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ✅ 27 | 24 | _function_call[0].relationship_ ~~'siblings'~~ 'sibling' | _function_call[0].relationship_ ~~'siblings'~~ 'sibling_domains' 28 | 25 | ➕ _function_call[0].limit_ 100 ⸱ ➕ _function_call[0].cursor_ '' ⸱ _function_call[0].x-apikey_ ~~'delta_key'~~ 'your_delta_key' | ✅ 29 | 26 | ✅ | ✅ 30 | 27 | ✅ | ➕ _function_call[0].ip_ '44.55.66.77' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '44.55.66.77'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address' 31 | 28 | ➖ ~~_function_call[0].relationship_ 'graphs'~~ ⸱ ➖ ~~_function_call[0].x-apikey_ 'sec_key2'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_votes_on_ip_address' | ✅ 32 | 29 | ✅ | ✅ 33 | 30 | ✅ | ✅ 34 | 31 | ✅ | ✅ 35 | 32 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅ 36 | 33 | _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅ 37 | 34 | ✅ | ➕ _function_call[0].ip_ '10.0.0.1' ⸱ ➕ _function_call[0].relationship_ 'resolutions' ⸱ ➖ ~~_function_call[0].id_ '10.0.0.1'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_dns_resolution_object'~~ 'vt_get_objects_related_to_ip_address' 38 | 35 | ✅ | ✅ 39 | 36 | ✅ | ✅ 40 | 37 | ✅ | ✅ 41 | 38 | ✅ | ✅ 42 | 39 | ✅ | ✅ 43 | 40 | ➕ _function_call[0].limit_ 100 | ✅ 44 | 41 | ✅ | ➕ _function_call[0].domain_ 'mysite.io' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'mysite.io'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report' 45 | 42 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_objects_related_to_domain' ⸱ _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅ 46 | 43 | ✅ | ✅ 47 | 44 | ✅ | ✅ 48 | 45 | ✅ | ✅ 49 | 46 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ ➖ ~~_function_call[0].x-apikey_ 'gamma_key'~~ 
⸱ ➖ ~~_function_call[0].cursor_ 'next_page'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_votes_on_ip_address' | ✅ 50 | 47 | ✅ | ➕ _function_call[0].domain_ 'samplepage.net' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'samplepage.net'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report' 51 | 48 | ✅ | ✅ 52 | 49 | ➖ ~~_function_call[0].cursor_ 'start_cursor'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'downloaded_files'~~ 'downloaded_from' | ➖ ~~_function_call[0].cursor_ 'start_cursor'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' 53 | 50 | ➖ ~~_function_call[0].relationship_ 'parent'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ➕ _function_call[0].limit_ 1 54 | 51 | ✅ | ✅ 55 | 52 | _function_call[0].relationship_ ~~'caa_records'~~ 'CAA' | ✅ 56 | 53 | ✅ | ✅ 57 | 54 | ✅ | ✅ 58 | 55 | ✅ | ✅ 59 | 56 | ✅ | ✅ 60 | 57 | ✅ | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' 61 | 58 | ➕ _function_call[0].limit_ 0 ⸱ ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅ 62 | 59 | ✅ | _function_call[0].ip_ ~~'https://www.example.org'~~ '93.184.216.34' 63 | 60 | ➖ ~~_function_call[0].relationship_ 'siblings'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | _function_call[0].relationship_ ~~'siblings'~~ 'sibling_domains' 64 | 61 | _function_call[0].ip_ ~~'viewpage.net'~~ 'http://viewpage.net' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'viewpage.net'~~ 'http://viewpage.net' 65 | 62 | ✅ | ✅ 66 | 63 | ✅ | ✅ 67 | 64 | ➖ ~~_function_call[0].relationship_ 'caa_records'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ✅ 68 | 65 | ✅ | _tool_call[1]_ ➕ {'type': 'function', 'function': {'name': 'vt_get_comments_on_domain', 'arguments': {'domain': 'reddit.com', 'x-apikey': 'reddit_api_key'}}} 69 | 66 | ✅ | ✅ 70 | 67 | ✅ | ✅ 71 | 68 | _function_call[0].relationship_ ~~'historical_whois'~~ 'whois' ⸱ _function_call[0].x-apikey_ ~~'elite_api'~~ 'your_api_key' | _function_call[0].relationship_ ~~'historical_whois'~~ 'whois' 72 | 69 | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' 73 | 70 | ✅ | ✅ 74 | 71 | ➕ _function_call[0].limit_ 100 | ✅ 75 | 72 | ✅ | ✅ 76 | 73 | ✅ | ✅ 77 | 74 | ✅ | ✅ 78 | 75 | ✅ | ✅ 79 | 76 | _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_object_descriptors_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'related_threat_actors'~~ 'threat_actor' | _function_call[0].relationship_ ~~'related_threat_actors'~~ 'threat_actors' 80 | 77 | ➖ ~~_function_call[0].relationship_ 'subdomains'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 
'vt_get_comments_on_domain' | ✅ 81 | 78 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅ 82 | 79 | ➕ _function_call[0].id_ 'dns_resolution_object_id' ⸱ ➖ ~~_function_call[0].domain_ 'site5.info'~~ ⸱ ➖ ~~_function_call[0].relationship_ 'resolutions'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_dns_resolution_object' | ✅ 83 | 80 | ✅ | ✅ 84 | 81 | ✅ | ✅ 85 | 82 | ✅ | ✅ 86 | 83 | ✅ | ✅ 87 | 84 | _function_call[0].relationship_ ~~'historical_whois'~~ 'whois' | ✅ 88 | 85 | ➕ _function_call[0].id_ 'yahoo.com' ⸱ ➖ ~~_function_call[0].domain_ 'yahoo.com'~~ ⸱ ➖ ~~_function_call[0].relationship_ 'resolutions'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_dns_resolution_object' | ✅ 89 | 86 | _function_call[0].relationship_ ~~'referrer_files'~~ 'contains' | _function_call[0].relationship_ ~~'referrer_files'~~ 'files' 90 | 87 | _function_call[0].ip_ ~~'digdeep.io'~~ 'http://digdeep.io' | ➕ _function_call[0].domain_ 'digdeep.io' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'digdeep.io'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report' 91 | 88 | ✅ | ✅ 92 | 89 | _function_call[0].ip_ ~~'surfthis.net'~~ 'http://surfthis.net' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'surfthis.net'~~ 'http://surfthis.net' 93 | 90 | _function_call[0].relationship_ ~~'communicating_files'~~ 'communicates_with' | ✅ 94 | 91 | ➖ ~~_function_call[0].relationship_ 'urls'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅ 95 | 92 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl_certificate' | _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl_certificates' 96 | 93 | _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_object_descriptors_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl-certificate' | _function_call[0].relationship_ ~~'historical_ssl_certificates'~~ 'ssl_certificates' 97 | 94 | _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_ip_address'~~ 'vt_get_objects_related_to_ip_address' ⸱ _function_call[0].relationship_ ~~'referrer_files'~~ 'REFERENCES' | ✅ 98 | 95 | _function_call[0].id_ ~~'10.10.10.10linked.site'~~ '10.10.10.10_linked.site' | ✅ 99 | 96 | _function_call[0].ip_ ~~'checkthisout.net'~~ 'http://checkthisout.net' | ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_ip_address_report' ⸱ _function_call[0].ip_ ~~'checkthisout.net'~~ 'http://checkthisout.net' 100 | 97 | _function_call[0].domain_ ~~'sample.org'~~ 'sample.com' ⸱ _function_call[0].relationship_ ~~'cname_records'~~ 'dns_resolution' | _function_call[0].domain_ ~~'sample.org'~~ 'sample.com' 101 | 98 | _function_call[0].relationship_ ~~'communicating_files'~~ 'file' | ✅ 102 | 99 | _function_call[0].x-apikey_ ~~'eta_key'~~ 'your_api_key' | ✅ 103 | 100 | ➖ ~~_function_call[0].relationship_ 'historical_whois'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_ip_address_report' | ➖ 
~~_function_call[0].relationship_ 'historical_whois'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_ip_address'~~ 'vt_get_ip_address_report' 104 | 101 | ✅ | ✅ 105 | 102 | ➖ ~~_function_call[0].x-apikey_ 'KEY123'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_ip_address_report'~~ 'vt_get_votes_on_ip_address' | ✅ 106 | 103 | ✅ | ✅ 107 | 104 | ✅ | ✅ 108 | 105 | ✅ | ✅ 109 | 106 | _function_call[0].ip_ ~~'inspectlink.com'~~ 'http://inspectlink.com' | ➕ _function_call[0].domain_ 'inspectlink.com' ⸱ ➕ _function_call[0].x-apikey_ 'your_api_key_here' ⸱ ➖ ~~_function_call[0].ip_ 'inspectlink.com'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_votes_on_ip_address'~~ 'vt_get_domain_report' 110 | 107 | ✅ | ✅ 111 | 108 | ➖ ~~_function_call[0].relationship_ 'historical_whois'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_domain_report' | ✅ 112 | 109 | ✅ | ✅ 113 | 110 | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' | ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_objects_related_to_domain'~~ 'vt_get_comments_on_domain' 114 | 111 | ➕ _function_call[0].limit_ 100 ⸱ ➖ ~~_function_call[0].relationship_ 'comments'~~ ⸱ _function_call[0].['name']_ ~~'vt_get_object_descriptors_related_to_domain'~~ 'vt_get_comments_on_domain' | ✅ 115 | pass | 59 (52.68%) | 77 (68.75%) 116 | -------------------------------------------------------------------------------- /src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/requirements.txt: -------------------------------------------------------------------------------- 1 | pyarrow 2 | -------------------------------------------------------------------------------- /src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/single_turn-00000-of-00001.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/otriscon/llm-structured-output/037e8eb7447005fda06e7d811b041efcb94b0cef/src/tests/data/fireworks-ai_function-calling-eval-dataset-v0/single_turn-00000-of-00001.parquet -------------------------------------------------------------------------------- /src/tests/eval_api.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | """ 3 | Run a tool use evaluation using an LLM with an OpenAI-like API. 4 | """ 5 | import argparse 6 | import json 7 | import time 8 | import requests 9 | 10 | from llm_structured_output.util.output import info, inverse, debug, warning 11 | 12 | from .eval_report import eval_completion 13 | 14 | 15 | def run_eval_case( 16 | api_url, 17 | api_key, 18 | model_name, 19 | case, 20 | header, 21 | temp=0, 22 | seed=0, 23 | stream=False, 24 | out=None, 25 | ): 26 | options = case.get("options", {}) 27 | prompt_includes_schema = options.get("prompt_includes_schema", False) 28 | 29 | payload = { 30 | "model": model_name, 31 | "messages": case["prompt"], 32 | "tools": case["tools"], 33 | "tool_choice": "auto", 34 | "temperature": temp, 35 | "seed": seed, 36 | } 37 | if stream: 38 | payload["stream"] = True 39 | payload["stream_options"] = {"include_usage": True} 40 | if prompt_includes_schema and "api.openai.com" not in api_url: 41 | # Non-standard option, should not be set for OpenAI API. 
42 | payload["tool_options"] = { 43 | # Do not dump the schema again, since it's already in the prompt 44 | "no_prompt_steering": True, 45 | } 46 | 47 | info(f"{header} Sending API request...") 48 | start_time = time.time_ns() 49 | 50 | r = requests.post( 51 | f"{api_url}/v1/chat/completions", 52 | json=payload, 53 | headers={"Authorization": f"Bearer {api_key}"}, 54 | timeout=60, 55 | stream=stream, 56 | ) 57 | if stream: 58 | response = None 59 | tool_calls = [] 60 | for line in r.iter_lines(decode_unicode=True): 61 | if not line: 62 | continue 63 | if not line.startswith("data:"): 64 | warning("Expected all server-sent events to start with 'data:'") 65 | line = line[5:].strip() 66 | if line == "[DONE]": 67 | break 68 | message = json.loads(line) 69 | if response is None: 70 | response = message 71 | elif "usage" in message: 72 | response["usage"] = message["usage"] 73 | if not message["choices"]: 74 | continue 75 | tool_deltas = message["choices"][0]["delta"].get("tool_calls", []) 76 | if len(tool_deltas) > 1: 77 | warning( 78 | f"Expected updates for one tool_call at a time, got multiple: {tool_deltas=}" 79 | ) 80 | if tool_deltas: 81 | tool_delta = tool_deltas[0] 82 | index = tool_delta["index"] 83 | argument_delta = tool_delta["function"]["arguments"] 84 | if index == len(tool_calls): 85 | tool_calls.append(tool_delta) 86 | tool_name = tool_delta["function"][ 87 | "name" 88 | ] # name may not be present in additional updates 89 | debug( 90 | f"[call #{index}]\nname: {tool_name}\narguments: {argument_delta}", 91 | end="", 92 | ) 93 | elif index == len(tool_calls) - 1: 94 | tool_calls[index]["function"]["arguments"] += argument_delta 95 | debug(argument_delta, end="") 96 | else: 97 | warning( 98 | f"Unexpected tool_delta out of sequence: " 99 | f"current_index={len(tool_calls)-1} {tool_delta=}" 100 | ) 101 | response["choices"] = [ 102 | {"message": {"role": "assistant", "tool_calls": tool_calls}} 103 | ] 104 | debug() 105 | else: 106 | response = r.json() 107 | debug(response) 108 | 109 | total_time = (time.time_ns() - start_time) / 1e6 110 | prompt_tokens = response["usage"]["prompt_tokens"] 111 | completion_tokens = response["usage"]["completion_tokens"] 112 | info(f"{header} {prompt_tokens=} {completion_tokens=} {total_time=:.02f}") 113 | 114 | if out: 115 | json.dump(response, out) 116 | out.write("\n") 117 | 118 | diff = eval_completion(case, response) 119 | if diff: 120 | inverse(f"{header} DIFF:", diff) 121 | return False 122 | else: 123 | info(f"{header} PASS") 124 | return True 125 | 126 | 127 | def main(): 128 | parser = argparse.ArgumentParser( 129 | description="Run a function calling evaluation with the Fireworks AI dataset or similar" 130 | ) 131 | parser.add_argument( 132 | "--api-url", 133 | type=str, 134 | default="https://api.openai.com", 135 | help="The URL of the API server", 136 | ) 137 | parser.add_argument( 138 | "--api-key", 139 | type=str, 140 | default=None, 141 | help="The URL of the API server", 142 | ) 143 | parser.add_argument( 144 | "--model-name", 145 | type=str, 146 | default="gpt-4o", 147 | help="The name of the model to use", 148 | ) 149 | parser.add_argument( 150 | "--dataset-path", 151 | required=True, 152 | type=str, 153 | help="The path to the evaluation dataset (JSONL)", 154 | ) 155 | parser.add_argument( 156 | "--skip", 157 | type=int, 158 | default=0, 159 | help="Start at the given evaluation case number", 160 | ) 161 | parser.add_argument( 162 | "--count", 163 | type=int, 164 | default=None, 165 | help="Limit the number of cases to run", 166 
| ) 167 | parser.add_argument( 168 | "--temp", 169 | help="The sampling temperature.", 170 | type=float, 171 | default=0.0, 172 | ) 173 | parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") 174 | parser.add_argument( 175 | "--stream", 176 | help="Use streaming API.", 177 | action=argparse.BooleanOptionalAction, 178 | default=False, 179 | ) 180 | parser.add_argument( 181 | "--output-file", 182 | help="Write completions to JSONL file.", 183 | type=str, 184 | default=None, 185 | ) 186 | args = parser.parse_args() 187 | 188 | out = None 189 | if args.output_file: 190 | out = open(args.output_file, mode="w", encoding="utf-8") 191 | 192 | with open(args.dataset_path, encoding="utf-8") as dataset: 193 | if args.count: 194 | end_index = args.skip + args.count 195 | else: 196 | end_index = None 197 | pass_count = 0 198 | fail_count = 0 199 | t0 = time.time_ns() 200 | for i, line in enumerate(dataset.readlines()): 201 | if i < args.skip: 202 | continue 203 | if end_index is not None and i == end_index: 204 | break 205 | case = json.loads(line) 206 | if run_eval_case( 207 | args.api_url, 208 | args.api_key, 209 | args.model_name, 210 | case, 211 | f"[{i}]", 212 | temp=args.temp, 213 | seed=args.seed, 214 | stream=args.stream, 215 | out=out, 216 | ): 217 | pass_count += 1 218 | else: 219 | fail_count += 1 220 | average_time = (time.time_ns() - t0) / 1e9 / (pass_count + fail_count) 221 | info(f"Totals: {pass_count=} {fail_count=} {average_time=:.02}s") 222 | 223 | if out: 224 | out.close() 225 | 226 | 227 | main() 228 | -------------------------------------------------------------------------------- /src/tests/eval_local.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | """ 3 | Run a tool use evaluation using a local LLM. 4 | """ 5 | import argparse 6 | import json 7 | import time 8 | 9 | from examples.llm_schema import Model 10 | from llm_structured_output.util.output import info, bold, inverse, debug 11 | 12 | from .eval_report import eval_tool_calls 13 | 14 | 15 | def run_eval_case(model, case, header, temp=None, seed=None, preemptive_batch_size=0): 16 | messages = case["prompt"] 17 | tools = case["tools"] 18 | options = case.get("options", {}) 19 | prompt_includes_schema = options.get("prompt_includes_schema", False) 20 | single_tool = options.get("single_tool", False) 21 | 22 | tool_schemas = [ 23 | { 24 | "type": "object", 25 | "properties": { 26 | "name": { 27 | "type": "const", 28 | "const": tool["function"]["name"], 29 | }, 30 | "arguments": tool["function"]["parameters"], 31 | }, 32 | "required": ["name", "arguments"], 33 | } 34 | for tool in tools 35 | ] 36 | 37 | separator = "\n" 38 | if single_tool: 39 | schema = {"anyOf": tool_schemas} 40 | if not prompt_includes_schema: 41 | schema_message = f""" 42 | You are a helpful assistant with access to tools that you must invoke to answer the user's request. 43 | The following tools are available: 44 | {separator.join([ f''' 45 | Tool {repr(tool[tool["type"]]["name"])}: {tool[tool["type"]]["description"]} 46 | Invocation schema: {json.dumps(tool_schema)} 47 | ''' for tool, tool_schema in zip(tools, tool_schemas) ])} 48 | Your answer is a JSON object according to the invocation schema of the most appropriate tool to use 49 | to answer the user request below. 
50 | """ 51 | print(json.dumps(schema, indent=2)) ### 52 | print(schema_message) ### 53 | messages.insert(0, {"role": "system", "message": schema_message}) 54 | else: 55 | tool_call_schemas = [ 56 | { 57 | "type": "object", 58 | "properties": { 59 | "type": { 60 | "type": "const", 61 | "const": tool["type"], 62 | }, 63 | tool["type"]: tool_schema, 64 | }, 65 | "required": ["type", tool["type"]], 66 | } 67 | for tool, tool_schema in zip(tools, tool_schemas) 68 | ] 69 | schema = { 70 | "type": "array", 71 | "items": {"anyOf": tool_call_schemas}, 72 | } 73 | if not prompt_includes_schema: 74 | schema_message = f""" 75 | You are a helpful assistant with access to tools that you must invoke to answer the user's request. 76 | The following tools are available: 77 | {separator.join([ f''' 78 | Tool {repr(tool[tool["type"]]["name"])}: {tool[tool["type"]]["description"]} 79 | Invocation schema: {json.dumps(tool_call_schema)} 80 | ''' for tool, tool_call_schema in zip(tools, tool_call_schemas) ])} 81 | Your answer is a JSON array with one or more tool invocations according to the appropriate schema(s) 82 | in order to answer the user request below. 83 | """ 84 | print(json.dumps(schema, indent=2)) ### 85 | print(schema_message) ### 86 | messages.insert(0, {"role": "system", "message": schema_message}) 87 | 88 | info(f"{header} Starting generation...") 89 | content = "" 90 | prompt_tokens = 0 91 | completion_tokens = 0 92 | completion_time = 0 93 | start_time = time.time_ns() 94 | 95 | for result in model.completion( 96 | messages, 97 | schema=schema, 98 | max_tokens=4000, 99 | temp=temp, 100 | seed=seed, 101 | preemptive_batch_size=preemptive_batch_size, 102 | cache_prompt=True, 103 | ): 104 | if result["op"] == "evaluatedPrompt": 105 | prompt_tokens += result["token_count"] 106 | prompt_time = result["time_ms"] 107 | elif result["op"] == "generatedTokens": 108 | completion_tokens += result["token_count"] 109 | completion_time += result["time_ms"] 110 | content += result["text"] 111 | bold(result["text"], end="", flush=True) 112 | elif result["op"] == "stop": 113 | print() 114 | else: 115 | debug(f"{result=}") 116 | assert False 117 | 118 | total_time = (time.time_ns() - start_time) / 1e6 119 | prompt_tps = prompt_tokens / prompt_time * 1e3 120 | completion_tps = completion_tokens / completion_time * 1e3 121 | info( 122 | f"{header} {prompt_tokens=} {prompt_tps=:.02f} {completion_tokens=} {completion_tps=:.02f}" 123 | f" {prompt_time=:.02f} {completion_time=:.02f} {total_time=:.02f}" 124 | ) 125 | 126 | tool_calls = json.loads(content) 127 | if single_tool: 128 | tool_calls = [{"type": "function", "function": tool_calls}] 129 | 130 | diff = eval_tool_calls(case, tool_calls) 131 | if diff: 132 | inverse(f"{header} DIFF:", diff) 133 | return False 134 | else: 135 | info(f"{header} PASS") 136 | return True 137 | 138 | 139 | def main(): 140 | parser = argparse.ArgumentParser( 141 | description="Run a function calling evaluation with the Fireworks AI dataset or similar" 142 | ) 143 | parser.add_argument( 144 | "--model-path", 145 | type=str, 146 | default="mlx_model", 147 | help="The path to the model weights and tokenizer", 148 | ) 149 | parser.add_argument( 150 | "--dataset-path", 151 | required=True, 152 | type=str, 153 | help="The path to the evaluation dataset (JSONL)", 154 | ) 155 | parser.add_argument( 156 | "--skip", 157 | type=int, 158 | default=0, 159 | help="Start at the given evaluation case number", 160 | ) 161 | parser.add_argument( 162 | "--count", 163 | type=int, 164 | default=None, 165 | 
help="Limit the number of cases to run", 166 | ) 167 | parser.add_argument( 168 | "--temp", 169 | help="The sampling temperature.", 170 | type=float, 171 | default=0.0, 172 | ) 173 | parser.add_argument("--seed", type=int, default=0, help="The PRNG seed") 174 | parser.add_argument( 175 | "--preemptive", 176 | type=int, 177 | default=0, 178 | help="If greater than zero, the maximum size of the batch for pre-emptive decoding", 179 | ) 180 | args = parser.parse_args() 181 | 182 | info("Loading model...") 183 | model = Model() 184 | model.load(args.model_path) 185 | 186 | with open(args.dataset_path, encoding="utf-8") as dataset: 187 | if args.count: 188 | end_index = args.skip + args.count 189 | else: 190 | end_index = None 191 | pass_count = 0 192 | fail_count = 0 193 | t0 = time.time_ns() 194 | for i, line in enumerate(dataset.readlines()): 195 | if i < args.skip: 196 | continue 197 | if end_index is not None and i == end_index: 198 | break 199 | case = json.loads(line) 200 | if run_eval_case( 201 | model, 202 | case, 203 | f"[{i}]", 204 | temp=args.temp, 205 | seed=args.seed, 206 | preemptive_batch_size=args.preemptive, 207 | ): 208 | pass_count += 1 209 | else: 210 | fail_count += 1 211 | average_time = (time.time_ns() - t0) / 1e9 / (pass_count + fail_count) 212 | info(f"Totals: {pass_count=} {fail_count=} {average_time=:.02}s") 213 | 214 | 215 | main() 216 | -------------------------------------------------------------------------------- /src/tests/eval_report.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=missing-function-docstring 2 | """ 3 | Create a markdown report from an evaluation dataset and one or more completions. 4 | """ 5 | import argparse 6 | import json 7 | import re 8 | import sys 9 | 10 | from deepdiff import DeepDiff 11 | 12 | 13 | def eval_tool_calls(case, tool_calls): 14 | single_tool = case.get("options", {}).get("single_tool", False) 15 | 16 | best_diff_count = 1e10 17 | for gold_tool_calls in case["gold"]: 18 | if single_tool: 19 | # The gold set in the source dataset is a single tool invocation instead of an array. 20 | # We could use the legacy function_call method to force a single function call, but 21 | # we think it's better to evaluate the model for non-legacy tool use. If the model 22 | # comes up with multi-tool solutions that are deemed acceptable, we can then: 23 | # - Remove this flag for this evaluation case, 24 | # - Wrap each existing gold value for this case in an array, 25 | # - Add the new solution that has multiple invocations to the gold set for the case. 26 | gold_tool_calls = [gold_tool_calls] 27 | diff = DeepDiff(gold_tool_calls, tool_calls, verbose_level=2) 28 | if diff is None: 29 | best_diff = None 30 | best_diff_count = 0 31 | break 32 | else: 33 | diff_count = diff.get_stats()["DIFF COUNT"] 34 | if diff_count < best_diff_count: 35 | best_diff_count = diff_count 36 | best_diff = diff 37 | return best_diff 38 | 39 | 40 | def eval_completion(case, completion): 41 | try: 42 | completion_tool_calls = completion["choices"][0]["message"]["tool_calls"] 43 | except (KeyError, TypeError) as e: 44 | sys.stderr.write( 45 | f"Completion object doesn't match expected format: {completion=}\n" 46 | ) 47 | completion_tool_calls = [ 48 | { 49 | "type": "error", 50 | "error": { 51 | "error": f"Parsing tool_calls: {repr(e)}", 52 | "completion_message": completion["choices"][0]["message"], 53 | }, 54 | } 55 | ] 56 | 57 | # Remove call metadata (currently only id) to compare with gold. 
58 | # Note that we expect the gold set in the evaluation dataset to have 59 | # deserialized function arguments rather than as a string. 60 | tool_calls = [ 61 | ( 62 | { 63 | "type": "function", 64 | "function": { 65 | "name": tool_call["function"]["name"], 66 | "arguments": json.loads(tool_call["function"]["arguments"]), 67 | }, 68 | } 69 | if tool_call["type"] == "function" 70 | else { 71 | "type": tool_call["type"], 72 | tool_call["type"]: tool_call[tool_call["type"]], 73 | } 74 | ) 75 | for tool_call in completion_tool_calls 76 | ] 77 | 78 | return eval_tool_calls(case, tool_calls) 79 | 80 | 81 | CHANGE_FORMATTERS = { 82 | "type_changes": lambda path, change: f"_{path}_ ~~{repr(change['old_value'])} [{change['old_type'].__name__}]~~ {repr(change['new_value'])} [{change['new_type'].__name__}]", 83 | "values_changed": lambda path, change: f"_{path}_ ~~{repr(change['old_value'])}~~ {repr(change['new_value'])}", 84 | "dictionary_item_added": lambda path, change: f"➕ _{path}_ {repr(change)}", 85 | "dictionary_item_removed": lambda path, change: f"➖ ~~_{path}_ {repr(change)}~~", 86 | "iterable_item_added": lambda path, change: f"_{path}_ ➕ {repr(change)}", 87 | "iterable_item_removed": lambda path, change: f"_{path}_ ➖ ~~{repr(change)}~~", 88 | "set_item_added": lambda path, change: f"_{path}_ ➕ {repr(change)}", 89 | "set_item_removed": lambda path, change: f"_{path}_ ➖ ~~{repr(change)}~~", 90 | } 91 | 92 | 93 | def diff_to_md(diff): 94 | if not diff: 95 | return "✅" 96 | md_changes = [] 97 | for change_type, changes in diff.items(): 98 | formatter = CHANGE_FORMATTERS[change_type] 99 | for path, change in changes.items(): 100 | path = re.sub(r"root\[(\d*)\]\['function'\]", "function_call[\\1].", path) 101 | path = re.sub(r"root\[(\d*)\]", "tool_call[\\1]", path) 102 | path = re.sub(r"\['arguments'\]\['([^']*)']", "\\1", path) 103 | md_changes.append(formatter(path, change)) 104 | return " ⸱ ".join(md_changes) 105 | 106 | 107 | def report_eval_case( 108 | case, 109 | completions, 110 | index, 111 | out, 112 | ): 113 | eval_diffs = [eval_completion(case, completion) for completion in completions] 114 | columns = [diff_to_md(diff) for diff in eval_diffs] 115 | out.write(f"{index} | {' | '.join(columns)}\n") 116 | results = [not diff for diff in eval_diffs] 117 | return results 118 | 119 | 120 | def main(): 121 | parser = argparse.ArgumentParser( 122 | description="Run a function calling evaluation with the Fireworks AI dataset or similar" 123 | ) 124 | parser.add_argument( 125 | "--dataset-path", 126 | required=True, 127 | type=str, 128 | help="The path to the evaluation dataset (JSONL)", 129 | ) 130 | parser.add_argument( 131 | "completions", 132 | metavar="completion_files", 133 | type=str, 134 | nargs="+", 135 | help="One or more jsonl files with completions for the evaluation dataset", 136 | ) 137 | parser.add_argument( 138 | "--output-file", 139 | help="Write report to a file instead of stdout", 140 | type=str, 141 | default=None, 142 | ) 143 | args = parser.parse_args() 144 | 145 | input_files = [open(filename, encoding="utf-8") for filename in args.completions] 146 | 147 | out = sys.stdout 148 | if args.output_file: 149 | out = open(args.output_file, mode="w", encoding="utf-8") 150 | 151 | i = 0 152 | with open(args.dataset_path, encoding="utf-8") as dataset: 153 | for i, line in enumerate(dataset.readlines()): 154 | case = json.loads(line) 155 | completions = [ 156 | json.loads(input_file.readline()) for input_file in input_files 157 | ] 158 | if i == 0: 159 | sum_results = [0 for 
completion in completions] 160 | models = [completion["model"] for completion in completions] 161 | out.write(f"case | {' | '.join(models)}\n") 162 | out.write(f"--- | {' | '.join(['---'] * len(models))}\n") 163 | results = report_eval_case(case, completions, i, out) 164 | sum_results = [sum_results[i] + result for i, result in enumerate(results)] 165 | total = i + 1 166 | out.write( 167 | f"pass | {' | '.join([f'{r} ({round(100*r/total, 2)}%)' for r in sum_results])}\n" 168 | ) 169 | 170 | for input_file in input_files: 171 | input_file.close() 172 | if out: 173 | out.close() 174 | 175 | 176 | if __name__ == "__main__": 177 | main() 178 | -------------------------------------------------------------------------------- /src/tests/requirements.txt: -------------------------------------------------------------------------------- 1 | mlx-lm >= 0.19.2 2 | tokenizers >= 0.20.1 3 | sentencepiece 4 | deepdiff 5 | requests 6 | --------------------------------------------------------------------------------
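Putting the utilities above together, here is a minimal sketch of how the token trie and the logit-bias bitmap interact. This is not a file from the repository: the model id and the digit-only constraint are illustrative assumptions standing in for the bitmap that the JSON schema acceptors would compute from their current state.

import numpy as np
from transformers import AutoTokenizer
from llm_structured_output.util.tokenization import HuggingfaceTokenizerHelper
from llm_structured_output.util.tokentrie import TokenTrie
from llm_structured_output.util.bitmap import bias_logits

# Assumed model id for illustration; any Huggingface tokenizer works here.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
helper = HuggingfaceTokenizerHelper(tokenizer)
vocabulary, eos_token_id = helper.extract_vocabulary()

# Load the vocabulary into a prefix trie; each node holds an id bitmap.
trie = TokenTrie()
trie.insert_all(vocabulary)

# Toy constraint: accept only tokens whose text is purely digits. The real
# acceptors derive this bitmap from the schema and their parsing state instead.
accepted_bitmap = 0
for token_text, ids in trie.dfs():
    if token_text.isdigit():
        accepted_bitmap |= ids

vocab_size = max(token_id for token_id, _ in vocabulary) + 1
logits = np.random.randn(vocab_size)  # stand-in for the model's output logits
biased = bias_logits(np, logits, accepted_bitmap)  # rejected tokens get -inf
next_token = int(np.argmax(biased))  # greedy pick among accepted tokens only

In the repository's MLX examples the same bias_logits call is made with mlx.core in place of numpy, which is why the array module is passed in as a parameter rather than imported inside bitmap.py.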